Deflake cross_device_ops test.
PiperOrigin-RevId: 310039394 Change-Id: I1d0ac7506996d00283323c86de74c6dbbe01bd00
This commit is contained in:
parent
22546b562d
commit
9d15d75988
@ -934,8 +934,6 @@ cuda_py_test(
|
||||
srcs = ["cross_device_ops_test.py"],
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
"no_oss", # TODO(b/151025792): enable after this is fixed.
|
||||
"notap", # TODO(b/151025792): enable after this is fixed.
|
||||
],
|
||||
deps = [
|
||||
":collective_all_reduce_strategy",
|
||||
|
@ -120,7 +120,8 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
|
||||
self.evaluate(ops.convert_to_tensor(left)),
|
||||
self.evaluate(ops.convert_to_tensor(right)))
|
||||
|
||||
def _assert_mirrored_equal(self, left_list, right_list, sess):
|
||||
def _assert_mirrored_equal(self, left_list, right_list, sess,
|
||||
run_options=None):
|
||||
if not isinstance(left_list, list):
|
||||
left_list, right_list = [left_list], [right_list]
|
||||
|
||||
@ -141,7 +142,13 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
|
||||
# Densify IndexedSlices.
|
||||
left = [ops.convert_to_tensor(v) for v in left]
|
||||
right = [ops.convert_to_tensor(v) for v in right]
|
||||
if context.executing_eagerly():
|
||||
# Optional args in session run are not supported when eager execution
|
||||
# is enabled.
|
||||
assert run_options is None
|
||||
left, right = sess.run((left, right))
|
||||
else:
|
||||
left, right = sess.run((left, right), options=run_options)
|
||||
for left_value, right_value in zip(left, right):
|
||||
self.assertAllEqual(left_value, right_value)
|
||||
|
||||
@ -552,6 +559,17 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
|
||||
return (collective_all_reduce_ops, devices,
|
||||
"grpc://" + self._cluster_spec[task_type][task_id])
|
||||
|
||||
def _assert_mirrored_equal(self, left_list, right_list, sess):
|
||||
if context.executing_eagerly():
|
||||
run_options = None
|
||||
else:
|
||||
# TODO(b/151025792): figure out why missing run options would make the
|
||||
# test flaky and whether this is a problem in TF 2.
|
||||
run_options = config_pb2.RunOptions()
|
||||
run_options.experimental.collective_graph_key = 5
|
||||
super(CollectiveAllReduceTest, self)._assert_mirrored_equal(
|
||||
left_list, right_list, sess, run_options=run_options)
|
||||
|
||||
def _test_reduction(self,
|
||||
task_type,
|
||||
task_id,
|
||||
|
Loading…
Reference in New Issue
Block a user