MultiProcessRunner: symbol replacement: barrier->get_barrier

PiperOrigin-RevId: 335901962
Change-Id: I55ac18bb4a6ac52ca22ac06c329d02f4fdfe083c
Rick Chao, 2020-10-07 10:41:31 -07:00; committed by TensorFlower Gardener
parent e0ed4b42ee
commit cb1786e165
7 changed files with 17 additions and 16 deletions

View File

@@ -1145,11 +1145,12 @@ def run(fn,
 _barrier = None


-def barrier():
+def get_barrier():
   if _barrier is None:
     raise ValueError(
-        'barrier is not defined. It is likely because you are calling barrier()'
-        'in the main process. barrier() can only be called in the subprocesses.'
+        'barrier is not defined. It is likely because you are calling '
+        'get_barrier() in the main process. get_barrier() can only be called '
+        'in the subprocesses.'
     )
   return _barrier
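
For context, the short sketch below (not part of this commit) shows how the renamed helper is reached: workers spawned by multi_process_runner.run() call get_barrier() to synchronize with each other, while calling it in the main process raises ValueError. The import paths and the run()/create_cluster_spec() arguments are assumptions based on the test hunks that follow.

# Minimal usage sketch; module paths are assumed to match the internal
# TensorFlow test utilities exercised in this commit.
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base


def fn():
  # Blocks until every subprocess in the cluster reaches this point.
  # In the main process there is no barrier, so get_barrier() raises
  # ValueError instead (see the test change below).
  multi_process_runner.get_barrier().wait()
  return 'synced'


cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
result = multi_process_runner.run(fn, cluster_spec)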

View File

@@ -54,7 +54,7 @@ def fn_that_returns_args_and_kwargs(*args, **kwargs):


 def fn_with_barrier():
-  return multi_process_runner.barrier()
+  return multi_process_runner.get_barrier()


 def fn_that_returns_pid():
@@ -296,7 +296,7 @@ class MultiProcessRunnerTest(test.TestCase):

   def test_barrier_called_in_main_process(self):
     with self.assertRaises(ValueError):
-      multi_process_runner.barrier()
+      multi_process_runner.get_barrier()

   def test_stdout_available_when_timeout(self):

View File

@@ -74,7 +74,7 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):

     def worker_step_fn(worker_id):
       strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
       # Make sure the processes are in sync after updating the cluster
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()

       @def_function.function
       def run_reduce():
@@ -107,7 +107,7 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):

     def worker_step_fn(worker_id, num_dims):
       strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
       # Make sure the processes are in sync after updating the cluster
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()
       tensor_shape = [2] * num_dims

       def variable_fn():

View File

@@ -124,10 +124,10 @@ def _get_multi_worker_mirrored_creator(required_gpus):
     # collectives may hang if any worker launches collectives before the chief
     # creates the strategy.
     try:
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()
     except ValueError:
       # If the creator is called in the main process,
-      # multi_process_runner.barrier() raises ValueError, which is safe to
+      # multi_process_runner.get_barrier() raises ValueError, which is safe to
       # ignore.
       pass
     return strategy

View File

@@ -204,7 +204,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
         if 'Interrupting!' not in str(e):
           raise
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()
       backup_filepath = os.path.join(bar_dir, 'checkpoint')
       test_obj.assertTrue(file_io.file_exists_v2(backup_filepath))
       test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))
@@ -218,7 +218,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
              callbacks.BackupAndRestore(backup_dir=bar_dir),
              AssertCallback()
          ])
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()
       test_obj.assertFalse(file_io.file_exists_v2(backup_filepath))
       test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))
@@ -306,7 +306,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
       # The saving_filepath shouldn't exist at the beginning (as it's unique).
       test_obj.assertFalse(file_io.file_exists_v2(saving_filepath))
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()

       model.fit(
           x=train_ds,
@@ -314,7 +314,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
           steps_per_epoch=steps,
           callbacks=[callbacks.TensorBoard(log_dir=saving_filepath)])
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()

       test_obj.assertTrue(file_io.list_directory_v2(saving_filepath))

View File

@@ -158,7 +158,7 @@ class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase):
         file_io.delete_recursively_v2(os.path.dirname(write_model_path))

       # Make sure chief finishes saving before non-chief's assertions.
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()

       if not file_io.file_exists_v2(model_path):
         raise RuntimeError()
@@ -179,7 +179,7 @@ class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase):
         file_io.delete_recursively_v2(write_checkpoint_dir)

       # Make sure chief finishes saving before non-chief's assertions.
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()

       if not file_io.file_exists_v2(checkpoint_dir):
         raise RuntimeError()

View File

@@ -64,7 +64,7 @@ class CollectiveOpTest(test.TestCase):
         except errors.UnavailableError:
           continue
         break
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()

     cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
     mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
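
The last hunk constructs the runner object directly instead of going through run(). A minimal sketch of that flow is below; the start()/join() calls are assumptions based on the runner's usual test usage and are not part of this diff.

# Hypothetical driver for the object form shown above; start()/join()
# are assumed entry points, not part of this commit.
def worker_fn():
  # Each spawned worker waits here until all workers arrive.
  multi_process_runner.get_barrier().wait()


cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
mpr.start()
mpr.join()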