diff --git a/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py b/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py
index 7e890934aa7..c99b6db8f4d 100644
--- a/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py
+++ b/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py
@@ -94,6 +94,8 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
         x=train_ds,
         epochs=num_epoch,
         steps_per_epoch=steps,
+        validation_data=train_ds,
+        validation_steps=steps,
         callbacks=[
             callbacks.ModelCheckpoint(
                 filepath=saving_filepath, save_weights_only=save_weights_only)
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 76020aba8d8..fe91a6c1ab0 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 import copy
 
 from tensorflow.python.distribute import distribute_coordinator as dc
+from tensorflow.python.distribute import distribute_coordinator_context as dc_context
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import values as ds_values
 from tensorflow.python.eager import backprop
@@ -63,6 +64,10 @@ def enable_multi_worker(method):
     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
       return method(self, *args, **kwargs)
 
+    # Running inside `run_distribute_coordinator` already.
+    if dc_context.get_current_worker_context():
+      return method(self, *args, **kwargs)
+
     return dc.run_distribute_coordinator(
         lambda _: method(self, *args, **kwargs),
         self.distribute_strategy,
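
Note on the change: `enable_multi_worker` wraps `Model` methods such as `fit` and `evaluate` so they run under `dc.run_distribute_coordinator` in multi-worker mode. When `fit` is given `validation_data` it calls `evaluate` internally, and without the guard that inner call would try to start a second, nested coordinator; the new `dc_context.get_current_worker_context()` check runs the method directly instead. The test change adds `validation_data`/`validation_steps` precisely to exercise that path. Below is a minimal self-contained sketch of the same re-entrancy-guard pattern; the toy coordinator, `get_current_worker_context` helper, and `Model` class here are hypothetical stand-ins for the real TensorFlow modules, for illustration only.

import functools
import threading

_worker_context = threading.local()

def get_current_worker_context():
  """Returns the current worker context, or None outside the coordinator."""
  return getattr(_worker_context, 'current', None)

def run_distribute_coordinator(worker_fn):
  """Toy stand-in for dc.run_distribute_coordinator: set context, run fn."""
  _worker_context.current = object()
  try:
    return worker_fn(None)
  finally:
    _worker_context.current = None

def enable_multi_worker(method):
  """Decorator: route `method` through the coordinator exactly once."""
  def _method_wrapper(self, *args, **kwargs):
    # Running inside `run_distribute_coordinator` already; entering it
    # again would nest coordinators, so call the method directly
    # (this mirrors the guard added in the diff above).
    if get_current_worker_context():
      return method(self, *args, **kwargs)
    return run_distribute_coordinator(
        lambda _: method(self, *args, **kwargs))
  return functools.wraps(method)(_method_wrapper)

class Model(object):
  @enable_multi_worker
  def fit(self, x=None):
    # `fit` calls `evaluate` internally, as it does in Keras when
    # `validation_data` is set; the guard keeps this from re-entering
    # the coordinator.
    return self.evaluate(x)

  @enable_multi_worker
  def evaluate(self, x=None):
    return x

Model().fit(x=1)  # one coordinator entry; inner `evaluate` runs directly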