In the optimizer, enter the captured distribution strategy scope (if any) when creating slot variables.
This allows code that creates slot variables, such as checkpoint restore, to live outside of a strategy scope. Slot variables created in `apply_gradients` are unaffected, since `apply_gradients` is typically called inside the `fn` passed to `strategy.run`, where the scope is entered automatically.

If the scope the optimizer captured differs from the scope under which the model variables were created, or if slot variables are created under a different scope than the captured one, an error is raised.

This change also switches slot variable restoration from delayed to immediate when a strategy is present, which gives more consistent behavior (see b/172323399) and avoids double initialization.

PiperOrigin-RevId: 343568263
Change-Id: I2c19a592638a718e11974fcfb17a5a2b98139bf8
parent 37cfaab27f
commit b9d4d5d560
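For illustration, a rough end-to-end sketch of the workflow this change enables, written against the public Keras and tf.distribute APIs and mirroring the new test below; the checkpoint path and the single-CPU MirroredStrategy are placeholders, not part of the change.

import tensorflow as tf

# Train and checkpoint a model without any strategy.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer='adam', loss='mse')
x = y = tf.ones((1, 1))
model.fit(x, y, batch_size=1)
model.save_weights('/tmp/ckpt')  # placeholder path

# Recreate the model under a strategy scope, then restore the checkpoint
# outside the scope. Slot variables are still created under the strategy the
# optimizer captured, so they come back as distributed variables.
strategy = tf.distribute.MirroredStrategy(['/cpu:0'])
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.compile(optimizer='adam', loss='mse')
  model.build([None, 1])
model.load_weights('/tmp/ckpt')
assert model.optimizer.weights  # slot variables were restored immediately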
@@ -328,6 +328,7 @@ py_library(
         "//tensorflow/python/distribute:parameter_server_strategy_v2",
         "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/distribute:tpu_strategy",
+        "//tensorflow/python/distribute:values",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:test",
         "//tensorflow/python/estimator:estimator_py",
@@ -40,6 +40,7 @@ from tensorflow.python.distribute import parameter_server_strategy_v2
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import tpu_strategy
+from tensorflow.python.distribute import values as ds_values_lib
 from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -2677,6 +2678,43 @@ class TestModelCapturesStrategy(test.TestCase, parameterized.TestCase):
         loss=keras.losses.MeanSquaredError(),
         metrics=[keras.metrics.BinaryAccuracy()])
 
+  @ds_combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.mirrored_strategy_with_one_cpu,
+          mode=['eager']))
+  def test_optimizer(self, distribution):
+    temp_dir = os.path.join(self.get_temp_dir(), 'ckpt')
+
+    def create_model():
+      model = keras.models.Sequential([
+          keras.layers.Dense(1),
+      ])
+      model.compile(optimizer='adam', loss='mse')
+      model.build([None, 1])  # Create weights.
+      self.assertEmpty(model.optimizer.weights)
+      return model
+
+    model = create_model()
+    x = y = array_ops.ones(shape=(1, 1))
+    model.fit(x=x, y=y, batch_size=1)
+    model.save_weights(temp_dir)
+
+    with distribution.scope():
+      model = create_model()
+      model.load_weights(temp_dir)
+      self.assertNotEmpty(model.optimizer.weights)
+      self.assertIsInstance(model.optimizer.weights[0],
+                            ds_values_lib.DistributedVariable)
+
+    with distribution.scope():
+      model = create_model()
+    # Creating/restoring slot variables outside of the scope is fine.
+    model.load_weights(temp_dir)
+    self.assertNotEmpty(model.optimizer.weights)
+    self.assertIsInstance(model.optimizer.weights[0],
+                          ds_values_lib.DistributedVariable)
+
+
 if __name__ == '__main__':
   base_layer_utils.enable_v2_dtype_behavior()
   multi_process_runner.test_main()
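The test above covers the supported paths. For contrast, a hedged sketch of the mismatch that the ValueError in the next hunk reports: model variables created under one strategy, slot-variable creation (here triggered by a checkpoint restore) attempted under another. The strategy choices and the checkpoint path are illustrative only.

import tensorflow as tf

strategy_a = tf.distribute.MirroredStrategy(['/cpu:0'])
strategy_b = tf.distribute.OneDeviceStrategy('/cpu:0')

with strategy_a.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.compile(optimizer='adam', loss='mse')
  model.build([None, 1])

# Restoring under a different, explicitly entered strategy scope would create
# slot variables under strategy_b while the model variables live under
# strategy_a, so add_slot raises:
#   ValueError: Trying to create optimizer slot variable under the scope for
#   tf.distribute.Strategy (...), which is different from the scope used for
#   the original variable (...).
with strategy_b.scope():
  model.load_weights('/tmp/ckpt')  # placeholder checkpoint written elsewhere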
@@ -884,22 +884,24 @@ class OptimizerV2(trackable.Trackable):
             initializer, shape=slot_shape, dtype=var.dtype)
       else:
         initial_value = initializer
-      strategy = distribute_ctx.get_strategy()
-      if not strategy.extended.variable_created_in_scope(var):
-        raise ValueError(
-            "Trying to create optimizer slot variable under the scope for "
-            "tf.distribute.Strategy ({}), which is different from the scope "
-            "used for the original variable ({}). Make sure the slot "
-            "variables are created under the same strategy scope. This may "
-            "happen if you're restoring from a checkpoint outside the scope"
-            .format(strategy, var))
-
-      with strategy.extended.colocate_vars_with(var):
-        weight = tf_variables.Variable(
-            name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
-            dtype=var.dtype,
-            trainable=False,
-            initial_value=initial_value)
+
+      with self._distribution_strategy_scope():
+        strategy = distribute_ctx.get_strategy()
+        if not strategy.extended.variable_created_in_scope(var):
+          raise ValueError(
+              "Trying to create optimizer slot variable under the scope for "
+              "tf.distribute.Strategy ({}), which is different from the scope "
+              "used for the original variable ({}). Make sure the slot "
+              "variables are created under the same strategy scope. This may "
+              "happen if you're restoring from a checkpoint outside the scope"
+              .format(strategy, var))
+
+        with strategy.extended.colocate_vars_with(var):
+          weight = tf_variables.Variable(
+              name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
+              dtype=var.dtype,
+              trainable=False,
+              initial_value=initial_value)
       backend.track_variable(weight)
       slot_dict[slot_name] = weight
       self._restore_slot_variable(
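For context, the `_distribution_strategy_scope()` entered above behaves, roughly, like a context manager that re-enters the strategy captured at optimizer construction, but only when no strategy scope is currently active. A paraphrased sketch of that method, not the verbatim TensorFlow source:

import contextlib
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx

@contextlib.contextmanager
def _distribution_strategy_scope(self):
  # Re-enter the captured strategy only when no strategy scope is currently
  # active, so an explicitly entered scope is never silently overridden.
  if self._distribution_strategy and not distribute_ctx.has_strategy():
    with self._distribution_strategy.scope():
      yield
  else:
    yield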
@@ -1359,7 +1361,13 @@ class OptimizerV2(trackable.Trackable):
         # a slot variable if not for this case). Deferring is mostly harmless
         # (aside from double initialization), and makes variable creator scopes
         # behave the same way they do when graph building.
-        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
+        #
+        # One notable case is with distribution strategy, which uses variable
+        # creator scope but always desires the `variable` and the slot to use
+        # the same scope, thus we can safely eagerly create/restore slot
+        # variables.
+        and (not ops.get_default_graph()._variable_creator_stack or  # pylint: disable=protected-access
+             self._distribution_strategy)):
       initializer = trackable.CheckpointInitialValueCallable(
           checkpoint_position=slot_variable_position)
       slot_variable = self.add_slot(
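In plain terms, the condition change widens when a slot variable may be restored eagerly. An illustrative paraphrase with made-up names, not the real method signature:

def restore_slot_eagerly(variable_creator_stack, captured_strategy):
  # Before this change: eager restore only when no variable creator scope is
  # active. After: a captured distribution strategy also allows eager restore,
  # because add_slot re-enters that same strategy scope, so the slot is
  # guaranteed to land in the same scope as its variable.
  return (not variable_creator_stack) or (captured_strategy is not None)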