Test with MWMS in custom_training_loop_optimizer_test

PiperOrigin-RevId: 324654618
Change-Id: I1a43c1b275bdb9b084432220519044fe4a69e1a8
This commit is contained in:
Ran Chen 2020-08-03 12:12:20 -07:00 committed by TensorFlower Gardener
parent 7a4b89dbaf
commit 3b83a25110
2 changed files with 17 additions and 10 deletions

View File

@@ -275,7 +275,7 @@ distribute_py_test(
"//tensorflow/python:variables",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/distribute:values",
"//tensorflow/python/distribute:test_util",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
"//tensorflow/python/keras/optimizer_v2",

View File

@@ -22,7 +22,7 @@ from absl.testing import parameterized
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import values
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import ops
@@ -35,7 +35,14 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.times(
combinations.combine(
distribution=strategy_combinations.multidevice_strategies,
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.tpu_strategy,
strategy_combinations.tpu_strategy_one_step,
],
mode=["eager"],
),
combinations.concat(
@@ -55,10 +62,10 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
@def_function.function
def optimize():
grads = values.PerReplica([
ops.convert_to_tensor([1., 1.]),
ops.convert_to_tensor([2., 2.]),
])
grads = ops.convert_to_tensor([[1., 1.],
[2., 2.]])
grads = distribution.experimental_distribute_values_from_function(
lambda ctx: grads[ctx.replica_id_in_sync_group])
def step_fn(grads):
optimizer.apply_gradients(
@@ -66,8 +73,8 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
experimental_aggregate_gradients=experimental_aggregate_gradients)
return v.read_value()
return distribution.experimental_local_results(
distribution.run(step_fn, args=(grads,)))
return test_util.gather(distribution,
distribution.run(step_fn, args=(grads,)))
self.assertAllClose(optimize(), expected)
@@ -118,4 +125,4 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
if __name__ == "__main__":
test.main()
combinations.main()