MultiProcessRunner: symbol replacement: barrier->get_barrier

PiperOrigin-RevId: 335901962
Change-Id: I55ac18bb4a6ac52ca22ac06c329d02f4fdfe083c
Rick Chao 2020-10-07 10:41:31 -07:00 committed by TensorFlower Gardener
parent e0ed4b42ee
commit cb1786e165
7 changed files with 17 additions and 16 deletions

View File

@@ -1145,11 +1145,12 @@ def run(fn,
 _barrier = None
-def barrier():
+def get_barrier():
   if _barrier is None:
     raise ValueError(
-        'barrier is not defined. It is likely because you are calling barrier()'
-        'in the main process. barrier() can only be called in the subprocesses.'
+        'barrier is not defined. It is likely because you are calling '
+        'get_barrier() in the main process. get_barrier() can only be called '
+        'in the subprocesses.'
     )
   return _barrier
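
Note: the renamed accessor is only meaningful inside subprocesses launched by multi_process_runner.run(). A minimal usage sketch (not part of this commit; the worker fn and cluster size are made up for illustration, and the internal test modules are assumed importable from tensorflow.python.distribute):

from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base


def fn():
  # Inside a subprocess the module-level _barrier is set, so get_barrier()
  # returns a barrier shared by all tasks in the cluster.
  multi_process_runner.get_barrier().wait()
  return 'synced'


cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
multi_process_runner.run(fn, cluster_spec)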

View File

@@ -54,7 +54,7 @@ def fn_that_returns_args_and_kwargs(*args, **kwargs):
 def fn_with_barrier():
-  return multi_process_runner.barrier()
+  return multi_process_runner.get_barrier()
 def fn_that_returns_pid():
@@ -296,7 +296,7 @@ class MultiProcessRunnerTest(test.TestCase):
   def test_barrier_called_in_main_process(self):
     with self.assertRaises(ValueError):
-      multi_process_runner.barrier()
+      multi_process_runner.get_barrier()
   def test_stdout_available_when_timeout(self):
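
As the test above asserts, get_barrier() raises ValueError when called in the main process. Code that may run either in the parent or in a worker subprocess can guard the call the way the strategy_combinations hunk further down does; a short sketch of that pattern:

from tensorflow.python.distribute import multi_process_runner

try:
  multi_process_runner.get_barrier().wait()
except ValueError:
  # get_barrier() is only defined inside subprocesses; in the main process
  # the ValueError is expected and safe to ignore.
  pass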

View File

@@ -74,7 +74,7 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):
     def worker_step_fn(worker_id):
       strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
       # Make sure the processeses are in sync after updating the cluster
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()
       @def_function.function
       def run_reduce():
@@ -107,7 +107,7 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):
     def worker_step_fn(worker_id, num_dims):
       strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
       # Make sure the processeses are in sync after updating the cluster
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()
       tensor_shape = [2] * num_dims
       def variable_fn():

View File

@@ -124,10 +124,10 @@ def _get_multi_worker_mirrored_creator(required_gpus):
     # collectives may hang if any worker launches collectives before the chief
     # creates the strategy.
     try:
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()
     except ValueError:
       # If the creator is called in the main process,
-      # multi_process_runner.barrier() raises ValueError, which is safe to
+      # multi_process_runner.get_barrier() raises ValueError, which is safe to
       # ignore.
       pass
     return strategy

View File

@@ -204,7 +204,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
         if 'Interrupting!' not in str(e):
           raise
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()
       backup_filepath = os.path.join(bar_dir, 'checkpoint')
       test_obj.assertTrue(file_io.file_exists_v2(backup_filepath))
       test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))
@@ -218,7 +218,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
           callbacks.BackupAndRestore(backup_dir=bar_dir),
           AssertCallback()
       ])
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()
       test_obj.assertFalse(file_io.file_exists_v2(backup_filepath))
       test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))
@@ -306,7 +306,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
       # The saving_filepath shouldn't exist at the beginning (as it's unique).
       test_obj.assertFalse(file_io.file_exists_v2(saving_filepath))
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()
       model.fit(
           x=train_ds,
@@ -314,7 +314,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
          steps_per_epoch=steps,
          callbacks=[callbacks.TensorBoard(log_dir=saving_filepath)])
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()
       test_obj.assertTrue(file_io.list_directory_v2(saving_filepath))

View File

@@ -158,7 +158,7 @@ class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase):
         file_io.delete_recursively_v2(os.path.dirname(write_model_path))
       # Make sure chief finishes saving before non-chief's assertions.
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()
       if not file_io.file_exists_v2(model_path):
         raise RuntimeError()
@@ -179,7 +179,7 @@ class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase):
         file_io.delete_recursively_v2(write_checkpoint_dir)
       # Make sure chief finishes saving before non-chief's assertions.
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()
       if not file_io.file_exists_v2(checkpoint_dir):
         raise RuntimeError()

View File

@@ -64,7 +64,7 @@ class CollectiveOpTest(test.TestCase):
         except errors.UnavailableError:
           continue
         break
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()
     cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
     mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
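
For context, a runner constructed as in the last hunk is typically driven with start()/join(); the continuation is not shown in this diff, so the following is only a rough sketch of the surrounding usage:

mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
mpr.start()  # launch one subprocess per task in cluster_spec
mpr.join()   # wait for all workers; raises if any subprocess failed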