diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 1021b5fa408..04adc0185ac 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -1523,35 +1523,22 @@ cuda_py_test(
         ":combinations",
         ":cross_device_utils",
         ":multi_worker_test_base",
-        ":multi_worker_util",
-        ":reduce_util",
         ":strategy_combinations",
         ":strategy_test_lib",
-        ":values",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:constant_op",
         "//tensorflow/python:dtypes",
-        "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
-        "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:gradients",
         "//tensorflow/python:init_ops",
-        "//tensorflow/python:nn",
-        "//tensorflow/python:random_ops",
-        "//tensorflow/python:training_lib",
-        "//tensorflow/python:training_util",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
-        "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
         "//tensorflow/python/eager:context",
-        "//tensorflow/python/keras:testing_utils",
-        "//tensorflow/python/keras/layers:core",
-        "//tensorflow/python/keras/mixed_precision/experimental:policy",
+        "//tensorflow/python/estimator:estimator_py",
+        "//tensorflow/python/keras/layers",
         "//tensorflow/python/keras/mixed_precision/experimental:test_util",
-        "//tensorflow/python/ops/losses",
         "//third_party/py/numpy",
         "@absl_py//absl/testing:parameterized",
     ],
diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py
index 0e015596bb4..04248ee140d 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py
@@ -31,7 +31,6 @@ from tensorflow.python.distribute import cross_device_utils
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.distribute import reduce_util
-from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import strategy_test_lib
 from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
@@ -565,33 +564,6 @@ class DistributedCollectiveAllReduceStrategyTestWithChief(
         num_gpus=required_gpus)
 
 
-class MultiworkerMirroredStrategyTest(test.TestCase, parameterized.TestCase):
-
-  @combinations.generate(
-      combinations.combine(
-          strategy=strategy_combinations.multi_worker_mirrored_two_workers,
-          mode=['eager']))
-  def testReduce(self, strategy):
-
-    def fn():
-
-      def replica_fn():
-        return array_ops.ones((), dtypes.int32)
-
-      per_replica_value = strategy.run(replica_fn)
-      return strategy.reduce(
-          reduce_util.ReduceOp.SUM, value=per_replica_value, axis=None)
-
-    # Run reduce under the strategy scope to explicitly enter
-    # strategy default_device scope.
-    with strategy.scope():
-      self.assertEqual(fn().numpy(), 2)
-
-    # Run reduce without a strategy scope to implicitly enter
-    # strategy default_device scope.
-    self.assertEqual(fn().numpy(), 2)
-
-
 class LocalCollectiveAllReduceStrategy(
     CollectiveAllReduceStrategyTestBase,
     strategy_test_lib.DistributionTestBase,
@@ -701,4 +673,4 @@ class LocalCollectiveAllReduceStrategy(
 
 
 if __name__ == '__main__':
-  combinations.main()
+  test.main()
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index 6baa15f59c1..d17a594cb5e 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -1912,8 +1912,9 @@ class StrategyExtendedV2(object):
 
   def _reduce(self, reduce_op, value):
     # Default implementation until we have an implementation for each strategy.
-    dst = device_util.current() or self._default_device or "/device:CPU:0"
-    return self._local_results(self.reduce_to(reduce_op, value, dst))[0]
+    return self._local_results(
+        self.reduce_to(reduce_op, value,
+                       device_util.current() or "/device:CPU:0"))[0]
 
   def reduce_to(self, reduce_op, value, destinations, experimental_hints=None):
     """Combine (via e.g. sum or mean) values across replicas.
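
For context, a minimal usage sketch (not part of this patch) of the public behavior the simplified default _reduce path above implements: with no device scope active, strategy.reduce places its result on device_util.current() or, failing that, "/device:CPU:0". The single-device MirroredStrategy below is illustrative, assuming TensorFlow 2.x eager mode.

# Minimal sketch, assuming TF 2.x eager mode; the "/cpu:0" device list is
# illustrative and not taken from this patch.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])

def replica_fn():
  # Each replica contributes a scalar 1.
  return tf.ones((), tf.int32)

per_replica_value = strategy.run(replica_fn)
# Outside any strategy or device scope, the default _reduce implementation
# resolves the destination to the current device or "/device:CPU:0".
total = strategy.reduce(
    tf.distribute.ReduceOp.SUM, value=per_replica_value, axis=None)
print(total.numpy())  # 1 for this single-replica strategy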