Test with MWMS in custom_training_loop_optimizer_test
PiperOrigin-RevId: 324654618 Change-Id: I1a43c1b275bdb9b084432220519044fe4a69e1a8
This commit is contained in:
parent
7a4b89dbaf
commit
3b83a25110
@ -275,7 +275,7 @@ distribute_py_test(
|
|||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
"//tensorflow/python/distribute:combinations",
|
"//tensorflow/python/distribute:combinations",
|
||||||
"//tensorflow/python/distribute:strategy_combinations",
|
"//tensorflow/python/distribute:strategy_combinations",
|
||||||
"//tensorflow/python/distribute:values",
|
"//tensorflow/python/distribute:test_util",
|
||||||
"//tensorflow/python/eager:def_function",
|
"//tensorflow/python/eager:def_function",
|
||||||
"//tensorflow/python/eager:test",
|
"//tensorflow/python/eager:test",
|
||||||
"//tensorflow/python/keras/optimizer_v2",
|
"//tensorflow/python/keras/optimizer_v2",
|
||||||
|
@ -22,7 +22,7 @@ from absl.testing import parameterized
|
|||||||
|
|
||||||
from tensorflow.python.distribute import combinations
|
from tensorflow.python.distribute import combinations
|
||||||
from tensorflow.python.distribute import strategy_combinations
|
from tensorflow.python.distribute import strategy_combinations
|
||||||
from tensorflow.python.distribute import values
|
from tensorflow.python.distribute import test_util
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -35,7 +35,14 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
|
|||||||
@combinations.generate(
|
@combinations.generate(
|
||||||
combinations.times(
|
combinations.times(
|
||||||
combinations.combine(
|
combinations.combine(
|
||||||
distribution=strategy_combinations.multidevice_strategies,
|
distribution=[
|
||||||
|
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||||
|
strategy_combinations.mirrored_strategy_with_two_gpus,
|
||||||
|
strategy_combinations.multi_worker_mirrored_2x1_cpu,
|
||||||
|
strategy_combinations.multi_worker_mirrored_2x1_gpu,
|
||||||
|
strategy_combinations.tpu_strategy,
|
||||||
|
strategy_combinations.tpu_strategy_one_step,
|
||||||
|
],
|
||||||
mode=["eager"],
|
mode=["eager"],
|
||||||
),
|
),
|
||||||
combinations.concat(
|
combinations.concat(
|
||||||
@ -55,10 +62,10 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
def optimize():
|
def optimize():
|
||||||
grads = values.PerReplica([
|
grads = ops.convert_to_tensor([[1., 1.],
|
||||||
ops.convert_to_tensor([1., 1.]),
|
[2., 2.]])
|
||||||
ops.convert_to_tensor([2., 2.]),
|
grads = distribution.experimental_distribute_values_from_function(
|
||||||
])
|
lambda ctx: grads[ctx.replica_id_in_sync_group])
|
||||||
|
|
||||||
def step_fn(grads):
|
def step_fn(grads):
|
||||||
optimizer.apply_gradients(
|
optimizer.apply_gradients(
|
||||||
@ -66,7 +73,7 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
|
|||||||
experimental_aggregate_gradients=experimental_aggregate_gradients)
|
experimental_aggregate_gradients=experimental_aggregate_gradients)
|
||||||
return v.read_value()
|
return v.read_value()
|
||||||
|
|
||||||
return distribution.experimental_local_results(
|
return test_util.gather(distribution,
|
||||||
distribution.run(step_fn, args=(grads,)))
|
distribution.run(step_fn, args=(grads,)))
|
||||||
|
|
||||||
self.assertAllClose(optimize(), expected)
|
self.assertAllClose(optimize(), expected)
|
||||||
@ -118,4 +125,4 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
combinations.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user