Test with MWMS in custom_training_loop_optimizer_test

PiperOrigin-RevId: 324679085
Change-Id: I3a550fde03380d0327906dcc9d449837a5f6cc8b
This commit is contained in:
A. Unique TensorFlower 2020-08-03 14:06:04 -07:00 committed by TensorFlower Gardener
parent a3df6cff1b
commit 729b23995f
2 changed files with 10 additions and 17 deletions

View File

@ -275,7 +275,7 @@ distribute_py_test(
"//tensorflow/python:variables",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/distribute:test_util",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
"//tensorflow/python/keras/optimizer_v2",

View File

@ -22,7 +22,7 @@ from absl.testing import parameterized
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import ops
@ -35,14 +35,7 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.times(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_one_step,
],
distribution=strategy_combinations.multidevice_strategies,
mode=["eager"],
),
combinations.concat(
@ -62,10 +55,10 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
@def_function.function
def optimize():
grads = ops.convert_to_tensor([[1., 1.],
[2., 2.]])
grads = distribution.experimental_distribute_values_from_function(
lambda ctx: grads[ctx.replica_id_in_sync_group])
grads = values.PerReplica([
ops.convert_to_tensor([1., 1.]),
ops.convert_to_tensor([2., 2.]),
])
def step_fn(grads):
optimizer.apply_gradients(
@ -73,8 +66,8 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
experimental_aggregate_gradients=experimental_aggregate_gradients)
return v.read_value()
return test_util.gather(distribution,
distribution.run(step_fn, args=(grads,)))
return distribution.experimental_local_results(
distribution.run(step_fn, args=(grads,)))
self.assertAllClose(optimize(), expected)
@ -125,4 +118,4 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
if __name__ == "__main__":
combinations.main()
test.main()