Fix MultiWorkerMirroredStrategy validation in Model.fit
PiperOrigin-RevId: 299150128 Change-Id: Ie0ef99dcbd1afc91ad1a0e19c56638c5e48a7865
This commit is contained in:
parent
d5504fbc78
commit
a45094173c
@ -94,6 +94,8 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
|
||||
x=train_ds,
|
||||
epochs=num_epoch,
|
||||
steps_per_epoch=steps,
|
||||
validation_data=train_ds,
|
||||
validation_steps=steps,
|
||||
callbacks=[
|
||||
callbacks.ModelCheckpoint(
|
||||
filepath=saving_filepath, save_weights_only=save_weights_only)
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import copy
|
||||
|
||||
from tensorflow.python.distribute import distribute_coordinator as dc
|
||||
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
from tensorflow.python.distribute import values as ds_values
|
||||
from tensorflow.python.eager import backprop
|
||||
@ -63,6 +64,10 @@ def enable_multi_worker(method):
|
||||
if not self._in_multi_worker_mode(): # pylint: disable=protected-access
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
# Running inside `run_distribute_coordinator` already.
|
||||
if dc_context.get_current_worker_context():
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return dc.run_distribute_coordinator(
|
||||
lambda _: method(self, *args, **kwargs),
|
||||
self.distribute_strategy,
|
||||
|
Loading…
Reference in New Issue
Block a user