From 12b6de1bb31e150eb66074219ed54fd2190ec2f7 Mon Sep 17 00:00:00 2001 From: Ran Chen Date: Wed, 23 Dec 2020 15:03:55 -0800 Subject: [PATCH] Add a test env object to allow passing additional information to worker processes Sometimes we need to prepare the test environment in the main process and pass the information to worker processes. For example to test tf.data service, we need to start the tf.data service cluster in the main process and pass dispatcher address to the worker processes. We hardcode the fields of this object to give us a better idea of its usage. There shouldn't be too many of them. PiperOrigin-RevId: 348852814 Change-Id: Idef7e2659b11623abb46f06139e385047b73cb2f --- tensorflow/python/distribute/combinations.py | 50 +++++++++++++++++-- .../python/distribute/combinations_test.py | 18 +++++++ 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/distribute/combinations.py b/tensorflow/python/distribute/combinations.py index a2e3610b7c5..59e7343df74 100644 --- a/tensorflow/python/distribute/combinations.py +++ b/tensorflow/python/distribute/combinations.py @@ -359,10 +359,50 @@ NamedObject = combinations_lib.NamedObject _running_in_worker = False +def in_main_process(): + """Whether it's in the main test process. + + This is normally used to prepare the test environment which should only happen + in the main process. + + Returns: + A boolean. + """ + return not _running_in_worker + + +class TestEnvironment(object): + + def __init__(self): + self.tf_data_service_dispatcher = None + + def __setattr__(self, name, value): + if not in_main_process(): + raise ValueError( + "combinations.env() should only be modified in the main process. " + "Condition your code on combinations.in_main_process().") + super().__setattr__(name, value) + + +_env = TestEnvironment() + + +def env(): + """Returns the object holds the test environment information. + + Tests should modifies this in the main process if needed, and it will be + passed to the worker processes each time a test case is ran. + + Returns: + a TestEnvironment object. + """ + return _env + + _TestResult = collections.namedtuple("_TestResult", ["status", "message"]) -def _test_runner(test_id): +def _test_runner(test_id, test_env): """Executes the test with the given test_id. This is a simple wrapper around TestRunner to be used with @@ -372,14 +412,16 @@ def _test_runner(test_id): Args: test_id: TestCase.id() + test_env: a TestEnvironment object. Returns: A boolean indicates whether the test succeeds. """ - global _running_in_worker + global _running_in_worker, _env # No need to restore the value of _running_in_worker since it should always be # True in worker processes. _running_in_worker = True + _env = test_env test = unittest.defaultTestLoader.loadTestsFromName(test_id) runner = unittest.TextTestRunner() result = runner.run(test) @@ -453,7 +495,7 @@ def _multi_worker_test(test_method): # [sub process]test_method() test_id = self.id() if runner: - results = runner.run(_test_runner, args=(test_id,)) + results = runner.run(_test_runner, args=(test_id, _env)) else: cluster_spec = multi_worker_test_base.create_cluster_spec( has_chief=has_chief, @@ -461,7 +503,7 @@ def _multi_worker_test(test_method): num_ps=0, has_eval=False) results = multi_process_runner.run( - _test_runner, cluster_spec, args=(test_id,)).return_value + _test_runner, cluster_spec, args=(test_id, _env)).return_value skip_reason = None for result in results: diff --git a/tensorflow/python/distribute/combinations_test.py b/tensorflow/python/distribute/combinations_test.py index 02ddcbef632..09028f9266e 100644 --- a/tensorflow/python/distribute/combinations_test.py +++ b/tensorflow/python/distribute/combinations_test.py @@ -97,6 +97,24 @@ class ClusterCombinationTest(test.TestCase, parameterized.TestCase): self.assertNotEqual(os.getenv("TF_CONFIG"), "") +@combinations.generate(combinations.combine(num_workers=2)) +class ClusterCombinationTestEnvTest(test.TestCase, parameterized.TestCase): + + def setUp(self): + # Note that test case fixtures are executed in both the main process and + # worker processes. + super().setUp() + if combinations.in_main_process(): + combinations.env().tf_data_service_dispatcher = "localhost" + + def testTfDataServiceDispatcher(self): + self.assertEqual(combinations.env().tf_data_service_dispatcher, "localhost") + + def testUpdateEnvInWorker(self): + with self.assertRaises(ValueError): + combinations.env().tf_data_service_dispatcher = "localhost" + + # unittest.expectedFailure doesn't work with parameterized test methods, so we # have to decorate the class instead. @unittest.expectedFailure