Fixes how the compile-time distribution strategy is used in the Keras v2 execution path. Specifically, this allows a model that was compiled with no distribution strategy to be fit inside a different distribution strategy scope, and correctly picks up the active distribution strategy when predicting with a model that has never been compiled.
PiperOrigin-RevId: 260038542
This commit is contained in:
parent
2cae1803c1
commit
4ad169057c
@ -34,8 +34,9 @@ class KerasExperimentalSaveLoadTest(test_base.TestSavedModelBase):
|
|||||||
saved_model.export_saved_model(model, saved_dir)
|
saved_model.export_saved_model(model, saved_dir)
|
||||||
|
|
||||||
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
|
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
|
||||||
output_name):
|
output_name, run_distributed):
|
||||||
restored_keras_model = saved_model.load_from_saved_model(saved_dir)
|
restored_keras_model = saved_model.load_from_saved_model(saved_dir)
|
||||||
|
restored_keras_model._run_distributed = run_distributed
|
||||||
return restored_keras_model.predict(
|
return restored_keras_model.predict(
|
||||||
predict_dataset, steps=test_base.PREDICT_STEPS)
|
predict_dataset, steps=test_base.PREDICT_STEPS)
|
||||||
|
|
||||||
|
@ -34,8 +34,9 @@ class KerasSaveLoadTest(test_base.TestSavedModelBase):
|
|||||||
model.save(saved_dir, save_format='tf')
|
model.save(saved_dir, save_format='tf')
|
||||||
|
|
||||||
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
|
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
|
||||||
output_name):
|
output_name, run_distributed):
|
||||||
restored_keras_model = save.load_model(saved_dir)
|
restored_keras_model = save.load_model(saved_dir)
|
||||||
|
restored_keras_model._run_distributed = run_distributed
|
||||||
return restored_keras_model.predict(
|
return restored_keras_model.predict(
|
||||||
predict_dataset, steps=test_base.PREDICT_STEPS)
|
predict_dataset, steps=test_base.PREDICT_STEPS)
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
|
|||||||
keras_saved_model.export_saved_model(model, saved_dir, serving_only=True)
|
keras_saved_model.export_saved_model(model, saved_dir, serving_only=True)
|
||||||
|
|
||||||
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
|
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
|
||||||
output_name):
|
output_name, run_distributed):
|
||||||
return test_base.load_and_run_with_saved_model_api(distribution, saved_dir,
|
return test_base.load_and_run_with_saved_model_api(distribution, saved_dir,
|
||||||
predict_dataset,
|
predict_dataset,
|
||||||
output_name)
|
output_name)
|
||||||
|
@ -34,7 +34,7 @@ class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
|
|||||||
saved_model.save(model, saved_dir)
|
saved_model.save(model, saved_dir)
|
||||||
|
|
||||||
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
|
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
|
||||||
output_name):
|
output_name, run_distributed):
|
||||||
return test_base.load_and_run_with_saved_model_api(distribution, saved_dir,
|
return test_base.load_and_run_with_saved_model_api(distribution, saved_dir,
|
||||||
predict_dataset,
|
predict_dataset,
|
||||||
output_name)
|
output_name)
|
||||||
|
@ -118,7 +118,7 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
|
|||||||
raise NotImplementedError('must be implemented in descendants')
|
raise NotImplementedError('must be implemented in descendants')
|
||||||
|
|
||||||
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
|
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
|
||||||
output_name):
|
output_name, run_distributed):
|
||||||
"""Load the model and run 1 step of predict with it.
|
"""Load the model and run 1 step of predict with it.
|
||||||
|
|
||||||
This method must be implemented by the subclasses.
|
This method must be implemented by the subclasses.
|
||||||
@ -131,6 +131,7 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
|
|||||||
cross_replica context.
|
cross_replica context.
|
||||||
output_name: the string representing the name of the output layer of the
|
output_name: the string representing the name of the output layer of the
|
||||||
model.
|
model.
|
||||||
|
run_distributed: Whether to use the v2 execution path for models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
raise NotImplementedError('must be implemented in descendants')
|
raise NotImplementedError('must be implemented in descendants')
|
||||||
@ -172,7 +173,8 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
|
|||||||
distribution=distribution,
|
distribution=distribution,
|
||||||
saved_dir=saved_dir,
|
saved_dir=saved_dir,
|
||||||
predict_dataset=predict_dataset,
|
predict_dataset=predict_dataset,
|
||||||
output_name=output_name)
|
output_name=output_name,
|
||||||
|
run_distributed=run_distributed)
|
||||||
|
|
||||||
self.assertAllClose(result_before_save, result_after_save, atol=_TOLERANCE)
|
self.assertAllClose(result_before_save, result_after_save, atol=_TOLERANCE)
|
||||||
|
|
||||||
@ -203,7 +205,8 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
|
|||||||
distribution=None,
|
distribution=None,
|
||||||
saved_dir=saved_dir,
|
saved_dir=saved_dir,
|
||||||
predict_dataset=predict_dataset,
|
predict_dataset=predict_dataset,
|
||||||
output_name=output_name)
|
output_name=output_name,
|
||||||
|
run_distributed=run_distributed)
|
||||||
|
|
||||||
self.assertAllClose(result_before_save, load_result, atol=_TOLERANCE)
|
self.assertAllClose(result_before_save, load_result, atol=_TOLERANCE)
|
||||||
|
|
||||||
@ -237,6 +240,7 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
|
|||||||
distribution=distribution_for_restoring,
|
distribution=distribution_for_restoring,
|
||||||
saved_dir=saved_dir,
|
saved_dir=saved_dir,
|
||||||
predict_dataset=predict_dataset,
|
predict_dataset=predict_dataset,
|
||||||
output_name=output_name)
|
output_name=output_name,
|
||||||
|
run_distributed=run_distributed)
|
||||||
|
|
||||||
self.assertAllClose(result_before_save, load_result, atol=_TOLERANCE)
|
self.assertAllClose(result_before_save, load_result, atol=_TOLERANCE)
|
||||||
|
@ -1009,12 +1009,12 @@ def _copy_weights_to_original_model(model, mode):
|
|||||||
model.set_weights(updated_weights)
|
model.set_weights(updated_weights)
|
||||||
|
|
||||||
|
|
||||||
def _per_replica_aggregate_batch(batch_outs, model, mode):
|
def _per_replica_aggregate_batch(strategy, batch_outs, model, mode):
|
||||||
"""Aggregates the per-replica batch-level outputs from a distributed step."""
|
"""Aggregates the per-replica batch-level outputs from a distributed step."""
|
||||||
if model._distribution_strategy is not None and mode == ModeKeys.PREDICT:
|
if strategy is not None and mode == ModeKeys.PREDICT:
|
||||||
total_batch_outs = []
|
total_batch_outs = []
|
||||||
for i in range(len(model.outputs)):
|
for i in range(len(model.outputs)):
|
||||||
num_replicas = model._distribution_strategy.num_replicas_in_sync
|
num_replicas = strategy.num_replicas_in_sync
|
||||||
nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas]
|
nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas]
|
||||||
total_batch_outs.append(np.concatenate(nest.flatten(nested_outs)))
|
total_batch_outs.append(np.concatenate(nest.flatten(nested_outs)))
|
||||||
return total_batch_outs
|
return total_batch_outs
|
||||||
|
@ -146,6 +146,7 @@ class Model(network.Network):
|
|||||||
# initializing _distribution_strategy here since it is possible to call
|
# initializing _distribution_strategy here since it is possible to call
|
||||||
# predict on a model without compiling it.
|
# predict on a model without compiling it.
|
||||||
self._distribution_strategy = None
|
self._distribution_strategy = None
|
||||||
|
self._compile_time_distribution_strategy = None
|
||||||
|
|
||||||
# This flag is used to track if the user is using the deprecated path of
|
# This flag is used to track if the user is using the deprecated path of
|
||||||
# passing distribution strategy to compile rather than creating the model
|
# passing distribution strategy to compile rather than creating the model
|
||||||
@ -161,8 +162,10 @@ class Model(network.Network):
|
|||||||
Returns:
|
Returns:
|
||||||
A flat list of Numpy arrays.
|
A flat list of Numpy arrays.
|
||||||
"""
|
"""
|
||||||
if self._distribution_strategy:
|
strategy = (self._distribution_strategy or
|
||||||
with self._distribution_strategy.scope():
|
self._compile_time_distribution_strategy)
|
||||||
|
if strategy:
|
||||||
|
with strategy.scope():
|
||||||
return super(Model, self).get_weights()
|
return super(Model, self).get_weights()
|
||||||
return super(Model, self).get_weights()
|
return super(Model, self).get_weights()
|
||||||
|
|
||||||
@ -250,6 +253,9 @@ class Model(network.Network):
|
|||||||
# Fallback out of things that aren't supported with v2 loops
|
# Fallback out of things that aren't supported with v2 loops
|
||||||
self._run_distributed = False
|
self._run_distributed = False
|
||||||
|
|
||||||
|
self._compile_time_distribution_strategy = (
|
||||||
|
distribution_strategy_context.get_strategy())
|
||||||
|
|
||||||
if distribute is not None:
|
if distribute is not None:
|
||||||
if tf2.enabled() or self._run_distributed:
|
if tf2.enabled() or self._run_distributed:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -334,7 +334,7 @@ def model_iteration(model,
|
|||||||
|
|
||||||
if model._distribution_strategy:
|
if model._distribution_strategy:
|
||||||
batch_outs = distributed_training_utils._per_replica_aggregate_batch(
|
batch_outs = distributed_training_utils._per_replica_aggregate_batch(
|
||||||
batch_outs, model, mode)
|
model._distribution_strategy, batch_outs, model, mode)
|
||||||
|
|
||||||
# Aggregate results.
|
# Aggregate results.
|
||||||
if step == 0:
|
if step == 0:
|
||||||
|
@ -155,7 +155,7 @@ def run_one_epoch(model,
|
|||||||
batch_outs = [batch_outs]
|
batch_outs = [batch_outs]
|
||||||
if strategy:
|
if strategy:
|
||||||
batch_outs = dist_utils._per_replica_aggregate_batch(
|
batch_outs = dist_utils._per_replica_aggregate_batch(
|
||||||
batch_outs, model, mode)
|
strategy, batch_outs, model, mode)
|
||||||
|
|
||||||
if step == 0:
|
if step == 0:
|
||||||
aggregator.create(batch_outs)
|
aggregator.create(batch_outs)
|
||||||
@ -448,24 +448,12 @@ class Loop(training_utils.TrainingLoop):
|
|||||||
|
|
||||||
def _get_distribution_strategy(model):
|
def _get_distribution_strategy(model):
|
||||||
"""Get the model's distribution strategy."""
|
"""Get the model's distribution strategy."""
|
||||||
if model._distribution_strategy:
|
if model._compile_time_distribution_strategy:
|
||||||
return model._distribution_strategy
|
strategy = model._compile_time_distribution_strategy
|
||||||
else:
|
else:
|
||||||
# Use the default strategy if no strategy was present at compile.
|
# Grab the active strategy if the model was never compiled
|
||||||
# Validate there is no actual strategy scope active at execution
|
# but it is now predicting.
|
||||||
# time.
|
|
||||||
strategy = distribution_strategy_context.get_strategy()
|
strategy = distribution_strategy_context.get_strategy()
|
||||||
if distribution_strategy_context.has_strategy():
|
|
||||||
raise ValueError(
|
|
||||||
'Model was compiled without any active distribution strategy, '
|
|
||||||
'but there is an execution-time distribution '
|
|
||||||
'strategy scope of (%s). '
|
|
||||||
'Try to make sure your code looks similar to the following.\n'
|
|
||||||
'with strategy.scope():\n'
|
|
||||||
' model=_create_model()\n'
|
|
||||||
' model.compile(...)\n'
|
|
||||||
' model.fit(...)'% strategy)
|
|
||||||
|
|
||||||
return strategy
|
return strategy
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user