Update DS tests in keras to use public TF combination symbols.

PiperOrigin-RevId: 330850095
Change-Id: I93f45c0f9c35a6c7257e182e9126d96f3a02e2c7
This commit is contained in:
Scott Zhu 2020-09-09 19:38:27 -07:00 committed by TensorFlower Gardener
parent 9d36befed5
commit 9645e47535
41 changed files with 260 additions and 230 deletions

View File

@ -21,11 +21,12 @@ import os
from absl.testing import parameterized from absl.testing import parameterized
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.optimizer_v2 import adam from tensorflow.python.keras.optimizer_v2 import adam
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables as variables_lib from tensorflow.python.ops import variables as variables_lib
@ -36,7 +37,7 @@ from tensorflow.python.training.tracking import util as trackable_utils
class TrainingCheckpointTests(test.TestCase, parameterized.TestCase): class TrainingCheckpointTests(test.TestCase, parameterized.TestCase):
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.mirrored_strategy_with_one_cpu, strategy_combinations.mirrored_strategy_with_one_cpu,
@ -95,7 +96,7 @@ class TrainingCheckpointTests(test.TestCase, parameterized.TestCase):
ValueError, "optimizer slot variable under the scope"): ValueError, "optimizer slot variable under the scope"):
checkpoint.restore(save_path) checkpoint.restore(save_path)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.mirrored_strategy_with_one_cpu, strategy_combinations.mirrored_strategy_with_one_cpu,

View File

