diff --git a/tensorflow/python/distribute/keras_experimental_saved_model_test.py b/tensorflow/python/distribute/keras_experimental_saved_model_test.py
index 0a0a57ffe33..92d9f14a6ed 100644
--- a/tensorflow/python/distribute/keras_experimental_saved_model_test.py
+++ b/tensorflow/python/distribute/keras_experimental_saved_model_test.py
@@ -34,8 +34,9 @@ class KerasExperimentalSaveLoadTest(test_base.TestSavedModelBase):
     saved_model.export_saved_model(model, saved_dir)
 
   def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
-                          output_name):
+                          output_name, run_distributed):
     restored_keras_model = saved_model.load_from_saved_model(saved_dir)
+    restored_keras_model._run_distributed = run_distributed
     return restored_keras_model.predict(
         predict_dataset, steps=test_base.PREDICT_STEPS)
 
diff --git a/tensorflow/python/distribute/keras_save_load_test.py b/tensorflow/python/distribute/keras_save_load_test.py
index fcb4941688d..2ff856ff6b0 100644
--- a/tensorflow/python/distribute/keras_save_load_test.py
+++ b/tensorflow/python/distribute/keras_save_load_test.py
@@ -34,8 +34,9 @@ class KerasSaveLoadTest(test_base.TestSavedModelBase):
     model.save(saved_dir, save_format='tf')
 
   def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
-                          output_name):
+                          output_name, run_distributed):
     restored_keras_model = save.load_model(saved_dir)
+    restored_keras_model._run_distributed = run_distributed
     return restored_keras_model.predict(
         predict_dataset, steps=test_base.PREDICT_STEPS)
 
diff --git a/tensorflow/python/distribute/saved_model_mixed_api_test.py b/tensorflow/python/distribute/saved_model_mixed_api_test.py
index 834cfbbabeb..dc2a40568b9 100644
--- a/tensorflow/python/distribute/saved_model_mixed_api_test.py
+++ b/tensorflow/python/distribute/saved_model_mixed_api_test.py
@@ -42,7 +42,7 @@ class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
     keras_saved_model.export_saved_model(model, saved_dir, serving_only=True)
 
   def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
-                          output_name):
+                          output_name, run_distributed):
     return test_base.load_and_run_with_saved_model_api(distribution, saved_dir,
                                                        predict_dataset,
                                                        output_name)
diff --git a/tensorflow/python/distribute/saved_model_save_load_test.py b/tensorflow/python/distribute/saved_model_save_load_test.py
index 6c0b2463de4..39e1d8a2b98 100644
--- a/tensorflow/python/distribute/saved_model_save_load_test.py
+++ b/tensorflow/python/distribute/saved_model_save_load_test.py
@@ -34,7 +34,7 @@ class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
     saved_model.save(model, saved_dir)
 
   def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
-                          output_name):
+                          output_name, run_distributed):
     return test_base.load_and_run_with_saved_model_api(distribution, saved_dir,
                                                        predict_dataset,
                                                        output_name)
diff --git a/tensorflow/python/distribute/saved_model_test_base.py b/tensorflow/python/distribute/saved_model_test_base.py
index 31b84b13b88..6326aafa5bc 100644
--- a/tensorflow/python/distribute/saved_model_test_base.py
+++ b/tensorflow/python/distribute/saved_model_test_base.py
@@ -118,7 +118,7 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
     raise NotImplementedError('must be implemented in descendants')
 
   def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
-                          output_name):
+                          output_name, run_distributed):
     """Load the model and run 1 step of predict with it.
 
     This method must be implemented by the subclasses.
@@ -131,6 +131,7 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
         cross_replica context.
      output_name: the string representing the name of the output layer of the
        model.
+      run_distributed: Whether to use the v2 execution path for models.
    """
    raise NotImplementedError('must be implemented in descendants')
 
@@ -172,7 +173,8 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
        distribution=distribution,
        saved_dir=saved_dir,
        predict_dataset=predict_dataset,
-        output_name=output_name)
+        output_name=output_name,
+        run_distributed=run_distributed)
 
    self.assertAllClose(result_before_save, result_after_save, atol=_TOLERANCE)
 
@@ -203,7 +205,8 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
        distribution=None,
        saved_dir=saved_dir,
        predict_dataset=predict_dataset,
-        output_name=output_name)
+        output_name=output_name,
+        run_distributed=run_distributed)
 
    self.assertAllClose(result_before_save, load_result, atol=_TOLERANCE)
 
@@ -237,6 +240,7 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
        distribution=distribution_for_restoring,
        saved_dir=saved_dir,
        predict_dataset=predict_dataset,
-        output_name=output_name)
+        output_name=output_name,
+        run_distributed=run_distributed)
 
    self.assertAllClose(result_before_save, load_result, atol=_TOLERANCE)
