Create different strategy based on TF1/2 in strategy_combinations
TF1 and TF2 strategy APIs have diverged, and are going to diverge further. We should create corresponding strategies in tests, so that TF1 tests can be left untouched while we change TF2 APIs.

PiperOrigin-RevId: 336391144
Change-Id: I242e03341ec24442b82705a86ec8b5e3dff2ddbb
This commit is contained in:
parent b0f7ab99e7
commit 1435e86cb6
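The change hinges on a small factory that defers the TF1-vs-TF2 decision to the moment a strategy is created. The sketch below is a minimal, self-contained illustration of that pattern; StrategyV1, StrategyV2, and the module-level _TF2_ENABLED flag are stand-ins (the real helper, _version_chooser in strategy_combinations.py below, consults tf2.enabled() instead).

class StrategyV1:  # stand-in for a TF1 strategy class such as MirroredStrategyV1
  pass


class StrategyV2:  # stand-in for its TF2 counterpart, e.g. MirroredStrategy
  pass


_TF2_ENABLED = True  # stand-in; the real helper calls tf2.enabled() at creation time


def _version_chooser(tf1_cls, tf2_cls):
  """Returns a creator that picks the TF1 or TF2 class when invoked."""

  def creator(*args, **kwargs):
    if _TF2_ENABLED:
      return tf2_cls(*args, **kwargs)
    return tf1_cls(*args, **kwargs)

  return creator


# Public aliases resolve to whichever class matches the active mode.
MirroredStrategy = _version_chooser(StrategyV1, StrategyV2)
print(type(MirroredStrategy()).__name__)  # prints "StrategyV2" while the flag is set

Because the decision happens inside creator() rather than at import time, test combinations built from these aliases pick up whatever mode a test forces in setUp().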
@@ -862,13 +862,21 @@ distribute_py_test(
    disable_mlir_bridge = False,
    python_version = "PY3",
    deps = [
        ":central_storage_strategy",
        ":collective_all_reduce_strategy",
        ":combinations",
        ":mirrored_strategy",
        ":one_device_strategy",
        ":reduce_util",
        ":strategy_combinations",
        ":test_util",
        ":tpu_strategy",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:config",
        "//tensorflow/python:constant_op",
        "//tensorflow/python/eager:context",
        "//tensorflow/python:tf2",
        "//tensorflow/python/eager:def_function",
        "@absl_py//absl/testing:parameterized",
    ],
)
@@ -44,6 +44,28 @@ CollectiveAllReduceExtended = (
    collective_all_reduce_strategy.CollectiveAllReduceExtended)


def _version_chooser(tf1_cls, tf2_cls):

  def creator(*args, **kwargs):
    if tf2.enabled():
      return tf2_cls(*args, **kwargs)
    return tf1_cls(*args, **kwargs)

  return creator


MirroredStrategy = _version_chooser(mirrored_lib.MirroredStrategyV1,
                                    mirrored_lib.MirroredStrategy)
CentralStorageStrategy = _version_chooser(
    central_storage_strategy.CentralStorageStrategyV1,
    central_storage_strategy.CentralStorageStrategy)
OneDeviceStrategy = _version_chooser(one_device_lib.OneDeviceStrategyV1,
                                     one_device_lib.OneDeviceStrategy)
# Only V2 CollectiveAllReduceStrategy combinations are supported.
CollectiveAllReduceStrategy = (
    collective_all_reduce_strategy.CollectiveAllReduceStrategy)


# pylint: disable=missing-docstring
def _get_tpu_strategy_creator(steps_per_run,
                              use_single_core=False,
@@ -79,8 +101,8 @@ def _get_tpu_strategy_creator(steps_per_run,
    device_assignment = None
    if use_single_core:
      device_assignment = device_assignment_lib.DeviceAssignment(
          topology, core_assignment=device_assignment_lib.
          SINGLE_CORE_ASSIGNMENT)
          topology,
          core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)

    # Steps per run is only supported in TF 1.x
    if tf2.enabled():
@@ -120,8 +142,7 @@ def _get_multi_worker_mirrored_creator(required_gpus):
    # configures the eager context. The eager context can no longer be
    # configured after initialization.
    with context.eager_mode():
      strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
          cluster_resolver=resolver)
      strategy = CollectiveAllReduceStrategy(cluster_resolver=resolver)
    # TODO(b/152320929): Wait for the cluster before proceeding, otherwise
    # collectives may hang if any worker launches collectives before the chief
    # creates the strategy.
@@ -143,20 +164,16 @@ default_strategy = combinations.NamedDistribution(
    distribution_strategy_context._get_default_strategy,  # pylint: disable=protected-access
    required_gpus=None)
one_device_strategy = combinations.NamedDistribution(
    "OneDeviceCPU",
    lambda: one_device_lib.OneDeviceStrategy("/cpu:0"),
    required_gpus=None)
    "OneDeviceCPU", lambda: OneDeviceStrategy("/cpu:0"), required_gpus=None)
one_device_strategy_gpu = combinations.NamedDistribution(
    "OneDeviceGPU",
    lambda: one_device_lib.OneDeviceStrategy("/gpu:0"),
    required_gpus=1)
    "OneDeviceGPU", lambda: OneDeviceStrategy("/gpu:0"), required_gpus=1)
one_device_strategy_on_worker_1 = combinations.NamedDistribution(
    "OneDeviceOnWorker1CPU",
    lambda: one_device_lib.OneDeviceStrategy("/job:worker/replica:0/task:1/cpu:0"),  # pylint: disable=line-too-long
    lambda: OneDeviceStrategy("/job:worker/replica:0/task:1/cpu:0"),
    required_gpus=None)
one_device_strategy_gpu_on_worker_1 = combinations.NamedDistribution(
    "OneDeviceOnWorker1GPU",
    lambda: one_device_lib.OneDeviceStrategy("/job:worker/replica:0/task:1/gpu:0"),  # pylint: disable=line-too-long
    lambda: OneDeviceStrategy("/job:worker/replica:0/task:1/gpu:0"),
    required_gpus=1)
tpu_strategy = combinations.NamedDistribution(
    "TPU", _get_tpu_strategy_creator(steps_per_run=2), required_tpu=True)
@@ -180,35 +197,32 @@ cloud_tpu_strategy = combinations.NamedDistribution(
    required_tpu=True,
    use_cloud_tpu=True)
mirrored_strategy_with_one_cpu = combinations.NamedDistribution(
    "Mirrored1CPU", lambda: mirrored_lib.MirroredStrategy(["/cpu:0"]))
    "Mirrored1CPU", lambda: MirroredStrategy(["/cpu:0"]))
mirrored_strategy_with_one_gpu = combinations.NamedDistribution(
    "Mirrored1GPU",
    lambda: mirrored_lib.MirroredStrategy(["/gpu:0"]),
    required_gpus=1)
    "Mirrored1GPU", lambda: MirroredStrategy(["/gpu:0"]), required_gpus=1)
mirrored_strategy_with_gpu_and_cpu = combinations.NamedDistribution(
    "MirroredCPUAndGPU",
    lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/cpu:0"]),
    lambda: MirroredStrategy(["/gpu:0", "/cpu:0"]),
    required_gpus=1)
mirrored_strategy_with_two_gpus = combinations.NamedDistribution(
    "Mirrored2GPUs",
    lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/gpu:1"]),
    lambda: MirroredStrategy(["/gpu:0", "/gpu:1"]),
    required_gpus=2)
# Should call set_virtual_cpus_to_at_least(3) in your test's setUp methods.
mirrored_strategy_with_cpu_1_and_2 = combinations.NamedDistribution(
    "Mirrored2CPU", lambda: mirrored_lib.MirroredStrategy(["/cpu:1", "/cpu:2"]))
    "Mirrored2CPU", lambda: MirroredStrategy(["/cpu:1", "/cpu:2"]))
mirrored_strategy_with_cpu_1_and_2.__doc__ = (
    """Mirrored strategy with 2 virtual CPUs.

    Should call set_virtual_cpus_to_at_least(3) in the test's setUp methods.
    Should set up logical devices before use
    """)
central_storage_strategy_with_two_gpus = combinations.NamedDistribution(
    "CentralStorage2GPUs",
    lambda: central_storage_strategy.CentralStorageStrategy._from_num_gpus(2),  # pylint: disable=protected-access
    lambda: CentralStorageStrategy._from_num_gpus(2),  # pylint: disable=protected-access
    required_gpus=2)
central_storage_strategy_with_gpu_and_cpu = combinations.NamedDistribution(
    "CentralStorageCPUAndGPU",
    lambda: central_storage_strategy.CentralStorageStrategy(
        ["/gpu:0", "/cpu:0"]),
    lambda: CentralStorageStrategy(["/gpu:0", "/cpu:0"]),
    required_gpus=1)
# chief + 1 worker, with CPU.
multi_worker_mirrored_2x1_cpu = combinations.NamedDistribution(
@@ -310,8 +324,7 @@ multidevice_strategies = [
]

multiworker_strategies = [
    multi_worker_mirrored_2x1_cpu,
    multi_worker_mirrored_2x1_gpu,
    multi_worker_mirrored_2x1_cpu, multi_worker_mirrored_2x1_gpu,
    multi_worker_mirrored_2x2_gpu
]
@@ -20,10 +20,16 @@ from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python import tf2
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import one_device_strategy
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
@@ -78,5 +84,112 @@ class StrategyCombinationsTest(test.TestCase, parameterized.TestCase):
    self.assertEqual(2, self.evaluate(num_replicas))


class V1StrategyTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    tf2.disable()

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.one_device_strategy,
          strategy_combinations.one_device_strategy_gpu,
          strategy_combinations.one_device_strategy_gpu_on_worker_1,
          strategy_combinations.one_device_strategy_on_worker_1
      ]))
  def testOneDevice(self, strategy):
    self.assertIsInstance(strategy, one_device_strategy.OneDeviceStrategyV1)

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.mirrored_strategy_with_one_cpu,
          strategy_combinations.mirrored_strategy_with_one_gpu,
          strategy_combinations.mirrored_strategy_with_two_gpus,
      ]))
  def testMirrored(self, strategy):
    self.assertIsInstance(strategy, mirrored_strategy.MirroredStrategyV1)

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.multi_worker_mirrored_2x1_cpu,
          strategy_combinations.multi_worker_mirrored_2x1_gpu,
          strategy_combinations.multi_worker_mirrored_2x2_gpu,
          strategy_combinations.multi_worker_mirrored_4x1_cpu,
      ]))
  def testMultiWorkerMirrored(self, strategy):
    # MultiWorkerMirroredStrategy combinations only supports V2.
    self.assertIsInstance(
        strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy)

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
          strategy_combinations.central_storage_strategy_with_two_gpus,
      ]))
  def testCentralStorage(self, strategy):
    self.assertIsInstance(strategy,
                          central_storage_strategy.CentralStorageStrategyV1)

  @combinations.generate(
      combinations.combine(strategy=strategy_combinations.tpu_strategies))
  def testTPU(self, strategy):
    self.assertIsInstance(strategy, tpu_strategy.TPUStrategyV1)


class V2StrategyTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    tf2.enable()

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.one_device_strategy,
          strategy_combinations.one_device_strategy_gpu,
          strategy_combinations.one_device_strategy_gpu_on_worker_1,
          strategy_combinations.one_device_strategy_on_worker_1
      ]))
  def testOneDevice(self, strategy):
    self.assertIsInstance(strategy, one_device_strategy.OneDeviceStrategy)

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.mirrored_strategy_with_one_cpu,
          strategy_combinations.mirrored_strategy_with_one_gpu,
          strategy_combinations.mirrored_strategy_with_two_gpus,
      ]))
  def testMirrored(self, strategy):
    self.assertIsInstance(strategy, mirrored_strategy.MirroredStrategy)

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.multi_worker_mirrored_2x1_cpu,
          strategy_combinations.multi_worker_mirrored_2x1_gpu,
          strategy_combinations.multi_worker_mirrored_2x2_gpu,
          strategy_combinations.multi_worker_mirrored_4x1_cpu,
      ]))
  def testMultiWorkerMirrored(self, strategy):
    self.assertIsInstance(
        strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy)

  @combinations.generate(
      combinations.combine(strategy=[
          strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
          strategy_combinations.central_storage_strategy_with_two_gpus,
      ]))
  def testCentralStorage(self, strategy):
    self.assertIsInstance(strategy,
                          central_storage_strategy.CentralStorageStrategy)

  @combinations.generate(
      combinations.combine(strategy=strategy_combinations.tpu_strategies))
  def testTPU(self, strategy):
    self.assertIsInstance(strategy, tpu_strategy.TPUStrategy)


if __name__ == "__main__":
  test_util.main()
@@ -1371,7 +1371,8 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
          metrics=metrics)

      batch_size = 8
      if isinstance(distribution, mirrored_strategy.MirroredStrategy):
      if isinstance(distribution, (mirrored_strategy.MirroredStrategy,
                                   mirrored_strategy.MirroredStrategyV1)):
        # MirroredStrategy uses global batch size.
        batch_size = 8 * distribution.num_replicas_in_sync
@@ -2011,7 +2012,8 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
    model.compile(optimizer, 'mae')

    if isinstance(distribution,
                  central_storage_strategy.CentralStorageStrategy):
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      with self.assertRaisesRegex(ValueError, 'not supported'):
        model.fit(x, y, batch_size=10, epochs=1)
    else:
@@ -2023,7 +2025,8 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
      combinations.combine(distribution=all_strategies, mode=['eager']))
  def test_custom_gradient_transformation(self, distribution):
    if isinstance(distribution,
                  central_storage_strategy.CentralStorageStrategy):
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      self.skipTest('Not supported with `CentralStorageStrategy`')

    class MyLayer(keras.layers.Layer):
@@ -340,6 +340,7 @@ def compare_results(results_with_ds,
  # so use larger tolerance for now. Predict should be related to weights.
  if (isinstance(distribution,
                 (mirrored_strategy.MirroredStrategy,
                  mirrored_strategy.MirroredStrategyV1,
                  distribute_lib._DefaultDistributionStrategy)) and  # pylint: disable=protected-access
      key.startswith(('weights_1', 'weights_2', 'predict_result'))):
    return relaxed_tolerance
@@ -92,7 +92,8 @@ def make_gradient_clipnorm_fn(clipnorm):
  def gradient_clipnorm_fn(grads_and_vars):

    if isinstance(distribute_ctx.get_strategy(),
                  central_storage_strategy.CentralStorageStrategy):
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`clipnorm` is not supported with `CenteralStorageStrategy`")
@@ -112,7 +113,8 @@ def make_global_gradient_clipnorm_fn(clipnorm):
  def gradient_clipnorm_fn(grads_and_vars):

    if isinstance(distribute_ctx.get_strategy(),
                  central_storage_strategy.CentralStorageStrategy):
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`global_clipnorm` is not supported with `CenteralStorageStrategy`")
@@ -132,7 +134,8 @@ def make_gradient_clipvalue_fn(clipvalue):
  def gradient_clipvalue_fn(grads_and_vars):

    if isinstance(distribute_ctx.get_strategy(),
                  central_storage_strategy.CentralStorageStrategy):
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`clipvalue` is not supported with `CenteralStorageStrategy`")
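Since the combinations above can now hand a test either the V1 or the V2 class, call sites that branch on strategy type (the Keras dataset test and the clipnorm/clipvalue guards) are widened to isinstance checks against a tuple covering both generations. A small illustrative sketch of that pattern, using stand-in classes and a hypothetical helper name:

class CentralStorageStrategy:  # stand-in for the V2 class
  pass


class CentralStorageStrategyV1:  # stand-in for the V1 class
  pass


_CENTRAL_STORAGE_CLASSES = (CentralStorageStrategy, CentralStorageStrategyV1)


def check_clipnorm_supported(strategy):
  # Mirrors the widened guards: reject the strategy whichever API generation it is.
  if isinstance(strategy, _CENTRAL_STORAGE_CLASSES):
    raise ValueError("`clipnorm` is not supported with `CentralStorageStrategy`")


check_clipnorm_supported(object())  # fine: not a central storage strategy
# check_clipnorm_supported(CentralStorageStrategyV1())  # would raise ValueError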