From b9d4d5d56018c59f96065872a8f99ff2f660a581 Mon Sep 17 00:00:00 2001
From: Chenkai Kuang
Date: Fri, 20 Nov 2020 14:44:36 -0800
Subject: [PATCH] In optimizer, enters captured distribution strategy scope (if
 any) when creating slot variables.

This allows code that creates slot variables to be placed outside of a
strategy scope; one common use case is checkpoint restore. Slot variables
created in `apply_gradients` are not affected, since `apply_gradients` is
usually called inside the `fn` passed to `strategy.run`, and the scope is
entered automatically inside `strategy.run`. If the scope the optimizer
captured differs from the scope used to create the variables, or if slot
variables are created under a different scope than the one the optimizer
captured, an error is raised.

This change also switches slot variable restoration from deferred to
immediate when a strategy is present, which gives more consistent behavior
(see b/172323399) and avoids double initialization.

PiperOrigin-RevId: 343568263
Change-Id: I2c19a592638a718e11974fcfb17a5a2b98139bf8
---
 tensorflow/python/keras/distribute/BUILD      |  1 +
 .../distribute/distribute_strategy_test.py    | 38 ++++++++++++++++++
 .../python/keras/optimizer_v2/optimizer_v2.py | 40 +++++++++++--------
 3 files changed, 63 insertions(+), 16 deletions(-)

diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index 43b5cf63677..34c0fc202db 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -328,6 +328,7 @@ py_library(
         "//tensorflow/python/distribute:parameter_server_strategy_v2",
         "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/distribute:tpu_strategy",
+        "//tensorflow/python/distribute:values",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:test",
         "//tensorflow/python/estimator:estimator_py",
diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py
index 69aa1c7e9eb..972eab4b999 100644
--- a/tensorflow/python/keras/distribute/distribute_strategy_test.py
+++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py
@@ -40,6 +40,7 @@ from tensorflow.python.distribute import parameter_server_strategy_v2
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import tpu_strategy
+from tensorflow.python.distribute import values as ds_values_lib
 from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -2677,6 +2678,43 @@ class TestModelCapturesStrategy(test.TestCase, parameterized.TestCase):
         loss=keras.losses.MeanSquaredError(),
         metrics=[keras.metrics.BinaryAccuracy()])
 
+  @ds_combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.mirrored_strategy_with_one_cpu,
+          mode=['eager']))
+  def test_optimizer(self, distribution):
+    temp_dir = os.path.join(self.get_temp_dir(), 'ckpt')
+
+    def create_model():
+      model = keras.models.Sequential([
+          keras.layers.Dense(1),
+      ])
+      model.compile(optimizer='adam', loss='mse')
+      model.build([None, 1])  # create weights.
+      self.assertEmpty(model.optimizer.weights)
+      return model
+
+    model = create_model()
+    x = y = array_ops.ones(shape=(1, 1))
+    model.fit(x=x, y=y, batch_size=1)
+    model.save_weights(temp_dir)
+
+    with distribution.scope():
+      model = create_model()
+      model.load_weights(temp_dir)
+      self.assertNotEmpty(model.optimizer.weights)
+      self.assertIsInstance(model.optimizer.weights[0],
+                            ds_values_lib.DistributedVariable)
+
+    with distribution.scope():
+      model = create_model()
+    # create/restore slot variables outside of scope is fine.
+    model.load_weights(temp_dir)
+    self.assertNotEmpty(model.optimizer.weights)
+    self.assertIsInstance(model.optimizer.weights[0],
+                          ds_values_lib.DistributedVariable)
+
+
 if __name__ == '__main__':
   base_layer_utils.enable_v2_dtype_behavior()
   multi_process_runner.test_main()
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
index ad6c2f22e85..a7d937344f4 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -884,22 +884,24 @@ class OptimizerV2(trackable.Trackable):
             initializer, shape=slot_shape, dtype=var.dtype)
       else:
         initial_value = initializer
 
-      strategy = distribute_ctx.get_strategy()
-      if not strategy.extended.variable_created_in_scope(var):
-        raise ValueError(
-            "Trying to create optimizer slot variable under the scope for "
-            "tf.distribute.Strategy ({}), which is different from the scope "
-            "used for the original variable ({}). Make sure the slot "
-            "variables are created under the same strategy scope. This may "
-            "happen if you're restoring from a checkpoint outside the scope"
-            .format(strategy, var))
-      with strategy.extended.colocate_vars_with(var):
-        weight = tf_variables.Variable(
-            name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
-            dtype=var.dtype,
-            trainable=False,
-            initial_value=initial_value)
+      with self._distribution_strategy_scope():
+        strategy = distribute_ctx.get_strategy()
+        if not strategy.extended.variable_created_in_scope(var):
+          raise ValueError(
+              "Trying to create optimizer slot variable under the scope for "
+              "tf.distribute.Strategy ({}), which is different from the scope "
+              "used for the original variable ({}). Make sure the slot "
+              "variables are created under the same strategy scope. This may "
+              "happen if you're restoring from a checkpoint outside the scope"
+              .format(strategy, var))
+
+        with strategy.extended.colocate_vars_with(var):
+          weight = tf_variables.Variable(
+              name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
+              dtype=var.dtype,
+              trainable=False,
+              initial_value=initial_value)
       backend.track_variable(weight)
       slot_dict[slot_name] = weight
       self._restore_slot_variable(
@@ -1359,7 +1361,13 @@ class OptimizerV2(trackable.Trackable):
         # a slot variable if not for this case). Deferring is mostly harmless
         # (aside from double initialization), and makes variable creator scopes
         # behave the same way they do when graph building.
-        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
+        #
+        # One notable case is with distribution strategy, which uses variable
+        # creator scope but always desires the `variable` and the slot to use
+        # the same scope, thus we can safely eagerly create/restore slot
+        # variables.
+        and (not ops.get_default_graph()._variable_creator_stack or  # pylint: disable=protected-access
+             self._distribution_strategy)):
       initializer = trackable.CheckpointInitialValueCallable(
           checkpoint_position=slot_variable_position)
       slot_variable = self.add_slot(
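
Not part of the patch itself: the sketch below is a rough, public-API illustration of the
restore-outside-the-scope pattern that the new test exercises. It assumes TF 2.x with
`tf.keras` and `tf.distribute.MirroredStrategy`; the checkpoint path, layer size, and
training data are made up for the example.

import os
import tensorflow as tf

ckpt_path = os.path.join("/tmp", "ckpt")  # illustrative path


def create_model():
  # Build and compile a trivial model; no optimizer slot variables exist yet.
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.compile(optimizer="adam", loss="mse")
  model.build([None, 1])
  return model


# Train and checkpoint without any distribution strategy.
model = create_model()
model.fit(tf.ones((1, 1)), tf.ones((1, 1)), batch_size=1)
model.save_weights(ckpt_path)

# Re-create the model's variables under a strategy scope ...
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model = create_model()

# ... and restore outside the scope. With this change, the optimizer
# re-enters the strategy scope it captured when creating slot variables,
# so the restore below creates the slots immediately (rather than
# deferring them) and places them under the strategy.
model.load_weights(ckpt_path)
assert model.optimizer.weights  # slot variables now exist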