Fixes how the compile-time distribution strategy is used in the Keras v2 execution path. Namely, this allows fitting a model that was compiled without a distribution strategy inside a different distribution strategy's scope, and correctly picks up the active distribution strategy when predicting with a model that has never been compiled.
PiperOrigin-RevId: 260038542
This commit is contained in:
parent
2cae1803c1
commit
4ad169057c
@ -34,8 +34,9 @@ class KerasExperimentalSaveLoadTest(test_base.TestSavedModelBase):
|
||||
saved_model.export_saved_model(model, saved_dir)
|
||||
|
||||
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
|
||||
output_name):
|
||||
output_name, run_distributed):
|
||||
restored_keras_model = saved_model.load_from_saved_model(saved_dir)
|
||||
restored_keras_model._run_distributed = run_distributed
|
||||
return restored_keras_model.predict(
|
||||
predict_dataset, steps=test_base.PREDICT_STEPS)
|
||||
|
||||
|
||||
@ -34,8 +34,9 @@ class KerasSaveLoadTest(test_base.TestSavedModelBase):
|
||||
model.save(saved_dir, save_format='tf')
|
||||
|
||||
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
|
||||
output_name):
|
||||
output_name, run_distributed):
|
||||
restored_keras_model = save.load_model(saved_dir)
|
||||
restored_keras_model._run_distributed = run_distributed
|
||||
return restored_keras_model.predict(
|
||||
predict_dataset, steps=test_base.PREDICT_STEPS)
|
||||
|
||||
|
||||
@ -42,7 +42,7 @@ class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
|
||||
keras_saved_model.export_saved_model(model, saved_dir, serving_only=True)
|
||||
|
||||
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
|
||||
output_name):
|
||||
output_name, run_distributed):
|
||||
return test_base.load_and_run_with_saved_model_api(distribution, saved_dir,
|
||||
predict_dataset,
|
||||
output_name)
|
||||
|
||||
@ -34,7 +34,7 @@ class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
|
||||
saved_model.save(model, saved_dir)
|
||||
|
||||
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
|
||||
output_name):
|
||||
output_name, run_distributed):
|
||||
return test_base.load_and_run_with_saved_model_api(distribution, saved_dir,
|
||||
predict_dataset,
|
||||
output_name)
|
||||
|
||||
@ -118,7 +118,7 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
|
||||
raise NotImplementedError('must be implemented in descendants')
|
||||
|
||||
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
|
||||
output_name):
|
||||
output_name, run_distributed):
|
||||
"""Load the model and run 1 step of predict with it.
|
||||
|
||||
This method must be implemented by the subclasses.
|
||||
@ -131,6 +131,7 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
|
||||
cross_replica context.
|
||||
output_name: the string representing the name of the output layer of the
|
||||
model.
|
||||
run_distributed: Whether to use the v2 execution path for models.
|
||||
"""
|
||||
|
||||
raise NotImplementedError('must be implemented in descendants')
|
||||
@ -172,7 +173,8 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
|
||||
distribution=distribution,
|
||||
saved_dir=saved_dir,
|
||||
predict_dataset=predict_dataset,
|
||||
output_name=output_name)
|
||||
output_name=output_name,
|
||||
run_distributed=run_distributed)
|
||||
|
||||
self.assertAllClose(result_before_save, result_after_save, atol=_TOLERANCE)
|
||||
|
||||
@ -203,7 +205,8 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
|
||||
distribution=None,
|
||||
saved_dir=saved_dir,
|
||||
predict_dataset=predict_dataset,
|
||||
output_name=output_name)
|
||||
output_name=output_name,
|
||||
run_distributed=run_distributed)
|
||||
|
||||
self.assertAllClose(result_before_save, load_result, atol=_TOLERANCE)
|
||||
|
||||
@ -237,6 +240,7 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
|
||||
distribution=distribution_for_restoring,
|
||||
saved_dir=saved_dir,
|
||||
predict_dataset=predict_dataset,
|
||||
output_name=output_name)
|
||||
output_name=output_name,
|
||||
run_distributed=run_distributed)
|
||||
|
||||
self.assertAllClose(result_before_save, load_result, atol=_TOLERANCE)
|
||||
|
||||
@ -1009,12 +1009,12 @@ def _copy_weights_to_original_model(model, mode):
|
||||
model.set_weights(updated_weights)
|
||||
|
||||
|
||||
def _per_replica_aggregate_batch(batch_outs, model, mode):
|
||||
def _per_replica_aggregate_batch(strategy, batch_outs, model, mode):
|
||||
"""Aggregates the per-replica batch-level outputs from a distributed step."""
|
||||
if model._distribution_strategy is not None and mode == ModeKeys.PREDICT:
|
||||
if strategy is not None and mode == ModeKeys.PREDICT:
|
||||
total_batch_outs = []
|
||||
for i in range(len(model.outputs)):
|
||||
num_replicas = model._distribution_strategy.num_replicas_in_sync
|
||||
num_replicas = strategy.num_replicas_in_sync
|
||||
nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas]
|
||||
total_batch_outs.append(np.concatenate(nest.flatten(nested_outs)))
|
||||
return total_batch_outs
|
||||
|
||||
@ -146,6 +146,7 @@ class Model(network.Network):
|
||||
# initializing _distribution_strategy here since it is possible to call
|
||||
# predict on a model without compiling it.
|
||||
self._distribution_strategy = None
|
||||
self._compile_time_distribution_strategy = None
|
||||
|
||||
# This flag is used to track if the user is using the deprecated path of
|
||||
# passing distribution strategy to compile rather than creating the model
|
||||
@ -161,8 +162,10 @@ class Model(network.Network):
|
||||
Returns:
|
||||
A flat list of Numpy arrays.
|
||||
"""
|
||||
if self._distribution_strategy:
|
||||
with self._distribution_strategy.scope():
|
||||
strategy = (self._distribution_strategy or
|
||||
self._compile_time_distribution_strategy)
|
||||
if strategy:
|
||||
with strategy.scope():
|
||||
return super(Model, self).get_weights()
|
||||
return super(Model, self).get_weights()
|
||||
|
||||
@ -250,6 +253,9 @@ class Model(network.Network):
|
||||
# Fallback out of things that aren't supported with v2 loops
|
||||
self._run_distributed = False
|
||||
|
||||
self._compile_time_distribution_strategy = (
|
||||
distribution_strategy_context.get_strategy())
|
||||
|
||||
if distribute is not None:
|
||||
if tf2.enabled() or self._run_distributed:
|
||||
raise ValueError(
|
||||
|
||||
@ -334,7 +334,7 @@ def model_iteration(model,
|
||||
|
||||
if model._distribution_strategy:
|
||||
batch_outs = distributed_training_utils._per_replica_aggregate_batch(
|
||||
batch_outs, model, mode)
|
||||
model._distribution_strategy, batch_outs, model, mode)
|
||||
|
||||
# Aggregate results.
|
||||
if step == 0:
|
||||
|
||||
@ -155,7 +155,7 @@ def run_one_epoch(model,
|
||||
batch_outs = [batch_outs]
|
||||
if strategy:
|
||||
batch_outs = dist_utils._per_replica_aggregate_batch(
|
||||
batch_outs, model, mode)
|
||||
strategy, batch_outs, model, mode)
|
||||
|
||||
if step == 0:
|
||||
aggregator.create(batch_outs)
|
||||
@ -448,25 +448,13 @@ class Loop(training_utils.TrainingLoop):
|
||||
|
||||
def _get_distribution_strategy(model):
|
||||
"""Get the model's distribution strategy."""
|
||||
if model._distribution_strategy:
|
||||
return model._distribution_strategy
|
||||
if model._compile_time_distribution_strategy:
|
||||
strategy = model._compile_time_distribution_strategy
|
||||
else:
|
||||
# Use the default strategy if no strategy was present at compile.
|
||||
# Validate there is no actual strategy scope active at execution
|
||||
# time.
|
||||
# Grab the active strategy if the model was never compiled
|
||||
# but it is now predicting.
|
||||
strategy = distribution_strategy_context.get_strategy()
|
||||
if distribution_strategy_context.has_strategy():
|
||||
raise ValueError(
|
||||
'Model was compiled without any active distribution strategy, '
|
||||
'but there is an execution-time distribution '
|
||||
'strategy scope of (%s). '
|
||||
'Try to make sure your code looks similar to the following.\n'
|
||||
'with strategy.scope():\n'
|
||||
' model=_create_model()\n'
|
||||
' model.compile(...)\n'
|
||||
' model.fit(...)'% strategy)
|
||||
|
||||
return strategy
|
||||
return strategy
|
||||
|
||||
|
||||
def _process_training_inputs(model, x, y, batch_size=None,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user