Fix MultiWorkerMirroredStrategy validation in Model.fit

PiperOrigin-RevId: 299150128
Change-Id: Ie0ef99dcbd1afc91ad1a0e19c56638c5e48a7865
This commit is contained in:
Thomas O'Malley 2020-03-05 11:29:05 -08:00 committed by TensorFlower Gardener
parent d5504fbc78
commit a45094173c
2 changed files with 7 additions and 0 deletions

View File

@ -94,6 +94,8 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
x=train_ds,
epochs=num_epoch,
steps_per_epoch=steps,
validation_data=train_ds,
validation_steps=steps,
callbacks=[
callbacks.ModelCheckpoint(
filepath=saving_filepath, save_weights_only=save_weights_only)

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import copy
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values as ds_values
from tensorflow.python.eager import backprop
@ -63,6 +64,10 @@ def enable_multi_worker(method):
if not self._in_multi_worker_mode(): # pylint: disable=protected-access
return method(self, *args, **kwargs)
# Running inside `run_distribute_coordinator` already.
if dc_context.get_current_worker_context():
return method(self, *args, **kwargs)
return dc.run_distribute_coordinator(
lambda _: method(self, *args, **kwargs),
self.distribute_strategy,