Test with MWMS in custom_training_loop_optimizer_test
PiperOrigin-RevId: 324679085
Change-Id: I3a550fde03380d0327906dcc9d449837a5f6cc8b

commit 729b23995f
parent a3df6cff1b
@@ -275,7 +275,7 @@ distribute_py_test(
        "//tensorflow/python:variables",
        "//tensorflow/python/distribute:combinations",
        "//tensorflow/python/distribute:strategy_combinations",
        "//tensorflow/python/distribute:test_util",
        "//tensorflow/python/distribute:values",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/keras/optimizer_v2",
@@ -22,7 +22,7 @@ from absl.testing import parameterized

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import ops
@@ -35,14 +35,7 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
  @combinations.generate(
      combinations.times(
          combinations.combine(
              distribution=[
                  strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
                  strategy_combinations.mirrored_strategy_with_two_gpus,
                  strategy_combinations.multi_worker_mirrored_2x1_cpu,
                  strategy_combinations.multi_worker_mirrored_2x1_gpu,
                  strategy_combinations.tpu_strategy,
                  strategy_combinations.tpu_strategy_one_step,
              ],
              distribution=strategy_combinations.multidevice_strategies,
              mode=["eager"],
          ),
          combinations.concat(
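For readers outside the TF tree: the combinations.generate decorator in this hunk is what expands the test body over the listed tf.distribute strategies. Below is a rough, self-contained sketch of the same parameterize-over-strategies idea using only public APIs and absl's parameterized helpers instead of the internal combinations/strategy_combinations modules; the class name, strategy factories and toy assertion are illustrative assumptions, not part of this commit:

    from absl.testing import parameterized
    import tensorflow as tf


    class ToyStrategyTest(tf.test.TestCase, parameterized.TestCase):

      # Each named parameter supplies a strategy factory, so the one test
      # body runs once per distribution strategy.
      @parameterized.named_parameters(
          ("one_device", lambda: tf.distribute.OneDeviceStrategy("/cpu:0")),
          ("mirrored_one_cpu", lambda: tf.distribute.MirroredStrategy(["/cpu:0"])),
      )
      def test_add_one(self, strategy_fn):
        strategy = strategy_fn()
        with strategy.scope():
          v = tf.Variable(1.0)

        @tf.function
        def step():
          return strategy.run(lambda: v + 1.0)

        results = strategy.experimental_local_results(step())
        self.assertAllClose(results[0], 2.0)


    if __name__ == "__main__":
      tf.test.main()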
@@ -62,10 +55,10 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):

    @def_function.function
    def optimize():
      grads = ops.convert_to_tensor([[1., 1.],
                                     [2., 2.]])
      grads = distribution.experimental_distribute_values_from_function(
          lambda ctx: grads[ctx.replica_id_in_sync_group])
      grads = values.PerReplica([
          ops.convert_to_tensor([1., 1.]),
          ops.convert_to_tensor([2., 2.]),
      ])

      def step_fn(grads):
        optimizer.apply_gradients(
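This hunk concerns how the per-replica gradients are built before being handed to distribution.run. A minimal sketch of experimental_distribute_values_from_function using the public tf.distribute API; the two-logical-CPU MirroredStrategy setup is an illustrative assumption so that two replicas exist to split the tensor across:

    import tensorflow as tf

    # Split the single physical CPU into two logical devices so that
    # MirroredStrategy has two replicas (illustrative setup only).
    tf.config.set_logical_device_configuration(
        tf.config.list_physical_devices("CPU")[0],
        [tf.config.LogicalDeviceConfiguration(),
         tf.config.LogicalDeviceConfiguration()])
    strategy = tf.distribute.MirroredStrategy(["/cpu:0", "/cpu:1"])

    grads = tf.constant([[1., 1.],
                         [2., 2.]])

    # Each replica evaluates the value_fn with its own context and picks the
    # row matching its replica id, producing a per-replica value.
    per_replica_grads = strategy.experimental_distribute_values_from_function(
        lambda ctx: grads[ctx.replica_id_in_sync_group])

    # The per-replica value can then be passed to strategy.run; every replica
    # sees only its own slice.
    print(strategy.run(lambda g: g * 10., args=(per_replica_grads,)))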
@@ -73,8 +66,8 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
            experimental_aggregate_gradients=experimental_aggregate_gradients)
        return v.read_value()

      return test_util.gather(distribution,
                              distribution.run(step_fn, args=(grads,)))
      return distribution.experimental_local_results(
          distribution.run(step_fn, args=(grads,)))

    self.assertAllClose(optimize(), expected)
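The change in this hunk is only in how the per-replica results of distribution.run are collected for the assertion: the internal test_util.gather is there to fetch values from every replica, including those on other workers, while experimental_local_results returns only the current worker's values. A rough sketch of that collection pattern with the public API, reusing the two-logical-CPU setup above and substituting a simple per-replica transform for the optimizer.apply_gradients(..., experimental_aggregate_gradients=...) call so the sketch stays version-agnostic; all names below are illustrative:

    import tensorflow as tf

    # Illustrative two-replica setup on one machine (same trick as above).
    tf.config.set_logical_device_configuration(
        tf.config.list_physical_devices("CPU")[0],
        [tf.config.LogicalDeviceConfiguration(),
         tf.config.LogicalDeviceConfiguration()])
    strategy = tf.distribute.MirroredStrategy(["/cpu:0", "/cpu:1"])

    grads = tf.constant([[1., 1.],
                         [2., 2.]])
    per_replica_grads = strategy.experimental_distribute_values_from_function(
        lambda ctx: grads[ctx.replica_id_in_sync_group])


    @tf.function
    def optimize():

      def step_fn(g):
        # Stand-in for the apply_gradients call in the real test: each
        # replica just transforms its own gradient slice.
        return g * 10.

      # experimental_local_results unpacks the per-replica return value into
      # a tuple with one entry per replica on this worker; the test_util.gather
      # call in the hunk above is used so values also come back from remote
      # workers once multi_worker_mirrored combinations run.
      return strategy.experimental_local_results(
          strategy.run(step_fn, args=(per_replica_grads,)))


    print(optimize())  # One transformed slice per local replica.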
@@ -125,4 +118,4 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):


if __name__ == "__main__":
  combinations.main()
  test.main()
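Last, the entry point: the internal combinations.main() routes the test through TF's multi-process runner so that multi_worker_mirrored combinations can spawn their worker processes, whereas test.main() keeps everything in a single process. Outside the TF repo, a single-process test file would simply call the public tf.test.main(); a minimal sketch with a toy test case (illustrative only):

    import tensorflow as tf


    class SingleProcessSmokeTest(tf.test.TestCase):

      def test_mirrored_sum(self):
        strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
        per_replica = strategy.run(lambda: tf.constant(2.))
        # Reduce across replicas (just one here) back to a regular tensor.
        total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
        self.assertAllClose(total, 2.)


    if __name__ == "__main__":
      tf.test.main()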