Test with MWMS in custom_training_loop_optimizer_test
PiperOrigin-RevId: 324654618 Change-Id: I1a43c1b275bdb9b084432220519044fe4a69e1a8
This commit is contained in:
parent
7a4b89dbaf
commit
3b83a25110
@ -275,7 +275,7 @@ distribute_py_test(
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/distribute:combinations",
|
||||
"//tensorflow/python/distribute:strategy_combinations",
|
||||
"//tensorflow/python/distribute:values",
|
||||
"//tensorflow/python/distribute:test_util",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python/keras/optimizer_v2",
|
||||
|
@ -22,7 +22,7 @@ from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.distribute import test_util
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import ops
|
||||
@ -35,7 +35,14 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
combinations.combine(
|
||||
distribution=strategy_combinations.multidevice_strategies,
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_two_gpus,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_cpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_gpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.tpu_strategy_one_step,
|
||||
],
|
||||
mode=["eager"],
|
||||
),
|
||||
combinations.concat(
|
||||
@ -55,10 +62,10 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@def_function.function
|
||||
def optimize():
|
||||
grads = values.PerReplica([
|
||||
ops.convert_to_tensor([1., 1.]),
|
||||
ops.convert_to_tensor([2., 2.]),
|
||||
])
|
||||
grads = ops.convert_to_tensor([[1., 1.],
|
||||
[2., 2.]])
|
||||
grads = distribution.experimental_distribute_values_from_function(
|
||||
lambda ctx: grads[ctx.replica_id_in_sync_group])
|
||||
|
||||
def step_fn(grads):
|
||||
optimizer.apply_gradients(
|
||||
@ -66,8 +73,8 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
|
||||
experimental_aggregate_gradients=experimental_aggregate_gradients)
|
||||
return v.read_value()
|
||||
|
||||
return distribution.experimental_local_results(
|
||||
distribution.run(step_fn, args=(grads,)))
|
||||
return test_util.gather(distribution,
|
||||
distribution.run(step_fn, args=(grads,)))
|
||||
|
||||
self.assertAllClose(optimize(), expected)
|
||||
|
||||
@ -118,4 +125,4 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
combinations.main()
|
||||
|
Loading…
Reference in New Issue
Block a user