tf.saved_model: Re-create concrete functions at saving time - this ensures that if the cache key has changed, the function will be traced again. For example, when the model is run with a distribution strategy but saved without it, we want the saved version to be traced again without the strategy.
PiperOrigin-RevId: 261074074
This commit is contained in:
parent
415771767b
commit
91d7124c3c
tensorflow/python
distribute
eager
keras
@ -22,9 +22,12 @@ import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.distribute.model_collection import model_collection_base
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.keras.optimizer_v2 import gradient_descent
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import variables
|
||||
|
||||
_BATCH_SIZE = 10
|
||||
|
||||
@ -131,3 +134,27 @@ class SimpleSubclassModel(model_collection_base.ModelAndInput):
|
||||
|
||||
def get_batch_size(self):
  # Returns the module-level `_BATCH_SIZE` constant.
  return _BATCH_SIZE
|
||||
|
||||
|
||||
class _SimpleModule(module.Module):
  """Minimal module whose call multiplies its input by one variable."""

  def __init__(self):
    # Run the `module.Module` initializer first so that whatever state the
    # base class sets up exists before attributes are attached to `self`.
    super(_SimpleModule, self).__init__()
    self.v = variables.Variable(3.0)

  @def_function.function
  def __call__(self, x):
    """Returns `x` multiplied by the variable `self.v`."""
    return self.v * x
|
||||
|
||||
|
||||
class SimpleTFModuleModel(model_collection_base.ModelAndInput):
  """A simple model based on tf.Module and its data."""

  def get_model(self, **kwargs):
    """Builds a new `_SimpleModule`; any keyword arguments are ignored.

    Returns:
      A 2-tuple of the freshly constructed module and the string 'foo'.
    """
    return _SimpleModule(), 'foo'

  def get_data(self):
    """Returns whatever `_get_data_for_simple_models()` produces."""
    data = _get_data_for_simple_models()
    return data

  def get_batch_size(self):
    """Returns the module-level `_BATCH_SIZE` constant."""
    return _BATCH_SIZE
|
||||
|
@ -29,3 +29,6 @@ simple_sequential_model = combinations.NamedObject(
|
||||
|
||||
simple_subclass_model = combinations.NamedObject(
|
||||
"SimpleSubclassModel", simple_models.SimpleSubclassModel())
|
||||
|
||||
simple_tfmodule_model = combinations.NamedObject(
|
||||
"SimpleTFModuleModel", simple_models.SimpleTFModuleModel())
|
||||
|
@ -21,14 +21,16 @@ from __future__ import print_function
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import saved_model_test_base as test_base
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.saved_model import saved_model
|
||||
|
||||
|
||||
class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
|
||||
class SavedModelKerasModelTest(test_base.TestSavedModelBase):
|
||||
|
||||
def setUp(self):
|
||||
self._root_dir = 'saved_model_save_load'
|
||||
super(SavedModelSaveAndLoadTest, self).setUp()
|
||||
super(SavedModelKerasModelTest, self).setUp()
|
||||
|
||||
def _save_model(self, model, saved_dir):
|
||||
saved_model.save(model, saved_dir)
|
||||
@ -53,6 +55,7 @@ class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
|
||||
distribution, save_in_scope,
|
||||
experimental_run_tf_function):
|
||||
if save_in_scope:
|
||||
# TODO(b/134703272): Unskip this test when fixed.
|
||||
self.skipTest(('Saving model within tf.distribute.Strategy scope is not ',
|
||||
'supported.'))
|
||||
self.run_test_save_strategy_restore_no_strategy(
|
||||
@ -68,6 +71,80 @@ class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
|
||||
save_in_scope,
|
||||
experimental_run_tf_function):
|
||||
if save_in_scope:
|
||||
# TODO(b/134703272): Unskip this test when fixed.
|
||||
self.skipTest(('Saving model within tf.distribute.Strategy scope is not ',
|
||||
'supported.'))
|
||||
self.run_test_save_strategy_restore_strategy(model_and_input,
|
||||
distribution_for_saving,
|
||||
distribution_for_restoring,
|
||||
save_in_scope,
|
||||
experimental_run_tf_function)
|
||||
|
||||
|
||||
class SavedModelTFModuleTest(test_base.TestSavedModelBase):
|
||||
|
||||
def setUp(self):
  # `_root_dir` is assigned before delegating to the base-class `setUp`;
  # presumably the base reads it while preparing the test — confirm there.
  self._root_dir = 'saved_model_save_load'
  super(SavedModelTFModuleTest, self).setUp()