@ -25,7 +25,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compat import v2_compat from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import multi_worker_util
@ -36,6 +36,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.keras import layers from tensorflow.python.keras import layers
from tensorflow.python.keras import testing_utils from tensorflow.python.keras import testing_utils
@ -257,13 +258,13 @@ class DistributedCollectiveAllReduceStrategyTest(
cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
num_workers=3, num_ps=0) num_workers=3, num_ps=0)
@combinations.generate( @ds_combinations.generate(
combinations.combine(mode=['graph'], required_gpus=[0, 1, 2])) combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
def testComplexModel(self, required_gpus): def testComplexModel(self, required_gpus):
self._run_between_graph_clients( self._run_between_graph_clients(
self._test_complex_model, self._cluster_spec, num_gpus=required_gpus) self._test_complex_model, self._cluster_spec, num_gpus=required_gpus)
@combinations.generate( @ds_combinations.generate(
combinations.combine(mode=['graph'], required_gpus=[0, 1, 2])) combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
@testing_utils.enable_v2_dtype_behavior @testing_utils.enable_v2_dtype_behavior
def testMixedPrecision(self, required_gpus): def testMixedPrecision(self, required_gpus):
@ -285,13 +286,13 @@ class DistributedCollectiveAllReduceStrategyTestWithChief(
cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
num_workers=3, num_ps=0, has_chief=True) num_workers=3, num_ps=0, has_chief=True)
@combinations.generate( @ds_combinations.generate(
combinations.combine(mode=['graph'], required_gpus=[0, 1, 2])) combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
def testComplexModel(self, required_gpus): def testComplexModel(self, required_gpus):
self._run_between_graph_clients( self._run_between_graph_clients(
self._test_complex_model, self._cluster_spec, num_gpus=required_gpus) self._test_complex_model, self._cluster_spec, num_gpus=required_gpus)
@combinations.generate( @ds_combinations.generate(
combinations.combine(mode=['graph'], required_gpus=[0, 1, 2])) combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
@testing_utils.enable_v2_dtype_behavior @testing_utils.enable_v2_dtype_behavior
def testMixedPrecision(self, required_gpus): def testMixedPrecision(self, required_gpus):
@ -310,12 +311,12 @@ class LocalCollectiveAllReduceStrategy(
strategy_test_lib.TwoDeviceDistributionTestBase, strategy_test_lib.TwoDeviceDistributionTestBase,
parameterized.TestCase): parameterized.TestCase):
@combinations.generate( @ds_combinations.generate(
combinations.combine(mode=['graph'], required_gpus=[2, 4])) combinations.combine(mode=['graph'], required_gpus=[2, 4]))
def testComplexModel(self, required_gpus): def testComplexModel(self, required_gpus):
self._test_complex_model(None, None, required_gpus) self._test_complex_model(None, None, required_gpus)
@combinations.generate( @ds_combinations.generate(
combinations.combine(mode=['graph'], required_gpus=[2, 4])) combinations.combine(mode=['graph'], required_gpus=[2, 4]))
@testing_utils.enable_v2_dtype_behavior @testing_utils.enable_v2_dtype_behavior
def testMixedPrecision(self, required_gpus): def testMixedPrecision(self, required_gpus):
@ -323,7 +324,7 @@ class LocalCollectiveAllReduceStrategy(
self._test_mixed_precision(None, None, required_gpus) self._test_mixed_precision(None, None, required_gpus)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
strategy=[ strategy=[
strategy_combinations.multi_worker_mirrored_2x1_cpu, strategy_combinations.multi_worker_mirrored_2x1_cpu,
@ -370,4 +371,4 @@ class DistributedCollectiveAllReduceStrategyEagerTest(test.TestCase,
if __name__ == '__main__': if __name__ == '__main__':
v2_compat.enable_v2_behavior() v2_compat.enable_v2_behavior()
combinations.main() ds_combinations.main()

View File

@ -25,13 +25,14 @@ import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.compat import v2_compat from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.keras.distribute import optimizer_combinations from tensorflow.python.keras.distribute import optimizer_combinations
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
@ -224,7 +225,7 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
np.random.seed(_RANDOM_SEED) np.random.seed(_RANDOM_SEED)
random_seed.set_random_seed(_RANDOM_SEED) random_seed.set_random_seed(_RANDOM_SEED)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=strategy_combinations.all_strategies, distribution=strategy_combinations.all_strategies,
optimizer_fn=optimizer_combinations.optimizers_v2, optimizer_fn=optimizer_combinations.optimizers_v2,
@ -277,4 +278,4 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
if __name__ == '__main__': if __name__ == '__main__':
combinations.main() ds_combinations.main()

View File

@ -22,17 +22,18 @@ from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import metrics from tensorflow.python.keras import metrics
from tensorflow.python.platform import test from tensorflow.python.platform import test
class KerasMetricsTest(test.TestCase, parameterized.TestCase): class KerasMetricsTest(test.TestCase, parameterized.TestCase):
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=strategy_combinations.all_strategies + distribution=strategy_combinations.all_strategies +
strategy_combinations.multiworker_strategies, strategy_combinations.multiworker_strategies,
@ -57,7 +58,7 @@ class KerasMetricsTest(test.TestCase, parameterized.TestCase):
loss_metric_2.result().numpy()) loss_metric_2.result().numpy())
self.assertEqual(loss_metric.result().numpy(), 5.0) self.assertEqual(loss_metric.result().numpy(), 5.0)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=strategy_combinations.all_strategies+ distribution=strategy_combinations.all_strategies+
strategy_combinations.multiworker_strategies, strategy_combinations.multiworker_strategies,
@ -81,7 +82,7 @@ class KerasMetricsTest(test.TestCase, parameterized.TestCase):
# of 10 resulting in mean of 4.5. # of 10 resulting in mean of 4.5.
self.assertEqual(metric.result().numpy(), 4.5) self.assertEqual(metric.result().numpy(), 4.5)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=strategy_combinations.all_strategies, distribution=strategy_combinations.all_strategies,
mode=["eager"] mode=["eager"]
@ -100,4 +101,4 @@ class KerasMetricsTest(test.TestCase, parameterized.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
combinations.main() ds_combinations.main()

View File

@ -25,11 +25,12 @@ import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.module import module from tensorflow.python.module import module
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -52,7 +53,7 @@ class CustomModel(module.Module):
return x return x
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=(strategy_combinations.all_strategies + distribution=(strategy_combinations.all_strategies +
strategy_combinations.multiworker_strategies), strategy_combinations.multiworker_strategies),
@ -414,7 +415,7 @@ class KerasModelsTest(test.TestCase, parameterized.TestCase):
class KerasModelsXLATest(test.TestCase, parameterized.TestCase): class KerasModelsXLATest(test.TestCase, parameterized.TestCase):
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=strategy_combinations.tpu_strategies, mode=["eager"])) distribution=strategy_combinations.tpu_strategies, mode=["eager"]))
def test_tf_function_experimental_compile(self, distribution): def test_tf_function_experimental_compile(self, distribution):
@ -474,4 +475,4 @@ def _get_model():
if __name__ == "__main__": if __name__ == "__main__":
combinations.main() ds_combinations.main()

View File

@ -19,12 +19,12 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized from absl.testing import parameterized
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import values from tensorflow.python.distribute import values
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.optimizer_v2 import gradient_descent from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -32,13 +32,13 @@ from tensorflow.python.platform import test
class OptimizerTest(test.TestCase, parameterized.TestCase): class OptimizerTest(test.TestCase, parameterized.TestCase):
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
combinations.combine( combinations.combine(
distribution=strategy_combinations.multidevice_strategies, distribution=strategy_combinations.multidevice_strategies,
mode=["eager"], mode=["eager"],
), ),
combinations.concat( ds_combinations.concat(
combinations.combine( combinations.combine(
experimental_aggregate_gradients=True, experimental_aggregate_gradients=True,
expected=[[[-0.3, -0.3], [-0.3, -0.3]]]), expected=[[[-0.3, -0.3], [-0.3, -0.3]]]),
@ -71,7 +71,7 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
self.assertAllClose(optimize(), expected) self.assertAllClose(optimize(), expected)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=strategy_combinations.one_device_strategy, distribution=strategy_combinations.one_device_strategy,
mode=["eager"], mode=["eager"],
@ -98,7 +98,7 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
self.assertAllClose(optimize(), [[-0.1, -0.1]]) self.assertAllClose(optimize(), [[-0.1, -0.1]])
@combinations.generate( @ds_combinations.generate(
combinations.combine(distribution=[ combinations.combine(distribution=[
strategy_combinations.central_storage_strategy_with_gpu_and_cpu strategy_combinations.central_storage_strategy_with_gpu_and_cpu
])) ]))

View File

@ -23,7 +23,7 @@ from tensorflow.python import keras
from tensorflow.python.data.experimental.ops import cardinality from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import central_storage_strategy from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import multi_worker_test_base
@ -37,6 +37,7 @@ from tensorflow.python.eager import backprop
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import testing_utils from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.distribute import distributed_training_utils from tensorflow.python.keras.distribute import distributed_training_utils
from tensorflow.python.keras.distribute import optimizer_combinations from tensorflow.python.keras.distribute import optimizer_combinations
@ -353,7 +354,7 @@ class BatchCountingCB(keras.callbacks.Callback):
class TestDistributionStrategyWithNumpyArrays(test.TestCase, class TestDistributionStrategyWithNumpyArrays(test.TestCase,
parameterized.TestCase): parameterized.TestCase):
@combinations.generate(all_strategy_combinations()) @ds_combinations.generate(all_strategy_combinations())
def test_calculating_input_params_no_steps_no_batch_size(self, distribution): def test_calculating_input_params_no_steps_no_batch_size(self, distribution):
# Calculate the per_replica_batch_size scaling factor for strategies # Calculate the per_replica_batch_size scaling factor for strategies
# that use per_core_batch_size # that use per_core_batch_size
@ -374,7 +375,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
self.assertEqual(batch_size, 20 // replica_scale_factor) self.assertEqual(batch_size, 20 // replica_scale_factor)
self.assertEqual(steps, 1) self.assertEqual(steps, 1)
@combinations.generate(all_strategy_combinations()) @ds_combinations.generate(all_strategy_combinations())
def test_calculating_input_params_with_steps_no_batch_size( def test_calculating_input_params_with_steps_no_batch_size(
self, distribution): self, distribution):
# Calculate the per_replica_batch_size scaling factor for strategies # Calculate the per_replica_batch_size scaling factor for strategies
@ -417,7 +418,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
distributed_training_utils.get_input_params( distributed_training_utils.get_input_params(
distribution, 63, steps=1, batch_size=None) distribution, 63, steps=1, batch_size=None)
@combinations.generate(all_strategy_combinations()) @ds_combinations.generate(all_strategy_combinations())
def test_calculating_input_params_no_steps_with_batch_size( def test_calculating_input_params_no_steps_with_batch_size(
self, distribution): self, distribution):
# Calculate the per_replica_batch_size scaling factor for strategies # Calculate the per_replica_batch_size scaling factor for strategies
@ -439,7 +440,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
self.assertEqual(batch_size, 32) self.assertEqual(batch_size, 32)
self.assertEqual(steps, 2 // replica_scale_factor) self.assertEqual(steps, 2 // replica_scale_factor)
@combinations.generate(all_strategy_combinations()) @ds_combinations.generate(all_strategy_combinations())
def test_calculating_input_params_with_steps_with_batch_size( def test_calculating_input_params_with_steps_with_batch_size(
self, distribution): self, distribution):
with self.cached_session(): with self.cached_session():
@ -454,7 +455,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
distributed_training_utils.get_input_params( distributed_training_utils.get_input_params(
distribution, 64, steps=10, batch_size=13) distribution, 64, steps=10, batch_size=13)
@combinations.generate(all_strategy_combinations()) @ds_combinations.generate(all_strategy_combinations())
def test_calling_model_with_numpy_arrays(self, distribution): def test_calling_model_with_numpy_arrays(self, distribution):
with self.cached_session(): with self.cached_session():
with distribution.scope(): with distribution.scope():
@ -488,7 +489,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
model.predict(inputs) model.predict(inputs)
model.predict(inputs, batch_size=8) model.predict(inputs, batch_size=8)
@combinations.generate(all_strategy_combinations()) @ds_combinations.generate(all_strategy_combinations())
def test_calling_model_with_mixed_precision(self, distribution): def test_calling_model_with_mixed_precision(self, distribution):
if isinstance(distribution.extended, if isinstance(distribution.extended,
parameter_server_strategy.ParameterServerStrategyExtended): parameter_server_strategy.ParameterServerStrategyExtended):
@ -534,7 +535,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
model.predict(inputs) model.predict(inputs)
model.predict(inputs, batch_size=8) model.predict(inputs, batch_size=8)
@combinations.generate(all_strategy_combinations()) @ds_combinations.generate(all_strategy_combinations())
def test_operator_overload_mixed_precision(self, distribution): def test_operator_overload_mixed_precision(self, distribution):
# Regression test that tests a fixed bug does not reoccur. Adding an # Regression test that tests a fixed bug does not reoccur. Adding an
# AutoCastVariable to a tensor on a TPU, where the variable was the LHS of # AutoCastVariable to a tensor on a TPU, where the variable was the LHS of
@ -575,7 +576,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
self.assertIsNotNone(grad_v1) self.assertIsNotNone(grad_v1)
self.assertIsNotNone(grad_v2) self.assertIsNotNone(grad_v2)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=[strategy_combinations.one_device_strategy], distribution=[strategy_combinations.one_device_strategy],
mode=['graph', 'eager'])) mode=['graph', 'eager']))
@ -593,7 +594,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
'cannot be called in cross-replica context'): 'cannot be called in cross-replica context'):
optimizer.apply_gradients(zip(gradients, model.trainable_variables)) optimizer.apply_gradients(zip(gradients, model.trainable_variables))
@combinations.generate(all_strategy_combinations()) @ds_combinations.generate(all_strategy_combinations())
def test_calling_model_with_nested_numpy_arrays(self, distribution): def test_calling_model_with_nested_numpy_arrays(self, distribution):
with self.cached_session(): with self.cached_session():
with distribution.scope(): with distribution.scope():
@ -624,7 +625,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
model.predict(inputs) model.predict(inputs)
model.predict(inputs, batch_size=8) model.predict(inputs, batch_size=8)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=strategies_minus_tpu, distribution=strategies_minus_tpu,
mode=['graph', 'eager'])) mode=['graph', 'eager']))
@ -665,7 +666,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
result = model.evaluate(inputs, targets, batch_size=2, verbose=1) result = model.evaluate(inputs, targets, batch_size=2, verbose=1)
self.assertAllClose(result, 13.5) self.assertAllClose(result, 13.5)
@combinations.generate(all_strategy_combinations()) @ds_combinations.generate(all_strategy_combinations())
def test_flatten_predict_outputs(self, distribution): def test_flatten_predict_outputs(self, distribution):
with self.cached_session(): with self.cached_session():
with distribution.scope(): with distribution.scope():
@ -692,7 +693,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
self.assertAllEqual([6, 7], outs[0].shape) self.assertAllEqual([6, 7], outs[0].shape)
self.assertAllEqual([6, 7], outs[1].shape) self.assertAllEqual([6, 7], outs[1].shape)
@combinations.generate( @ds_combinations.generate(
combinations.times(tpu_strategy_combinations_graph_only(), combinations.times(tpu_strategy_combinations_graph_only(),
combinations.combine(batch_size=[4, 6]))) combinations.combine(batch_size=[4, 6])))
def test_evaluate_with_partial_batch(self, distribution, batch_size): def test_evaluate_with_partial_batch(self, distribution, batch_size):
@ -735,7 +736,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
atol=1e-5, atol=1e-5,
rtol=1e-5) rtol=1e-5)
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
tpu_strategy_combinations_graph_only())) tpu_strategy_combinations_graph_only()))
def test_predict_with_partial_batch(self, distribution): def test_predict_with_partial_batch(self, distribution):
@ -772,7 +773,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
atol=1e-5, atol=1e-5,
rtol=1e-5) rtol=1e-5)
@combinations.generate(tpu_strategy_combinations_graph_only()) @ds_combinations.generate(tpu_strategy_combinations_graph_only())
def test_no_target_model(self, distribution): def test_no_target_model(self, distribution):
with self.cached_session(): with self.cached_session():
optimizer = gradient_descent.GradientDescentOptimizer(0.001) optimizer = gradient_descent.GradientDescentOptimizer(0.001)
@ -797,7 +798,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
model.predict(inputs, steps=1) model.predict(inputs, steps=1)
model.evaluate(inputs, steps=1) model.evaluate(inputs, steps=1)
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
tpu_strategy_combinations_graph_only())) tpu_strategy_combinations_graph_only()))
def test_predict_multi_output_model_with_partial_batch( def test_predict_multi_output_model_with_partial_batch(
@ -832,7 +833,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
atol=1e-4, atol=1e-4,
rtol=1e-4) rtol=1e-4)
@combinations.generate(all_strategy_combinations()) @ds_combinations.generate(all_strategy_combinations())
def test_gradients_are_none(self, distribution): def test_gradients_are_none(self, distribution):
if not context.executing_eagerly(): if not context.executing_eagerly():
@ -863,7 +864,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
class TestDistributionStrategyWithDatasets(test.TestCase, class TestDistributionStrategyWithDatasets(test.TestCase,
parameterized.TestCase): parameterized.TestCase):
@combinations.generate(all_strategy_combinations()) @ds_combinations.generate(all_strategy_combinations())
def test_calling_model_on_same_dataset(self, distribution): def test_calling_model_on_same_dataset(self, distribution):
with self.cached_session(): with self.cached_session():
with distribution.scope(): with distribution.scope():
@ -896,7 +897,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
validation_steps=2) validation_steps=2)
model.predict(get_predict_dataset(distribution), steps=2) model.predict(get_predict_dataset(distribution), steps=2)
@combinations.generate(all_strategy_combinations()) @ds_combinations.generate(all_strategy_combinations())
def test_model_interleaved_eval_same_as_direct_eval( def test_model_interleaved_eval_same_as_direct_eval(
self, distribution): self, distribution):
with self.cached_session(): with self.cached_session():
@ -947,7 +948,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
self.assertEqual(interleaved_output.history['val_categorical_accuracy'], self.assertEqual(interleaved_output.history['val_categorical_accuracy'],
[x[2] for x in user_controlled_output]) [x[2] for x in user_controlled_output])
@combinations.generate(all_strategy_combinations()) @ds_combinations.generate(all_strategy_combinations())
def test_fit_with_tuple_and_dict_dataset_inputs(self, distribution): def test_fit_with_tuple_and_dict_dataset_inputs(self, distribution):
with self.cached_session(): with self.cached_session():
with distribution.scope(): with distribution.scope():
@ -984,7 +985,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1) model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1)
@combinations.generate(all_strategy_combinations()) @ds_combinations.generate(all_strategy_combinations())
def test_fit_with_dictionary_in_the_dataset_b135161171( def test_fit_with_dictionary_in_the_dataset_b135161171(
self, distribution): self, distribution):
@ -1032,7 +1033,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
model.fit(data) model.fit(data)
@combinations.generate(all_strategy_combinations()) @ds_combinations.generate(all_strategy_combinations())
def test_fit_eval_and_predict_methods_on_dataset_without_steps( def test_fit_eval_and_predict_methods_on_dataset_without_steps(
self, distribution): self, distribution):
with self.cached_session(): with self.cached_session():
@ -1068,7 +1069,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
self.assertAllClose( self.assertAllClose(
predict_with_numpy, predict_with_ds, atol=1e-4, rtol=1e-4) predict_with_numpy, predict_with_ds, atol=1e-4, rtol=1e-4)
@combinations.generate(all_strategy_combinations()) @ds_combinations.generate(all_strategy_combinations())
def test_on_dataset_with_unknown_cardinality_without_steps( def test_on_dataset_with_unknown_cardinality_without_steps(
self, distribution, mode): self, distribution, mode):
# TODO(b/155867206): Investigate why this test occasionally segfaults on TPU # TODO(b/155867206): Investigate why this test occasionally segfaults on TPU
@ -1131,7 +1132,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
atol=1e-4, atol=1e-4,
rtol=1e-4) rtol=1e-4)
@combinations.generate(tpu_strategy_combinations_graph_only()) @ds_combinations.generate(tpu_strategy_combinations_graph_only())
def test_on_dataset_with_unknown_cardinality(self, distribution): def test_on_dataset_with_unknown_cardinality(self, distribution):
with self.cached_session(): with self.cached_session():
with distribution.scope(): with distribution.scope():
@ -1172,7 +1173,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
'Number of steps could not be inferred'): 'Number of steps could not be inferred'):
model.fit(dataset, epochs=1) model.fit(dataset, epochs=1)
@combinations.generate(all_strategy_combinations()) @ds_combinations.generate(all_strategy_combinations())
def test_fit_eval_and_predict_methods_on_dataset( def test_fit_eval_and_predict_methods_on_dataset(
self, distribution): self, distribution):
with self.cached_session(): with self.cached_session():
@ -1193,7 +1194,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
model.evaluate(dataset, steps=2, verbose=1) model.evaluate(dataset, steps=2, verbose=1)
model.predict(get_predict_dataset(distribution), steps=2) model.predict(get_predict_dataset(distribution), steps=2)
@combinations.generate(strategy_and_optimizer_combinations()) @ds_combinations.generate(strategy_and_optimizer_combinations())
def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer): def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer):
with self.cached_session(): with self.cached_session():
@ -1211,7 +1212,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
model.evaluate(dataset, steps=2, verbose=1) model.evaluate(dataset, steps=2, verbose=1)
model.predict(get_predict_dataset(distribution), steps=2) model.predict(get_predict_dataset(distribution), steps=2)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
@ -1243,7 +1244,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
with self.assertRaisesRegex(ValueError, 'is incompatible with'): with self.assertRaisesRegex(ValueError, 'is incompatible with'):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu strategy_combinations.mirrored_strategy_with_gpu_and_cpu
@ -1269,7 +1270,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
@ -1324,7 +1325,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
ref_output = np.ones((160, 1), dtype=np.float32) ref_output = np.ones((160, 1), dtype=np.float32)
self.assertArrayNear(output, ref_output, 1e-1) self.assertArrayNear(output, ref_output, 1e-1)
@combinations.generate(all_strategy_combinations()) @ds_combinations.generate(all_strategy_combinations())
def testOptimizerWithCallbacks(self, distribution): def testOptimizerWithCallbacks(self, distribution):
with self.cached_session(): with self.cached_session():
with distribution.scope(): with distribution.scope():
@ -1348,7 +1349,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) callbacks=[keras.callbacks.LearningRateScheduler(schedule)])
self.assertAllClose(0.001, keras.backend.get_value(model.optimizer.lr)) self.assertAllClose(0.001, keras.backend.get_value(model.optimizer.lr))
@combinations.generate( @ds_combinations.generate(
combinations.times(tpu_strategy_combinations_graph_only(), combinations.times(tpu_strategy_combinations_graph_only(),
combinations.combine(batch_size=[4, 6]))) combinations.combine(batch_size=[4, 6])))
def test_evaluate_with_dataset_with_partial_batch(self, distribution, def test_evaluate_with_dataset_with_partial_batch(self, distribution,
@ -1389,7 +1390,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
atol=1e-5, atol=1e-5,
rtol=1e-5) rtol=1e-5)
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
tpu_strategy_combinations_graph_only())) tpu_strategy_combinations_graph_only()))
def test_predict_with_dataset_with_partial_batch( def test_predict_with_dataset_with_partial_batch(
@ -1421,7 +1422,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
atol=1e-5, atol=1e-5,
rtol=1e-5) rtol=1e-5)
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
tpu_strategy_combinations_graph_only())) tpu_strategy_combinations_graph_only()))
def test_predict_multi_output_model_with_dataset_with_partial_batch( def test_predict_multi_output_model_with_dataset_with_partial_batch(
@ -1458,7 +1459,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
atol=1e-4, atol=1e-4,
rtol=1e-4) rtol=1e-4)
@combinations.generate(all_strategy_combinations_minus_default()) @ds_combinations.generate(all_strategy_combinations_minus_default())
def test_match_model_input_matches_with_dataset_tensors(self, distribution): def test_match_model_input_matches_with_dataset_tensors(self, distribution):
def _create_model_input_output_tensors(): def _create_model_input_output_tensors():
@ -1511,7 +1512,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
atol=1e-4, atol=1e-4,
rtol=1e-4) rtol=1e-4)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=strategies_minus_tpu, distribution=strategies_minus_tpu,
mode=['graph', 'eager'])) mode=['graph', 'eager']))
@ -1572,7 +1573,7 @@ class TestRegularizerLoss(test.TestCase, parameterized.TestCase):
def loss_fn(_, y_pred): def loss_fn(_, y_pred):
return math_ops.reduce_mean(y_pred) return math_ops.reduce_mean(y_pred)
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
strategy_combinations.all_strategy_combinations_minus_default())) strategy_combinations.all_strategy_combinations_minus_default()))
def test_regularizer_loss(self, distribution): def test_regularizer_loss(self, distribution):
@ -1610,7 +1611,7 @@ class TestRegularizerLoss(test.TestCase, parameterized.TestCase):
class TestDistributionStrategyWithKerasModels(test.TestCase, class TestDistributionStrategyWithKerasModels(test.TestCase,
parameterized.TestCase): parameterized.TestCase):
@combinations.generate(all_strategy_combinations()) @ds_combinations.generate(all_strategy_combinations())
def test_distribution_strategy_on_sequential_model( def test_distribution_strategy_on_sequential_model(
self, distribution): self, distribution):
with distribution.scope(): with distribution.scope():
@ -1629,7 +1630,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
model.predict(inputs, batch_size=10) model.predict(inputs, batch_size=10)
model.evaluate(inputs, targets, batch_size=10) model.evaluate(inputs, targets, batch_size=10)
@combinations.generate(all_strategy_combinations()) @ds_combinations.generate(all_strategy_combinations())
def test_distribution_strategy_on_functional_model( def test_distribution_strategy_on_functional_model(
self, distribution): self, distribution):
with distribution.scope(): with distribution.scope():
@ -1648,7 +1649,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
model.predict(inputs) model.predict(inputs)
model.evaluate(inputs, targets) model.evaluate(inputs, targets)
@combinations.generate( @ds_combinations.generate(
combinations.combine(distribution=all_strategies, mode=['eager'])) combinations.combine(distribution=all_strategies, mode=['eager']))
def test_distributed_dataset(self, distribution): def test_distributed_dataset(self, distribution):
with distribution.scope(): with distribution.scope():
@ -1700,7 +1701,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
'distributed dataset, you must specify'): 'distributed dataset, you must specify'):
model.fit(ds, epochs=2) model.fit(ds, epochs=2)
@combinations.generate( @ds_combinations.generate(
combinations.combine(distribution=all_strategies, mode=['eager'])) combinations.combine(distribution=all_strategies, mode=['eager']))
def test_distributed_datasets_from_function(self, distribution): def test_distributed_datasets_from_function(self, distribution):
with distribution.scope(): with distribution.scope():
@ -1754,7 +1755,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
'distributed dataset, you must specify'): 'distributed dataset, you must specify'):
model.fit(ds, epochs=2) model.fit(ds, epochs=2)
@combinations.generate( @ds_combinations.generate(
combinations.combine(distribution=all_strategies, mode=['eager'])) combinations.combine(distribution=all_strategies, mode=['eager']))
def test_host_training_loop(self, distribution): def test_host_training_loop(self, distribution):
with distribution.scope(): with distribution.scope():
@ -1780,7 +1781,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
self.assertEqual(bc.predict_begin_batches, [0, 10, 20, 30, 40]) self.assertEqual(bc.predict_begin_batches, [0, 10, 20, 30, 40])
self.assertEqual(bc.predict_end_batches, [9, 19, 29, 39, 49]) self.assertEqual(bc.predict_end_batches, [9, 19, 29, 39, 49])
@combinations.generate( @ds_combinations.generate(
combinations.combine(distribution=all_strategies, mode=['eager'])) combinations.combine(distribution=all_strategies, mode=['eager']))
def test_host_training_loop_last_partial_execution(self, distribution): def test_host_training_loop_last_partial_execution(self, distribution):
with distribution.scope(): with distribution.scope():
@ -1804,7 +1805,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
self.assertEqual(bc.predict_begin_batches, [0, 20, 40]) self.assertEqual(bc.predict_begin_batches, [0, 20, 40])
self.assertEqual(bc.predict_end_batches, [19, 39, 49]) self.assertEqual(bc.predict_end_batches, [19, 39, 49])
@combinations.generate( @ds_combinations.generate(
combinations.combine(distribution=all_strategies, mode=['eager'])) combinations.combine(distribution=all_strategies, mode=['eager']))
def test_host_training_loop_dataset_unknown_size(self, distribution): def test_host_training_loop_dataset_unknown_size(self, distribution):
with distribution.scope(): with distribution.scope():
@ -1840,7 +1841,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
self.assertEqual(bc.predict_begin_batches, [0, 20, 40]) self.assertEqual(bc.predict_begin_batches, [0, 20, 40])
self.assertEqual(bc.predict_end_batches, [19, 39, 49]) self.assertEqual(bc.predict_end_batches, [19, 39, 49])
@combinations.generate( @ds_combinations.generate(
combinations.combine(distribution=all_strategies, mode=['eager'])) combinations.combine(distribution=all_strategies, mode=['eager']))
def test_host_training_loop_truncate_to_epoch(self, distribution): def test_host_training_loop_truncate_to_epoch(self, distribution):
with distribution.scope(): with distribution.scope():
@ -1866,7 +1867,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
self.assertEqual(bc.predict_begin_batches, [0]) self.assertEqual(bc.predict_begin_batches, [0])
self.assertEqual(bc.predict_end_batches, [24]) self.assertEqual(bc.predict_end_batches, [24])
@combinations.generate( @ds_combinations.generate(
combinations.combine(distribution=all_strategies, mode=['eager'])) combinations.combine(distribution=all_strategies, mode=['eager']))
def test_gradient_clipping(self, distribution): def test_gradient_clipping(self, distribution):
@ -1896,7 +1897,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
self.assertAllClose(self.evaluate(layer.v1), 3.) self.assertAllClose(self.evaluate(layer.v1), 3.)
self.assertAllClose(self.evaluate(layer.v2), -1.) self.assertAllClose(self.evaluate(layer.v2), -1.)
@combinations.generate( @ds_combinations.generate(
combinations.combine(distribution=all_strategies, mode=['eager'])) combinations.combine(distribution=all_strategies, mode=['eager']))
def test_custom_gradient_transformation(self, distribution): def test_custom_gradient_transformation(self, distribution):
if isinstance(distribution, if isinstance(distribution,
@ -1929,7 +1930,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
self.assertAllClose(self.evaluate(layer.v1), 0.) self.assertAllClose(self.evaluate(layer.v1), 0.)
self.assertAllClose(self.evaluate(layer.v2), -2.) self.assertAllClose(self.evaluate(layer.v2), -2.)
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
all_strategy_combinations_minus_default())) all_strategy_combinations_minus_default()))
def test_distribution_strategy_one_dimensional(self, distribution): def test_distribution_strategy_one_dimensional(self, distribution):
@ -1947,7 +1948,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
model.fit(x, y, epochs=1, steps_per_epoch=2) model.fit(x, y, epochs=1, steps_per_epoch=2)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
@ -1991,7 +1992,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
self.assertArrayNear(history.history['loss'], ds_history.history['loss'], self.assertArrayNear(history.history['loss'], ds_history.history['loss'],
1e-5) 1e-5)
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
all_strategy_combinations_minus_default())) all_strategy_combinations_minus_default()))
def test_distribution_strategy_with_symbolic_add_loss( def test_distribution_strategy_with_symbolic_add_loss(
@ -2022,7 +2023,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
self.assertAllClose(history.history, ds_history.history) self.assertAllClose(history.history, ds_history.history)
# TODO(omalleyt): Investigate flakiness and re-enable. # TODO(omalleyt): Investigate flakiness and re-enable.
@combinations.generate(all_strategy_minus_default_and_tpu_combinations()) @ds_combinations.generate(all_strategy_minus_default_and_tpu_combinations())
def DISABLED_test_distribution_strategy_with_callable_add_loss( def DISABLED_test_distribution_strategy_with_callable_add_loss(
self, distribution): self, distribution):
@ -2053,7 +2054,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
self.assertAllClose(history.history, ds_history.history) self.assertAllClose(history.history, ds_history.history)
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
all_strategy_minus_default_and_tpu_combinations())) all_strategy_minus_default_and_tpu_combinations()))
def test_distribution_strategy_with_add_metric_in_call( def test_distribution_strategy_with_add_metric_in_call(
@ -2101,7 +2102,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
self.assertAllClose(history.history, ds_history.history) self.assertAllClose(history.history, ds_history.history)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.one_device_strategy, strategy_combinations.one_device_strategy,
@ -2155,7 +2156,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
self.assertAllClose(history.history, ds_history.history) self.assertAllClose(history.history, ds_history.history)
@combinations.generate( @ds_combinations.generate(
# TODO(phillypham): Why does validation_steps > 1 not work on TPUs? # TODO(phillypham): Why does validation_steps > 1 not work on TPUs?
combinations.times( combinations.times(
all_strategy_minus_default_and_tpu_combinations())) all_strategy_minus_default_and_tpu_combinations()))
@ -2195,7 +2196,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
self.assertAllClose(history.history, ds_history.history) self.assertAllClose(history.history, ds_history.history)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=strategies_minus_tpu, distribution=strategies_minus_tpu,
mode=['eager'])) mode=['eager']))
@ -2224,7 +2225,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
self.assertAllEqual(output.values, expected_values) self.assertAllEqual(output.values, expected_values)
self.assertAllEqual(output.dense_shape, expected_dense_shape) self.assertAllEqual(output.dense_shape, expected_dense_shape)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=strategies_minus_tpu, distribution=strategies_minus_tpu,
mode=['eager'])) mode=['eager']))
@ -2251,7 +2252,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
expected_values = [[1], [2, 3]] expected_values = [[1], [2, 3]]
self.assertAllEqual(expected_values, output) self.assertAllEqual(expected_values, output)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=strategies_minus_default_minus_tpu + tpu_strategies, distribution=strategies_minus_default_minus_tpu + tpu_strategies,
mode=['eager'])) mode=['eager']))
@ -2317,7 +2318,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
for x in dataset: for x in dataset:
train_step(x) train_step(x)
@combinations.generate(combinations.combine(mode=['graph', 'eager'])) @ds_combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_unimplemented_parameter_server_strategy(self): def test_unimplemented_parameter_server_strategy(self):
cluster_spec = multi_worker_test_base.create_in_process_cluster( cluster_spec = multi_worker_test_base.create_in_process_cluster(
num_workers=3, num_ps=2) num_workers=3, num_ps=2)
@ -2433,7 +2434,7 @@ class TestDistributionStrategyWithMultipleAddLossAndMetricCalls(
test.TestCase, parameterized.TestCase): test.TestCase, parameterized.TestCase):
"""Tests complex models with multiple add loss and metric calls.""" """Tests complex models with multiple add loss and metric calls."""
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
all_strategy_combinations_minus_default(), all_strategy_combinations_minus_default(),
combinations.combine( combinations.combine(
@ -2504,7 +2505,7 @@ class DeterministicModel(keras.Model):
class TestModelCapturesStrategy(test.TestCase, parameterized.TestCase): class TestModelCapturesStrategy(test.TestCase, parameterized.TestCase):
"""Tests that model creation captures the strategy.""" """Tests that model creation captures the strategy."""
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=strategy_combinations.all_strategies, distribution=strategy_combinations.all_strategies,
mode=['eager'])) mode=['eager']))

