Create Keras Optimizer non-slot variables inside the strategy scope if the optimizer is created inside the strategy scope.

PiperOrigin-RevId: 308363394
Change-Id: If64c8cf449ad08d870ed39c764bbe3e2d368fd8e
Author:    A. Unique TensorFlower
Date:      2020-04-24 18:25:38 -07:00
Committer: TensorFlower Gardener
Parent:    c942431b49
Commit:    5712d2cac6

3 changed files with 20 additions and 57 deletions
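For context, here is a minimal sketch of the scenario this change (and the updated test below) is about. It is not part of the commit; the toy model, loss, and shapes are placeholders, written against the public TF 2.x API:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # The model is created inside the strategy scope; before this change the
  # test also created the optimizer here.
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
  optimizer = tf.keras.optimizers.Adam(1e-4)

# compile() runs outside the scope.  The optimizer's non-slot variables
# (`iterations` and the hyperparameter variables) are created lazily, so the
# scope they end up under depends on how the optimizer handles this case.
model.compile(optimizer=optimizer, loss='mse')
_ = optimizer.iterations  # first access creates the `iter` variable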

Changed file 1 of 3 (Keras distribute test, class TestModelCapturesStrategy):

@@ -2394,17 +2394,13 @@ class TestModelCapturesStrategy(test.TestCase, parameterized.TestCase):
     # Make model with distribution strategy
     with distribution.scope():
       model = DeterministicModel(distribution)
-      optimizer = keras.optimizers.adam_v2.Adam(1e-4)

     # Compile & evaluate the model outside of the distribution strategy scope
     model.compile(
-        optimizer=optimizer,
+        optimizer=keras.optimizers.adam_v2.Adam(1e-4),
         loss=keras.losses.MeanSquaredError(),
         metrics=['binary_accuracy'])

-    # Call `optimizer.iterations` out of strategy scope.
-    self.assertEqual(model.optimizer.iterations.numpy(), 0)
-
     # Non-eager training doesn't support steps_per_epoch=None.
     for unused_epoch in range(2):
       model.fit(dataset)
@@ -2433,7 +2429,7 @@ class TestModelCapturesStrategy(test.TestCase, parameterized.TestCase):
     with distribution.scope():
       metric = keras.metrics.BinaryAccuracy()
     model.compile(
-        optimizer=optimizer,
+        optimizer=keras.optimizers.adam_v2.Adam(1e-4),
         loss=keras.losses.MeanSquaredError(),
         metrics=[metric])

Changed file 2 of 3 (Keras Model.compile validation, class Model):

@@ -1724,20 +1724,6 @@ class Model(network.Network, version_utils.ModelVersionSelector):
               'strategy scope.' % (metric, strategy)
           )

-    # Model metrics must be created in the same distribution strategy scope
-    # as the model.
-    for opt in nest.flatten(optimizer):
-      for v in getattr(opt, '_weights', []):
-        if not strategy.extended.variable_created_in_scope(v):
-          raise ValueError(
-              'Optimizer (%s) passed to model.compile was created inside of a '
-              'different distribution strategy scope than the model. All '
-              'optimizers must be created in the same distribution strategy '
-              'scope as the model (in this case %s). If you pass in a string '
-              'identifier for an optimizer to compile the optimizer will '
-              'automatically be created in the correct distribution '
-              'strategy scope.' % (opt, strategy))
-
   def _maybe_load_initial_epoch_from_ckpt(self, initial_epoch):
     """Maybe load initial epoch from ckpt considering possible worker recovery.

Changed file 3 of 3 (Keras OptimizerV2, class OptimizerV2):

@@ -20,7 +20,6 @@ from __future__ import division
 from __future__ import print_function

 import abc
-import contextlib
 import functools

 import six
@@ -338,13 +337,6 @@ class OptimizerV2(trackable.Trackable):

     self._hypers_created = False

-    # Store the distribution strategy object if the optimizer is created inside
-    # strategy scope, so it could be used to create variables later.
-    if distribute_ctx.has_strategy():
-      self._distribution_strategy = distribute_ctx.get_strategy()
-    else:
-      self._distribution_strategy = None
-
   def minimize(self, loss, var_list, grad_loss=None, name=None):
     """Minimize `loss` by updating `var_list`.
@@ -808,32 +800,30 @@ class OptimizerV2(trackable.Trackable):

   def _create_hypers(self):
     if self._hypers_created:
       return
-    with self._distribution_strategy_scope():
-      # Iterate hyper values deterministically.
-      for name, value in sorted(self._hyper.items()):
-        if isinstance(value,
-                      (ops.Tensor, tf_variables.Variable)) or callable(value):
-          continue
-        else:
-          self._hyper[name] = self.add_weight(
-              name,
-              shape=[],
-              trainable=False,
-              initializer=value,
-              aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
+    # Iterate hyper values deterministically.
+    for name, value in sorted(self._hyper.items()):
+      if isinstance(
+          value, (ops.Tensor, tf_variables.Variable)) or callable(value):
+        continue
+      else:
+        self._hyper[name] = self.add_weight(
+            name,
+            shape=[],
+            trainable=False,
+            initializer=value,
+            aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
     self._hypers_created = True

   @property
   def iterations(self):
     """Variable. The number of training steps this Optimizer has run."""
     if self._iterations is None:
-      with self._distribution_strategy_scope():
-        self._iterations = self.add_weight(
-            "iter",
-            shape=[],
-            dtype=dtypes.int64,
-            trainable=False,
-            aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
+      self._iterations = self.add_weight(
+          "iter",
+          shape=[],
+          dtype=dtypes.int64,
+          trainable=False,
+          aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
       self._weights.append(self._iterations)
     return self._iterations
@@ -1243,15 +1233,6 @@ class OptimizerV2(trackable.Trackable):
           slot_name, {}).setdefault(variable_key, []).append(
               slot_variable_position)

-  @contextlib.contextmanager
-  def _distribution_strategy_scope(self):
-    """Returns the `tf.distribute.Strategy` this optimizer was created under."""
-    if self._distribution_strategy and not distribute_ctx.has_strategy():
-      with self._distribution_strategy.scope():
-        yield self._distribution_strategy.scope()
-    else:
-      yield
-

 def _filter_grads(grads_and_vars):
   """Filter out iterable with grad equal to None."""