[retry] Use same var key in _create_slots/get_slot in V1 optimizer

We have special handling for distributed variables in get_slot, but not in
_create_slots, even though the keys need to match. This change modifies get_slot to use _var_key as well, to avoid confusion. It also prepares for an upcoming refactor in the distribution strategy code.

Note that we need to make sure the keys don't change, so existing checkpoints can still be used.

A number of build rules are also modified to break cyclic dependencies.
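The key-matching invariant is easiest to see in isolation. Below is a minimal, hypothetical sketch (FakeVar and FakeOptimizer are simplified stand-ins, not the real TF classes): slots are stored in a per-name dictionary keyed by _var_key(var) at creation time, so get_slot must derive its lookup key with the same function or the lookup silently misses.

    # Minimal sketch of the keying invariant; hypothetical stand-ins only.
    class FakeVar(object):

      def __init__(self, unique_id):
        self._unique_id = unique_id


    def _var_key(var):
      # Simplified: the real helper also special-cases distributed and
      # graph-mode variables. The point is that slot creation and slot
      # lookup both go through this one function.
      return var._unique_id


    class FakeOptimizer(object):

      def __init__(self):
        # Maps slot_name -> {_var_key(variable): slot_value}.
        self._slots = {}

      def _create_slot(self, var, name, value):
        # The key is chosen once, at creation time.
        self._slots.setdefault(name, {})[_var_key(var)] = value

      def get_slot(self, var, name):
        named_slots = self._slots.get(name, None)
        if not named_slots:
          return None
        # Must use the same key function as _create_slot.
        return named_slots.get(_var_key(var), None)


    opt = FakeOptimizer()
    v = FakeVar("var0")
    opt._create_slot(v, "m", 0.0)
    assert opt.get_slot(v, "m") == 0.0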

PiperOrigin-RevId: 354341520
Change-Id: Ifd9786263024a11806ddde0c3bd1d36157ab8db7
Author: Ran Chen
Date: 2021-01-28 10:31:34 -08:00
Committed by: TensorFlower Gardener
Commit: 8a356e8ca5 (parent: 852c7a1716)
9 changed files with 72 additions and 23 deletions


@@ -44,6 +44,9 @@ py_library(
         "//tensorflow/python:util",
         "//tensorflow/python:variables",
         "//tensorflow/python/autograph/utils",
+        "//tensorflow/python/data/experimental/ops:cardinality",
+        "//tensorflow/python/data/experimental/ops:scan_ops",
+        "//tensorflow/python/data/experimental/ops:take_while_ops",
         "//tensorflow/python/data/ops:dataset_ops",
         "//third_party/py/numpy",
     ],


@@ -11,9 +11,18 @@ py_library(
     visibility = ["//tensorflow:internal"],
     deps = [
         "//tensorflow/python:control_flow_v2_toggles",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:tensor_shape",
+        "//tensorflow/python:tf2",
         "//tensorflow/python:util",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python/data/experimental/ops:counter",
+        "//tensorflow/python/data/experimental/ops:interleave_ops",
+        "//tensorflow/python/data/experimental/ops:random_ops",
+        "//tensorflow/python/data/experimental/ops:readers",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/data/ops:readers",
+        "//tensorflow/python/eager:monitoring",
         "//tensorflow/python/util:tf_export",
     ],
 )


@@ -299,6 +299,7 @@ py_library(
     srcs_version = "PY3",
     deps = [
         ":batching",
+        ":error_ops",
         ":interleave_ops",
         ":optimization",
         ":parsing_ops",


@@ -31,6 +31,7 @@ py_library(
         "//tensorflow/python/data/experimental/ops:optimization_options",
         "//tensorflow/python/data/experimental/ops:stats_options",
         "//tensorflow/python/data/experimental/ops:threading_options",
+        "//tensorflow/python/data/util:convert",
         "//tensorflow/python/data/util:nest",
         "//tensorflow/python/data/util:options",
         "//tensorflow/python/data/util:random_seed",


@@ -150,7 +150,7 @@ py_library(
         "//tensorflow/python:summary_ops_v2",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
-        "//tensorflow/python/data",
+        "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/ops/losses",
         "//tensorflow/tools/docs:doc_controls",
     ],
@@ -289,7 +289,6 @@ py_library(
         "//tensorflow/python:pywrap_tfe",
         "//tensorflow/python:summary_ops_v2",
         "//tensorflow/python:tensor_util",
-        "//tensorflow/python:training",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python/autograph/core",
@@ -509,7 +508,13 @@ py_library(
         ":values",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:sparse_tensor",
+        "//tensorflow/python/data/experimental/ops:batching",
+        "//tensorflow/python/data/experimental/ops:cardinality",
+        "//tensorflow/python/data/experimental/ops:distribute",
         "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/ops:iterator_ops",
+        "//tensorflow/python/data/ops:multi_device_iterator_ops",
+        "//tensorflow/python/data/ops:optional_ops",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/ops/ragged:ragged_tensor",
     ],
@@ -810,7 +815,7 @@ py_library(
         "//tensorflow/python:resource_variable_ops_gen",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:tape",
-        "//tensorflow/python/tpu:tpu_lib",
+        "//tensorflow/python/tpu:tpu_py",
     ],
 )


@@ -394,6 +394,7 @@ py_library(
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
         "//tensorflow/python/distribute:distribute_lib",
+        "//tensorflow/python/distribute:distribute_utils",
         "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:context",
@@ -514,6 +515,7 @@ py_library(
     srcs = ["summary_io.py"],
     srcs_version = "PY3",
     deps = [
+        "//tensorflow/python:summary",
         "//tensorflow/python:util",
     ],
 )
@@ -1216,6 +1218,9 @@ cuda_py_tests(
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
+        "//tensorflow/python/distribute:cross_device_ops",
+        "//tensorflow/python/distribute:distribute_utils",
+        "//tensorflow/python/distribute:mirrored_strategy",
         "//third_party/py/numpy",
         "@six_archive//:six",
     ],


@@ -25,6 +25,7 @@ import abc
 import six
 from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
 from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.eager import backprop
@@ -81,10 +82,17 @@ def _deduplicate_indexed_slices(values, indices):
 
 def _var_key(var):
-  # TODO(ashankar): Consolidate handling for eager and graph
+  """Returns slot key for `var`."""
+  # pylint: disable=protected-access
+  if hasattr(var, "_distributed_container"):
+    var = var._distributed_container()
+  if (distribute_utils.is_distributed_variable(var) and
+      not ops.executing_eagerly_outside_functions()):
+    return (var.graph, var._shared_name)
   if hasattr(var, "op"):
     return (var.op.graph, var.op.name)
-  return var._unique_id  # pylint: disable=protected-access
+  return var._unique_id
+  # pylint: enable=protected-access
 
 
 @six.add_metaclass(abc.ABCMeta)
@@ -751,26 +759,16 @@ class Optimizer(
     Returns:
       The `Variable` for the slot if it was created, `None` otherwise.
     """
-    # pylint: disable=protected-access
     named_slots = self._slots.get(name, None)
     if not named_slots:
       return None
-
-    if hasattr(var, "_distributed_container"):
-      # NOTE: If this isn't patched, then there is no `handle` in
-      # `_resource_apply_dense`.
-      distributed_container = var._distributed_container()
-      assert distributed_container is not None
-      if ops.executing_eagerly_outside_functions():
-        key = distributed_container._unique_id
-      else:
-        key = (distributed_container.graph, distributed_container._shared_name)
-      # pylint: enable=protected-access
-      mirrored_slot = named_slots.get(key, None)
-      if mirrored_slot is None: return None
-      return mirrored_slot._get_on_device_or_primary()  # pylint: disable=protected-access
-
-    return named_slots.get(_var_key(var), None)
+    slot = named_slots.get(_var_key(var), None)
+    if (distribute_utils.is_distributed_variable(slot) and
+        not distribute_utils.is_distributed_variable(var)):
+      # Make sure var and slot are either both DistributedVariable, or both
+      # per replica variables.
+      slot = slot._get_on_device_or_primary()  # pylint: disable=protected-access
+    return slot
 
   def get_slot_names(self):
     """Return a list of the names of slots created by the `Optimizer`.


@@ -18,6 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.distribute import cross_device_ops
+from tensorflow.python.distribute import distribute_utils
+from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -29,6 +32,7 @@ from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
+from tensorflow.python.training import adam
 from tensorflow.python.training import gradient_descent
@@ -269,6 +273,28 @@ class OptimizerTest(test.TestCase):
     self.assertAllClose([-0.1, -0.1], self.evaluate(var0))
     self.assertAllClose([0., 0.], self.evaluate(var1))
 
+  @test_util.run_deprecated_v1
+  def testGetSlotUnderDistributedStrategy(self):
+    # Only run this test in graph mode so we don't need actual GPU.
+    ds = mirrored_strategy.MirroredStrategy(
+        ['CPU:0', 'GPU:0'],
+        cross_device_ops=cross_device_ops.HierarchicalCopyAllReduce())
+    # We need an optimizer that creates slots.
+    optimizer = adam.AdamOptimizer()
+
+    def f():
+      v = variables.Variable([1.0])
+      self.assertTrue(distribute_utils.is_distributed_variable(v))
+      # Slot variables are created in the first call to apply_gradients.
+      optimizer.apply_gradients([(ops.convert_to_tensor([1.0]), v)])
+      self.assertTrue(optimizer.get_slot_names())
+      for name in optimizer.get_slot_names():
+        slot = optimizer.get_slot(v, name)
+        self.assertIsNotNone(slot)
+        self.assertTrue(distribute_utils.is_distributed_variable(slot))
+
+    ds.run(f)
+
 
 if __name__ == '__main__':
   test.main()


@@ -111,6 +111,7 @@ py_library(
         "//tensorflow/python:constant_op",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python/training:optimizer",
         "//tensorflow/python/training/saving:saveable_object",
         "//tensorflow/python/training/saving:saveable_object_util",
     ],