View File

@ -24,13 +24,13 @@ import numpy as np
import six import six
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import tpu_strategy from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import random_seed from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.distribute import distributed_training_utils from tensorflow.python.keras.distribute import distributed_training_utils
from tensorflow.python.keras.mixed_precision.experimental import policy from tensorflow.python.keras.mixed_precision.experimental import policy
from tensorflow.python.keras.preprocessing import sequence from tensorflow.python.keras.preprocessing import sequence

View File

@ -20,9 +20,10 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import backend as K from tensorflow.python.keras import backend as K
from tensorflow.python.keras import testing_utils from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.distribute import keras_correctness_test_base from tensorflow.python.keras.distribute import keras_correctness_test_base
@ -100,12 +101,12 @@ class TestDistributionStrategyDnnCorrectness(
x_predict = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32) x_predict = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32)
return x_train, y_train, x_eval, y_eval, x_predict return x_train, y_train, x_eval, y_eval, x_predict
@combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.all_strategy_and_input_config_combinations()) keras_correctness_test_base.all_strategy_and_input_config_combinations())
def test_dnn_correctness(self, distribution, use_numpy, use_validation_data): def test_dnn_correctness(self, distribution, use_numpy, use_validation_data):
self.run_correctness_test(distribution, use_numpy, use_validation_data) self.run_correctness_test(distribution, use_numpy, use_validation_data)
@combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.test_combinations_with_tpu_strategies()) keras_correctness_test_base.test_combinations_with_tpu_strategies())
def test_dnn_correctness_with_partial_last_batch_eval(self, distribution, def test_dnn_correctness_with_partial_last_batch_eval(self, distribution,
use_numpy, use_numpy,
@ -113,7 +114,7 @@ class TestDistributionStrategyDnnCorrectness(
self.run_correctness_test( self.run_correctness_test(
distribution, use_numpy, use_validation_data, partial_last_batch='eval') distribution, use_numpy, use_validation_data, partial_last_batch='eval')
@combinations.generate( @ds_combinations.generate(
keras_correctness_test_base keras_correctness_test_base
.strategy_minus_tpu_and_input_config_combinations_eager()) .strategy_minus_tpu_and_input_config_combinations_eager())
def test_dnn_correctness_with_partial_last_batch(self, distribution, def test_dnn_correctness_with_partial_last_batch(self, distribution,
@ -127,7 +128,7 @@ class TestDistributionStrategyDnnCorrectness(
partial_last_batch='train_and_eval', partial_last_batch='train_and_eval',
training_epochs=1) training_epochs=1)
@combinations.generate(all_strategy_combinations_with_graph_mode()) @ds_combinations.generate(all_strategy_combinations_with_graph_mode())
def test_dnn_with_dynamic_learning_rate(self, distribution): def test_dnn_with_dynamic_learning_rate(self, distribution):
self.run_dynamic_lr_test(distribution) self.run_dynamic_lr_test(distribution)
@ -166,7 +167,8 @@ class TestDistributionStrategyDnnMetricCorrectness(
history = model.fit(x=train_dataset, epochs=2, steps_per_epoch=10) history = model.fit(x=train_dataset, epochs=2, steps_per_epoch=10)
self.assertEqual(history.history['binary_accuracy'], [1.0, 1.0]) self.assertEqual(history.history['binary_accuracy'], [1.0, 1.0])
@combinations.generate(all_strategy_combinations_with_eager_and_graph_modes()) @ds_combinations.generate(
all_strategy_combinations_with_eager_and_graph_modes())
def test_simple_dnn_metric_correctness(self, distribution): def test_simple_dnn_metric_correctness(self, distribution):
self.run_metric_correctness_test(distribution) self.run_metric_correctness_test(distribution)
@ -214,7 +216,8 @@ class TestDistributionStrategyDnnMetricEvalCorrectness(
self.assertEqual(outs[1], 0.) self.assertEqual(outs[1], 0.)
self.assertEqual(outs[2], 0.) self.assertEqual(outs[2], 0.)
@combinations.generate(all_strategy_combinations_with_eager_and_graph_modes()) @ds_combinations.generate(
all_strategy_combinations_with_eager_and_graph_modes())
def test_identity_model_metric_eval_correctness(self, distribution): def test_identity_model_metric_eval_correctness(self, distribution):
self.run_eval_metrics_correctness_test(distribution) self.run_eval_metrics_correctness_test(distribution)
@ -261,7 +264,7 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
metrics=['mse']) metrics=['mse'])
return model return model
@combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.all_strategy_and_input_config_combinations()) keras_correctness_test_base.all_strategy_and_input_config_combinations())
def test_dnn_correctness(self, distribution, use_numpy, use_validation_data): def test_dnn_correctness(self, distribution, use_numpy, use_validation_data):
if (context.executing_eagerly()) or is_default_strategy(distribution): if (context.executing_eagerly()) or is_default_strategy(distribution):
@ -280,7 +283,7 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
'`input_dim` set in its first layer or a subclassed model.'): '`input_dim` set in its first layer or a subclassed model.'):
self.run_correctness_test(distribution, use_numpy, use_validation_data) self.run_correctness_test(distribution, use_numpy, use_validation_data)
@combinations.generate(all_strategy_combinations_with_graph_mode()) @ds_combinations.generate(all_strategy_combinations_with_graph_mode())
def test_dnn_with_dynamic_learning_rate(self, distribution): def test_dnn_with_dynamic_learning_rate(self, distribution):
if ((context.executing_eagerly() and not K.is_tpu_strategy(distribution)) or if ((context.executing_eagerly() and not K.is_tpu_strategy(distribution)) or
is_default_strategy(distribution)): is_default_strategy(distribution)):
@ -299,7 +302,7 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
'`input_dim` set in its first layer or a subclassed model.'): '`input_dim` set in its first layer or a subclassed model.'):
self.run_dynamic_lr_test(distribution) self.run_dynamic_lr_test(distribution)
@combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.test_combinations_with_tpu_strategies()) keras_correctness_test_base.test_combinations_with_tpu_strategies())
def test_dnn_correctness_with_partial_last_batch_eval(self, distribution, def test_dnn_correctness_with_partial_last_batch_eval(self, distribution,
use_numpy, use_numpy,

View File

@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.keras.distribute import keras_correctness_test_base from tensorflow.python.keras.distribute import keras_correctness_test_base
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -55,7 +55,7 @@ class DistributionStrategyEmbeddingModelCorrectnessTest(
metrics=['sparse_categorical_accuracy']) metrics=['sparse_categorical_accuracy'])
return model return model
@combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.test_combinations_for_embedding_model()) keras_correctness_test_base.test_combinations_for_embedding_model())
def test_embedding_model_correctness(self, distribution, use_numpy, def test_embedding_model_correctness(self, distribution, use_numpy,
use_validation_data): use_validation_data):
@ -63,7 +63,7 @@ class DistributionStrategyEmbeddingModelCorrectnessTest(
self.use_distributed_dense = False self.use_distributed_dense = False
self.run_correctness_test(distribution, use_numpy, use_validation_data) self.run_correctness_test(distribution, use_numpy, use_validation_data)
@combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.test_combinations_for_embedding_model()) keras_correctness_test_base.test_combinations_for_embedding_model())
def test_embedding_time_distributed_model_correctness( def test_embedding_time_distributed_model_correctness(
self, distribution, use_numpy, use_validation_data): self, distribution, use_numpy, use_validation_data):
@ -145,7 +145,7 @@ class DistributionStrategySiameseEmbeddingModelCorrectnessTest(
return x_train, y_train, x_predict return x_train, y_train, x_predict
@combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.test_combinations_for_embedding_model()) keras_correctness_test_base.test_combinations_for_embedding_model())
def test_siamese_embedding_model_correctness(self, distribution, use_numpy, def test_siamese_embedding_model_correctness(self, distribution, use_numpy,
use_validation_data): use_validation_data):

