Create Keras Optimizer non-slot variables inside strategy scope if the optimizer is created inside strategy scope.
PiperOrigin-RevId: 308864248
Change-Id: Ibec7170b1cf70794af741bd9f1230844df2ec12f
commit f9f6b4cec2
parent b091e42120
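
For orientation, here is a minimal sketch of the usage pattern this change enables (the strategy, model, and learning rate below are illustrative assumptions, not taken from this commit): an optimizer constructed inside a `tf.distribute` strategy scope remembers that strategy, so variables it creates lazily later, such as `iterations`, are still placed under that strategy even when they are first touched outside the scope.

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
        optimizer = tf.keras.optimizers.Adam(1e-4)

    # Compile and touch `optimizer.iterations` outside of the strategy scope.
    # The variable is created lazily here, under the strategy captured above.
    model.compile(optimizer=optimizer, loss='mse')
    print(int(model.optimizer.iterations.numpy()))  # 0 before any training step
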
@@ -2447,13 +2447,17 @@ class TestModelCapturesStrategy(test.TestCase, parameterized.TestCase):
     # Make model with distribution strategy
     with distribution.scope():
       model = DeterministicModel(distribution)
+      optimizer = keras.optimizers.adam_v2.Adam(1e-4)
 
     # Compile & evaluate the model outside of the distribution strategy scope
     model.compile(
-        optimizer=keras.optimizers.adam_v2.Adam(1e-4),
+        optimizer=optimizer,
         loss=keras.losses.MeanSquaredError(),
         metrics=['binary_accuracy'])
 
+    # Call `optimizer.iterations` out of strategy scope.
+    self.assertEqual(model.optimizer.iterations.numpy(), 0)
+
     # Non-eager training doesn't support steps_per_epoch=None.
     for unused_epoch in range(2):
       model.fit(dataset)
@@ -2482,7 +2486,7 @@ class TestModelCapturesStrategy(test.TestCase, parameterized.TestCase):
     with distribution.scope():
       metric = keras.metrics.BinaryAccuracy()
     model.compile(
-        optimizer=keras.optimizers.adam_v2.Adam(1e-4),
+        optimizer=optimizer,
         loss=keras.losses.MeanSquaredError(),
         metrics=[metric])
 
@@ -1737,6 +1737,20 @@ class Model(network.Network, version_utils.ModelVersionSelector):
               'strategy scope.' % (metric, strategy)
           )
 
+    # Model metrics must be created in the same distribution strategy scope
+    # as the model.
+    for opt in nest.flatten(optimizer):
+      for v in getattr(opt, '_weights', []):
+        if not strategy.extended.variable_created_in_scope(v):
+          raise ValueError(
+              'Optimizer (%s) passed to model.compile was created inside of a '
+              'different distribution strategy scope than the model. All '
+              'optimizers must be created in the same distribution strategy '
+              'scope as the model (in this case %s). If you pass in a string '
+              'identifier for an optimizer to compile the optimizer will '
+              'automatically be created in the correct distribution '
+              'strategy scope.' % (opt, strategy))
+
   def _maybe_load_initial_epoch_from_ckpt(self, initial_epoch):
     """Maybe load initial epoch from ckpt considering possible worker recovery.
 
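
To make the intent of the added check concrete, a hedged sketch of the failure mode it reports (the strategy choices here are assumptions for illustration): an optimizer whose variables were created under a different strategy than the model should now fail at compile time with the ValueError above.

    import tensorflow as tf

    other_strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
    with other_strategy.scope():
        optimizer = tf.keras.optimizers.Adam(1e-4)
        _ = optimizer.iterations  # forces creation of a non-slot variable

    model_strategy = tf.distribute.MirroredStrategy()
    with model_strategy.scope():
        model = tf.keras.Sequential([tf.keras.layers.Dense(1)])

    try:
        model.compile(optimizer=optimizer, loss='mse')
    except ValueError as err:
        print(err)  # 'Optimizer (...) passed to model.compile was created inside of a different distribution strategy scope ...'
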
@@ -20,6 +20,7 @@ from __future__ import division
 from __future__ import print_function
 
 import abc
+import contextlib
 import functools
 
 import six
@@ -337,6 +338,13 @@ class OptimizerV2(trackable.Trackable):
 
     self._hypers_created = False
 
+    # Store the distribution strategy object if the optimizer is created inside
+    # strategy scope, so it could be used to create variables later.
+    if distribute_ctx.has_strategy():
+      self._distribution_strategy = distribute_ctx.get_strategy()
+    else:
+      self._distribution_strategy = None
+
   def minimize(self, loss, var_list, grad_loss=None, name=None):
     """Minimize `loss` by updating `var_list`.
 
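
As a hedged illustration of what the constructor now records, using the public `tf.distribute` equivalents of the internal `distribute_ctx` helpers:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()

    print(tf.distribute.has_strategy())      # False: only the default strategy is active
    with strategy.scope():
        print(tf.distribute.has_strategy())  # True: this is what the constructor checks
        captured = tf.distribute.get_strategy()
    print(captured is strategy)              # True: the object that would be stored
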
@@ -800,30 +808,32 @@ class OptimizerV2(trackable.Trackable):
   def _create_hypers(self):
     if self._hypers_created:
       return
-    # Iterate hyper values deterministically.
-    for name, value in sorted(self._hyper.items()):
-      if isinstance(
-          value, (ops.Tensor, tf_variables.Variable)) or callable(value):
-        continue
-      else:
-        self._hyper[name] = self.add_weight(
-            name,
-            shape=[],
-            trainable=False,
-            initializer=value,
-            aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
+    with self._distribution_strategy_scope():
+      # Iterate hyper values deterministically.
+      for name, value in sorted(self._hyper.items()):
+        if isinstance(value,
+                      (ops.Tensor, tf_variables.Variable)) or callable(value):
+          continue
+        else:
+          self._hyper[name] = self.add_weight(
+              name,
+              shape=[],
+              trainable=False,
+              initializer=value,
+              aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
     self._hypers_created = True
 
   @property
   def iterations(self):
     """Variable. The number of training steps this Optimizer has run."""
     if self._iterations is None:
-      self._iterations = self.add_weight(
-          "iter",
-          shape=[],
-          dtype=dtypes.int64,
-          trainable=False,
-          aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
+      with self._distribution_strategy_scope():
+        self._iterations = self.add_weight(
+            "iter",
+            shape=[],
+            dtype=dtypes.int64,
+            trainable=False,
+            aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
       self._weights.append(self._iterations)
     return self._iterations
 
@@ -1233,6 +1243,15 @@ class OptimizerV2(trackable.Trackable):
             slot_name, {}).setdefault(variable_key, []).append(
                 slot_variable_position)
 
+  @contextlib.contextmanager
+  def _distribution_strategy_scope(self):
+    """Returns the `tf.distribute.Strategy` this optimizer was created under."""
+    if self._distribution_strategy and not distribute_ctx.has_strategy():
+      with self._distribution_strategy.scope():
+        yield self._distribution_strategy.scope()
+    else:
+      yield
+
 
 def _filter_grads(grads_and_vars):
   """Filter out iterable with grad equal to None."""
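
A small hedged end-to-end check of the behavior the new context manager provides (the strategy below is an assumption; `variable_created_in_scope` is the same API the compile-time check uses): a variable the optimizer creates lazily outside the scope should still count as created inside the stored strategy's scope.

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        optimizer = tf.keras.optimizers.Adam(1e-4)

    # First access outside the scope: `_distribution_strategy_scope` re-enters
    # the stored strategy before `add_weight` runs.
    iterations = optimizer.iterations
    print(strategy.extended.variable_created_in_scope(iterations))  # expected: True
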