Use same var key in _create_slots/get_slot in V1 optimizer
We have special handling for distributed variable in get_slot, but not create_slot. This happens to work before but upcoming change in distributed library will break it. PiperOrigin-RevId: 301658655 Change-Id: I9fc3dd9bacb277a9a6c7d9dba743b5885cad59e4
This commit is contained in:
parent
1e0821e601
commit
019639ca42
@ -5220,7 +5220,6 @@ py_library(
|
||||
"//tensorflow/python/distribute:distribute_coordinator_context",
|
||||
"//tensorflow/python/distribute:distribute_lib",
|
||||
"//tensorflow/python/distribute:reduce_util",
|
||||
"//tensorflow/python/distribute:values",
|
||||
"//tensorflow/python/eager:backprop",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/keras/optimizer_v2:learning_rate_schedule",
|
||||
@ -6591,9 +6590,6 @@ cuda_py_tests(
|
||||
":variable_scope",
|
||||
":variables",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python/distribute:cross_device_ops",
|
||||
"//tensorflow/python/distribute:mirrored_strategy",
|
||||
"//tensorflow/python/distribute:values",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
|
@ -27,7 +27,6 @@ import six
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
|
||||
from tensorflow.python.distribute import reduce_util as ds_reduce_util
|
||||
from tensorflow.python.distribute import values as ds_values
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -82,17 +81,10 @@ def _deduplicate_indexed_slices(values, indices):
|
||||
|
||||
|
||||
def _var_key(var):
|
||||
"""Returns slot key for `var`."""
|
||||
# pylint: disable=protected-access
|
||||
if hasattr(var, "_distributed_container"):
|
||||
var = var._distributed_container()
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
return var._unique_id
|
||||
if ds_values.is_distributed_variable(var):
|
||||
return (var.graph, var._shared_name)
|
||||
else:
|
||||
# TODO(ashankar): Consolidate handling for eager and graph
|
||||
if hasattr(var, "op"):
|
||||
return (var.op.graph, var.op.name)
|
||||
# pylint: enable=protected-access
|
||||
return var._unique_id # pylint: disable=protected-access
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
@ -759,16 +751,26 @@ class Optimizer(
|
||||
Returns:
|
||||
The `Variable` for the slot if it was created, `None` otherwise.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
named_slots = self._slots.get(name, None)
|
||||
if not named_slots:
|
||||
return None
|
||||
slot = named_slots.get(_var_key(var), None)
|
||||
if (ds_values.is_distributed_variable(slot) and
|
||||
not ds_values.is_distributed_variable(var)):
|
||||
# Make sure var and slot are either both DistributedVariable, or both
|
||||
# per replica variables.
|
||||
slot = slot._get_closest() # pylint: disable=protected-access
|
||||
return slot
|
||||
|
||||
if hasattr(var, "_distributed_container"):
|
||||
# NOTE: If this isn't patched, then there is no `handle` in
|
||||
# `_resource_apply_dense`.
|
||||
distributed_container = var._distributed_container()
|
||||
assert distributed_container is not None
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
key = distributed_container._unique_id
|
||||
else:
|
||||
key = (distributed_container.graph, distributed_container._shared_name)
|
||||
# pylint: enable=protected-access
|
||||
mirrored_slot = named_slots.get(key, None)
|
||||
if mirrored_slot is None: return None
|
||||
return mirrored_slot._get_closest() # pylint: disable=protected-access
|
||||
|
||||
return named_slots.get(_var_key(var), None)
|
||||
|
||||
def get_slot_names(self):
|
||||
"""Return a list of the names of slots created by the `Optimizer`.
|
||||
|
@ -18,9 +18,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.distribute import cross_device_ops
|
||||
from tensorflow.python.distribute import mirrored_strategy
|
||||
from tensorflow.python.distribute import values as ds_values
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -32,7 +29,6 @@ from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import adam
|
||||
from tensorflow.python.training import gradient_descent
|
||||
|
||||
|
||||
@ -273,28 +269,6 @@ class OptimizerTest(test.TestCase):
|
||||
self.assertAllClose([-0.1, -0.1], self.evaluate(var0))
|
||||
self.assertAllClose([0., 0.], self.evaluate(var1))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testGetSlotUnderDistributedStrategy(self):
|
||||
# Only run this test in graph mode so we don't need actual GPU.
|
||||
ds = mirrored_strategy.MirroredStrategy(
|
||||
['CPU:0', 'GPU:0'],
|
||||
cross_device_ops=cross_device_ops.HierarchicalCopyAllReduce())
|
||||
# We need an optimizer that creates slots.
|
||||
optimizer = adam.AdamOptimizer()
|
||||
|
||||
def f():
|
||||
v = variables.Variable([1.0])
|
||||
self.assertTrue(ds_values.is_distributed_variable(v))
|
||||
# Slot variables are created in the first call to apply_gradients.
|
||||
optimizer.apply_gradients([(ops.convert_to_tensor([1.0]), v)])
|
||||
self.assertTrue(optimizer.get_slot_names())
|
||||
for name in optimizer.get_slot_names():
|
||||
slot = optimizer.get_slot(v, name)
|
||||
self.assertIsNotNone(slot)
|
||||
self.assertTrue(ds_values.is_distributed_variable(slot))
|
||||
|
||||
ds.experimental_run_v2(f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user