Deflake cross_device_ops test.

PiperOrigin-RevId: 310039394
Change-Id: I1d0ac7506996d00283323c86de74c6dbbe01bd00
This commit is contained in:
Yuefeng Zhou 2020-05-05 16:04:04 -07:00 committed by TensorFlower Gardener
parent 22546b562d
commit 9d15d75988
2 changed files with 20 additions and 4 deletions

View File

@ -934,8 +934,6 @@ cuda_py_test(
srcs = ["cross_device_ops_test.py"],
tags = [
"multi_and_single_gpu",
"no_oss", # TODO(b/151025792): enable after this is fixed.
"notap", # TODO(b/151025792): enable after this is fixed.
],
deps = [
":collective_all_reduce_strategy",

View File

@ -120,7 +120,8 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
self.evaluate(ops.convert_to_tensor(left)),
self.evaluate(ops.convert_to_tensor(right)))
def _assert_mirrored_equal(self, left_list, right_list, sess):
def _assert_mirrored_equal(self, left_list, right_list, sess,
run_options=None):
if not isinstance(left_list, list):
left_list, right_list = [left_list], [right_list]
@ -141,7 +142,13 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
# Densify IndexedSlices.
left = [ops.convert_to_tensor(v) for v in left]
right = [ops.convert_to_tensor(v) for v in right]
left, right = sess.run((left, right))
if context.executing_eagerly():
# Optional args in session run are not supported when eager execution
# is enabled.
assert run_options is None
left, right = sess.run((left, right))
else:
left, right = sess.run((left, right), options=run_options)
for left_value, right_value in zip(left, right):
self.assertAllEqual(left_value, right_value)
@ -552,6 +559,17 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
return (collective_all_reduce_ops, devices,
"grpc://" + self._cluster_spec[task_type][task_id])
def _assert_mirrored_equal(self, left_list, right_list, sess):
  """Delegates to the parent check, adding run options in graph mode.

  Args:
    left_list: forwarded unchanged to the parent implementation.
    right_list: forwarded unchanged to the parent implementation.
    sess: forwarded unchanged to the parent implementation.
  """
  # Eager execution gets no run options; only graph mode builds them.
  options = None
  if not context.executing_eagerly():
    # TODO(b/151025792): figure out why missing run options would make the
    # test flaky and whether this is a problem in TF 2.
    options = config_pb2.RunOptions()
    options.experimental.collective_graph_key = 5
  super(CollectiveAllReduceTest, self)._assert_mirrored_equal(
      left_list, right_list, sess, run_options=options)
def _test_reduction(self,
task_type,
task_id,