Deflake cross_device_ops test.
PiperOrigin-RevId: 310039394 Change-Id: I1d0ac7506996d00283323c86de74c6dbbe01bd00
This commit is contained in:
parent
22546b562d
commit
9d15d75988
@ -934,8 +934,6 @@ cuda_py_test(
|
|||||||
srcs = ["cross_device_ops_test.py"],
|
srcs = ["cross_device_ops_test.py"],
|
||||||
tags = [
|
tags = [
|
||||||
"multi_and_single_gpu",
|
"multi_and_single_gpu",
|
||||||
"no_oss", # TODO(b/151025792): enable after this is fixed.
|
|
||||||
"notap", # TODO(b/151025792): enable after this is fixed.
|
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":collective_all_reduce_strategy",
|
":collective_all_reduce_strategy",
|
||||||
|
@ -120,7 +120,8 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
|
|||||||
self.evaluate(ops.convert_to_tensor(left)),
|
self.evaluate(ops.convert_to_tensor(left)),
|
||||||
self.evaluate(ops.convert_to_tensor(right)))
|
self.evaluate(ops.convert_to_tensor(right)))
|
||||||
|
|
||||||
def _assert_mirrored_equal(self, left_list, right_list, sess):
|
def _assert_mirrored_equal(self, left_list, right_list, sess,
|
||||||
|
run_options=None):
|
||||||
if not isinstance(left_list, list):
|
if not isinstance(left_list, list):
|
||||||
left_list, right_list = [left_list], [right_list]
|
left_list, right_list = [left_list], [right_list]
|
||||||
|
|
||||||
@ -141,7 +142,13 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
|
|||||||
# Densify IndexedSlices.
|
# Densify IndexedSlices.
|
||||||
left = [ops.convert_to_tensor(v) for v in left]
|
left = [ops.convert_to_tensor(v) for v in left]
|
||||||
right = [ops.convert_to_tensor(v) for v in right]
|
right = [ops.convert_to_tensor(v) for v in right]
|
||||||
|
if context.executing_eagerly():
|
||||||
|
# Optional args in session run are not supported when eager execution
|
||||||
|
# is enabled.
|
||||||
|
assert run_options is None
|
||||||
left, right = sess.run((left, right))
|
left, right = sess.run((left, right))
|
||||||
|
else:
|
||||||
|
left, right = sess.run((left, right), options=run_options)
|
||||||
for left_value, right_value in zip(left, right):
|
for left_value, right_value in zip(left, right):
|
||||||
self.assertAllEqual(left_value, right_value)
|
self.assertAllEqual(left_value, right_value)
|
||||||
|
|
||||||
@ -552,6 +559,17 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
|
|||||||
return (collective_all_reduce_ops, devices,
|
return (collective_all_reduce_ops, devices,
|
||||||
"grpc://" + self._cluster_spec[task_type][task_id])
|
"grpc://" + self._cluster_spec[task_type][task_id])
|
||||||
|
|
||||||
|
def _assert_mirrored_equal(self, left_list, right_list, sess):
|
||||||
|
if context.executing_eagerly():
|
||||||
|
run_options = None
|
||||||
|
else:
|
||||||
|
# TODO(b/151025792): figure out why missing run options would make the
|
||||||
|
# test flaky and whether this is a problem in TF 2.
|
||||||
|
run_options = config_pb2.RunOptions()
|
||||||
|
run_options.experimental.collective_graph_key = 5
|
||||||
|
super(CollectiveAllReduceTest, self)._assert_mirrored_equal(
|
||||||
|
left_list, right_list, sess, run_options=run_options)
|
||||||
|
|
||||||
def _test_reduction(self,
|
def _test_reduction(self,
|
||||||
task_type,
|
task_type,
|
||||||
task_id,
|
task_id,
|
||||||
|
Loading…
Reference in New Issue
Block a user