diff --git a/tensorflow/python/keras/distribute/distributed_training_utils.py b/tensorflow/python/keras/distribute/distributed_training_utils.py
index 28489de3fc1..00d182d2368 100644
--- a/tensorflow/python/keras/distribute/distributed_training_utils.py
+++ b/tensorflow/python/keras/distribute/distributed_training_utils.py
@@ -1009,12 +1009,12 @@ def _copy_weights_to_original_model(model, mode):
     model.set_weights(updated_weights)
 
-def _per_replica_aggregate_batch(batch_outs, model, mode):
+def _per_replica_aggregate_batch(strategy, batch_outs, model, mode):
   """Aggregates the per-replica batch-level outputs from a distributed step."""
-  if model._distribution_strategy is not None and mode == ModeKeys.PREDICT:
+  if strategy is not None and mode == ModeKeys.PREDICT:
     total_batch_outs = []
     for i in range(len(model.outputs)):
-      num_replicas = model._distribution_strategy.num_replicas_in_sync
+      num_replicas = strategy.num_replicas_in_sync
       nested_outs = batch_outs[i * num_replicas:i * num_replicas +
                                num_replicas]
       total_batch_outs.append(np.concatenate(nest.flatten(nested_outs)))
     return total_batch_outs
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index cf8bd1bc22d..dda30461d95 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -146,6 +146,7 @@ class Model(network.Network):
     # initializing _distribution_strategy here since it is possible to call
     # predict on a model without compiling it.
     self._distribution_strategy = None
+    self._compile_time_distribution_strategy = None
 
     # This flag is used to track if the user is using the deprecated path of
     # passing distribution strategy to compile rather than creating the model
@@ -161,8 +162,10 @@ class Model(network.Network):
     Returns:
         A flat list of Numpy arrays.
     """
-    if self._distribution_strategy:
-      with self._distribution_strategy.scope():
+    strategy = (self._distribution_strategy or
+                self._compile_time_distribution_strategy)
+    if strategy:
+      with strategy.scope():
         return super(Model, self).get_weights()
     return super(Model, self).get_weights()
 
@@ -250,6 +253,9 @@ class Model(network.Network):
       # Fallback out of things that aren't supported with v2 loops
       self._run_distributed = False
 
+    self._compile_time_distribution_strategy = (
+        distribution_strategy_context.get_strategy())
+
     if distribute is not None:
       if tf2.enabled() or self._run_distributed:
         raise ValueError(
diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py
index 8780739398d..1f9bce7372a 100644
--- a/tensorflow/python/keras/engine/training_arrays.py
+++ b/tensorflow/python/keras/engine/training_arrays.py
@@ -334,7 +334,7 @@ def model_iteration(model,
 
       if model._distribution_strategy:
         batch_outs = distributed_training_utils._per_replica_aggregate_batch(
-            batch_outs, model, mode)
+            model._distribution_strategy, batch_outs, model, mode)
 
       # Aggregate results.
       if step == 0:
diff --git a/tensorflow/python/keras/engine/training_v2.py b/tensorflow/python/keras/engine/training_v2.py
index 2371d20684b..8f6d4abfec6 100644
--- a/tensorflow/python/keras/engine/training_v2.py
+++ b/tensorflow/python/keras/engine/training_v2.py
@@ -155,7 +155,7 @@ def run_one_epoch(model,
         batch_outs = [batch_outs]
       if strategy:
         batch_outs = dist_utils._per_replica_aggregate_batch(
-            batch_outs, model, mode)
+            strategy, batch_outs, model, mode)
 
       if step == 0:
         aggregator.create(batch_outs)
@@ -448,25 +448,13 @@ class Loop(training_utils.TrainingLoop):
 
 def _get_distribution_strategy(model):
   """Get the model's distribution strategy."""
-  if model._distribution_strategy:
-    return model._distribution_strategy
+  if model._compile_time_distribution_strategy:
+    strategy = model._compile_time_distribution_strategy
   else:
-    # Use the default strategy if no strategy was present at compile.
-    # Validate there is no actual strategy scope active at execution
-    # time.
+    # Grab the active strategy if the model was never compiled
+    # but it is now predicting.
     strategy = distribution_strategy_context.get_strategy()
-    if distribution_strategy_context.has_strategy():
-      raise ValueError(
-          'Model was compiled without any active distribution strategy, '
-          'but there is an execution-time distribution '
-          'strategy scope of (%s). '
-          'Try to make sure your code looks similar to the following.\n'
-          'with strategy.scope():\n'
-          '  model=_create_model()\n'
-          '  model.compile(...)\n'
-          '  model.fit(...)'% strategy)
-
-    return strategy
+  return strategy
 
 
 def _process_training_inputs(model, x, y, batch_size=None,