A tf.Keras implementation of the paper "Lookahead Optimizer: k steps forward, 1 step back" (https://arxiv.org/abs/1907.08610).

Adapted from 苏剑林 (bojone)'s repo: https://github.com/bojone/keras_lookahead
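
For reference, Lookahead maintains slow weights φ alongside the fast weights θ updated by an inner optimizer: after every k fast steps it sets φ ← φ + α(θ − φ) and resets θ ← φ. A minimal NumPy sketch of that loop (`grad_fn` and `inner_step` are illustrative placeholders, not part of the implementations below):

```python
import numpy as np

def lookahead(theta0, grad_fn, inner_step, k=5, alpha=0.5, outer_steps=100):
    """Plain-NumPy sketch of the Lookahead update rule (illustrative only)."""
    phi = theta0.copy()                        # slow weights
    theta = theta0.copy()                      # fast weights
    for _ in range(outer_steps):
        for _ in range(k):                     # k fast steps with the inner optimizer
            theta = inner_step(theta, grad_fn(theta))
        phi = phi + alpha * (theta - phi)      # slow weights interpolate toward theta
        theta = phi.copy()                     # fast weights restart from slow weights
    return phi
```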

Implementation for tf 1.14

Because tf.keras diverges quite a bit from standalone keras, this implementation differs from the original.

# NOTE from https://github.com/bojone/keras_lookahead
from tensorflow import keras
from tensorflow.keras import backend as K


class Lookahead(object):
"""Add the [Lookahead Optimizer](https://arxiv.org/abs/1907.08610) functionality for [keras](https://keras.io/).
"""

    def __init__(self, k=5, alpha=0.5):
        self.k = k          # number of fast steps between slow-weight updates
        self.alpha = alpha  # interpolation coefficient for the slow weights
        self.count = 0      # counter of fast training steps

def inject(self, model: keras.models.Model):
"""Inject the Lookahead algorithm for the given model.
The following code is modified from keras's _make_train_function method.
See: https://github.com/keras-team/keras/blob/master/keras/engine/training.py#L497
"""
if not hasattr(model, 'train_function'):
raise RuntimeError('You must compile your model before using it.')

model._check_trainable_weights_consistency()
metrics_tensors = [
model._all_metrics_tensors[m] for m in model.metrics_names[1:]
]
if model.train_function is None:
inputs = (model._feed_inputs +
model._feed_targets +
model._feed_sample_weights)
if not isinstance(K.symbolic_learning_phase(), int):
inputs += [K.symbolic_learning_phase()]
fast_params = model._collected_trainable_weights

with K.name_scope('training'):
with K.name_scope(model.optimizer.__class__.__name__):
training_updates = model.optimizer.get_updates(
params=fast_params,
loss=model.total_loss)
slow_params = [K.variable(p) for p in fast_params]

fast_updates = (model.updates +
training_updates +
model.get_updates_for(None) +
model.get_updates_for(model.inputs))

slow_updates, copy_updates = [], []
for p, q in zip(fast_params, slow_params):
slow_updates.append(K.update(q, q + self.alpha * (p - q)))
copy_updates.append(K.update(p, q))

# Gets loss and metrics. Updates weights at each call.
fast_train_function = K.function(
inputs, [model.total_loss] + metrics_tensors,
updates=fast_updates,
name='fast_train_function',
**model._function_kwargs)

        def F(inputs):
            # Run one fast-optimizer step and count it.
            self.count += 1
            R = fast_train_function(inputs)
            if self.count % self.k == 0:
                # Every k-th step: evaluate the update ops, which (1) move the
                # slow weights toward the fast weights, then (2) copy the slow
                # weights back into the fast weights.
                K.batch_get_value(slow_updates)
                K.batch_get_value(copy_updates)
            return R

model.train_function = F
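
A minimal usage sketch under tf 1.14 (the toy model and data are illustrative):

```python
import numpy as np
from tensorflow import keras

# Toy data and model; shapes and hyper-parameters are illustrative only.
x_train = np.random.rand(256, 784).astype('float32')
y_train = np.random.randint(0, 10, size=(256,))

model = keras.Sequential([
    keras.layers.Dense(64, activation='relu', input_shape=(784,)),
    keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Replace the compiled model's train function with the Lookahead-wrapped one.
Lookahead(k=5, alpha=0.5).inject(model)
model.fit(x_train, y_train, batch_size=32, epochs=2)
```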

Implementation for tf 1.15

Because tf.keras changed substantially again in this version, the implementation again differs from the one above.

from tensorflow import keras
from tensorflow.keras import backend as K


class Lookahead(object):
"""Add the [Lookahead Optimizer](https://arxiv.org/abs/1907.08610) functionality for [keras](https://keras.io/).
"""

def __init__(self, k=5, alpha=0.5):
self.k = k
self.alpha = alpha
self.count = 0

def inject(self, model: keras.models.Model):
""" from tensorflow.keras `_make_train_function` refer from
https://github.com/tensorflow/tensorflow/blob/590d6eef7e91a6a7392c8ffffb7b58f2e0c8bc6b/tensorflow/python/keras/engine/training.py#L2091 and https://github.com/bojone/keras_lookahead/blob/master/lookahead.py
"""
has_recompiled = model._recompile_weights_loss_and_weighted_metrics()
model._check_trainable_weights_consistency()
if isinstance(model.optimizer, list):
raise ValueError('The `optimizer` in `compile` should be a single '
'optimizer.')
        # If we have re-compiled the loss/weighted-metric sub-graphs then create
        # the train function even if one already exists, because the
        # `_feed_sample_weights` list has been updated on re-compile.
        if getattr(model, 'train_function', None) is None or has_recompiled:
# Restore the compiled trainable state.
current_trainable_state = model._get_trainable_state()
model._set_trainable_state(model._compiled_trainable_state)

inputs = (model._feed_inputs +
model._feed_targets +
model._feed_sample_weights)
if not isinstance(K.symbolic_learning_phase(), int):
inputs += [K.symbolic_learning_phase()]

fast_params = model._collected_trainable_weights

with K.get_graph().as_default():
with K.name_scope('training'):
# Training updates
training_updates = model.optimizer.get_updates(
params=fast_params, loss=model.total_loss)
slow_params = [K.variable(p) for p in fast_params]

fast_updates = (
training_updates +
# Unconditional updates
model.get_updates_for(None) +
# Conditional updates relevant to this model
model.get_updates_for(model.inputs))

metrics = model._get_training_eval_metrics()
metrics_tensors = [
m._call_result for m in metrics if hasattr(m, '_call_result') # pylint: disable=protected-access
]

with K.name_scope('training'):
slow_updates, copy_updates = [], []
for p, q in zip(fast_params, slow_params):
slow_updates.append(K.update(q, q + self.alpha * (p - q)))
copy_updates.append(K.update(p, q))

# Gets loss and metrics. Updates weights at each call.
fast_train_function = K.function(
inputs, [model.total_loss] + metrics_tensors,
updates=fast_updates,
name='train_function',
**model._function_kwargs)

def F(inputs):
self.count += 1
R = fast_train_function(inputs)
if self.count % self.k == 0:
K.batch_get_value(slow_updates)
K.batch_get_value(copy_updates)
return R

setattr(model, 'train_function', F)

# Restore the current trainable state
model._set_trainable_state(current_trainable_state)
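
Usage is the same as for the tf 1.14 version above. The main differences here are that the update ops are built under `K.get_graph().as_default()` and that the compiled trainable state is saved and restored around the build, mirroring what tf 1.15's `_make_train_function` does. A quick sanity check after injection (illustrative, reusing the toy model and data from the example above):

```python
lookahead = Lookahead(k=5, alpha=0.5)
lookahead.inject(model)                  # model compiled as in the example above
assert callable(model.train_function)    # the Lookahead-wrapped function is in place
model.fit(x_train, y_train, batch_size=32, epochs=1)
```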

Implementation for tf 2.0

# NOTE from https://github.com/bojone/keras_lookahead
from tensorflow import keras
from tensorflow.keras import backend as K


class Lookahead(object):
"""Add the [Lookahead Optimizer](https://arxiv.org/abs/1907.08610) functionality for [keras](https://keras.io/).
"""

def __init__(self, k=5, alpha=0.5):
self.k = k
self.alpha = alpha
self.count = 0

    def inject(self, model: keras.models.Model):
has_recompiled = model._recompile_weights_loss_and_weighted_metrics()
model._check_trainable_weights_consistency()
if isinstance(model.optimizer, list):
raise ValueError('The `optimizer` in `compile` should be a single '
'optimizer.')
        # If we have re-compiled the loss/weighted-metric sub-graphs then create
        # the train function even if one already exists, because the
        # `_feed_sample_weights` list has been updated on re-compile.
if getattr(model, 'train_function', None) is None or has_recompiled:
current_trainable_state = model._get_trainable_state()
model._set_trainable_state(model._compiled_trainable_state)

inputs = (model._feed_inputs +
model._feed_targets +
model._feed_sample_weights)
if not isinstance(K.symbolic_learning_phase(), int):
inputs += [K.symbolic_learning_phase()]

with K.get_graph().as_default():
with K.name_scope('training'):
# Training updates
fast_params = model._collected_trainable_weights
training_updates = model.optimizer.get_updates(
params=fast_params, loss=model.total_loss)
slow_params = [K.variable(p) for p in fast_params]

fast_updates = (
training_updates +
model.get_updates_for(None) +
model.get_updates_for(model.inputs)
)
metrics = model._get_training_eval_metrics()
metrics_tensors = [
m._call_result for m in metrics if hasattr(m, '_call_result') # pylint: disable=protected-access
]

with K.name_scope('training'):
slow_updates, copy_updates = [], []
for p, q in zip(fast_params, slow_params):
slow_updates.append(K.update(q, q + self.alpha * (p - q)))
copy_updates.append(K.update(p, q))

# Gets loss and metrics. Updates weights at each call.
fast_train_function = K.function(
inputs, [model.total_loss] + metrics_tensors,
updates=fast_updates,
name='fast_train_function',
**model._function_kwargs)

def F(inputs):
self.count += 1
R = fast_train_function(inputs)
if self.count % self.k == 0:
K.batch_get_value(slow_updates)
K.batch_get_value(copy_updates)
return R

setattr(model, 'train_function', F)
# Restore the current trainable state
model._set_trainable_state(current_trainable_state)
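
Putting it together under TF 2.0: since `K.function(..., updates=...)` builds graph-mode ops, this presumably only works with eager execution disabled via the v1 compatibility API. That requirement is my assumption, not something verified against the original repo; the toy model and data are illustrative:

```python
import numpy as np
import tensorflow as tf

# Assumption: graph mode is needed for K.function with `updates` under TF 2.0.
tf.compat.v1.disable_eager_execution()

from tensorflow import keras

x = np.random.rand(128, 16).astype('float32')
y = np.random.randint(0, 2, size=(128,))

model = keras.Sequential([
    keras.layers.Dense(8, activation='relu', input_shape=(16,)),
    keras.layers.Dense(2, activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

Lookahead(k=5, alpha=0.5).inject(model)
model.fit(x, y, batch_size=32, epochs=1)
```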