```python
# NOTE from https://github.com/bojone/keras_lookahead
import keras
from keras import backend as K


class Lookahead(object):
    """Add the [Lookahead Optimizer](https://arxiv.org/abs/1907.08610)
    functionality for [keras](https://keras.io/).
    """

    def __init__(self, k=5, alpha=0.5):
        self.k = k          # sync period: pull the slow weights every k steps
        self.alpha = alpha  # slow-weight step size
        self.count = 0

    def inject(self, model: keras.models.Model):
        """Inject the Lookahead algorithm for the given model.
        The following code is modified from keras's _make_train_function method.
        See: https://github.com/keras-team/keras/blob/master/keras/engine/training.py#L497
        """
        if not hasattr(model, 'train_function'):
            raise RuntimeError('You must compile your model before using it.')

        model._check_trainable_weights_consistency()
        metrics_tensors = [
            model._all_metrics_tensors[m] for m in model.metrics_names[1:]
        ]

        if model.train_function is None:
            inputs = (model._feed_inputs +
                      model._feed_targets +
                      model._feed_sample_weights)
            if not isinstance(K.symbolic_learning_phase(), int):
                inputs += [K.symbolic_learning_phase()]
            fast_params = model._collected_trainable_weights

            with K.name_scope('training'):
                with K.name_scope(model.optimizer.__class__.__name__):
                    # The inner optimizer updates the fast weights as usual.
                    training_updates = model.optimizer.get_updates(
                        params=fast_params, loss=model.total_loss)
                    # One slow-weight copy per trainable parameter.
                    slow_params = [K.variable(p) for p in fast_params]

                fast_updates = (training_updates +
                                model.get_updates_for(None) +
                                model.get_updates_for(model.inputs))

                # Every k steps: slow <- slow + alpha * (fast - slow),
                # then fast <- slow.
                slow_updates, copy_updates = [], []
                for p, q in zip(fast_params, slow_params):
                    slow_updates.append(K.update(q, q + self.alpha * (p - q)))
                    copy_updates.append(K.update(p, q))

                # Gets loss and metrics. Updates weights at each call.
                fast_train_function = K.function(
                    inputs,
                    [model.total_loss] + metrics_tensors,
                    updates=fast_updates,
                    name='fast_train_function',
                    **model._function_kwargs)

                def F(inputs):
                    self.count += 1
                    R = fast_train_function(inputs)
                    if self.count % self.k == 0:
                        # Evaluating the update ops applies them.
                        K.batch_get_value(slow_updates)
                        K.batch_get_value(copy_updates)
                    return R

                model.train_function = F
```
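Usage follows the pattern from the keras_lookahead README: compile the model with any inner optimizer first, then inject. A minimal sketch, assuming a standalone keras 2.x whose internals match the listing above; the toy model and data are purely illustrative:

```python
import numpy as np
import keras
from keras.layers import Dense
from keras.optimizers import Adam

# Toy model, only for illustration.
model = keras.models.Sequential([Dense(1, input_dim=4)])
model.compile(optimizer=Adam(1e-3), loss='mse')  # any inner optimizer

lookahead = Lookahead(k=5, alpha=0.5)  # sync every 5 steps, step size 0.5
lookahead.inject(model)                # replaces model.train_function

x, y = np.random.rand(64, 4), np.random.rand(64, 1)
model.fit(x, y, epochs=2)  # fit now runs the fast/slow weight schedule
```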
Implementation for tf 1.15
Because tf.keras in newer TensorFlow versions has changed quite a lot relative to standalone keras, the implementation here differs from the original one above.
```python
import tensorflow.keras as keras
from tensorflow.python.keras import backend as K


class Lookahead(object):
    """Add the [Lookahead Optimizer](https://arxiv.org/abs/1907.08610)
    functionality for [keras](https://keras.io/).
    """

    def __init__(self, k=5, alpha=0.5):
        self.k = k
        self.alpha = alpha
        self.count = 0

    def inject(self, model: keras.models.Model):
        """Adapted from tensorflow.keras `_make_train_function`.
        See: https://github.com/tensorflow/tensorflow/blob/590d6eef7e91a6a7392c8ffffb7b58f2e0c8bc6b/tensorflow/python/keras/engine/training.py#L2091
        and https://github.com/bojone/keras_lookahead/blob/master/lookahead.py
        """
        has_recompiled = model._recompile_weights_loss_and_weighted_metrics()
        model._check_trainable_weights_consistency()
        if isinstance(model.optimizer, list):
            raise ValueError('The `optimizer` in `compile` should be a single '
                             'optimizer.')
        # If we have re-compiled the loss/weighted metric sub-graphs then create
        # the train function even if one exists already. This is because the
        # `_feed_sample_weights` list has been updated on re-compile.
        if getattr(model, 'train_function', None) is None or has_recompiled:
            # Restore the compiled trainable state.
            current_trainable_state = model._get_trainable_state()
            model._set_trainable_state(model._compiled_trainable_state)

            inputs = (model._feed_inputs +
                      model._feed_targets +
                      model._feed_sample_weights)
            if not isinstance(K.symbolic_learning_phase(), int):
                inputs += [K.symbolic_learning_phase()]
            fast_params = model._collected_trainable_weights

            with K.get_graph().as_default():
                with K.name_scope('training'):
                    # Training updates
                    training_updates = model.optimizer.get_updates(
                        params=fast_params, loss=model.total_loss)
                    slow_params = [K.variable(p) for p in fast_params]

                    fast_updates = (
                        training_updates +
                        # Unconditional updates
                        model.get_updates_for(None) +
                        # Conditional updates relevant to this model
                        model.get_updates_for(model.inputs))

                    # slow <- slow + alpha * (fast - slow); then fast <- slow
                    slow_updates, copy_updates = [], []
                    for p, q in zip(fast_params, slow_params):
                        slow_updates.append(
                            K.update(q, q + self.alpha * (p - q)))
                        copy_updates.append(K.update(p, q))

                    metrics = model._get_training_eval_metrics()
                    metrics_tensors = [
                        m._call_result for m in metrics
                        if hasattr(m, '_call_result')  # pylint: disable=protected-access
                    ]

                    # Gets loss and metrics. Updates weights at each call.
                    fast_train_function = K.function(
                        inputs,
                        [model.total_loss] + metrics_tensors,
                        updates=fast_updates,
                        name='train_function',
                        **model._function_kwargs)

                    def F(inputs):
                        self.count += 1
                        R = fast_train_function(inputs)
                        if self.count % self.k == 0:
                            K.batch_get_value(slow_updates)
                            K.batch_get_value(copy_updates)
                        return R

                    setattr(model, 'train_function', F)

            # Restore the current trainable state
            model._set_trainable_state(current_trainable_state)
```
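Stripped of the keras plumbing, `slow_updates` and `copy_updates` implement the interpolation rule from the Lookahead paper: after every k inner-optimizer steps, the slow weights move a fraction alpha toward the fast weights, and the fast weights are reset to them. A plain NumPy sketch of that rule, with SGD standing in for the inner optimizer (for intuition only, not part of the implementation):

```python
import numpy as np

def lookahead_sketch(grad_fn, w0, lr=0.1, k=5, alpha=0.5, steps=100):
    """Illustrative NumPy version of the update performed by
    slow_updates/copy_updates above, with plain SGD as the inner step."""
    fast = np.array(w0, dtype=float)   # trainable ("fast") weights
    slow = fast.copy()                 # the K.variable(p) snapshots
    for t in range(1, steps + 1):
        fast -= lr * grad_fn(fast)     # inner optimizer step (training_updates)
        if t % k == 0:
            slow += alpha * (fast - slow)  # slow_updates: q <- q + alpha*(p - q)
            fast = slow.copy()             # copy_updates: p <- q
    return slow

# Minimize f(w) = ||w||^2, whose gradient is 2w; the result approaches 0.
print(lookahead_sketch(lambda w: 2 * w, w0=[1.0, -2.0]))
```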
Implementation for tf 2.0
```python
# NOTE from https://github.com/bojone/keras_lookahead
from tensorflow import keras as k
from tensorflow.python.keras import backend as K


class Lookahead(object):
    """Add the [Lookahead Optimizer](https://arxiv.org/abs/1907.08610)
    functionality for [keras](https://keras.io/).
    """

    def __init__(self, k=5, alpha=0.5):
        self.k = k
        self.alpha = alpha
        self.count = 0

    def inject(self, model: k.models.Model):
        has_recompiled = model._recompile_weights_loss_and_weighted_metrics()
        model._check_trainable_weights_consistency()
        if isinstance(model.optimizer, list):
            raise ValueError('The `optimizer` in `compile` should be a single '
                             'optimizer.')
        # If we have re-compiled the loss/weighted metric sub-graphs then create
        # the train function even if one exists already. This is because the
        # `_feed_sample_weights` list has been updated on re-compile.
        if getattr(model, 'train_function', None) is None or has_recompiled:
            current_trainable_state = model._get_trainable_state()
            model._set_trainable_state(model._compiled_trainable_state)

            inputs = (model._feed_inputs +
                      model._feed_targets +
                      model._feed_sample_weights)
            if not isinstance(K.symbolic_learning_phase(), int):
                inputs += [K.symbolic_learning_phase()]

            with K.get_graph().as_default():
                with K.name_scope('training'):
                    # Training updates
                    fast_params = model._collected_trainable_weights
                    training_updates = model.optimizer.get_updates(
                        params=fast_params, loss=model.total_loss)
                    slow_params = [K.variable(p) for p in fast_params]

                    fast_updates = (
                        training_updates +
                        model.get_updates_for(None) +
                        model.get_updates_for(model.inputs))

                    # slow <- slow + alpha * (fast - slow); then fast <- slow
                    slow_updates, copy_updates = [], []
                    for p, q in zip(fast_params, slow_params):
                        slow_updates.append(
                            K.update(q, q + self.alpha * (p - q)))
                        copy_updates.append(K.update(p, q))

                    metrics = model._get_training_eval_metrics()
                    metrics_tensors = [
                        m._call_result for m in metrics
                        if hasattr(m, '_call_result')  # pylint: disable=protected-access
                    ]

                    # Gets loss and metrics. Updates weights at each call.
                    fast_train_function = K.function(
                        inputs,
                        [model.total_loss] + metrics_tensors,
                        updates=fast_updates,
                        name='fast_train_function',
                        **model._function_kwargs)

                    def F(inputs):
                        self.count += 1
                        R = fast_train_function(inputs)
                        if self.count % self.k == 0:
                            K.batch_get_value(slow_updates)
                            K.batch_get_value(copy_updates)
                        return R

                    setattr(model, 'train_function', F)

            # Restore the current trainable state
            model._set_trainable_state(current_trainable_state)
```
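The injection pattern is the same with tf.keras, with one caveat: the listing relies on the graph-mode `train_function` path, so under tf 2.0 eager execution has to be disabled first. A minimal sketch under that assumption; the toy model and data are illustrative:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras as k

tf.compat.v1.disable_eager_execution()  # use the graph-mode training path

model = k.models.Sequential([k.layers.Dense(1, input_dim=4)])  # toy model
model.compile(optimizer=k.optimizers.Adam(1e-3), loss='mse')

lookahead = Lookahead(k=5, alpha=0.5)
lookahead.inject(model)  # installs F as model.train_function

x, y = np.random.rand(64, 4), np.random.rand(64, 1)
model.fit(x, y, epochs=2)
```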