[retry] Use same var key in _create_slots/get_slot in V1 optimizer
We have special handling for distributed variables in get_slot, but not in
_create_slots, even though the keys used by the two must match. This change
modifies get_slot to use _var_key as well, to avoid confusion. It also
prepares for an upcoming refactor of the distribution strategy code. Note
that we need to make sure the keys don't change, so that existing
checkpoints can still be used. A bunch of build rules are modified to break
cyclic dependencies.

PiperOrigin-RevId: 354341520
Change-Id: Ifd9786263024a11806ddde0c3bd1d36157ab8db7
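Before the diff itself, a minimal, hypothetical sketch (not the TensorFlow
implementation) of the invariant this change enforces: the key that slot
creation stores under must be identical to the key that get_slot reads back
with, otherwise lookups silently return None. The ToyOptimizer class and
slot_key helper below are invented stand-ins, not TensorFlow APIs.

def slot_key(var):
    # Stand-in for _var_key: derive a stable, hashable key from a variable.
    return var["name"]

class ToyOptimizer:
    def __init__(self):
        # Maps slot_name -> {var_key: slot}, mirroring Optimizer._slots.
        self._slots = {}

    def create_slot(self, var, slot_name):
        named_slots = self._slots.setdefault(slot_name, {})
        named_slots[slot_key(var)] = {"name": var["name"] + "/" + slot_name}

    def get_slot(self, var, slot_name):
        # Same slot_key helper as create_slot, so the two always agree,
        # even if `var` is some wrapper around the underlying variable.
        named_slots = self._slots.get(slot_name, {})
        return named_slots.get(slot_key(var))

opt = ToyOptimizer()
v = {"name": "weights"}
opt.create_slot(v, "m")
assert opt.get_slot(v, "m") is not None  # matching keys: lookup succeeds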
commit 8a356e8ca5
parent 852c7a1716
@@ -44,6 +44,9 @@ py_library(
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/autograph/utils",
        "//tensorflow/python/data/experimental/ops:cardinality",
        "//tensorflow/python/data/experimental/ops:scan_ops",
        "//tensorflow/python/data/experimental/ops:take_while_ops",
        "//tensorflow/python/data/ops:dataset_ops",
        "//third_party/py/numpy",
    ],
@@ -11,9 +11,18 @@ py_library(
    visibility = ["//tensorflow:internal"],
    deps = [
        "//tensorflow/python:control_flow_v2_toggles",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tf2",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/data/experimental/ops:counter",
        "//tensorflow/python/data/experimental/ops:interleave_ops",
        "//tensorflow/python/data/experimental/ops:random_ops",
        "//tensorflow/python/data/experimental/ops:readers",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:readers",
        "//tensorflow/python/eager:monitoring",
        "//tensorflow/python/util:tf_export",
    ],
)
@@ -299,6 +299,7 @@ py_library(
    srcs_version = "PY3",
    deps = [
        ":batching",
        ":error_ops",
        ":interleave_ops",
        ":optimization",
        ":parsing_ops",
@@ -31,6 +31,7 @@ py_library(
        "//tensorflow/python/data/experimental/ops:optimization_options",
        "//tensorflow/python/data/experimental/ops:stats_options",
        "//tensorflow/python/data/experimental/ops:threading_options",
        "//tensorflow/python/data/util:convert",
        "//tensorflow/python/data/util:nest",
        "//tensorflow/python/data/util:options",
        "//tensorflow/python/data/util:random_seed",
@@ -150,7 +150,7 @@ py_library(
        "//tensorflow/python:summary_ops_v2",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/data",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/ops/losses",
        "//tensorflow/tools/docs:doc_controls",
    ],
@@ -289,7 +289,6 @@ py_library(
        "//tensorflow/python:pywrap_tfe",
        "//tensorflow/python:summary_ops_v2",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:training",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/autograph/core",
@@ -509,7 +508,13 @@ py_library(
        ":values",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python/data/experimental/ops:batching",
        "//tensorflow/python/data/experimental/ops:cardinality",
        "//tensorflow/python/data/experimental/ops:distribute",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:iterator_ops",
        "//tensorflow/python/data/ops:multi_device_iterator_ops",
        "//tensorflow/python/data/ops:optional_ops",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/ops/ragged:ragged_tensor",
    ],
@@ -810,7 +815,7 @@ py_library(
        "//tensorflow/python:resource_variable_ops_gen",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:tape",
        "//tensorflow/python/tpu:tpu_lib",
        "//tensorflow/python/tpu:tpu_py",
    ],
)
@@ -394,6 +394,7 @@ py_library(
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/distribute:distribute_lib",
        "//tensorflow/python/distribute:distribute_utils",
        "//tensorflow/python/distribute:reduce_util",
        "//tensorflow/python/eager:backprop",
        "//tensorflow/python/eager:context",
@@ -514,6 +515,7 @@ py_library(
    srcs = ["summary_io.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python:summary",
        "//tensorflow/python:util",
    ],
)
@@ -1216,6 +1218,9 @@ cuda_py_tests(
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/distribute:cross_device_ops",
        "//tensorflow/python/distribute:distribute_utils",
        "//tensorflow/python/distribute:mirrored_strategy",
        "//third_party/py/numpy",
        "@six_archive//:six",
    ],
@@ -25,6 +25,7 @@ import abc
 import six

 from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
 from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.eager import backprop
@@ -81,10 +82,17 @@ def _deduplicate_indexed_slices(values, indices):


 def _var_key(var):
-  # TODO(ashankar): Consolidate handling for eager and graph
+  """Returns slot key for `var`."""
+  # pylint: disable=protected-access
+  if hasattr(var, "_distributed_container"):
+    var = var._distributed_container()
+  if (distribute_utils.is_distributed_variable(var) and
+      not ops.executing_eagerly_outside_functions()):
+    return (var.graph, var._shared_name)
   if hasattr(var, "op"):
     return (var.op.graph, var.op.name)
-  return var._unique_id  # pylint: disable=protected-access
+  return var._unique_id
+  # pylint: enable=protected-access


 @six.add_metaclass(abc.ABCMeta)
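For intuition, here is a simplified, self-contained sketch of the dispatch
order the patched _var_key follows. It is not the real function: the
distributed check is reduced to a plain attribute, and the variables are
stand-in namespaces that only mimic the attributes being inspected.

from types import SimpleNamespace as NS

def var_key_sketch(var, graph_mode=True):
    # Simplified mirror of the patched _var_key. Dispatch order:
    # 1. unwrap a per-replica component to its distributed container,
    # 2. in graph mode, key distributed variables by (graph, shared_name),
    # 3. key ordinary graph variables by (graph, op_name),
    # 4. fall back to the eager unique id.
    if hasattr(var, "_distributed_container"):
        var = var._distributed_container()
    if getattr(var, "is_distributed", False) and graph_mode:
        return (var.graph, var.shared_name)
    if hasattr(var, "op"):
        return (var.op.graph, var.op.name)
    return var.unique_id

graph = object()  # stand-in for a tf.Graph
plain = NS(op=NS(graph=graph, name="w"))
eager = NS(unique_id="w_1")
assert var_key_sketch(plain) == (graph, "w")
assert var_key_sketch(eager, graph_mode=False) == "w_1"

Because _create_slots keys its dictionary the same way, both creation and
lookup now funnel through a single definition of "the key for this variable".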
@@ -751,26 +759,16 @@ class Optimizer(
     Returns:
       The `Variable` for the slot if it was created, `None` otherwise.
     """
-    # pylint: disable=protected-access
     named_slots = self._slots.get(name, None)
     if not named_slots:
       return None
-
-    if hasattr(var, "_distributed_container"):
-      # NOTE: If this isn't patched, then there is no `handle` in
-      # `_resource_apply_dense`.
-      distributed_container = var._distributed_container()
-      assert distributed_container is not None
-      if ops.executing_eagerly_outside_functions():
-        key = distributed_container._unique_id
-      else:
-        key = (distributed_container.graph, distributed_container._shared_name)
-      # pylint: enable=protected-access
-      mirrored_slot = named_slots.get(key, None)
-      if mirrored_slot is None: return None
-      return mirrored_slot._get_on_device_or_primary()  # pylint: disable=protected-access
-
-    return named_slots.get(_var_key(var), None)
+    slot = named_slots.get(_var_key(var), None)
+    if (distribute_utils.is_distributed_variable(slot) and
+        not distribute_utils.is_distributed_variable(var)):
+      # Make sure var and slot are either both DistributedVariable, or both
+      # per replica variables.
+      slot = slot._get_on_device_or_primary()  # pylint: disable=protected-access
+    return slot

   def get_slot_names(self):
     """Return a list of the names of slots created by the `Optimizer`.
@@ -18,6 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from tensorflow.python.distribute import cross_device_ops
+from tensorflow.python.distribute import distribute_utils
+from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -29,6 +32,7 @@ from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
+from tensorflow.python.training import adam
 from tensorflow.python.training import gradient_descent

@@ -269,6 +273,28 @@ class OptimizerTest(test.TestCase):
     self.assertAllClose([-0.1, -0.1], self.evaluate(var0))
     self.assertAllClose([0., 0.], self.evaluate(var1))

+  @test_util.run_deprecated_v1
+  def testGetSlotUnderDistributedStrategy(self):
+    # Only run this test in graph mode so we don't need actual GPU.
+    ds = mirrored_strategy.MirroredStrategy(
+        ['CPU:0', 'GPU:0'],
+        cross_device_ops=cross_device_ops.HierarchicalCopyAllReduce())
+    # We need an optimizer that creates slots.
+    optimizer = adam.AdamOptimizer()
+
+    def f():
+      v = variables.Variable([1.0])
+      self.assertTrue(distribute_utils.is_distributed_variable(v))
+      # Slot variables are created in the first call to apply_gradients.
+      optimizer.apply_gradients([(ops.convert_to_tensor([1.0]), v)])
+      self.assertTrue(optimizer.get_slot_names())
+      for name in optimizer.get_slot_names():
+        slot = optimizer.get_slot(v, name)
+        self.assertIsNotNone(slot)
+        self.assertTrue(distribute_utils.is_distributed_variable(slot))
+
+    ds.run(f)
+

 if __name__ == '__main__':
   test.main()
@@ -111,6 +111,7 @@ py_library(
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python/training:optimizer",
        "//tensorflow/python/training/saving:saveable_object",
        "//tensorflow/python/training/saving:saveable_object_util",
    ],