In the optimizer, enters the captured distribution strategy scope (if any) when creating slot variables.

This allows users to put code that creates slot variables outside of a strategy scope, which is convenient; a common use case is checkpoint restore. Slot variables created in `apply_gradients` are unaffected, since `apply_gradients` is typically called inside the `fn` passed to `strategy.run`, and the scope is entered automatically inside `strategy.run`.
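
For illustration, a minimal sketch of the checkpoint-restore use case (the checkpoint path and the single-CPU MirroredStrategy are placeholders; it mirrors the test added below):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(['/cpu:0'])

def create_model():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.compile(optimizer='adam', loss='mse')
  model.build([None, 1])  # Create the model weights.
  return model

# Train and checkpoint without any strategy.
model = create_model()
model.fit(x=tf.ones((1, 1)), y=tf.ones((1, 1)), batch_size=1)
model.save_weights('/tmp/ckpt')

# Create the model (and therefore the optimizer) under the strategy scope...
with strategy.scope():
  model = create_model()
# ...but restore outside the scope. Slot variables are still created under
# the strategy the optimizer captured, so this now works.
model.load_weights('/tmp/ckpt')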

If the scope the optimizer captured is not the same as the one used to create the model variables, or if slot variables are created under a scope other than the one the optimizer captured, an error is raised.
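
A hypothetical example of a mismatch that triggers the error, using the public `add_slot` API directly for brevity (the variable and slot name are illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(['/cpu:0'])

var = tf.Variable(1.0)              # Created outside the strategy scope.
with strategy.scope():
  opt = tf.keras.optimizers.Adam()  # Optimizer captures the strategy.

# The slot would be created under the captured strategy scope, but `var`
# was not created under that scope, so this raises a ValueError.
opt.add_slot(var, 'm')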

This change also switches slot variable restoration from delayed to immediate when a strategy is present, which gives more consistent behavior (see b/172323399) and avoids double initialization.
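
A short sketch of the difference, reusing `strategy`, `create_model`, and the checkpoint from the sketch above: with a strategy present, the slot variables now exist as soon as the checkpoint is restored rather than being deferred to the first training step.

with strategy.scope():
  model = create_model()
model.load_weights('/tmp/ckpt')
# Slots are restored immediately, not deferred until apply_gradients.
assert model.optimizer.weights
print(type(model.optimizer.weights[0]))  # A DistributedVariable.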

PiperOrigin-RevId: 343568263
Change-Id: I2c19a592638a718e11974fcfb17a5a2b98139bf8
Chenkai Kuang 2020-11-20 14:44:36 -08:00 committed by TensorFlower Gardener
parent 37cfaab27f
commit b9d4d5d560
3 changed files with 63 additions and 16 deletions

@@ -328,6 +328,7 @@ py_library(
        "//tensorflow/python/distribute:parameter_server_strategy_v2",
        "//tensorflow/python/distribute:strategy_combinations",
        "//tensorflow/python/distribute:tpu_strategy",
        "//tensorflow/python/distribute:values",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/estimator:estimator_py",

@@ -40,6 +40,7 @@ from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import values as ds_values_lib
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
@@ -2677,6 +2678,43 @@ class TestModelCapturesStrategy(test.TestCase, parameterized.TestCase):
        loss=keras.losses.MeanSquaredError(),
        metrics=[keras.metrics.BinaryAccuracy()])

  @ds_combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.mirrored_strategy_with_one_cpu,
          mode=['eager']))
  def test_optimizer(self, distribution):
    temp_dir = os.path.join(self.get_temp_dir(), 'ckpt')

    def create_model():
      model = keras.models.Sequential([
          keras.layers.Dense(1),
      ])
      model.compile(optimizer='adam', loss='mse')
      model.build([None, 1])  # create weights.
      self.assertEmpty(model.optimizer.weights)
      return model

    model = create_model()
    x = y = array_ops.ones(shape=(1, 1))
    model.fit(x=x, y=y, batch_size=1)
    model.save_weights(temp_dir)

    with distribution.scope():
      model = create_model()
      model.load_weights(temp_dir)
      self.assertNotEmpty(model.optimizer.weights)
      self.assertIsInstance(model.optimizer.weights[0],
                            ds_values_lib.DistributedVariable)

    with distribution.scope():
      model = create_model()
    # create/restore slot variables outside of scope is fine.
    model.load_weights(temp_dir)
    self.assertNotEmpty(model.optimizer.weights)
    self.assertIsInstance(model.optimizer.weights[0],
                          ds_values_lib.DistributedVariable)


if __name__ == '__main__':
  base_layer_utils.enable_v2_dtype_behavior()
  multi_process_runner.test_main()

@@ -884,22 +884,24 @@ class OptimizerV2(trackable.Trackable):
            initializer, shape=slot_shape, dtype=var.dtype)
      else:
        initial_value = initializer
      strategy = distribute_ctx.get_strategy()
      if not strategy.extended.variable_created_in_scope(var):
        raise ValueError(
            "Trying to create optimizer slot variable under the scope for "
            "tf.distribute.Strategy ({}), which is different from the scope "
            "used for the original variable ({}). Make sure the slot "
            "variables are created under the same strategy scope. This may "
            "happen if you're restoring from a checkpoint outside the scope"
            .format(strategy, var))
      with strategy.extended.colocate_vars_with(var):
        weight = tf_variables.Variable(
            name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
            dtype=var.dtype,
            trainable=False,
            initial_value=initial_value)
      with self._distribution_strategy_scope():
        strategy = distribute_ctx.get_strategy()
        if not strategy.extended.variable_created_in_scope(var):
          raise ValueError(
              "Trying to create optimizer slot variable under the scope for "
              "tf.distribute.Strategy ({}), which is different from the scope "
              "used for the original variable ({}). Make sure the slot "
              "variables are created under the same strategy scope. This may "
              "happen if you're restoring from a checkpoint outside the scope"
              .format(strategy, var))
        with strategy.extended.colocate_vars_with(var):
          weight = tf_variables.Variable(
              name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
              dtype=var.dtype,
              trainable=False,
              initial_value=initial_value)
      backend.track_variable(weight)
      slot_dict[slot_name] = weight
      self._restore_slot_variable(
@@ -1359,7 +1361,13 @@ class OptimizerV2(trackable.Trackable):
        # a slot variable if not for this case). Deferring is mostly harmless
        # (aside from double initialization), and makes variable creator scopes
        # behave the same way they do when graph building.
        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
        #
        # One notable case is with distribution strategy, which uses variable
        # creator scope but always desires the `variable` and the slot to use
        # the same scope, thus we can safely eagerly create/restore slot
        # variables.
        and (not ops.get_default_graph()._variable_creator_stack or  # pylint: disable=protected-access
             self._distribution_strategy)):
      initializer = trackable.CheckpointInitialValueCallable(
          checkpoint_position=slot_variable_position)
      slot_variable = self.add_slot(