Fix MultiWorkerMirroredStrategy validation in Model.fit

PiperOrigin-RevId: 299150128
Change-Id: Ie0ef99dcbd1afc91ad1a0e19c56638c5e48a7865

parent d5504fbc78
commit a45094173c
@@ -94,6 +94,8 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
         x=train_ds,
         epochs=num_epoch,
         steps_per_epoch=steps,
+        validation_data=train_ds,
+        validation_steps=steps,
         callbacks=[
             callbacks.ModelCheckpoint(
                 filepath=saving_filepath, save_weights_only=save_weights_only)
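For context, here is a minimal standalone sketch of the scenario this test now exercises: calling Model.fit with validation_data and validation_steps under MultiWorkerMirroredStrategy. The model, dataset, step counts, and checkpoint path below are illustrative placeholders, not code from the patch; only the argument names mirror the test.

    import tensorflow as tf

    # Illustrative setup; runs as a single worker when TF_CONFIG is unset.
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    with strategy.scope():
      model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
      model.compile(optimizer='sgd', loss='mse')

    train_ds = tf.data.Dataset.from_tensor_slices(
        (tf.random.normal([64, 4]), tf.random.normal([64, 1]))).batch(8).repeat()

    model.fit(
        x=train_ds,
        epochs=2,
        steps_per_epoch=4,
        validation_data=train_ds,  # the arguments newly exercised by this test
        validation_steps=4,
        callbacks=[tf.keras.callbacks.ModelCheckpoint(
            filepath='/tmp/ckpt', save_weights_only=True)])

Before this fix, a call like this would fail in multi-worker mode; the test change verifies the validation pass now works end to end.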
@@ -21,6 +21,7 @@ from __future__ import print_function
 import copy
 
 from tensorflow.python.distribute import distribute_coordinator as dc
+from tensorflow.python.distribute import distribute_coordinator_context as dc_context
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import values as ds_values
 from tensorflow.python.eager import backprop
@@ -63,6 +64,10 @@ def enable_multi_worker(method):
     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
       return method(self, *args, **kwargs)
 
+    # Running inside `run_distribute_coordinator` already.
+    if dc_context.get_current_worker_context():
+      return method(self, *args, **kwargs)
+
     return dc.run_distribute_coordinator(
         lambda _: method(self, *args, **kwargs),
         self.distribute_strategy,
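The early return added above is the heart of the fix. Both Model.fit and Model.evaluate are wrapped by enable_multi_worker, and fit presumably runs its validation pass by calling evaluate internally; without the guard, that nested call would try to launch run_distribute_coordinator a second time from inside an already-running worker. A simplified sketch of the decorator's control flow, using only the names visible in this diff (functools.wraps added here for completeness):

    import functools

    def enable_multi_worker(method):
      """Simplified sketch of the decorator this hunk modifies."""

      @functools.wraps(method)
      def _method_wrapper(self, *args, **kwargs):
        # Not in multi-worker mode: run the method directly.
        if not self._in_multi_worker_mode():  # pylint: disable=protected-access
          return method(self, *args, **kwargs)

        # New guard: a distribute coordinator is already driving this worker
        # (e.g. fit() invoking evaluate() for validation), so run directly
        # instead of starting a second coordinator.
        if dc_context.get_current_worker_context():
          return method(self, *args, **kwargs)

        # Top-level multi-worker entry point: hand the method to the
        # distribute coordinator.
        return dc.run_distribute_coordinator(
            lambda _: method(self, *args, **kwargs),
            self.distribute_strategy)

      return _method_wrapper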