diff --git a/tensorflow/python/autograph/operators/BUILD b/tensorflow/python/autograph/operators/BUILD
index ab9babf9149..908f6a7756c 100644
--- a/tensorflow/python/autograph/operators/BUILD
+++ b/tensorflow/python/autograph/operators/BUILD
@@ -44,6 +44,9 @@ py_library(
         "//tensorflow/python:util",
         "//tensorflow/python:variables",
         "//tensorflow/python/autograph/utils",
+        "//tensorflow/python/data/experimental/ops:cardinality",
+        "//tensorflow/python/data/experimental/ops:scan_ops",
+        "//tensorflow/python/data/experimental/ops:take_while_ops",
         "//tensorflow/python/data/ops:dataset_ops",
         "//third_party/py/numpy",
     ],
diff --git a/tensorflow/python/compat/BUILD b/tensorflow/python/compat/BUILD
index 7bfe34b53e5..1cc18af18e8 100644
--- a/tensorflow/python/compat/BUILD
+++ b/tensorflow/python/compat/BUILD
@@ -11,9 +11,18 @@ py_library(
     visibility = ["//tensorflow:internal"],
     deps = [
         "//tensorflow/python:control_flow_v2_toggles",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:tensor_shape",
         "//tensorflow/python:tf2",
-        "//tensorflow/python:util",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python/data/experimental/ops:counter",
+        "//tensorflow/python/data/experimental/ops:interleave_ops",
+        "//tensorflow/python/data/experimental/ops:random_ops",
+        "//tensorflow/python/data/experimental/ops:readers",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/ops:readers",
         "//tensorflow/python/eager:monitoring",
+        "//tensorflow/python/util:tf_export",
     ],
 )
 
diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD
index 9e0df45b49e..71cfc178297 100644
--- a/tensorflow/python/data/experimental/ops/BUILD
+++ b/tensorflow/python/data/experimental/ops/BUILD
@@ -299,6 +299,7 @@ py_library(
     srcs_version = "PY3",
     deps = [
         ":batching",
+        ":error_ops",
         ":interleave_ops",
         ":optimization",
         ":parsing_ops",
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index 5825ac98607..3cc84fad509 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -31,6 +31,7 @@ py_library(
         "//tensorflow/python/data/experimental/ops:optimization_options",
         "//tensorflow/python/data/experimental/ops:stats_options",
         "//tensorflow/python/data/experimental/ops:threading_options",
+        "//tensorflow/python/data/util:convert",
         "//tensorflow/python/data/util:nest",
         "//tensorflow/python/data/util:options",
         "//tensorflow/python/data/util:random_seed",
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 1a2d463adfc..cbe15de59e1 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -150,7 +150,7 @@ py_library(
         "//tensorflow/python:summary_ops_v2",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
-        "//tensorflow/python/data",
+        "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/ops/losses",
         "//tensorflow/tools/docs:doc_controls",
     ],
@@ -289,7 +289,6 @@ py_library(
         "//tensorflow/python:pywrap_tfe",
         "//tensorflow/python:summary_ops_v2",
         "//tensorflow/python:tensor_util",
-        "//tensorflow/python:training",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python/autograph/core",
@@ -509,7 +508,13 @@ py_library(
         ":values",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:sparse_tensor",
+        "//tensorflow/python/data/experimental/ops:batching",
+        "//tensorflow/python/data/experimental/ops:cardinality",
+        "//tensorflow/python/data/experimental/ops:distribute",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/ops:iterator_ops",
         "//tensorflow/python/data/ops:multi_device_iterator_ops",
+        "//tensorflow/python/data/ops:optional_ops",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/ops/ragged:ragged_tensor",
     ],
@@ -810,7 +815,7 @@ py_library(
         "//tensorflow/python:resource_variable_ops_gen",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:tape",
-        "//tensorflow/python/tpu:tpu_lib",
+        "//tensorflow/python/tpu:tpu_py",
     ],
 )
 
diff --git a/tensorflow/python/training/BUILD b/tensorflow/python/training/BUILD
index de4fed7ea07..0299d561d4b 100644
--- a/tensorflow/python/training/BUILD
+++ b/tensorflow/python/training/BUILD
@@ -394,6 +394,7 @@ py_library(
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
         "//tensorflow/python/distribute:distribute_lib",
+        "//tensorflow/python/distribute:distribute_utils",
        "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:context",
@@ -514,6 +515,7 @@ py_library(
     srcs = ["summary_io.py"],
     srcs_version = "PY3",
     deps = [
+        "//tensorflow/python:summary",
         "//tensorflow/python:util",
     ],
 )
@@ -1216,6 +1218,9 @@ cuda_py_tests(
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
+        "//tensorflow/python/distribute:cross_device_ops",
+        "//tensorflow/python/distribute:distribute_utils",
+        "//tensorflow/python/distribute:mirrored_strategy",
         "//third_party/py/numpy",
         "@six_archive//:six",
     ],
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 9e7d486123c..42beafac708 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -25,6 +25,7 @@ import abc
 import six
 
 from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
 from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.eager import backprop
@@ -81,10 +82,17 @@ def _deduplicate_indexed_slices(values, indices):
 
 
 def _var_key(var):
-  # TODO(ashankar): Consolidate handling for eager and graph
+  """Returns slot key for `var`."""
+  # pylint: disable=protected-access
+  if hasattr(var, "_distributed_container"):
+    var = var._distributed_container()
+  if (distribute_utils.is_distributed_variable(var) and
+      not ops.executing_eagerly_outside_functions()):
+    return (var.graph, var._shared_name)
   if hasattr(var, "op"):
     return (var.op.graph, var.op.name)
-  return var._unique_id  # pylint: disable=protected-access
+  return var._unique_id
+  # pylint: enable=protected-access
 
 
 @six.add_metaclass(abc.ABCMeta)
@@ -751,26 +759,16 @@ class Optimizer(
     Returns:
       The `Variable` for the slot if it was created, `None` otherwise.
     """
-    # pylint: disable=protected-access
     named_slots = self._slots.get(name, None)
     if not named_slots:
       return None
-
-    if hasattr(var, "_distributed_container"):
-      # NOTE: If this isn't patched, then there is no `handle` in
-      # `_resource_apply_dense`.
-      distributed_container = var._distributed_container()
-      assert distributed_container is not None
-      if ops.executing_eagerly_outside_functions():
-        key = distributed_container._unique_id
-      else:
-        key = (distributed_container.graph, distributed_container._shared_name)
-      # pylint: enable=protected-access
-      mirrored_slot = named_slots.get(key, None)
-      if mirrored_slot is None: return None
-      return mirrored_slot._get_on_device_or_primary()  # pylint: disable=protected-access
-
-    return named_slots.get(_var_key(var), None)
+    slot = named_slots.get(_var_key(var), None)
+    if (distribute_utils.is_distributed_variable(slot) and
+        not distribute_utils.is_distributed_variable(var)):
+      # Make sure var and slot are either both DistributedVariable, or both
+      # per replica variables.
+      slot = slot._get_on_device_or_primary()  # pylint: disable=protected-access
+    return slot
 
   def get_slot_names(self):
     """Return a list of the names of slots created by the `Optimizer`.
diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py
index 80689085587..96236da3dc4 100644
--- a/tensorflow/python/training/optimizer_test.py
+++ b/tensorflow/python/training/optimizer_test.py
@@ -18,6 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.distribute import cross_device_ops
+from tensorflow.python.distribute import distribute_utils
+from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -29,6 +32,7 @@ from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
+from tensorflow.python.training import adam
 from tensorflow.python.training import gradient_descent
 
 
@@ -269,6 +273,28 @@ class OptimizerTest(test.TestCase):
     self.assertAllClose([-0.1, -0.1], self.evaluate(var0))
     self.assertAllClose([0., 0.], self.evaluate(var1))
 
+  @test_util.run_deprecated_v1
+  def testGetSlotUnderDistributedStrategy(self):
+    # Only run this test in graph mode so we don't need actual GPU.
+    ds = mirrored_strategy.MirroredStrategy(
+        ['CPU:0', 'GPU:0'],
+        cross_device_ops=cross_device_ops.HierarchicalCopyAllReduce())
+    # We need an optimizer that creates slots.
+    optimizer = adam.AdamOptimizer()
+
+    def f():
+      v = variables.Variable([1.0])
+      self.assertTrue(distribute_utils.is_distributed_variable(v))
+      # Slot variables are created in the first call to apply_gradients.
+      optimizer.apply_gradients([(ops.convert_to_tensor([1.0]), v)])
+      self.assertTrue(optimizer.get_slot_names())
+      for name in optimizer.get_slot_names():
+        slot = optimizer.get_slot(v, name)
+        self.assertIsNotNone(slot)
+        self.assertTrue(distribute_utils.is_distributed_variable(slot))
+
+    ds.run(f)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/training/tracking/BUILD b/tensorflow/python/training/tracking/BUILD
index 37f56718a7c..ec2381946da 100644
--- a/tensorflow/python/training/tracking/BUILD
+++ b/tensorflow/python/training/tracking/BUILD
@@ -111,6 +111,7 @@ py_library(
         "//tensorflow/python:constant_op",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python/training:optimizer",
         "//tensorflow/python/training/saving:saveable_object",
         "//tensorflow/python/training/saving:saveable_object_util",
     ],