MultiProcessRunner: symbol replacement: barrier->get_barrier
PiperOrigin-RevId: 335901962 Change-Id: I55ac18bb4a6ac52ca22ac06c329d02f4fdfe083c
This commit is contained in:
parent
e0ed4b42ee
commit
cb1786e165
@ -1145,11 +1145,12 @@ def run(fn,
|
||||
_barrier = None
|
||||
|
||||
|
||||
def barrier():
|
||||
def get_barrier():
|
||||
if _barrier is None:
|
||||
raise ValueError(
|
||||
'barrier is not defined. It is likely because you are calling barrier()'
|
||||
'in the main process. barrier() can only be called in the subprocesses.'
|
||||
'barrier is not defined. It is likely because you are calling '
|
||||
'get_barrier() in the main process. get_barrier() can only be called '
|
||||
'in the subprocesses.'
|
||||
)
|
||||
return _barrier
|
||||
|
||||
|
@ -54,7 +54,7 @@ def fn_that_returns_args_and_kwargs(*args, **kwargs):
|
||||
|
||||
|
||||
def fn_with_barrier():
|
||||
return multi_process_runner.barrier()
|
||||
return multi_process_runner.get_barrier()
|
||||
|
||||
|
||||
def fn_that_returns_pid():
|
||||
@ -296,7 +296,7 @@ class MultiProcessRunnerTest(test.TestCase):
|
||||
|
||||
def test_barrier_called_in_main_process(self):
|
||||
with self.assertRaises(ValueError):
|
||||
multi_process_runner.barrier()
|
||||
multi_process_runner.get_barrier()
|
||||
|
||||
def test_stdout_available_when_timeout(self):
|
||||
|
||||
|
@ -74,7 +74,7 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):
|
||||
def worker_step_fn(worker_id):
|
||||
strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
|
||||
# Make sure the processes are in sync after updating the cluster
|
||||
multi_process_runner.barrier().wait()
|
||||
multi_process_runner.get_barrier().wait()
|
||||
|
||||
@def_function.function
|
||||
def run_reduce():
|
||||
@ -107,7 +107,7 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):
|
||||
def worker_step_fn(worker_id, num_dims):
|
||||
strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
|
||||
# Make sure the processes are in sync after updating the cluster
|
||||
multi_process_runner.barrier().wait()
|
||||
multi_process_runner.get_barrier().wait()
|
||||
tensor_shape = [2] * num_dims
|
||||
|
||||
def variable_fn():
|
||||
|
@ -124,10 +124,10 @@ def _get_multi_worker_mirrored_creator(required_gpus):
|
||||
# collectives may hang if any worker launches collectives before the chief
|
||||
# creates the strategy.
|
||||
try:
|
||||
multi_process_runner.barrier().wait()
|
||||
multi_process_runner.get_barrier().wait()
|
||||
except ValueError:
|
||||
# If the creator is called in the main process,
|
||||
# multi_process_runner.barrier() raises ValueError, which is safe to
|
||||
# multi_process_runner.get_barrier() raises ValueError, which is safe to
|
||||
# ignore.
|
||||
pass
|
||||
return strategy
|
||||
|
@ -204,7 +204,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
|
||||
if 'Interrupting!' not in str(e):
|
||||
raise
|
||||
|
||||
multi_process_runner.barrier().wait()
|
||||
multi_process_runner.get_barrier().wait()
|
||||
backup_filepath = os.path.join(bar_dir, 'checkpoint')
|
||||
test_obj.assertTrue(file_io.file_exists_v2(backup_filepath))
|
||||
test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))
|
||||
@ -218,7 +218,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
|
||||
callbacks.BackupAndRestore(backup_dir=bar_dir),
|
||||
AssertCallback()
|
||||
])
|
||||
multi_process_runner.barrier().wait()
|
||||
multi_process_runner.get_barrier().wait()
|
||||
test_obj.assertFalse(file_io.file_exists_v2(backup_filepath))
|
||||
test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))
|
||||
|
||||
@ -306,7 +306,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
|
||||
# The saving_filepath shouldn't exist at the beginning (as it's unique).
|
||||
test_obj.assertFalse(file_io.file_exists_v2(saving_filepath))
|
||||
|
||||
multi_process_runner.barrier().wait()
|
||||
multi_process_runner.get_barrier().wait()
|
||||
|
||||
model.fit(
|
||||
x=train_ds,
|
||||
@ -314,7 +314,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
|
||||
steps_per_epoch=steps,
|
||||
callbacks=[callbacks.TensorBoard(log_dir=saving_filepath)])
|
||||
|
||||
multi_process_runner.barrier().wait()
|
||||
multi_process_runner.get_barrier().wait()
|
||||
|
||||
test_obj.assertTrue(file_io.list_directory_v2(saving_filepath))
|
||||
|
||||
|
@ -158,7 +158,7 @@ class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase):
|
||||
file_io.delete_recursively_v2(os.path.dirname(write_model_path))
|
||||
|
||||
# Make sure chief finishes saving before non-chief's assertions.
|
||||
multi_process_runner.barrier().wait()
|
||||
multi_process_runner.get_barrier().wait()
|
||||
|
||||
if not file_io.file_exists_v2(model_path):
|
||||
raise RuntimeError()
|
||||
@ -179,7 +179,7 @@ class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase):
|
||||
file_io.delete_recursively_v2(write_checkpoint_dir)
|
||||
|
||||
# Make sure chief finishes saving before non-chief's assertions.
|
||||
multi_process_runner.barrier().wait()
|
||||
multi_process_runner.get_barrier().wait()
|
||||
|
||||
if not file_io.file_exists_v2(checkpoint_dir):
|
||||
raise RuntimeError()
|
||||
|
@ -64,7 +64,7 @@ class CollectiveOpTest(test.TestCase):
|
||||
except errors.UnavailableError:
|
||||
continue
|
||||
break
|
||||
multi_process_runner.barrier().wait()
|
||||
multi_process_runner.get_barrier().wait()
|
||||
|
||||
cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
|
||||
mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
|
||||
|
Loading…
Reference in New Issue
Block a user