View File

@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.keras import testing_utils from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.distribute import keras_correctness_test_base from tensorflow.python.keras.distribute import keras_correctness_test_base
@ -96,12 +96,12 @@ class DistributionStrategyCnnCorrectnessTest(
x_eval, y_eval = self._get_data(count=1000) x_eval, y_eval = self._get_data(count=1000)
return x_train, y_train, x_eval, y_eval, x_eval return x_train, y_train, x_eval, y_eval, x_eval
@combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.all_strategy_and_input_config_combinations()) keras_correctness_test_base.all_strategy_and_input_config_combinations())
def test_cnn_correctness(self, distribution, use_numpy, use_validation_data): def test_cnn_correctness(self, distribution, use_numpy, use_validation_data):
self.run_correctness_test(distribution, use_numpy, use_validation_data) self.run_correctness_test(distribution, use_numpy, use_validation_data)
@combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.all_strategy_and_input_config_combinations()) keras_correctness_test_base.all_strategy_and_input_config_combinations())
def test_cnn_with_batch_norm_correctness(self, distribution, use_numpy, def test_cnn_with_batch_norm_correctness(self, distribution, use_numpy,
use_validation_data): use_validation_data):
@ -112,7 +112,7 @@ class DistributionStrategyCnnCorrectnessTest(
use_validation_data, use_validation_data,
with_batch_norm='regular') with_batch_norm='regular')
@combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.all_strategy_and_input_config_combinations()) keras_correctness_test_base.all_strategy_and_input_config_combinations())
def test_cnn_with_sync_batch_norm_correctness(self, distribution, use_numpy, def test_cnn_with_sync_batch_norm_correctness(self, distribution, use_numpy,
use_validation_data): use_validation_data):
@ -125,7 +125,7 @@ class DistributionStrategyCnnCorrectnessTest(
use_validation_data, use_validation_data,
with_batch_norm='sync') with_batch_norm='sync')
@combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.test_combinations_with_tpu_strategies() + keras_correctness_test_base.test_combinations_with_tpu_strategies() +
keras_correctness_test_base keras_correctness_test_base
.strategy_minus_tpu_and_input_config_combinations_eager()) .strategy_minus_tpu_and_input_config_combinations_eager())
@ -139,7 +139,7 @@ class DistributionStrategyCnnCorrectnessTest(
partial_last_batch=True, partial_last_batch=True,
training_epochs=1) training_epochs=1)
@combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.test_combinations_with_tpu_strategies() + keras_correctness_test_base.test_combinations_with_tpu_strategies() +
keras_correctness_test_base keras_correctness_test_base
.strategy_minus_tpu_and_input_config_combinations_eager()) .strategy_minus_tpu_and_input_config_combinations_eager())