|
||||
|
||||
def _train_model(self, model, x_train, y_train, batch_size):
|
||||
pass
|
||||
|
||||
def _predict_with_model(self, distribution, model, predict_dataset):
|
||||
if distribution:
|
||||
dist_predict_dataset = distribution.experimental_distribute_dataset(
|
||||
predict_dataset)
|
||||
per_replica_predict_data = next(iter(dist_predict_dataset))
|
||||
result = distribution.experimental_run_v2(
|
||||
model, args=(per_replica_predict_data,))
|
||||
# Convert the per_replica value to a list, then concatenate them
|
||||
reduced = distribution.experimental_local_results(result)
|
||||
concat = array_ops.concat(reduced, 0)
|
||||
return concat
|
||||
else:
|
||||
return model(next(iter(predict_dataset)))
|
||||
|
||||
def _save_model(self, model, saved_dir):
  """Saves `model` to `saved_dir`, exporting its `__call__` as the signature."""
  input_spec = tensor_spec.TensorSpec(None)
  # Obtain a concrete function for `__call__` to hand over as the signature.
  serving_fn = model.__call__.get_concrete_function(input_spec)
  saved_model.save(model, saved_dir, signatures=serving_fn)
|
||||
|
||||
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
                        output_name, experimental_run_tf_function):
  """Loads the saved model from `saved_dir` and predicts with it.

  `output_name` and `experimental_run_tf_function` are accepted but not used.
  """
  del output_name, experimental_run_tf_function
  restored_model = saved_model.load(saved_dir)
  return self._predict_with_model(
      distribution, restored_model, predict_dataset)
|
||||
|
||||
# Parameterized over `test_base.tfmodule_models_with_strategies()`; the body
# only forwards its arguments to the shared base-class test routine.
@combinations.generate(test_base.tfmodule_models_with_strategies())
def test_save_no_strategy_restore_strategy(self, model_and_input,
                                           distribution,
                                           experimental_run_tf_function):
  self.run_test_save_no_strategy_restore_strategy(
      model_and_input, distribution, experimental_run_tf_function)
|
||||
|
||||
# Parameterized over the tf.Module model/strategy combinations crossed with
# both values of `save_in_scope`.
@combinations.generate(
    combinations.times(test_base.tfmodule_models_with_strategies(),
                       combinations.combine(save_in_scope=[True, False])))
def test_save_strategy_restore_no_strategy(
    self, model_and_input, distribution, save_in_scope,
    experimental_run_tf_function):
  if save_in_scope:
    # TODO(b/134703272): Unskip this test when fixed.
    # The reason must be a single string: the two literals are joined by
    # implicit concatenation. A comma between them would pass a tuple.
    self.skipTest('Saving model within tf.distribute.Strategy scope is not '
                  'supported.')
  self.run_test_save_strategy_restore_no_strategy(
      model_and_input, distribution, save_in_scope,
      experimental_run_tf_function)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.tfmodule_models_with_strategy_pairs(),
