MultiProcessRunner: symbol replacement: barrier->get_barrier
PiperOrigin-RevId: 335901962
Change-Id: I55ac18bb4a6ac52ca22ac06c329d02f4fdfe083c
parent e0ed4b42ee
commit cb1786e165
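
This change renames the module-level accessor `barrier` to `get_barrier`, making it explicit that the call returns the shared barrier object (on which callers then invoke `.wait()`) rather than performing the synchronization itself. A minimal migration sketch, assuming only the standard `tensorflow.python.distribute` import path; the `worker_fn` name is illustrative, not part of this change:

    # Migration sketch: only the accessor name changes; the returned
    # barrier object and its .wait() semantics are untouched.
    from tensorflow.python.distribute import multi_process_runner

    def worker_fn():
      # Before: multi_process_runner.barrier().wait()
      multi_process_runner.get_barrier().wait()  # After
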
@@ -1145,11 +1145,12 @@ def run(fn,
 _barrier = None
 
 
-def barrier():
+def get_barrier():
   if _barrier is None:
     raise ValueError(
-        'barrier is not defined. It is likely because you are calling barrier()'
-        'in the main process. barrier() can only be called in the subprocesses.'
+        'barrier is not defined. It is likely because you are calling '
+        'get_barrier() in the main process. get_barrier() can only be called '
+        'in the subprocesses.'
     )
   return _barrier
 
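
For context, a sketch of how the renamed accessor is exercised end to end, mirroring the test calls that appear in the hunks below; `mpr.start()` and `mpr.join()` are assumed from the runner's usual test flow and are not part of this diff:

    from tensorflow.python.distribute import multi_process_runner
    from tensorflow.python.distribute import multi_worker_test_base

    def worker_fn():
      # Inside a subprocess the shared barrier exists, so every worker
      # blocks here until all workers have arrived.
      multi_process_runner.get_barrier().wait()

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
    mpr.start()
    mpr.join()

    # In the main process _barrier is None, so calling
    # multi_process_runner.get_barrier() there raises ValueError.
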
@@ -54,7 +54,7 @@ def fn_that_returns_args_and_kwargs(*args, **kwargs):
 
 
 def fn_with_barrier():
-  return multi_process_runner.barrier()
+  return multi_process_runner.get_barrier()
 
 
 def fn_that_returns_pid():
 
@@ -296,7 +296,7 @@ class MultiProcessRunnerTest(test.TestCase):
 
   def test_barrier_called_in_main_process(self):
     with self.assertRaises(ValueError):
-      multi_process_runner.barrier()
+      multi_process_runner.get_barrier()
 
   def test_stdout_available_when_timeout(self):
 
@@ -74,7 +74,7 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):
     def worker_step_fn(worker_id):
       strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
       # Make sure the processes are in sync after updating the cluster
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()
 
       @def_function.function
       def run_reduce():
 
@@ -107,7 +107,7 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):
     def worker_step_fn(worker_id, num_dims):
       strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
       # Make sure the processes are in sync after updating the cluster
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()
       tensor_shape = [2] * num_dims
 
       def variable_fn():
 
@@ -124,10 +124,10 @@ def _get_multi_worker_mirrored_creator(required_gpus):
     # collectives may hang if any worker launches collectives before the chief
     # creates the strategy.
     try:
-      multi_process_runner.barrier().wait()
+      multi_process_runner.get_barrier().wait()
     except ValueError:
       # If the creator is called in the main process,
-      # multi_process_runner.barrier() raises ValueError, which is safe to
+      # multi_process_runner.get_barrier() raises ValueError, which is safe to
       # ignore.
       pass
     return strategy
 
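
The hunk above shows the guard used by code that may run either in the main process or in a subprocess: attempt the wait and swallow the main-process ValueError. A standalone sketch of that pattern; the helper name `sync_if_in_subprocess` is hypothetical:

    from tensorflow.python.distribute import multi_process_runner

    def sync_if_in_subprocess():
      # Wait on the shared barrier when running in a subprocess; in the
      # main process get_barrier() raises ValueError, which is safe to
      # ignore.
      try:
        multi_process_runner.get_barrier().wait()
      except ValueError:
        pass
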
@@ -204,7 +204,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
       if 'Interrupting!' not in str(e):
         raise
 
-    multi_process_runner.barrier().wait()
+    multi_process_runner.get_barrier().wait()
     backup_filepath = os.path.join(bar_dir, 'checkpoint')
     test_obj.assertTrue(file_io.file_exists_v2(backup_filepath))
     test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))
 
@@ -218,7 +218,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
         callbacks.BackupAndRestore(backup_dir=bar_dir),
         AssertCallback()
     ])
-    multi_process_runner.barrier().wait()
+    multi_process_runner.get_barrier().wait()
     test_obj.assertFalse(file_io.file_exists_v2(backup_filepath))
     test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))
 
@@ -306,7 +306,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
     # The saving_filepath shouldn't exist at the beginning (as it's unique).
     test_obj.assertFalse(file_io.file_exists_v2(saving_filepath))
 
-    multi_process_runner.barrier().wait()
+    multi_process_runner.get_barrier().wait()
 
     model.fit(
         x=train_ds,
 
@@ -314,7 +314,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
         steps_per_epoch=steps,
         callbacks=[callbacks.TensorBoard(log_dir=saving_filepath)])
 
-    multi_process_runner.barrier().wait()
+    multi_process_runner.get_barrier().wait()
 
     test_obj.assertTrue(file_io.list_directory_v2(saving_filepath))
 
@@ -158,7 +158,7 @@ class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase):
       file_io.delete_recursively_v2(os.path.dirname(write_model_path))
 
     # Make sure chief finishes saving before non-chief's assertions.
-    multi_process_runner.barrier().wait()
+    multi_process_runner.get_barrier().wait()
 
     if not file_io.file_exists_v2(model_path):
       raise RuntimeError()
 
@@ -179,7 +179,7 @@ class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase):
      file_io.delete_recursively_v2(write_checkpoint_dir)
 
    # Make sure chief finishes saving before non-chief's assertions.
-    multi_process_runner.barrier().wait()
+    multi_process_runner.get_barrier().wait()
 
    if not file_io.file_exists_v2(checkpoint_dir):
      raise RuntimeError()
 
@@ -64,7 +64,7 @@ class CollectiveOpTest(test.TestCase):
       except errors.UnavailableError:
         continue
       break
-    multi_process_runner.barrier().wait()
+    multi_process_runner.get_barrier().wait()
 
     cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
     mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)