From 9d15d759881ed8a2490ad861b268a8cbbf4c8a83 Mon Sep 17 00:00:00 2001 From: Yuefeng Zhou Date: Tue, 5 May 2020 16:04:04 -0700 Subject: [PATCH] Deflake cross_device_ops test. PiperOrigin-RevId: 310039394 Change-Id: I1d0ac7506996d00283323c86de74c6dbbe01bd00 --- tensorflow/python/distribute/BUILD | 2 -- .../distribute/cross_device_ops_test.py | 22 +++++++++++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index ca4e302073a..5dccb47fb19 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -934,8 +934,6 @@ cuda_py_test( srcs = ["cross_device_ops_test.py"], tags = [ "multi_and_single_gpu", - "no_oss", # TODO(b/151025792): enable after this is fixed. - "notap", # TODO(b/151025792): enable after this is fixed. ], deps = [ ":collective_all_reduce_strategy", diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py index 7f25066a45f..e1aa2bea97c 100644 --- a/tensorflow/python/distribute/cross_device_ops_test.py +++ b/tensorflow/python/distribute/cross_device_ops_test.py @@ -120,7 +120,8 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): self.evaluate(ops.convert_to_tensor(left)), self.evaluate(ops.convert_to_tensor(right))) - def _assert_mirrored_equal(self, left_list, right_list, sess): + def _assert_mirrored_equal(self, left_list, right_list, sess, + run_options=None): if not isinstance(left_list, list): left_list, right_list = [left_list], [right_list] @@ -141,7 +142,13 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): # Densify IndexedSlices. left = [ops.convert_to_tensor(v) for v in left] right = [ops.convert_to_tensor(v) for v in right] - left, right = sess.run((left, right)) + if context.executing_eagerly(): + # Optional args in session run are not supported when eager execution + # is enabled. + assert run_options is None + left, right = sess.run((left, right)) + else: + left, right = sess.run((left, right), options=run_options) for left_value, right_value in zip(left, right): self.assertAllEqual(left_value, right_value) @@ -552,6 +559,17 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase, return (collective_all_reduce_ops, devices, "grpc://" + self._cluster_spec[task_type][task_id]) + def _assert_mirrored_equal(self, left_list, right_list, sess): + if context.executing_eagerly(): + run_options = None + else: + # TODO(b/151025792): figure out why missing run options would make the + # test flaky and whether this is a problem in TF 2. + run_options = config_pb2.RunOptions() + run_options.experimental.collective_graph_key = 5 + super(CollectiveAllReduceTest, self)._assert_mirrored_equal( + left_list, right_list, sess, run_options=run_options) + def _test_reduction(self, task_type, task_id,