diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index ae59ff50705..f8c6b83ba83 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -862,13 +862,21 @@ distribute_py_test(
     disable_mlir_bridge = False,
     python_version = "PY3",
     deps = [
+        ":central_storage_strategy",
+        ":collective_all_reduce_strategy",
         ":combinations",
+        ":mirrored_strategy",
+        ":one_device_strategy",
         ":reduce_util",
         ":strategy_combinations",
         ":test_util",
+        ":tpu_strategy",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
         "//tensorflow/python:config",
         "//tensorflow/python:constant_op",
-        "//tensorflow/python/eager:context",
+        "//tensorflow/python:tf2",
+        "//tensorflow/python/eager:def_function",
         "@absl_py//absl/testing:parameterized",
     ],
 )
diff --git a/tensorflow/python/distribute/strategy_combinations.py b/tensorflow/python/distribute/strategy_combinations.py
index 862ff9addbc..0f0de47dc1d 100644
--- a/tensorflow/python/distribute/strategy_combinations.py
+++ b/tensorflow/python/distribute/strategy_combinations.py
@@ -44,6 +44,28 @@ CollectiveAllReduceExtended = (
     collective_all_reduce_strategy.CollectiveAllReduceExtended)
 
 
+def _version_chooser(tf1_cls, tf2_cls):
+
+  def creator(*args, **kwargs):
+    if tf2.enabled():
+      return tf2_cls(*args, **kwargs)
+    return tf1_cls(*args, **kwargs)
+
+  return creator
+
+
+MirroredStrategy = _version_chooser(mirrored_lib.MirroredStrategyV1,
+                                    mirrored_lib.MirroredStrategy)
+CentralStorageStrategy = _version_chooser(
+    central_storage_strategy.CentralStorageStrategyV1,
+    central_storage_strategy.CentralStorageStrategy)
+OneDeviceStrategy = _version_chooser(one_device_lib.OneDeviceStrategyV1,
+                                     one_device_lib.OneDeviceStrategy)
+# Only V2 CollectiveAllReduceStrategy combinations are supported.
+CollectiveAllReduceStrategy = (
+    collective_all_reduce_strategy.CollectiveAllReduceStrategy)
+
+
 # pylint: disable=missing-docstring
 def _get_tpu_strategy_creator(steps_per_run,
                               use_single_core=False,
@@ -79,8 +101,8 @@ def _get_tpu_strategy_creator(steps_per_run,
     device_assignment = None
     if use_single_core:
       device_assignment = device_assignment_lib.DeviceAssignment(
-          topology, core_assignment=device_assignment_lib.
-          SINGLE_CORE_ASSIGNMENT)
+          topology,
+          core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)
 
     # Steps per run is only supported in TF 1.x
     if tf2.enabled():
@@ -120,8 +142,7 @@ def _get_multi_worker_mirrored_creator(required_gpus):
     # configures the eager context. The eager context can no longer be
     # configured after initialization.
     with context.eager_mode():
-      strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
-          cluster_resolver=resolver)
+      strategy = CollectiveAllReduceStrategy(cluster_resolver=resolver)
     # TODO(b/152320929): Wait for the cluster before proceeding, otherwise
     # collectives may hang if any worker launches collectives before the chief
     # creates the strategy.
@@ -143,20 +164,16 @@ default_strategy = combinations.NamedDistribution(
     distribution_strategy_context._get_default_strategy,  # pylint: disable=protected-access
     required_gpus=None)
 one_device_strategy = combinations.NamedDistribution(
-    "OneDeviceCPU",
-    lambda: one_device_lib.OneDeviceStrategy("/cpu:0"),
-    required_gpus=None)
+    "OneDeviceCPU", lambda: OneDeviceStrategy("/cpu:0"), required_gpus=None)
 one_device_strategy_gpu = combinations.NamedDistribution(
-    "OneDeviceGPU",
-    lambda: one_device_lib.OneDeviceStrategy("/gpu:0"),
-    required_gpus=1)
+    "OneDeviceGPU", lambda: OneDeviceStrategy("/gpu:0"), required_gpus=1)
 one_device_strategy_on_worker_1 = combinations.NamedDistribution(
     "OneDeviceOnWorker1CPU",
-    lambda: one_device_lib.OneDeviceStrategy("/job:worker/replica:0/task:1/cpu:0"),  # pylint: disable=line-too-long
+    lambda: OneDeviceStrategy("/job:worker/replica:0/task:1/cpu:0"),
     required_gpus=None)
 one_device_strategy_gpu_on_worker_1 = combinations.NamedDistribution(
     "OneDeviceOnWorker1GPU",
-    lambda: one_device_lib.OneDeviceStrategy("/job:worker/replica:0/task:1/gpu:0"),  # pylint: disable=line-too-long
+    lambda: OneDeviceStrategy("/job:worker/replica:0/task:1/gpu:0"),
     required_gpus=1)
 tpu_strategy = combinations.NamedDistribution(
     "TPU", _get_tpu_strategy_creator(steps_per_run=2), required_tpu=True)
@@ -180,35 +197,32 @@ cloud_tpu_strategy = combinations.NamedDistribution(
     required_tpu=True,
     use_cloud_tpu=True)
 mirrored_strategy_with_one_cpu = combinations.NamedDistribution(
-    "Mirrored1CPU", lambda: mirrored_lib.MirroredStrategy(["/cpu:0"]))
+    "Mirrored1CPU", lambda: MirroredStrategy(["/cpu:0"]))
 mirrored_strategy_with_one_gpu = combinations.NamedDistribution(
-    "Mirrored1GPU",
-    lambda: mirrored_lib.MirroredStrategy(["/gpu:0"]),
-    required_gpus=1)
+    "Mirrored1GPU", lambda: MirroredStrategy(["/gpu:0"]), required_gpus=1)
 mirrored_strategy_with_gpu_and_cpu = combinations.NamedDistribution(
     "MirroredCPUAndGPU",
-    lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/cpu:0"]),
+    lambda: MirroredStrategy(["/gpu:0", "/cpu:0"]),
     required_gpus=1)
 mirrored_strategy_with_two_gpus = combinations.NamedDistribution(
     "Mirrored2GPUs",
-    lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/gpu:1"]),
+    lambda: MirroredStrategy(["/gpu:0", "/gpu:1"]),
     required_gpus=2)
 # Should call set_virtual_cpus_to_at_least(3) in your test's setUp methods.
 mirrored_strategy_with_cpu_1_and_2 = combinations.NamedDistribution(
-    "Mirrored2CPU", lambda: mirrored_lib.MirroredStrategy(["/cpu:1", "/cpu:2"]))
+    "Mirrored2CPU", lambda: MirroredStrategy(["/cpu:1", "/cpu:2"]))
 mirrored_strategy_with_cpu_1_and_2.__doc__ = (
     """Mirrored strategy with 2 virtual CPUs.
 
-    Should call set_virtual_cpus_to_at_least(3) in the test's setUp methods.
+    Should set up logical devices before use
     """)
 central_storage_strategy_with_two_gpus = combinations.NamedDistribution(
     "CentralStorage2GPUs",
-    lambda: central_storage_strategy.CentralStorageStrategy._from_num_gpus(2),  # pylint: disable=protected-access
+    lambda: CentralStorageStrategy._from_num_gpus(2),  # pylint: disable=protected-access
     required_gpus=2)
 central_storage_strategy_with_gpu_and_cpu = combinations.NamedDistribution(
     "CentralStorageCPUAndGPU",
-    lambda: central_storage_strategy.CentralStorageStrategy(
-        ["/gpu:0", "/cpu:0"]),
+    lambda: CentralStorageStrategy(["/gpu:0", "/cpu:0"]),
     required_gpus=1)
 # chief + 1 worker, with CPU.
 multi_worker_mirrored_2x1_cpu = combinations.NamedDistribution(
@@ -310,8 +324,7 @@ multidevice_strategies = [
 ]
 
 multiworker_strategies = [
-    multi_worker_mirrored_2x1_cpu,
-    multi_worker_mirrored_2x1_gpu,
+    multi_worker_mirrored_2x1_cpu, multi_worker_mirrored_2x1_gpu,
     multi_worker_mirrored_2x2_gpu
 ]
 
diff --git a/tensorflow/python/distribute/strategy_combinations_test.py b/tensorflow/python/distribute/strategy_combinations_test.py
index 1157520d654..9ee5eab93c6 100644
--- a/tensorflow/python/distribute/strategy_combinations_test.py
+++ b/tensorflow/python/distribute/strategy_combinations_test.py
@@ -20,10 +20,16 @@ from __future__ import print_function
 
 from absl.testing import parameterized
 
+from tensorflow.python import tf2
+from tensorflow.python.distribute import central_storage_strategy
+from tensorflow.python.distribute import collective_all_reduce_strategy
 from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import mirrored_strategy
+from tensorflow.python.distribute import one_device_strategy
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import test_util
+from tensorflow.python.distribute import tpu_strategy
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.ops import array_ops
@@ -78,5 +84,112 @@ class StrategyCombinationsTest(test.TestCase, parameterized.TestCase):
     self.assertEqual(2, self.evaluate(num_replicas))
 
 
+class V1StrategyTest(test.TestCase, parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    tf2.disable()
+
+  @combinations.generate(
+      combinations.combine(strategy=[
+          strategy_combinations.one_device_strategy,
+          strategy_combinations.one_device_strategy_gpu,
+          strategy_combinations.one_device_strategy_gpu_on_worker_1,
+          strategy_combinations.one_device_strategy_on_worker_1
+      ]))
+  def testOneDevice(self, strategy):
+    self.assertIsInstance(strategy, one_device_strategy.OneDeviceStrategyV1)
+
+  @combinations.generate(
+      combinations.combine(strategy=[
+          strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
+          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          strategy_combinations.mirrored_strategy_with_one_cpu,
+          strategy_combinations.mirrored_strategy_with_one_gpu,
+          strategy_combinations.mirrored_strategy_with_two_gpus,
+      ]))
+  def testMirrored(self, strategy):
+    self.assertIsInstance(strategy, mirrored_strategy.MirroredStrategyV1)
+
+  @combinations.generate(
+      combinations.combine(strategy=[
+          strategy_combinations.multi_worker_mirrored_2x1_cpu,
+          strategy_combinations.multi_worker_mirrored_2x1_gpu,
+          strategy_combinations.multi_worker_mirrored_2x2_gpu,
+          strategy_combinations.multi_worker_mirrored_4x1_cpu,
+      ]))
+  def testMultiWorkerMirrored(self, strategy):
+    # MultiWorkerMirroredStrategy combinations only support V2.
+    self.assertIsInstance(
+        strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy)
+
+  @combinations.generate(
+      combinations.combine(strategy=[
+          strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
+          strategy_combinations.central_storage_strategy_with_two_gpus,
+      ]))
+  def testCentralStorage(self, strategy):
+    self.assertIsInstance(strategy,
+                          central_storage_strategy.CentralStorageStrategyV1)
+
+  @combinations.generate(
+      combinations.combine(strategy=strategy_combinations.tpu_strategies))
+  def testTPU(self, strategy):
+    self.assertIsInstance(strategy, tpu_strategy.TPUStrategyV1)
+
+
+class V2StrategyTest(test.TestCase, parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    tf2.enable()
+
+  @combinations.generate(
+      combinations.combine(strategy=[
+          strategy_combinations.one_device_strategy,
+          strategy_combinations.one_device_strategy_gpu,
+          strategy_combinations.one_device_strategy_gpu_on_worker_1,
+          strategy_combinations.one_device_strategy_on_worker_1
+      ]))
+  def testOneDevice(self, strategy):
+    self.assertIsInstance(strategy, one_device_strategy.OneDeviceStrategy)
+
+  @combinations.generate(
+      combinations.combine(strategy=[
+          strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
+          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          strategy_combinations.mirrored_strategy_with_one_cpu,
+          strategy_combinations.mirrored_strategy_with_one_gpu,
+          strategy_combinations.mirrored_strategy_with_two_gpus,
+      ]))
+  def testMirrored(self, strategy):
+    self.assertIsInstance(strategy, mirrored_strategy.MirroredStrategy)
+
+  @combinations.generate(
+      combinations.combine(strategy=[
+          strategy_combinations.multi_worker_mirrored_2x1_cpu,
+          strategy_combinations.multi_worker_mirrored_2x1_gpu,
+          strategy_combinations.multi_worker_mirrored_2x2_gpu,
+          strategy_combinations.multi_worker_mirrored_4x1_cpu,
+      ]))
+  def testMultiWorkerMirrored(self, strategy):
+    self.assertIsInstance(
+        strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy)
+
+  @combinations.generate(
+      combinations.combine(strategy=[
+          strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
+          strategy_combinations.central_storage_strategy_with_two_gpus,
+      ]))
+  def testCentralStorage(self, strategy):
+    self.assertIsInstance(strategy,
+                          central_storage_strategy.CentralStorageStrategy)
+
+  @combinations.generate(
+      combinations.combine(strategy=strategy_combinations.tpu_strategies))
+  def testTPU(self, strategy):
+    self.assertIsInstance(strategy, tpu_strategy.TPUStrategy)
+
+
 if __name__ == "__main__":
   test_util.main()
diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py
index b0784d08b05..7bc70101fb4 100644
--- a/tensorflow/python/keras/distribute/distribute_strategy_test.py
+++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py
@@ -1371,7 +1371,8 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
         metrics=metrics)
 
     batch_size = 8
-    if isinstance(distribution, mirrored_strategy.MirroredStrategy):
+    if isinstance(distribution, (mirrored_strategy.MirroredStrategy,
+                                 mirrored_strategy.MirroredStrategyV1)):
       # MirroredStrategy uses global batch size.
       batch_size = 8 * distribution.num_replicas_in_sync
 
@@ -2011,7 +2012,8 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
     model.compile(optimizer, 'mae')
 
     if isinstance(distribution,
-                  central_storage_strategy.CentralStorageStrategy):
+                  (central_storage_strategy.CentralStorageStrategy,
+                   central_storage_strategy.CentralStorageStrategyV1)):
      with self.assertRaisesRegex(ValueError, 'not supported'):
        model.fit(x, y, batch_size=10, epochs=1)
     else:
@@ -2023,7 +2025,8 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
       combinations.combine(distribution=all_strategies, mode=['eager']))
   def test_custom_gradient_transformation(self, distribution):
     if isinstance(distribution,
-                  central_storage_strategy.CentralStorageStrategy):
+                  (central_storage_strategy.CentralStorageStrategy,
+                   central_storage_strategy.CentralStorageStrategyV1)):
       self.skipTest('Not supported with `CentralStorageStrategy`')
 
     class MyLayer(keras.layers.Layer):
diff --git a/tensorflow/python/keras/distribute/keras_correctness_test_base.py b/tensorflow/python/keras/distribute/keras_correctness_test_base.py
index a0c5f7a1299..77a5f290439 100644
--- a/tensorflow/python/keras/distribute/keras_correctness_test_base.py
+++ b/tensorflow/python/keras/distribute/keras_correctness_test_base.py
@@ -340,6 +340,7 @@ def compare_results(results_with_ds,
   # so use larger tolerance for now. Predict should be related to weights.
   if (isinstance(distribution,
                  (mirrored_strategy.MirroredStrategy,
+                  mirrored_strategy.MirroredStrategyV1,
                   distribute_lib._DefaultDistributionStrategy)) and  # pylint: disable=protected-access
       key.startswith(('weights_1', 'weights_2', 'predict_result'))):
     return relaxed_tolerance
diff --git a/tensorflow/python/keras/optimizer_v2/utils.py b/tensorflow/python/keras/optimizer_v2/utils.py
index 44958792c10..90f9a4975e7 100644
--- a/tensorflow/python/keras/optimizer_v2/utils.py
+++ b/tensorflow/python/keras/optimizer_v2/utils.py
@@ -92,7 +92,8 @@ def make_gradient_clipnorm_fn(clipnorm):
 
   def gradient_clipnorm_fn(grads_and_vars):
     if isinstance(distribute_ctx.get_strategy(),
-                  central_storage_strategy.CentralStorageStrategy):
+                  (central_storage_strategy.CentralStorageStrategy,
+                   central_storage_strategy.CentralStorageStrategyV1)):
       raise ValueError(
           "`clipnorm` is not supported with `CenteralStorageStrategy`")
 
@@ -112,7 +113,8 @@ def make_global_gradient_clipnorm_fn(clipnorm):
 
   def gradient_clipnorm_fn(grads_and_vars):
     if isinstance(distribute_ctx.get_strategy(),
-                  central_storage_strategy.CentralStorageStrategy):
+                  (central_storage_strategy.CentralStorageStrategy,
+                   central_storage_strategy.CentralStorageStrategyV1)):
       raise ValueError(
           "`global_clipnorm` is not supported with `CenteralStorageStrategy`")
 
@@ -132,7 +134,8 @@ def make_gradient_clipvalue_fn(clipvalue):
 
   def gradient_clipvalue_fn(grads_and_vars):
     if isinstance(distribute_ctx.get_strategy(),
-                  central_storage_strategy.CentralStorageStrategy):
+                  (central_storage_strategy.CentralStorageStrategy,
+                   central_storage_strategy.CentralStorageStrategyV1)):
       raise ValueError(
           "`clipvalue` is not supported with `CenteralStorageStrategy`")
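Editor's note (not part of the patch): the snippet below is only an illustrative sketch of the _version_chooser factory that strategy_combinations.py introduces above. It shows why the same NamedDistribution lambdas can back both V1StrategyTest and V2StrategyTest: the strategy class is chosen at call time from the TF2 switch, so flipping tf2.disable()/tf2.enable() in setUp changes which class the combination produces. StrategyV1, StrategyV2, and _TF2_ENABLED are stand-ins invented for the sketch; the real code dispatches on tf2.enabled() and the concrete TensorFlow strategy classes.

# Illustrative, self-contained sketch; names below are stand-ins, not TF APIs.
class StrategyV1:  # stand-in for e.g. mirrored_lib.MirroredStrategyV1
  pass


class StrategyV2:  # stand-in for e.g. mirrored_lib.MirroredStrategy
  pass


_TF2_ENABLED = True  # stand-in for tensorflow.python.tf2.enabled()


def _version_chooser(tf1_cls, tf2_cls):
  """Returns a factory that picks the TF1 or TF2 class when it is called."""

  def creator(*args, **kwargs):
    if _TF2_ENABLED:
      return tf2_cls(*args, **kwargs)
    return tf1_cls(*args, **kwargs)

  return creator


# The module-level name stays stable; only the instantiated class changes.
MirroredStrategy = _version_chooser(StrategyV1, StrategyV2)
assert isinstance(MirroredStrategy(), StrategyV2)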