View File

@ -19,9 +19,10 @@ from __future__ import print_function
from absl.testing import parameterized from absl.testing import parameterized
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import metrics from tensorflow.python.keras import metrics
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -114,7 +115,7 @@ class KerasMetricsTest(test.TestCase, parameterized.TestCase):
if batches_consumed >= 4: # Consume 4 input batches in total. if batches_consumed >= 4: # Consume 4 input batches in total.
break break
@combinations.generate(all_combinations() + tpu_combinations()) @ds_combinations.generate(all_combinations() + tpu_combinations())
def testMean(self, distribution): def testMean(self, distribution):
def _dataset_fn(): def _dataset_fn():
return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch( return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(

View File

@ -22,14 +22,15 @@ from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.platform import test from tensorflow.python.platform import test
class KerasModelsTest(test.TestCase, parameterized.TestCase): class KerasModelsTest(test.TestCase, parameterized.TestCase):
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=strategy_combinations.all_strategies, mode=["eager"])) distribution=strategy_combinations.all_strategies, mode=["eager"]))
def test_lstm_model_with_dynamic_batch(self, distribution): def test_lstm_model_with_dynamic_batch(self, distribution):

View File

@ -21,13 +21,14 @@ from __future__ import print_function
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.optimizer_v2 import adam from tensorflow.python.keras.optimizer_v2 import adam
from tensorflow.python.keras.optimizer_v2 import gradient_descent from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
@ -45,7 +46,7 @@ def get_model():
class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase): class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase):
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.central_storage_strategy_with_two_gpus, strategy_combinations.central_storage_strategy_with_two_gpus,
@ -101,7 +102,7 @@ class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase):
# v(2) = beta2 * v(1) + (1-beta2) * grad^2 = 0.2 * 1.8 + 0.8 * 2.25 # v(2) = beta2 * v(1) + (1-beta2) * grad^2 = 0.2 * 1.8 + 0.8 * 2.25
self.assertAllClose(2.16, self.evaluate(all_vars[2])) self.assertAllClose(2.16, self.evaluate(all_vars[2]))
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.central_storage_strategy_with_two_gpus, strategy_combinations.central_storage_strategy_with_two_gpus,

View File

@ -20,8 +20,9 @@ from __future__ import print_function
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.engine import sequential from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.layers import core from tensorflow.python.keras.layers import core
from tensorflow.python.keras.optimizer_v2 import adagrad from tensorflow.python.keras.optimizer_v2 import adagrad
@ -59,7 +60,7 @@ def get_dataset():
class KerasPremadeModelsTest(test.TestCase, parameterized.TestCase): class KerasPremadeModelsTest(test.TestCase, parameterized.TestCase):
@combinations.generate(strategy_combinations_eager_data_fn()) @ds_combinations.generate(strategy_combinations_eager_data_fn())
def test_linear_model(self, distribution, data_fn): def test_linear_model(self, distribution, data_fn):
with distribution.scope(): with distribution.scope():
model = linear.LinearModel() model = linear.LinearModel()
@ -72,7 +73,7 @@ class KerasPremadeModelsTest(test.TestCase, parameterized.TestCase):
hist = model.fit(get_dataset(), epochs=5) hist = model.fit(get_dataset(), epochs=5)
self.assertLess(hist.history['loss'][4], 0.2) self.assertLess(hist.history['loss'][4], 0.2)
@combinations.generate(strategy_combinations_eager_data_fn()) @ds_combinations.generate(strategy_combinations_eager_data_fn())
def test_wide_deep_model(self, distribution, data_fn): def test_wide_deep_model(self, distribution, data_fn):
with distribution.scope(): with distribution.scope():
linear_model = linear.LinearModel(units=1) linear_model = linear.LinearModel(units=1)

View File

