Add skip_if_error context manager in test_util.py to conveniently skip errors that are not related to what is being tested.

Fix lingering test flakiness.

PiperOrigin-RevId: 310583994
Change-Id: I15925753b6faf9dc5bf3603231b248aa02965c19
This commit is contained in:
Rick Chao 2020-05-08 10:36:54 -07:00 committed by TensorFlower Gardener
parent b08e6cd85a
commit 546319f28a
6 changed files with 113 additions and 11 deletions

View File

@ -452,6 +452,7 @@ cuda_py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:context",

View File

@ -160,7 +160,7 @@ class MultiProcessRunnerTest(test.TestCase):
for i in range(0, 10):
print(
'index {}, iteration {}'.format(self._worker_idx(), i), flush=True)
time.sleep(1)
time.sleep(5)
mpr = multi_process_runner.MultiProcessRunner(
proc_func,

View File

@ -33,10 +33,13 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
NUM_WORKERS = 5
@ -84,9 +87,10 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):
for _ in range(20):
worker_step_fn(worker_id)
multi_process_runner.run(
worker_fn,
cluster_spec=test_base.create_cluster_spec(num_workers=NUM_WORKERS))
with test_util.skip_if_error(self, errors_impl.UnavailableError):
multi_process_runner.run(
worker_fn,
cluster_spec=test_base.create_cluster_spec(num_workers=NUM_WORKERS))
@combinations.generate(combinations.combine(mode=['eager']))
def testVariableInitializationWithChangingShape(self, mode):
@ -116,9 +120,10 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):
for i in range(20):
worker_step_fn(worker_id, num_dims=(i + 1))
multi_process_runner.run(
worker_fn,
cluster_spec=test_base.create_cluster_spec(num_workers=NUM_WORKERS))
with test_util.skip_if_error(self, errors_impl.UnavailableError):
multi_process_runner.run(
worker_fn,
cluster_spec=test_base.create_cluster_spec(num_workers=NUM_WORKERS))
if __name__ == '__main__':

View File

@ -460,6 +460,38 @@ def skip_if(condition):
return real_skip_if
@contextlib.contextmanager
def skip_if_error(test_obj, error_type, messages=None):
"""Context manager to skip cases not considered failures by the tests.
Note that this does not work if used in setUpClass/tearDownClass.
Usage in setUp/tearDown works fine just like regular test methods.
Args:
test_obj: A test object provided as `self` in the test methods; this object
is usually an instance of `unittest.TestCase`'s subclass and should have
`skipTest` method.
error_type: The error type to skip. Note that if `messages` are given, both
`error_type` and `messages` need to match for the test to be skipped.
messages: Optional, a string or list of strings. If `None`, the test will be
skipped if `error_type` matches what is raised; otherwise, the test is
skipped if any of the `messages` is contained in the message of the error
raised, and `error_type` matches the error raised.
Yields:
Nothing.
"""
if messages:
messages = nest.flatten(messages)
try:
yield
except error_type as e:
if not messages or any([message in str(e) for message in messages]):
test_obj.skipTest("Skipping error: {}".format(str(e)))
else:
raise
def enable_c_shapes(fn):
"""No-op. TODO(b/74620627): Remove this."""
return fn

View File

@ -22,6 +22,7 @@ import collections
import copy
import random
import threading
import unittest
import weakref
from absl.testing import parameterized
@ -808,6 +809,66 @@ class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
self.assertEqual(tested_codepaths, set(["present", "future"]))
class SkipTestTest(test_util.TensorFlowTestCase):
def _verify_test_in_set_up_or_tear_down(self):
with self.assertRaises(unittest.SkipTest):
with test_util.skip_if_error(self, ValueError,
["foo bar", "test message"]):
raise ValueError("test message")
try:
with self.assertRaisesRegexp(ValueError, "foo bar"):
with test_util.skip_if_error(self, ValueError, "test message"):
raise ValueError("foo bar")
except unittest.SkipTest:
raise RuntimeError("Test is not supposed to skip.")
def setUp(self):
super(SkipTestTest, self).setUp()
self._verify_test_in_set_up_or_tear_down()
def tearDown(self):
super(SkipTestTest, self).tearDown()
self._verify_test_in_set_up_or_tear_down()
def test_skip_if_error_should_skip(self):
with self.assertRaises(unittest.SkipTest):
with test_util.skip_if_error(self, ValueError, "test message"):
raise ValueError("test message")
def test_skip_if_error_should_skip_with_list(self):
with self.assertRaises(unittest.SkipTest):
with test_util.skip_if_error(self, ValueError,
["foo bar", "test message"]):
raise ValueError("test message")
def test_skip_if_error_should_skip_without_expected_message(self):
with self.assertRaises(unittest.SkipTest):
with test_util.skip_if_error(self, ValueError):
raise ValueError("test message")
def test_skip_if_error_should_skip_without_error_message(self):
with self.assertRaises(unittest.SkipTest):
with test_util.skip_if_error(self, ValueError):
raise ValueError()
def test_skip_if_error_should_raise_message_mismatch(self):
try:
with self.assertRaisesRegexp(ValueError, "foo bar"):
with test_util.skip_if_error(self, ValueError, "test message"):
raise ValueError("foo bar")
except unittest.SkipTest:
raise RuntimeError("Test is not supposed to skip.")
def test_skip_if_error_should_raise_no_message(self):
try:
with self.assertRaisesRegexp(ValueError, ""):
with test_util.skip_if_error(self, ValueError, "test message"):
raise ValueError()
except unittest.SkipTest:
raise RuntimeError("Test is not supposed to skip.")
# Its own test case to reproduce variable sharing issues which only pop up when
# setUp() is overridden and super() is not called.
class GraphAndEagerNoVariableSharing(test_util.TensorFlowTestCase):

View File

@ -28,6 +28,8 @@ from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.keras.datasets import mnist
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.platform import test
@ -122,10 +124,11 @@ class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase):
steps_per_epoch=70,
callbacks=callbacks)
mpr_result = multi_process_runner.run(
proc_func,
multi_worker_test_base.create_cluster_spec(num_workers=num_workers),
list_stdout=True)
with test_util.skip_if_error(self, errors_impl.UnavailableError):
mpr_result = multi_process_runner.run(
proc_func,
multi_worker_test_base.create_cluster_spec(num_workers=num_workers),
list_stdout=True)
def extract_accuracy(worker_id, input_string):
match = re.match(