|
||||
combinations.combine(save_in_scope=[True, False])))
|
||||
def test_save_strategy_restore_strategy(self, model_and_input,
|
||||
distribution_for_saving,
|
||||
distribution_for_restoring,
|
||||
save_in_scope,
|
||||
experimental_run_tf_function):
|
||||
if save_in_scope:
|
||||
# TODO(b/134703272): Unskip this test when fixed.
|
||||
self.skipTest(('Saving model within tf.distribute.Strategy scope is not ',
|
||||
'supported.'))
|
||||
self.run_test_save_strategy_restore_strategy(model_and_input,
|
||||
|
@ -75,6 +75,23 @@ def simple_models_with_strategy_pairs():
|
||||
experimental_run_tf_function=[True, False])
|
||||
|
||||
|
||||
def tfmodule_models_with_strategies():
  """Combines the tf.Module model with each entry of `strategies_minus_tpu`.

  Only the 'eager' mode and `experimental_run_tf_function=True` are included.
  """
  tfmodule_models = [model_combinations.simple_tfmodule_model]
  return combinations.combine(
      model_and_input=tfmodule_models,
      distribution=strategies_minus_tpu,
      mode=['eager'],
      experimental_run_tf_function=[True])
|
||||
|
||||
|
||||
def tfmodule_models_with_strategy_pairs():
  """Combines the tf.Module model with saving/restoring strategy pairs.

  Both the saving and the restoring strategy range over
  `strategies_minus_tpu`. Only the 'eager' mode and
  `experimental_run_tf_function=True` are included.
  """
  tfmodule_models = [model_combinations.simple_tfmodule_model]
  return combinations.combine(
      model_and_input=tfmodule_models,
      distribution_for_saving=strategies_minus_tpu,
      distribution_for_restoring=strategies_minus_tpu,
      mode=['eager'],
      experimental_run_tf_function=[True])
|
||||
|
||||
|
||||
def load_and_run_with_saved_model_api(distribution, saved_dir, predict_dataset,
|
||||
output_name):
|
||||
"""Loads a saved_model using tf.saved_model API, and runs it."""
|
||||
@ -146,6 +163,9 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
|
||||
# Train the model for 1 epoch
|
||||
model.fit(x=training_dataset, epochs=1, steps_per_epoch=100)
|
||||
|
||||
def _predict_with_model(self, distribution, model, predict_dataset):
  """Predicts via `model.predict` over `PREDICT_STEPS` steps.

  `distribution` is part of the signature but is not used by this
  implementation.
  """
  del distribution
  predictions = model.predict(predict_dataset, steps=PREDICT_STEPS)
  return predictions