@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python import tf2 from tensorflow.python import tf2
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import tpu_strategy from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.keras import testing_utils from tensorflow.python.keras import testing_utils
@ -82,7 +82,7 @@ class DistributionStrategyGruModelCorrectnessTest(
else: else:
return rnn_v1.GRU return rnn_v1.GRU
@combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.test_combinations_for_embedding_model()) keras_correctness_test_base.test_combinations_for_embedding_model())
def test_gru_model_correctness(self, distribution, use_numpy, def test_gru_model_correctness(self, distribution, use_numpy,
use_validation_data): use_validation_data):
@ -103,13 +103,13 @@ class DistributionStrategyLstmModelCorrectnessTest(
else: else:
return rnn_v1.LSTM return rnn_v1.LSTM
@combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.test_combinations_for_embedding_model()) keras_correctness_test_base.test_combinations_for_embedding_model())
def test_lstm_model_correctness(self, distribution, use_numpy, def test_lstm_model_correctness(self, distribution, use_numpy,
use_validation_data): use_validation_data):
self.run_correctness_test(distribution, use_numpy, use_validation_data) self.run_correctness_test(distribution, use_numpy, use_validation_data)
@combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.test_combinations_for_embedding_model()) keras_correctness_test_base.test_combinations_for_embedding_model())
@testing_utils.enable_v2_dtype_behavior @testing_utils.enable_v2_dtype_behavior
def test_lstm_model_correctness_mixed_precision(self, distribution, use_numpy, def test_lstm_model_correctness_mixed_precision(self, distribution, use_numpy,

View File

@ -17,9 +17,9 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import combinations
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import testing_utils from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.distribute import saved_model_test_base as test_base from tensorflow.python.keras.distribute import saved_model_test_base as test_base
from tensorflow.python.keras.saving import save from tensorflow.python.keras.saving import save
@ -46,13 +46,13 @@ class KerasSaveLoadTest(test_base.TestSavedModelBase):
return restored_keras_model.predict( return restored_keras_model.predict(
predict_dataset, steps=test_base.PREDICT_STEPS) predict_dataset, steps=test_base.PREDICT_STEPS)
@combinations.generate(test_base.simple_models_with_strategies()) @ds_combinations.generate(test_base.simple_models_with_strategies())
def test_save_no_strategy_restore_strategy(self, model_and_input, def test_save_no_strategy_restore_strategy(self, model_and_input,
distribution): distribution):
self.run_test_save_no_strategy_restore_strategy( self.run_test_save_no_strategy_restore_strategy(
model_and_input, distribution) model_and_input, distribution)
@combinations.generate( @ds_combinations.generate(
combinations.times(test_base.simple_models_with_strategies(), combinations.times(test_base.simple_models_with_strategies(),
combinations.combine(save_in_scope=[True, False]))) combinations.combine(save_in_scope=[True, False])))
def test_save_strategy_restore_no_strategy(self, model_and_input, def test_save_strategy_restore_no_strategy(self, model_and_input,
@ -60,7 +60,7 @@ class KerasSaveLoadTest(test_base.TestSavedModelBase):
self.run_test_save_strategy_restore_no_strategy( self.run_test_save_strategy_restore_no_strategy(
model_and_input, distribution, save_in_scope) model_and_input, distribution, save_in_scope)
@combinations.generate( @ds_combinations.generate(
combinations.times(test_base.simple_models_with_strategy_pairs(), combinations.times(test_base.simple_models_with_strategy_pairs(),
combinations.combine(save_in_scope=[True, False]))) combinations.combine(save_in_scope=[True, False])))
def test_save_strategy_restore_strategy(self, model_and_input, def test_save_strategy_restore_strategy(self, model_and_input,

View File

@ -19,8 +19,9 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.distribute import keras_correctness_test_base from tensorflow.python.keras.distribute import keras_correctness_test_base
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -82,7 +83,7 @@ class DistributionStrategyStatefulLstmModelCorrectnessTest(
# TODO(jhseu): Disabled to fix b/130808953. Need to investigate why it # TODO(jhseu): Disabled to fix b/130808953. Need to investigate why it
# doesn't work and enable for DistributionStrategy more generally. # doesn't work and enable for DistributionStrategy more generally.
@combinations.generate(test_combinations_for_stateful_embedding_model()) @ds_combinations.generate(test_combinations_for_stateful_embedding_model())
def disabled_test_stateful_lstm_model_correctness( def disabled_test_stateful_lstm_model_correctness(
self, distribution, use_numpy, use_validation_data): self, distribution, use_numpy, use_validation_data):
self.run_correctness_test( self.run_correctness_test(
@ -91,7 +92,7 @@ class DistributionStrategyStatefulLstmModelCorrectnessTest(
use_validation_data, use_validation_data,
is_stateful_model=True) is_stateful_model=True)
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
keras_correctness_test_base.test_combinations_with_tpu_strategies())) keras_correctness_test_base.test_combinations_with_tpu_strategies()))
def test_incorrectly_use_multiple_cores_for_stateful_lstm_model( def test_incorrectly_use_multiple_cores_for_stateful_lstm_model(

View File

@ -25,13 +25,14 @@ from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import tpu_strategy from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import values from tensorflow.python.distribute import values
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import losses from tensorflow.python.keras import losses
from tensorflow.python.keras.distribute import distribute_strategy_test as keras_test_lib from tensorflow.python.keras.distribute import distribute_strategy_test as keras_test_lib
from tensorflow.python.keras.distribute import distributed_training_utils from tensorflow.python.keras.distribute import distributed_training_utils
@ -73,7 +74,7 @@ class Counter(keras.callbacks.Callback):
class TestDistributionStrategyWithCallbacks(test.TestCase, class TestDistributionStrategyWithCallbacks(test.TestCase,
parameterized.TestCase): parameterized.TestCase):
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
keras_test_lib.all_strategy_combinations())) keras_test_lib.all_strategy_combinations()))
def test_callbacks_in_fit(self, distribution): def test_callbacks_in_fit(self, distribution):
@ -127,7 +128,7 @@ class TestDistributionStrategyWithCallbacks(test.TestCase,
'on_train_end': 1 'on_train_end': 1
}) })
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
keras_test_lib.all_strategy_combinations())) keras_test_lib.all_strategy_combinations()))
def test_callbacks_in_eval(self, distribution): def test_callbacks_in_eval(self, distribution):
@ -151,7 +152,7 @@ class TestDistributionStrategyWithCallbacks(test.TestCase,
'on_test_end': 1 'on_test_end': 1
}) })
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
keras_test_lib.all_strategy_combinations())) keras_test_lib.all_strategy_combinations()))
def test_callbacks_in_predict(self, distribution): def test_callbacks_in_predict(self, distribution):
@ -181,7 +182,7 @@ class TestDistributionStrategyWithCallbacks(test.TestCase,
class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
@ -205,7 +206,7 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
distributed_training_utils.validate_distributed_dataset_inputs( distributed_training_utils.validate_distributed_dataset_inputs(
distribution, x, y) distribution, x, y)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
@ -229,7 +230,7 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
distributed_training_utils.validate_distributed_dataset_inputs( distributed_training_utils.validate_distributed_dataset_inputs(
distribution, x, y) distribution, x, y)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
@ -279,7 +280,7 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
model.predict(dataset, verbose=0) model.predict(dataset, verbose=0)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
@ -313,7 +314,7 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
model.compile( model.compile(
'sgd') 'sgd')
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
@ -340,7 +341,7 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
model.compile( model.compile(
'sgd') 'sgd')
@combinations.generate( @ds_combinations.generate(
keras_test_lib.all_strategy_combinations_minus_default()) keras_test_lib.all_strategy_combinations_minus_default())
def test_standalone_loss_without_loss_reduction(self, distribution): def test_standalone_loss_without_loss_reduction(self, distribution):
with distribution.scope(): with distribution.scope():
@ -358,7 +359,7 @@ class TestDistributionStrategyWithLossMasking(test.TestCase,
# TODO(priyag): Enable all strategies for this test. Currently it does not # TODO(priyag): Enable all strategies for this test. Currently it does not
# work for TPU due to some invalid datatype. # work for TPU due to some invalid datatype.
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
@ -391,7 +392,7 @@ class TestDistributionStrategyWithLossMasking(test.TestCase,
class TestDistributionStrategyWithNormalizationLayer(test.TestCase, class TestDistributionStrategyWithNormalizationLayer(test.TestCase,
parameterized.TestCase): parameterized.TestCase):
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
keras_test_lib.all_strategy_combinations(), keras_test_lib.all_strategy_combinations(),
combinations.combine( combinations.combine(
@ -436,7 +437,7 @@ class TestDistributionStrategyWithNormalizationLayer(test.TestCase,
class TestDistributionStrategySaveLoadWeights(test.TestCase, class TestDistributionStrategySaveLoadWeights(test.TestCase,
parameterized.TestCase): parameterized.TestCase):
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
keras_test_lib.all_strategy_combinations_minus_default(), keras_test_lib.all_strategy_combinations_minus_default(),
combinations.combine( combinations.combine(
@ -463,7 +464,7 @@ class TestDistributionStrategySaveLoadWeights(test.TestCase,
keras_test_lib.get_predict_dataset(distribution), steps=2) keras_test_lib.get_predict_dataset(distribution), steps=2)
model_2.fit(dataset, epochs=1, steps_per_epoch=1) model_2.fit(dataset, epochs=1, steps_per_epoch=1)
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
keras_test_lib.all_strategy_combinations_minus_default(), keras_test_lib.all_strategy_combinations_minus_default(),
combinations.combine( combinations.combine(
@ -498,7 +499,7 @@ class TestDistributionStrategySaveLoadWeights(test.TestCase,
class TestDistributionStrategyValidation(test.TestCase, parameterized.TestCase): class TestDistributionStrategyValidation(test.TestCase, parameterized.TestCase):
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
keras_test_lib.all_strategy_combinations_minus_default())) keras_test_lib.all_strategy_combinations_minus_default()))
def test_layer_outside_scope(self, distribution): def test_layer_outside_scope(self, distribution):
@ -517,7 +518,7 @@ class TestDistributionStrategyValidation(test.TestCase, parameterized.TestCase):
loss, loss,
metrics=metrics) metrics=metrics)
@combinations.generate( @ds_combinations.generate(
keras_test_lib.all_strategy_combinations_minus_default()) keras_test_lib.all_strategy_combinations_minus_default())
def test_model_outside_scope(self, distribution): def test_model_outside_scope(self, distribution):
with self.cached_session(): with self.cached_session():
@ -536,7 +537,7 @@ class TestDistributionStrategyValidation(test.TestCase, parameterized.TestCase):
class TestDistributionStrategyWithStaticShapes(test.TestCase, class TestDistributionStrategyWithStaticShapes(test.TestCase,
parameterized.TestCase): parameterized.TestCase):
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
@ -549,7 +550,7 @@ class TestDistributionStrategyWithStaticShapes(test.TestCase,
r'the number of replicas \(2\)'): r'the number of replicas \(2\)'):
keras.layers.Input(shape=(3,), batch_size=5, name='input') keras.layers.Input(shape=(3,), batch_size=5, name='input')
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.mirrored_strategy_with_gpu_and_cpu,

View File

@ -22,7 +22,7 @@ from absl.testing import parameterized
import numpy import numpy
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib from tensorflow.python.distribute import strategy_test_lib
@ -32,6 +32,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.distribute import optimizer_combinations from tensorflow.python.keras.distribute import optimizer_combinations
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
@ -71,7 +72,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
self.evaluate(iterator.initializer) self.evaluate(iterator.initializer)
return iterator return iterator
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
optimizer_combinations.distributions_and_v1_optimizers(), optimizer_combinations.distributions_and_v1_optimizers(),
combinations.combine(mode=["graph"], use_callable_loss=[True, False]) combinations.combine(mode=["graph"], use_callable_loss=[True, False])
@ -122,7 +123,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
self.assertTrue(is_not_increasing) self.assertTrue(is_not_increasing)
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
optimizer_combinations.distributions_and_v1_optimizers(), optimizer_combinations.distributions_and_v1_optimizers(),
combinations.combine(mode=["graph"], use_callable_loss=[True, False]) combinations.combine(mode=["graph"], use_callable_loss=[True, False])
@ -161,7 +162,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
self.assertTrue(is_not_increasing) self.assertTrue(is_not_increasing)
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
optimizer_combinations.distributions_and_v1_and_v2_optimizers(), optimizer_combinations.distributions_and_v1_and_v2_optimizers(),
combinations.combine(mode=["graph", "eager"])) + combinations.combine( combinations.combine(mode=["graph", "eager"])) + combinations.combine(
@ -228,7 +229,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
get_expected_variables(len(distribution.extended.parameter_devices)), get_expected_variables(len(distribution.extended.parameter_devices)),
set(created_variables)) set(created_variables))
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
combinations.combine(momentum=[0.8, 0.9, 0.99], renorm=[False, True]), combinations.combine(momentum=[0.8, 0.9, 0.99], renorm=[False, True]),
combinations.times( combinations.times(
@ -295,7 +296,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum)) expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum))
self.assertNear(expected_moving_means[i], moving_means[i], 0.0001) self.assertNear(expected_moving_means[i], moving_means[i], 0.0001)
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
combinations.combine(loss_reduction=[ combinations.combine(loss_reduction=[
losses_impl.Reduction.SUM, losses_impl.Reduction.MEAN, losses_impl.Reduction.SUM, losses_impl.Reduction.MEAN,
@ -411,7 +412,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
# One of the mean loss reductions. # One of the mean loss reductions.
self.assertNear(weight, 2 + 0.053, 0.0001) self.assertNear(weight, 2 + 0.053, 0.0001)
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
optimizer_combinations.distributions_and_v1_and_v2_optimizers(), optimizer_combinations.distributions_and_v1_and_v2_optimizers(),
combinations.combine(mode=["graph", "eager"]), combinations.combine(mode=["graph", "eager"]),
@ -541,7 +542,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
self.assertEqual(initial_loss.dtype, loss_tensor.dtype) self.assertEqual(initial_loss.dtype, loss_tensor.dtype)
self.assertEqual(initial_loss.shape, loss_tensor.shape) self.assertEqual(initial_loss.shape, loss_tensor.shape)
@combinations.generate( @ds_combinations.generate(
optimizer_combinations.distributions_and_v2_optimizers()) optimizer_combinations.distributions_and_v2_optimizers())
def test_empty_var_list(self, distribution, optimizer_fn): def test_empty_var_list(self, distribution, optimizer_fn):
opt = optimizer_fn() opt = optimizer_fn()

View File

@ -19,12 +19,12 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import function from tensorflow.python.eager import function
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.engine import training as keras_training from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.keras.layers import core as keras_core from tensorflow.python.keras.layers import core as keras_core
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -50,7 +50,7 @@ class MiniModel(keras_training.Model):
return self.fc(inputs) return self.fc(inputs)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.mirrored_strategy_with_gpu_and_cpu,

View File

@ -20,12 +20,13 @@ from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import distribute_utils from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import config from tensorflow.python.framework import config
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.layers import core from tensorflow.python.keras.layers import core
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -39,11 +40,11 @@ def _mimic_two_cpus():
]) ])
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
combinations.NamedDistribution( ds_combinations.NamedDistribution(
"Collective2CPUs", "Collective2CPUs",
# pylint: disable=g-long-lambda # pylint: disable=g-long-lambda
lambda: collective_all_reduce_strategy. lambda: collective_all_reduce_strategy.

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.distribute import combinations from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.distribute import simple_models from tensorflow.python.keras.distribute import simple_models
simple_functional_model = combinations.NamedObject( simple_functional_model = combinations.NamedObject(

View File

@ -24,10 +24,11 @@ import os
from absl.testing import parameterized from absl.testing import parameterized
from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import distributed_file_utils from tensorflow.python.distribute import distributed_file_utils
from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base as test_base from tensorflow.python.distribute import multi_worker_test_base as test_base
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import callbacks from tensorflow.python.keras import callbacks
from tensorflow.python.keras.distribute import multi_worker_testing_utils from tensorflow.python.keras.distribute import multi_worker_testing_utils
from tensorflow.python.lib.io import file_io from tensorflow.python.lib.io import file_io
@ -79,7 +80,7 @@ def _get_task_config():
class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase): class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
mode=['eager'], mode=['eager'],
file_format=['h5', 'tf'], file_format=['h5', 'tf'],
@ -137,7 +138,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
cluster_spec=test_base.create_cluster_spec(num_workers=2), cluster_spec=test_base.create_cluster_spec(num_workers=2),
args=(self, file_format)) args=(self, file_format))
@combinations.generate(combinations.combine(mode=['eager'])) @ds_combinations.generate(combinations.combine(mode=['eager']))
def test_model_checkpoint_works_with_same_file_path(self, mode): def test_model_checkpoint_works_with_same_file_path(self, mode):
def proc_model_checkpoint_works_with_same_file_path( def proc_model_checkpoint_works_with_same_file_path(
@ -163,7 +164,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
cluster_spec=test_base.create_cluster_spec(num_workers=2), cluster_spec=test_base.create_cluster_spec(num_workers=2),
args=(self, saving_filepath)) args=(self, saving_filepath))
@combinations.generate(combinations.combine(mode=['eager'])) @ds_combinations.generate(combinations.combine(mode=['eager']))
def test_backupandrestore_checkpoint_works_with_interruption(self, mode): def test_backupandrestore_checkpoint_works_with_interruption(self, mode):
class InterruptingCallback(callbacks.Callback): class InterruptingCallback(callbacks.Callback):
@ -228,7 +229,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
cluster_spec=test_base.create_cluster_spec(num_workers=2), cluster_spec=test_base.create_cluster_spec(num_workers=2),
args=(self, saving_filepath)) args=(self, saving_filepath))
@combinations.generate(combinations.combine(mode=['eager'])) @ds_combinations.generate(combinations.combine(mode=['eager']))
def test_tensorboard_saves_on_chief_but_not_otherwise(self, mode): def test_tensorboard_saves_on_chief_but_not_otherwise(self, mode):
def proc_tensorboard_saves_on_chief_but_not_otherwise(test_obj): def proc_tensorboard_saves_on_chief_but_not_otherwise(test_obj):
@ -266,7 +267,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
cluster_spec=test_base.create_cluster_spec(num_workers=2), cluster_spec=test_base.create_cluster_spec(num_workers=2),
args=(self,)) args=(self,))
@combinations.generate(combinations.combine(mode=['eager'])) @ds_combinations.generate(combinations.combine(mode=['eager']))
def test_tensorboard_can_still_save_to_temp_even_if_it_exists(self, mode): def test_tensorboard_can_still_save_to_temp_even_if_it_exists(self, mode):
def proc_tensorboard_can_still_save_to_temp_even_if_it_exists(test_obj): def proc_tensorboard_can_still_save_to_temp_even_if_it_exists(test_obj):
@ -295,7 +296,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
cluster_spec=test_base.create_cluster_spec(num_workers=2), cluster_spec=test_base.create_cluster_spec(num_workers=2),
args=(self,)) args=(self,))
@combinations.generate(combinations.combine(mode=['eager'])) @ds_combinations.generate(combinations.combine(mode=['eager']))
def test_tensorboard_works_with_same_file_path(self, mode): def test_tensorboard_works_with_same_file_path(self, mode):
def proc_tensorboard_works_with_same_file_path(test_obj, saving_filepath): def proc_tensorboard_works_with_same_file_path(test_obj, saving_filepath):
@ -324,7 +325,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
cluster_spec=test_base.create_cluster_spec(num_workers=2), cluster_spec=test_base.create_cluster_spec(num_workers=2),
args=(self, saving_filepath)) args=(self, saving_filepath))
@combinations.generate(combinations.combine(mode=['eager'])) @ds_combinations.generate(combinations.combine(mode=['eager']))
def test_early_stopping(self, mode): def test_early_stopping(self, mode):
def proc_early_stopping(test_obj): def proc_early_stopping(test_obj):

