Add a test env object to allow passing additional information to worker
processes Sometimes we need to prepare the test environment in the main process and pass the information to worker processes. For example to test tf.data service, we need to start the tf.data service cluster in the main process and pass dispatcher address to the worker processes. We hardcode the fields of this object to give us a better idea of its usage. There shouldn't be too many of them. PiperOrigin-RevId: 348852814 Change-Id: Idef7e2659b11623abb46f06139e385047b73cb2f
This commit is contained in:
parent
bbcc57de29
commit
12b6de1bb3
@ -359,10 +359,50 @@ NamedObject = combinations_lib.NamedObject
|
||||
_running_in_worker = False
|
||||
|
||||
|
||||
def in_main_process():
|
||||
"""Whether it's in the main test process.
|
||||
|
||||
This is normally used to prepare the test environment which should only happen
|
||||
in the main process.
|
||||
|
||||
Returns:
|
||||
A boolean.
|
||||
"""
|
||||
return not _running_in_worker
|
||||
|
||||
|
||||
class TestEnvironment(object):
|
||||
|
||||
def __init__(self):
|
||||
self.tf_data_service_dispatcher = None
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if not in_main_process():
|
||||
raise ValueError(
|
||||
"combinations.env() should only be modified in the main process. "
|
||||
"Condition your code on combinations.in_main_process().")
|
||||
super().__setattr__(name, value)
|
||||
|
||||
|
||||
_env = TestEnvironment()
|
||||
|
||||
|
||||
def env():
|
||||
"""Returns the object holds the test environment information.
|
||||
|
||||
Tests should modifies this in the main process if needed, and it will be
|
||||
passed to the worker processes each time a test case is ran.
|
||||
|
||||
Returns:
|
||||
a TestEnvironment object.
|
||||
"""
|
||||
return _env
|
||||
|
||||
|
||||
_TestResult = collections.namedtuple("_TestResult", ["status", "message"])
|
||||
|
||||
|
||||
def _test_runner(test_id):
|
||||
def _test_runner(test_id, test_env):
|
||||
"""Executes the test with the given test_id.
|
||||
|
||||
This is a simple wrapper around TestRunner to be used with
|
||||
@ -372,14 +412,16 @@ def _test_runner(test_id):
|
||||
|
||||
Args:
|
||||
test_id: TestCase.id()
|
||||
test_env: a TestEnvironment object.
|
||||
|
||||
Returns:
|
||||
A boolean indicates whether the test succeeds.
|
||||
"""
|
||||
global _running_in_worker
|
||||
global _running_in_worker, _env
|
||||
# No need to restore the value of _running_in_worker since it should always be
|
||||
# True in worker processes.
|
||||
_running_in_worker = True
|
||||
_env = test_env
|
||||
test = unittest.defaultTestLoader.loadTestsFromName(test_id)
|
||||
runner = unittest.TextTestRunner()
|
||||
result = runner.run(test)
|
||||
@ -453,7 +495,7 @@ def _multi_worker_test(test_method):
|
||||
# [sub process]test_method()
|
||||
test_id = self.id()
|
||||
if runner:
|
||||
results = runner.run(_test_runner, args=(test_id,))
|
||||
results = runner.run(_test_runner, args=(test_id, _env))
|
||||
else:
|
||||
cluster_spec = multi_worker_test_base.create_cluster_spec(
|
||||
has_chief=has_chief,
|
||||
@ -461,7 +503,7 @@ def _multi_worker_test(test_method):
|
||||
num_ps=0,
|
||||
has_eval=False)
|
||||
results = multi_process_runner.run(
|
||||
_test_runner, cluster_spec, args=(test_id,)).return_value
|
||||
_test_runner, cluster_spec, args=(test_id, _env)).return_value
|
||||
|
||||
skip_reason = None
|
||||
for result in results:
|
||||
|
@ -97,6 +97,24 @@ class ClusterCombinationTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertNotEqual(os.getenv("TF_CONFIG"), "")
|
||||
|
||||
|
||||
@combinations.generate(combinations.combine(num_workers=2))
|
||||
class ClusterCombinationTestEnvTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Note that test case fixtures are executed in both the main process and
|
||||
# worker processes.
|
||||
super().setUp()
|
||||
if combinations.in_main_process():
|
||||
combinations.env().tf_data_service_dispatcher = "localhost"
|
||||
|
||||
def testTfDataServiceDispatcher(self):
|
||||
self.assertEqual(combinations.env().tf_data_service_dispatcher, "localhost")
|
||||
|
||||
def testUpdateEnvInWorker(self):
|
||||
with self.assertRaises(ValueError):
|
||||
combinations.env().tf_data_service_dispatcher = "localhost"
|
||||
|
||||
|
||||
# unittest.expectedFailure doesn't work with parameterized test methods, so we
|
||||
# have to decorate the class instead.
|
||||
@unittest.expectedFailure
|
||||
|
Loading…
Reference in New Issue
Block a user