Add a test env object to allow passing additional information to worker

processes

Sometimes we need to prepare the test environment in the main process and pass the information to worker processes. For example to test tf.data service, we need to start the tf.data service cluster in the main process and pass dispatcher address to the worker processes.

We hardcode the fields of this object to give us a better idea of its usage. There shouldn't be too many of them.

PiperOrigin-RevId: 348852814
Change-Id: Idef7e2659b11623abb46f06139e385047b73cb2f
This commit is contained in:
Ran Chen 2020-12-23 15:03:55 -08:00 committed by TensorFlower Gardener
parent bbcc57de29
commit 12b6de1bb3
2 changed files with 64 additions and 4 deletions

View File

@ -359,10 +359,50 @@ NamedObject = combinations_lib.NamedObject
_running_in_worker = False
def in_main_process():
"""Whether it's in the main test process.
This is normally used to prepare the test environment which should only happen
in the main process.
Returns:
A boolean.
"""
return not _running_in_worker
class TestEnvironment(object):
def __init__(self):
self.tf_data_service_dispatcher = None
def __setattr__(self, name, value):
if not in_main_process():
raise ValueError(
"combinations.env() should only be modified in the main process. "
"Condition your code on combinations.in_main_process().")
super().__setattr__(name, value)
_env = TestEnvironment()
def env():
"""Returns the object holds the test environment information.
Tests should modifies this in the main process if needed, and it will be
passed to the worker processes each time a test case is ran.
Returns:
a TestEnvironment object.
"""
return _env
_TestResult = collections.namedtuple("_TestResult", ["status", "message"])
def _test_runner(test_id):
def _test_runner(test_id, test_env):
"""Executes the test with the given test_id.
This is a simple wrapper around TestRunner to be used with
@ -372,14 +412,16 @@ def _test_runner(test_id):
Args:
test_id: TestCase.id()
test_env: a TestEnvironment object.
Returns:
A boolean indicates whether the test succeeds.
"""
global _running_in_worker
global _running_in_worker, _env
# No need to restore the value of _running_in_worker since it should always be
# True in worker processes.
_running_in_worker = True
_env = test_env
test = unittest.defaultTestLoader.loadTestsFromName(test_id)
runner = unittest.TextTestRunner()
result = runner.run(test)
@ -453,7 +495,7 @@ def _multi_worker_test(test_method):
# [sub process]test_method()
test_id = self.id()
if runner:
results = runner.run(_test_runner, args=(test_id,))
results = runner.run(_test_runner, args=(test_id, _env))
else:
cluster_spec = multi_worker_test_base.create_cluster_spec(
has_chief=has_chief,
@ -461,7 +503,7 @@ def _multi_worker_test(test_method):
num_ps=0,
has_eval=False)
results = multi_process_runner.run(
_test_runner, cluster_spec, args=(test_id,)).return_value
_test_runner, cluster_spec, args=(test_id, _env)).return_value
skip_reason = None
for result in results:

View File

@ -97,6 +97,24 @@ class ClusterCombinationTest(test.TestCase, parameterized.TestCase):
self.assertNotEqual(os.getenv("TF_CONFIG"), "")
@combinations.generate(combinations.combine(num_workers=2))
class ClusterCombinationTestEnvTest(test.TestCase, parameterized.TestCase):
def setUp(self):
# Note that test case fixtures are executed in both the main process and
# worker processes.
super().setUp()
if combinations.in_main_process():
combinations.env().tf_data_service_dispatcher = "localhost"
def testTfDataServiceDispatcher(self):
self.assertEqual(combinations.env().tf_data_service_dispatcher, "localhost")
def testUpdateEnvInWorker(self):
with self.assertRaises(ValueError):
combinations.env().tf_data_service_dispatcher = "localhost"
# unittest.expectedFailure doesn't work with parameterized test methods, so we
# have to decorate the class instead.
@unittest.expectedFailure