View File

@ -31,12 +31,13 @@ from absl.testing import parameterized
# pylint: disable=g-direct-tensorflow-import # pylint: disable=g-direct-tensorflow-import
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import distribute_coordinator as dc from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import multi_worker_test_base as test_base from tensorflow.python.distribute import multi_worker_test_base as test_base
from tensorflow.python.distribute import parameter_server_strategy from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import backend from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks from tensorflow.python.keras import callbacks
from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras import metrics as metrics_module
@ -203,7 +204,7 @@ class MultiWorkerVerificationCallback(callbacks.Callback):
class KerasMultiWorkerTestIndependentWorker(test_base.IndependentWorkerTestBase, class KerasMultiWorkerTestIndependentWorker(test_base.IndependentWorkerTestBase,
parameterized.TestCase): parameterized.TestCase):
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
mode=['graph'], mode=['graph'],
strategy_cls=[ strategy_cls=[
@ -261,7 +262,7 @@ class KerasMultiWorkerTestIndependentWorker(test_base.IndependentWorkerTestBase,
self.join_independent_workers(threads_to_join) self.join_independent_workers(threads_to_join)
verification_callback.verify(self) verification_callback.verify(self)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
mode=['graph'], mode=['graph'],
strategy_cls=[ParameterServerStrategy], strategy_cls=[ParameterServerStrategy],

View File

@ -27,10 +27,11 @@ from tensorflow.python import keras
from tensorflow.python.data.experimental.ops import distribute_options from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.framework import errors_impl from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.keras.datasets import mnist from tensorflow.python.keras.datasets import mnist
from tensorflow.python.keras.optimizer_v2 import gradient_descent from tensorflow.python.keras.optimizer_v2 import gradient_descent
@ -56,7 +57,7 @@ class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase):
else: else:
raise raise
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
mode=['eager'], mode=['eager'],
shard_policy=[None] + list(distribute_options.AutoShardPolicy))) shard_policy=[None] + list(distribute_options.AutoShardPolicy)))

View File

@ -17,9 +17,8 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations as strategy_combinations_base from tensorflow.python.distribute import strategy_combinations as strategy_combinations_base
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_keras_v2 from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_keras_v2
from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_keras_v2 from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_keras_v2
from tensorflow.python.keras.optimizer_v2 import adam as adam_keras_v2 from tensorflow.python.keras.optimizer_v2 import adam as adam_keras_v2

View File

@ -23,9 +23,9 @@ tf.saved_model.save().
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import combinations
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import testing_utils from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.distribute import saved_model_test_base as test_base from tensorflow.python.keras.distribute import saved_model_test_base as test_base
from tensorflow.python.keras.saving import save from tensorflow.python.keras.saving import save
@ -54,13 +54,13 @@ class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
predict_dataset, predict_dataset,
output_name) output_name)
@combinations.generate(test_base.simple_models_with_strategies()) @ds_combinations.generate(test_base.simple_models_with_strategies())
def test_save_no_strategy_restore_strategy(self, model_and_input, def test_save_no_strategy_restore_strategy(self, model_and_input,
distribution): distribution):
self.run_test_save_no_strategy_restore_strategy( self.run_test_save_no_strategy_restore_strategy(
model_and_input, distribution) model_and_input, distribution)
@combinations.generate( @ds_combinations.generate(
combinations.times(test_base.simple_models_with_strategies(), combinations.times(test_base.simple_models_with_strategies(),
combinations.combine(save_in_scope=[True, False]))) combinations.combine(save_in_scope=[True, False])))
def test_save_strategy_restore_no_strategy(self, model_and_input, def test_save_strategy_restore_no_strategy(self, model_and_input,
@ -68,7 +68,7 @@ class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
self.run_test_save_strategy_restore_no_strategy( self.run_test_save_strategy_restore_no_strategy(
model_and_input, distribution, save_in_scope) model_and_input, distribution, save_in_scope)
@combinations.generate( @ds_combinations.generate(
combinations.times(test_base.simple_models_with_strategy_pairs(), combinations.times(test_base.simple_models_with_strategy_pairs(),
combinations.combine(save_in_scope=[True, False]))) combinations.combine(save_in_scope=[True, False])))
def test_save_strategy_restore_strategy(self, model_and_input, def test_save_strategy_restore_strategy(self, model_and_input,

View File

@ -19,11 +19,11 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import testing_utils from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.distribute import model_combinations from tensorflow.python.keras.distribute import model_combinations
from tensorflow.python.keras.distribute import saved_model_test_base as test_base from tensorflow.python.keras.distribute import saved_model_test_base as test_base
@ -54,13 +54,13 @@ class SavedModelKerasModelTest(test_base.TestSavedModelBase):
predict_dataset, predict_dataset,
output_name) output_name)
@combinations.generate(test_base.simple_models_with_strategies()) @ds_combinations.generate(test_base.simple_models_with_strategies())
def test_save_no_strategy_restore_strategy(self, model_and_input, def test_save_no_strategy_restore_strategy(self, model_and_input,
distribution): distribution):
self.run_test_save_no_strategy_restore_strategy( self.run_test_save_no_strategy_restore_strategy(
model_and_input, distribution) model_and_input, distribution)
@combinations.generate( @ds_combinations.generate(
combinations.times(test_base.simple_models_with_strategies(), combinations.times(test_base.simple_models_with_strategies(),
combinations.combine(save_in_scope=[True, False]))) combinations.combine(save_in_scope=[True, False])))
def test_save_strategy_restore_no_strategy(self, model_and_input, def test_save_strategy_restore_no_strategy(self, model_and_input,
@ -68,7 +68,7 @@ class SavedModelKerasModelTest(test_base.TestSavedModelBase):
self.run_test_save_strategy_restore_no_strategy( self.run_test_save_strategy_restore_no_strategy(
model_and_input, distribution, save_in_scope) model_and_input, distribution, save_in_scope)
@combinations.generate( @ds_combinations.generate(
combinations.times(test_base.simple_models_with_strategy_pairs(), combinations.times(test_base.simple_models_with_strategy_pairs(),
combinations.combine(save_in_scope=[True, False]))) combinations.combine(save_in_scope=[True, False])))
def test_save_strategy_restore_strategy(self, model_and_input, def test_save_strategy_restore_strategy(self, model_and_input,
@ -80,7 +80,7 @@ class SavedModelKerasModelTest(test_base.TestSavedModelBase):
distribution_for_restoring, distribution_for_restoring,
save_in_scope) save_in_scope)
@combinations.generate( @ds_combinations.generate(
combinations.times(test_base.simple_models_with_strategies(), combinations.times(test_base.simple_models_with_strategies(),
combinations.combine(save_in_scope=[True, False]))) combinations.combine(save_in_scope=[True, False])))
def test_no_variable_device_placement(self, model_and_input, distribution, def test_no_variable_device_placement(self, model_and_input, distribution,
@ -130,13 +130,13 @@ class SavedModelTFModuleTest(test_base.TestSavedModelBase):
model = saved_model.load(saved_dir) model = saved_model.load(saved_dir)
return self._predict_with_model(distribution, model, predict_dataset) return self._predict_with_model(distribution, model, predict_dataset)
@combinations.generate(test_base.tfmodule_models_with_strategies()) @ds_combinations.generate(test_base.tfmodule_models_with_strategies())
def test_save_no_strategy_restore_strategy(self, model_and_input, def test_save_no_strategy_restore_strategy(self, model_and_input,
distribution): distribution):
self.run_test_save_no_strategy_restore_strategy( self.run_test_save_no_strategy_restore_strategy(
model_and_input, distribution) model_and_input, distribution)
@combinations.generate( @ds_combinations.generate(
combinations.times(test_base.tfmodule_models_with_strategies(), combinations.times(test_base.tfmodule_models_with_strategies(),
combinations.combine(save_in_scope=[True, False]))) combinations.combine(save_in_scope=[True, False])))
def test_save_strategy_restore_no_strategy( def test_save_strategy_restore_no_strategy(
@ -144,7 +144,7 @@ class SavedModelTFModuleTest(test_base.TestSavedModelBase):
self.run_test_save_strategy_restore_no_strategy( self.run_test_save_strategy_restore_no_strategy(
model_and_input, distribution, save_in_scope) model_and_input, distribution, save_in_scope)
@combinations.generate( @ds_combinations.generate(
combinations.times(test_base.tfmodule_models_with_strategy_pairs(), combinations.times(test_base.tfmodule_models_with_strategy_pairs(),
combinations.combine(save_in_scope=[True, False]))) combinations.combine(save_in_scope=[True, False])))
def test_save_strategy_restore_strategy(self, model_and_input, def test_save_strategy_restore_strategy(self, model_and_input,
@ -156,7 +156,7 @@ class SavedModelTFModuleTest(test_base.TestSavedModelBase):
distribution_for_restoring, distribution_for_restoring,
save_in_scope) save_in_scope)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
model_and_input=[model_combinations.simple_tfmodule_model], model_and_input=[model_combinations.simple_tfmodule_model],
distribution=test_base.strategies + distribution=test_base.strategies +

View File

@ -24,9 +24,9 @@ from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.framework import random_seed from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.distribute import model_combinations from tensorflow.python.keras.distribute import model_combinations
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test

View File

@ -20,10 +20,11 @@ from __future__ import print_function
from absl.testing import parameterized from absl.testing import parameterized
import numpy import numpy
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute.single_loss_example import single_loss_example from tensorflow.python.distribute.single_loss_example import single_loss_example
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.keras.distribute import optimizer_combinations from tensorflow.python.keras.distribute import optimizer_combinations
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
@ -33,7 +34,7 @@ from tensorflow.python.platform import test
@test_util.with_control_flow_v2 @test_util.with_control_flow_v2
class SingleLossStepTest(test.TestCase, parameterized.TestCase): class SingleLossStepTest(test.TestCase, parameterized.TestCase):
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
optimizer_combinations.distributions_and_v1_optimizers(), optimizer_combinations.distributions_and_v1_optimizers(),
combinations.combine( combinations.combine(

View File

@ -21,8 +21,9 @@ import os
import sys import sys
from absl.testing import parameterized from absl.testing import parameterized
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_worker_test_base as test_base from tensorflow.python.distribute import multi_worker_test_base as test_base
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.framework.errors_impl import NotFoundError from tensorflow.python.framework.errors_impl import NotFoundError
from tensorflow.python.keras import callbacks from tensorflow.python.keras import callbacks
from tensorflow.python.keras.distribute import multi_worker_testing_utils from tensorflow.python.keras.distribute import multi_worker_testing_utils
@ -33,7 +34,7 @@ from tensorflow.python.platform import test
class ModelCheckpointTest(test_base.IndependentWorkerTestBase, class ModelCheckpointTest(test_base.IndependentWorkerTestBase,
parameterized.TestCase): parameterized.TestCase):
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
mode=['graph'], mode=['graph'],
required_gpus=[0, 1], required_gpus=[0, 1],

View File

@ -22,11 +22,12 @@ import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import tpu_strategy from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.framework import config from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.layers.preprocessing import category_crossing from tensorflow.python.keras.layers.preprocessing import category_crossing
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
@ -45,7 +46,7 @@ def batch_wrapper(dataset, batch_size, distribution, repeat=None):
return dataset.batch(batch_size) return dataset.batch(batch_size)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
# Investigate why crossing is not supported with TPU. # Investigate why crossing is not supported with TPU.
distribution=strategy_combinations.all_strategies, distribution=strategy_combinations.all_strategies,

View File

@ -22,11 +22,12 @@ import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import tpu_strategy from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.framework import config from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.layers.preprocessing import category_encoding from tensorflow.python.keras.layers.preprocessing import category_encoding
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
@ -45,7 +46,7 @@ def batch_wrapper(dataset, batch_size, distribution, repeat=None):
return dataset.batch(batch_size) return dataset.batch(batch_size)
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
# (b/156783625): Outside compilation failed for eager mode only. # (b/156783625): Outside compilation failed for eager mode only.
distribution=strategy_combinations.strategies_minus_tpu, distribution=strategy_combinations.strategies_minus_tpu,

View File

@ -21,16 +21,17 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.framework import config from tensorflow.python.framework import config
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.layers.preprocessing import discretization from tensorflow.python.keras.layers.preprocessing import discretization
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.platform import test from tensorflow.python.platform import test
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=strategy_combinations.strategies_minus_tpu, distribution=strategy_combinations.strategies_minus_tpu,
mode=["eager", "graph"])) mode=["eager", "graph"]))

View File

@ -22,17 +22,18 @@ import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.framework import config from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.layers.preprocessing import hashing from tensorflow.python.keras.layers.preprocessing import hashing
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.platform import test from tensorflow.python.platform import test
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=strategy_combinations.all_strategies, distribution=strategy_combinations.all_strategies,
mode=["eager", "graph"])) mode=["eager", "graph"]))

View File

@ -22,16 +22,17 @@ import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.layers.preprocessing import image_preprocessing from tensorflow.python.keras.layers.preprocessing import image_preprocessing
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.platform import test from tensorflow.python.platform import test
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=strategy_combinations.all_strategies, distribution=strategy_combinations.all_strategies,
mode=["eager", "graph"])) mode=["eager", "graph"]))

