Add skip_if_error context manager in test_util.py to conveniently skip errors that are not related to what is being tested.
Fix lingering test flakiness. PiperOrigin-RevId: 310583994 Change-Id: I15925753b6faf9dc5bf3603231b248aa02965c19
This commit is contained in:
parent
b08e6cd85a
commit
546319f28a
@ -452,6 +452,7 @@ cuda_py_test(
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/eager:context",
|
||||
|
@ -160,7 +160,7 @@ class MultiProcessRunnerTest(test.TestCase):
|
||||
for i in range(0, 10):
|
||||
print(
|
||||
'index {}, iteration {}'.format(self._worker_idx(), i), flush=True)
|
||||
time.sleep(1)
|
||||
time.sleep(5)
|
||||
|
||||
mpr = multi_process_runner.MultiProcessRunner(
|
||||
proc_func,
|
||||
|
@ -33,10 +33,13 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
|
||||
|
||||
NUM_WORKERS = 5
|
||||
|
||||
|
||||
@ -84,9 +87,10 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):
|
||||
for _ in range(20):
|
||||
worker_step_fn(worker_id)
|
||||
|
||||
multi_process_runner.run(
|
||||
worker_fn,
|
||||
cluster_spec=test_base.create_cluster_spec(num_workers=NUM_WORKERS))
|
||||
with test_util.skip_if_error(self, errors_impl.UnavailableError):
|
||||
multi_process_runner.run(
|
||||
worker_fn,
|
||||
cluster_spec=test_base.create_cluster_spec(num_workers=NUM_WORKERS))
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['eager']))
|
||||
def testVariableInitializationWithChangingShape(self, mode):
|
||||
@ -116,9 +120,10 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):
|
||||
for i in range(20):
|
||||
worker_step_fn(worker_id, num_dims=(i + 1))
|
||||
|
||||
multi_process_runner.run(
|
||||
worker_fn,
|
||||
cluster_spec=test_base.create_cluster_spec(num_workers=NUM_WORKERS))
|
||||
with test_util.skip_if_error(self, errors_impl.UnavailableError):
|
||||
multi_process_runner.run(
|
||||
worker_fn,
|
||||
cluster_spec=test_base.create_cluster_spec(num_workers=NUM_WORKERS))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -460,6 +460,38 @@ def skip_if(condition):
|
||||
return real_skip_if
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def skip_if_error(test_obj, error_type, messages=None):
|
||||
"""Context manager to skip cases not considered failures by the tests.
|
||||
|
||||
Note that this does not work if used in setUpClass/tearDownClass.
|
||||
Usage in setUp/tearDown works fine just like regular test methods.
|
||||
|
||||
Args:
|
||||
test_obj: A test object provided as `self` in the test methods; this object
|
||||
is usually an instance of `unittest.TestCase`'s subclass and should have
|
||||
`skipTest` method.
|
||||
error_type: The error type to skip. Note that if `messages` are given, both
|
||||
`error_type` and `messages` need to match for the test to be skipped.
|
||||
messages: Optional, a string or list of strings. If `None`, the test will be
|
||||
skipped if `error_type` matches what is raised; otherwise, the test is
|
||||
skipped if any of the `messages` is contained in the message of the error
|
||||
raised, and `error_type` matches the error raised.
|
||||
|
||||
Yields:
|
||||
Nothing.
|
||||
"""
|
||||
if messages:
|
||||
messages = nest.flatten(messages)
|
||||
try:
|
||||
yield
|
||||
except error_type as e:
|
||||
if not messages or any([message in str(e) for message in messages]):
|
||||
test_obj.skipTest("Skipping error: {}".format(str(e)))
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def enable_c_shapes(fn):
|
||||
"""No-op. TODO(b/74620627): Remove this."""
|
||||
return fn
|
||||
|
@ -22,6 +22,7 @@ import collections
|
||||
import copy
|
||||
import random
|
||||
import threading
|
||||
import unittest
|
||||
import weakref
|
||||
|
||||
from absl.testing import parameterized
|
||||
@ -808,6 +809,66 @@ class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
self.assertEqual(tested_codepaths, set(["present", "future"]))
|
||||
|
||||
|
||||
class SkipTestTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def _verify_test_in_set_up_or_tear_down(self):
|
||||
with self.assertRaises(unittest.SkipTest):
|
||||
with test_util.skip_if_error(self, ValueError,
|
||||
["foo bar", "test message"]):
|
||||
raise ValueError("test message")
|
||||
try:
|
||||
with self.assertRaisesRegexp(ValueError, "foo bar"):
|
||||
with test_util.skip_if_error(self, ValueError, "test message"):
|
||||
raise ValueError("foo bar")
|
||||
except unittest.SkipTest:
|
||||
raise RuntimeError("Test is not supposed to skip.")
|
||||
|
||||
def setUp(self):
|
||||
super(SkipTestTest, self).setUp()
|
||||
self._verify_test_in_set_up_or_tear_down()
|
||||
|
||||
def tearDown(self):
|
||||
super(SkipTestTest, self).tearDown()
|
||||
self._verify_test_in_set_up_or_tear_down()
|
||||
|
||||
def test_skip_if_error_should_skip(self):
|
||||
with self.assertRaises(unittest.SkipTest):
|
||||
with test_util.skip_if_error(self, ValueError, "test message"):
|
||||
raise ValueError("test message")
|
||||
|
||||
def test_skip_if_error_should_skip_with_list(self):
|
||||
with self.assertRaises(unittest.SkipTest):
|
||||
with test_util.skip_if_error(self, ValueError,
|
||||
["foo bar", "test message"]):
|
||||
raise ValueError("test message")
|
||||
|
||||
def test_skip_if_error_should_skip_without_expected_message(self):
|
||||
with self.assertRaises(unittest.SkipTest):
|
||||
with test_util.skip_if_error(self, ValueError):
|
||||
raise ValueError("test message")
|
||||
|
||||
def test_skip_if_error_should_skip_without_error_message(self):
|
||||
with self.assertRaises(unittest.SkipTest):
|
||||
with test_util.skip_if_error(self, ValueError):
|
||||
raise ValueError()
|
||||
|
||||
def test_skip_if_error_should_raise_message_mismatch(self):
|
||||
try:
|
||||
with self.assertRaisesRegexp(ValueError, "foo bar"):
|
||||
with test_util.skip_if_error(self, ValueError, "test message"):
|
||||
raise ValueError("foo bar")
|
||||
except unittest.SkipTest:
|
||||
raise RuntimeError("Test is not supposed to skip.")
|
||||
|
||||
def test_skip_if_error_should_raise_no_message(self):
|
||||
try:
|
||||
with self.assertRaisesRegexp(ValueError, ""):
|
||||
with test_util.skip_if_error(self, ValueError, "test message"):
|
||||
raise ValueError()
|
||||
except unittest.SkipTest:
|
||||
raise RuntimeError("Test is not supposed to skip.")
|
||||
|
||||
|
||||
# Its own test case to reproduce variable sharing issues which only pop up when
|
||||
# setUp() is overridden and super() is not called.
|
||||
class GraphAndEagerNoVariableSharing(test_util.TensorFlowTestCase):
|
||||
|
@ -28,6 +28,8 @@ from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import multi_process_runner
|
||||
from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.datasets import mnist
|
||||
from tensorflow.python.keras.optimizer_v2 import gradient_descent
|
||||
from tensorflow.python.platform import test
|
||||
@ -122,10 +124,11 @@ class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase):
|
||||
steps_per_epoch=70,
|
||||
callbacks=callbacks)
|
||||
|
||||
mpr_result = multi_process_runner.run(
|
||||
proc_func,
|
||||
multi_worker_test_base.create_cluster_spec(num_workers=num_workers),
|
||||
list_stdout=True)
|
||||
with test_util.skip_if_error(self, errors_impl.UnavailableError):
|
||||
mpr_result = multi_process_runner.run(
|
||||
proc_func,
|
||||
multi_worker_test_base.create_cluster_spec(num_workers=num_workers),
|
||||
list_stdout=True)
|
||||
|
||||
def extract_accuracy(worker_id, input_string):
|
||||
match = re.match(
|
||||
|
Loading…
Reference in New Issue
Block a user