|
||||
|
||||
def _get_predict_dataset(self, x_predict, batch_size):
|
||||
predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict)
|
||||
predict_dataset = predict_dataset.repeat()
|
||||
@ -163,10 +183,10 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
|
||||
experimental_run_tf_function=experimental_run_tf_function)
|
||||
x_train, y_train, x_predict = model_and_input.get_data()
|
||||
batch_size = model_and_input.get_batch_size()
|
||||
predict_dataset = self._get_predict_dataset(x_predict, batch_size)
|
||||
|
||||
self._train_model(model, x_train, y_train, batch_size)
|
||||
predict_dataset = self._get_predict_dataset(x_predict, batch_size)
|
||||
result_before_save = model.predict(predict_dataset, steps=PREDICT_STEPS)
|
||||
result_before_save = self._predict_with_model(None, model, predict_dataset)
|
||||
|
||||
self._save_model(model, saved_dir)
|
||||
|
||||
@ -195,7 +215,8 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self._train_model(model, x_train, y_train, batch_size)
|
||||
predict_dataset = self._get_predict_dataset(x_predict, batch_size)
|
||||
result_before_save = model.predict(predict_dataset, steps=PREDICT_STEPS)
|
||||
result_before_save = self._predict_with_model(
|
||||
distribution, model, predict_dataset)
|
||||
|
||||
if save_in_scope:
|
||||
with distribution.scope():
|
||||
@ -229,7 +250,8 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self._train_model(model, x_train, y_train, batch_size)
|
||||
predict_dataset = self._get_predict_dataset(x_predict, batch_size)
|
||||
result_before_save = model.predict(predict_dataset, steps=PREDICT_STEPS)
|
||||
result_before_save = self._predict_with_model(
|
||||
distribution_for_saving, model, predict_dataset)
|
||||
|
||||
if save_in_scope:
|
||||
with distribution_for_saving.scope():
|
||||
|
@ -603,13 +603,7 @@ class Function(object):
|
||||
concrete_functions.extend(
|
||||
self._stateless_fn._function_cache.all_values())
|
||||
# pylint: enable=protected-access
|
||||
deduplicated_concrete_functions = []
|
||||
seen_signatures = []
|
||||
# We are using a list so that:
|
||||
# - the returned collection is deterministic, and
|
||||
# - we can use a custom equality operator (is_same_structure).
|
||||
# This is run only at serialization time on likely very small inputs so we
|
||||
# are not concerned about O(n^2) runtime.
|
||||
for concrete_function in concrete_functions:
|
||||
signature = concrete_function.structured_input_signature
|
||||
flattened = nest.flatten(signature)
|
||||
@ -621,9 +615,14 @@ class Function(object):
|
||||
equal_to_signature = functools.partial(
|
||||
function_lib.is_same_structure, signature, check_values=True)
|
||||
if not any(equal_to_signature(s) for s in seen_signatures):
|
||||
deduplicated_concrete_functions.append(concrete_function)
|
||||
seen_signatures.append(signature)
|
||||
return deduplicated_concrete_functions
|
||||
|
||||
# Re-create concrete functions for these signatures. Re-creating ensures
|
||||
# that if the cache key has changed, the function will be traced again.
|
||||
concrete_functions = []
|
||||
for args, kwargs in seen_signatures:
|
||||
concrete_functions.append(self.get_concrete_function(*args, **kwargs))
|
||||
return concrete_functions
|
||||
|
||||
def get_concrete_function(self, *args, **kwargs):
|
||||
"""Returns a `ConcreteFunction` specialized to inputs and execution context.
|
||||
|
@ -1852,19 +1852,25 @@ class Layer(module.Module):
|
||||
return None
|
||||
return input_masks
|
||||
|
||||
def _call_arg_was_passed(self, arg_name, args, kwargs):
|
||||
def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
|
||||
if arg_name in kwargs:
|
||||
return True
|
||||
# Ignore `inputs` arg.
|
||||
if arg_name in dict(zip(self._call_fn_args[1:], args)):
|
||||
call_fn_args = self._call_fn_args
|
||||
if not inputs_in_args:
|
||||
# Ignore `inputs` arg.
|
||||
call_fn_args = call_fn_args[1:]
|
||||
if arg_name in dict(zip(call_fn_args, args)):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_call_arg_value(self, arg_name, args, kwargs):
|
||||
def _get_call_arg_value(self, arg_name, args, kwargs, inputs_in_args=False):
|
||||
if arg_name in kwargs:
|
||||
return kwargs[arg_name]
|
||||
# Ignore `inputs` arg.
|
||||
args_dict = dict(zip(self._call_fn_args[1:], args))
|
||||
call_fn_args = self._call_fn_args
|
||||
if not inputs_in_args:
|
||||
# Ignore `inputs` arg.
|
||||
call_fn_args = call_fn_args[1:]
|
||||
args_dict = dict(zip(call_fn_args, args))
|
||||
return args_dict[arg_name]
|
||||
|
||||
def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs):
|
||||
|
@ -496,8 +496,10 @@ def maintain_losses(method):
|
||||
layer = self.call_collection.layer
|
||||
training = None
|
||||
# pylint: disable=protected-access
|
||||
if layer._call_arg_was_passed('training', args, kwargs):
|
||||
training = layer._get_call_arg_value('training', args, kwargs)
|
||||
if (args or kwargs) and layer._call_arg_was_passed(
|
||||
'training', args, kwargs, inputs_in_args=True):
|
||||
training = layer._get_call_arg_value(
|
||||
'training', args, kwargs, inputs_in_args=True)
|
||||
# pylint: enable=protected-access
|
||||
original_losses = _reset_layer_losses(layer)
|
||||
with base_layer_utils.call_context().enter(layer, None, True, training):
|
||||
|
Loading…
Reference in New Issue
Block a user