View File

@ -22,11 +22,12 @@ import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import config from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.layers.preprocessing import index_lookup from tensorflow.python.keras.layers.preprocessing import index_lookup
from tensorflow.python.keras.layers.preprocessing import index_lookup_v1 from tensorflow.python.keras.layers.preprocessing import index_lookup_v1
@ -41,7 +42,7 @@ def get_layer_class():
return index_lookup_v1.IndexLookup return index_lookup_v1.IndexLookup
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=strategy_combinations.all_strategies, distribution=strategy_combinations.all_strategies,
mode=["eager"])) # Eager-only, no graph: b/158793009 mode=["eager"])) # Eager-only, no graph: b/158793009

View File

@ -22,9 +22,10 @@ import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.layers.preprocessing import normalization from tensorflow.python.keras.layers.preprocessing import normalization
from tensorflow.python.keras.layers.preprocessing import normalization_v1 from tensorflow.python.keras.layers.preprocessing import normalization_v1
@ -104,7 +105,7 @@ def _get_layer_computation_test_cases():
return crossed_test_cases return crossed_test_cases
@combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
combinations.combine( combinations.combine(
distribution=strategy_combinations.all_strategies, distribution=strategy_combinations.all_strategies,

View File

@ -22,11 +22,12 @@ import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import config from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.keras.layers.preprocessing import text_vectorization from tensorflow.python.keras.layers.preprocessing import text_vectorization
@ -41,7 +42,7 @@ def get_layer_class():
return text_vectorization_v1.TextVectorization return text_vectorization_v1.TextVectorization
@combinations.generate( @ds_combinations.generate(
combinations.combine( combinations.combine(
distribution=strategy_combinations.all_strategies, distribution=strategy_combinations.all_strategies,
mode=["eager", "graph"])) mode=["eager", "graph"]))

View File

@ -23,7 +23,7 @@ from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python import tf2 from tensorflow.python import tf2
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
@ -33,6 +33,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.mixed_precision.experimental import autocast_variable from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2 from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -52,14 +53,14 @@ def get_var(val, dtype, name=None):
return variables.VariableV1(val, use_resource=True, dtype=dtype, name=name) return variables.VariableV1(val, use_resource=True, dtype=dtype, name=name)
@combinations.generate(combinations.combine(mode=['graph', 'eager'])) @ds_combinations.generate(combinations.combine(mode=['graph', 'eager']))
class AutoCastVariableTest(test.TestCase, parameterized.TestCase): class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
def setUp(self): def setUp(self):
strategy_combinations.set_virtual_cpus_to_at_least(3) strategy_combinations.set_virtual_cpus_to_at_least(3)
super(AutoCastVariableTest, self).setUp() super(AutoCastVariableTest, self).setUp()
@combinations.generate(maybe_distribute) @ds_combinations.generate(maybe_distribute)
def test_read(self, distribution): def test_read(self, distribution):
with distribution.scope(): with distribution.scope():
x = get_var(1., dtypes.float32) x = get_var(1., dtypes.float32)
@ -103,7 +104,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
self.assertEqual(x.sparse_read([0]).dtype, dtypes.float16) self.assertEqual(x.sparse_read([0]).dtype, dtypes.float16)
self.assertEqual(x.gather_nd([0]).dtype, dtypes.float16) self.assertEqual(x.gather_nd([0]).dtype, dtypes.float16)
@combinations.generate(maybe_distribute) @ds_combinations.generate(maybe_distribute)
def test_read_nested_scopes(self, distribution): def test_read_nested_scopes(self, distribution):
with distribution.scope(): with distribution.scope():
x = get_var(1., dtypes.float32) x = get_var(1., dtypes.float32)
@ -123,7 +124,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
self.assertEqual(x.dtype, dtypes.float16) self.assertEqual(x.dtype, dtypes.float16)
self.assertEqual(x.read_value().dtype, dtypes.float16) self.assertEqual(x.read_value().dtype, dtypes.float16)
@combinations.generate(maybe_distribute) @ds_combinations.generate(maybe_distribute)
def test_dtype_is_not_string(self, distribution): def test_dtype_is_not_string(self, distribution):
with distribution.scope(): with distribution.scope():
x = get_var(1., dtypes.float32) x = get_var(1., dtypes.float32)
@ -140,7 +141,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
self.assertEqual(x.true_dtype, dtypes.float32) self.assertEqual(x.true_dtype, dtypes.float32)
self.assertIsInstance(x.true_dtype, dtypes.DType) self.assertIsInstance(x.true_dtype, dtypes.DType)
@combinations.generate(maybe_distribute) @ds_combinations.generate(maybe_distribute)
def test_method_delegations(self, distribution): def test_method_delegations(self, distribution):
# Test AutoCastVariable correctly delegates Variable methods to the # Test AutoCastVariable correctly delegates Variable methods to the
# underlying variable. # underlying variable.
@ -220,7 +221,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual( self.assertAllEqual(
evaluate(x.scatter_nd_update([[0], [1]], [1., 2.])), [1, 2]) evaluate(x.scatter_nd_update([[0], [1]], [1., 2.])), [1, 2])
@combinations.generate(maybe_distribute) @ds_combinations.generate(maybe_distribute)
def test_operator_overloads(self, distribution): def test_operator_overloads(self, distribution):
with distribution.scope(): with distribution.scope():
for read_dtype in (dtypes.float32, dtypes.float16): for read_dtype in (dtypes.float32, dtypes.float16):
@ -267,7 +268,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(x == [7., 8., 10.], [True, True, False]) self.assertAllEqual(x == [7., 8., 10.], [True, True, False])
self.assertAllEqual(x != [7., 8., 10.], [False, False, True]) self.assertAllEqual(x != [7., 8., 10.], [False, False, True])
@combinations.generate(maybe_distribute) @ds_combinations.generate(maybe_distribute)
def test_assign(self, distribution): def test_assign(self, distribution):
with distribution.scope(): with distribution.scope():
x = get_var(0., dtypes.float32) x = get_var(0., dtypes.float32)
@ -344,7 +345,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
# assign still expect float32 value even if in float16 scope # assign still expect float32 value even if in float16 scope
run_and_check() run_and_check()
@combinations.generate(maybe_distribute) @ds_combinations.generate(maybe_distribute)
def test_assign_tf_function(self, distribution): def test_assign_tf_function(self, distribution):
if not context.executing_eagerly(): if not context.executing_eagerly():
self.skipTest('Test is not compatible with graph mode') self.skipTest('Test is not compatible with graph mode')
@ -361,7 +362,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
dtypes.float16): dtypes.float16):
self.assertAllClose(5., self.evaluate(run_assign())) self.assertAllClose(5., self.evaluate(run_assign()))
@combinations.generate(maybe_distribute) @ds_combinations.generate(maybe_distribute)
def test_assign_op(self, distribution): def test_assign_op(self, distribution):
with distribution.scope(): with distribution.scope():
x = get_var(0., dtypes.float32) x = get_var(0., dtypes.float32)
@ -375,7 +376,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
func() func()
@combinations.generate(maybe_distribute) @ds_combinations.generate(maybe_distribute)
def test_tf_function_control_dependencies(self, distribution): def test_tf_function_control_dependencies(self, distribution):
if not context.executing_eagerly(): if not context.executing_eagerly():
self.skipTest('Test is not compatible with graph mode') self.skipTest('Test is not compatible with graph mode')
@ -393,7 +394,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
func() func()
self.assertAllClose(2., self.evaluate(x)) self.assertAllClose(2., self.evaluate(x))
@combinations.generate(maybe_distribute) @ds_combinations.generate(maybe_distribute)
def test_assign_stays_in_true_dtype(self, distribution): def test_assign_stays_in_true_dtype(self, distribution):
with distribution.scope(): with distribution.scope():
x = get_var(1., dtypes.float32) x = get_var(1., dtypes.float32)
@ -418,7 +419,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
self.assertEqual(1., self.evaluate(x.value())) self.assertEqual(1., self.evaluate(x.value()))
self.assertEqual(1. + small_val, self.evaluate(x)) self.assertEqual(1. + small_val, self.evaluate(x))
@combinations.generate(maybe_distribute) @ds_combinations.generate(maybe_distribute)
def test_checkpoint(self, distribution): def test_checkpoint(self, distribution):
with self.test_session(): with self.test_session():
with distribution.scope(): with distribution.scope():
@ -434,7 +435,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
checkpoint.restore(save_path).assert_consumed().run_restore_ops() checkpoint.restore(save_path).assert_consumed().run_restore_ops()
self.assertEqual(self.evaluate(x), 123.) self.assertEqual(self.evaluate(x), 123.)
@combinations.generate(maybe_distribute) @ds_combinations.generate(maybe_distribute)
def test_invalid_wrapped_variable(self, distribution): def test_invalid_wrapped_variable(self, distribution):
with distribution.scope(): with distribution.scope():
# Wrap a non-variable # Wrap a non-variable