From 5712d2cac6a4c109fae1ad20fea9db8d4f619b72 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 24 Apr 2020 18:25:38 -0700
Subject: [PATCH] Create Keras Optimizer non slot variables inside strategy
 scope if the optimizer is created inside strategy scope.

PiperOrigin-RevId: 308363394
Change-Id: If64c8cf449ad08d870ed39c764bbe3e2d368fd8e
---
 .../distribute/distribute_strategy_test.py    |  8 +--
 tensorflow/python/keras/engine/training.py    | 14 -----
 .../python/keras/optimizer_v2/optimizer_v2.py | 55 ++++++-------------
 3 files changed, 20 insertions(+), 57 deletions(-)

diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py
index 86574b7bc29..874ca84cab9 100644
--- a/tensorflow/python/keras/distribute/distribute_strategy_test.py
+++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py
@@ -2394,17 +2394,13 @@ class TestModelCapturesStrategy(test.TestCase, parameterized.TestCase):
     # Make model with distribution strategy
     with distribution.scope():
       model = DeterministicModel(distribution)
-      optimizer = keras.optimizers.adam_v2.Adam(1e-4)
 
     # Compile & evaluate the model outside of the distribution strategy scope
     model.compile(
-        optimizer=optimizer,
+        optimizer=keras.optimizers.adam_v2.Adam(1e-4),
         loss=keras.losses.MeanSquaredError(),
         metrics=['binary_accuracy'])
 
-    # Call `optimizer.iterations` out of strategy scope.
-    self.assertEqual(model.optimizer.iterations.numpy(), 0)
-
     # Non-eager training doesn't support steps_per_epoch=None.
     for unused_epoch in range(2):
       model.fit(dataset)
@@ -2433,7 +2429,7 @@ class TestModelCapturesStrategy(test.TestCase, parameterized.TestCase):
     with distribution.scope():
       metric = keras.metrics.BinaryAccuracy()
     model.compile(
-        optimizer=optimizer,
+        optimizer=keras.optimizers.adam_v2.Adam(1e-4),
         loss=keras.losses.MeanSquaredError(),
         metrics=[metric])
 
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 8c59100dce7..5d8c3dcf37e 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -1724,20 +1724,6 @@ class Model(network.Network, version_utils.ModelVersionSelector):
               'strategy scope.' % (metric, strategy)
          )
 
-    # Model metrics must be created in the same distribution strategy scope
-    # as the model.
-    for opt in nest.flatten(optimizer):
-      for v in getattr(opt, '_weights', []):
-        if not strategy.extended.variable_created_in_scope(v):
-          raise ValueError(
-              'Optimizer (%s) passed to model.compile was created inside of a '
-              'different distribution strategy scope than the model. All '
-              'optimizers must be created in the same distribution strategy '
-              'scope as the model (in this case %s). If you pass in a string '
-              'identifier for an optimizer to compile the optimizer will '
-              'automatically be created in the correct distribution '
-              'strategy scope.' % (opt, strategy))
-
   def _maybe_load_initial_epoch_from_ckpt(self, initial_epoch):
     """Maybe load initial epoch from ckpt considering possible worker recovery.
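Note (not part of the patch): the compile-time check deleted from training.py above relied on
`tf.distribute.StrategyExtended.variable_created_in_scope` to decide whether an optimizer
variable belonged to the model's strategy. A minimal sketch of that predicate, assuming TF 2.x
running eagerly with a MirroredStrategy:

# Sketch only; illustrates the predicate used by the removed check.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  in_scope = tf.Variable(1.0)     # created under the strategy's scope

out_of_scope = tf.Variable(1.0)   # created outside any strategy scope

print(strategy.extended.variable_created_in_scope(in_scope))      # True
print(strategy.extended.variable_created_in_scope(out_of_scope))  # False

Since OptimizerV2 creates its non-slot variables lazily (see the optimizer_v2.py changes below),
`getattr(opt, '_weights', [])` may still be empty when `compile` runs.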
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
index c55b332bfc0..4cf07033d92 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -20,7 +20,6 @@ from __future__ import division
 from __future__ import print_function
 
 import abc
-import contextlib
 import functools
 
 import six
@@ -338,13 +337,6 @@ class OptimizerV2(trackable.Trackable):
 
     self._hypers_created = False
 
-    # Store the distribution strategy object if the optimizer is created inside
-    # strategy scope, so it could be used to create variables later.
-    if distribute_ctx.has_strategy():
-      self._distribution_strategy = distribute_ctx.get_strategy()
-    else:
-      self._distribution_strategy = None
-
   def minimize(self, loss, var_list, grad_loss=None, name=None):
     """Minimize `loss` by updating `var_list`.
 
@@ -808,32 +800,30 @@ class OptimizerV2(trackable.Trackable):
   def _create_hypers(self):
     if self._hypers_created:
       return
-    with self._distribution_strategy_scope():
-      # Iterate hyper values deterministically.
-      for name, value in sorted(self._hyper.items()):
-        if isinstance(value,
-                      (ops.Tensor, tf_variables.Variable)) or callable(value):
-          continue
-        else:
-          self._hyper[name] = self.add_weight(
-              name,
-              shape=[],
-              trainable=False,
-              initializer=value,
-              aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
+    # Iterate hyper values deterministically.
+    for name, value in sorted(self._hyper.items()):
+      if isinstance(
+          value, (ops.Tensor, tf_variables.Variable)) or callable(value):
+        continue
+      else:
+        self._hyper[name] = self.add_weight(
+            name,
+            shape=[],
+            trainable=False,
+            initializer=value,
+            aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
     self._hypers_created = True
 
   @property
   def iterations(self):
     """Variable. The number of training steps this Optimizer has run."""
     if self._iterations is None:
-      with self._distribution_strategy_scope():
-        self._iterations = self.add_weight(
-            "iter",
-            shape=[],
-            dtype=dtypes.int64,
-            trainable=False,
-            aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
+      self._iterations = self.add_weight(
+          "iter",
+          shape=[],
+          dtype=dtypes.int64,
+          trainable=False,
+          aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
       self._weights.append(self._iterations)
     return self._iterations
 
@@ -1243,15 +1233,6 @@ class OptimizerV2(trackable.Trackable):
             slot_name, {}).setdefault(variable_key, []).append(
                 slot_variable_position)
 
-  @contextlib.contextmanager
-  def _distribution_strategy_scope(self):
-    """Returns the `tf.distribute.Strategy` this optimizer was created under."""
-    if self._distribution_strategy and not distribute_ctx.has_strategy():
-      with self._distribution_strategy.scope():
-        yield self._distribution_strategy.scope()
-    else:
-      yield
-
 
 def _filter_grads(grads_and_vars):
   """Filter out iterable with grad equal to None."""
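For reference, a minimal usage sketch (not part of the patch, assuming TF 2.x running eagerly
with a MirroredStrategy) of the pattern the commit subject describes: the optimizer is
constructed inside a strategy scope, while its non-slot variables such as `iterations` are only
created lazily, on first access (see the `iterations` property above).

# Sketch only; names and shapes here are illustrative, not taken from the patch.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
  optimizer = tf.keras.optimizers.Adam(1e-4)   # constructed inside the scope

# Compile and train outside the scope; the model itself captured the strategy
# when its variables were created, and the optimizer's hyperparameter and
# `iterations` variables are created lazily, the first time they are needed.
model.compile(optimizer=optimizer, loss='mse')

x = tf.random.normal([8, 4])
y = tf.random.normal([8, 1])
model.fit(x, y, epochs=1, verbose=0)

print(int(optimizer.iterations))  # training steps applied so far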