tf.saved_model: Re-create concrete functions at saving time - this ensures that if the cache key has changed, the function will be traced again. For example, when the model is run with distribution strategy, but saved without it, we want to the saved version to trace again without strategy.

PiperOrigin-RevId: 261074074
This commit is contained in:
Priya Gupta 2019-08-01 00:37:19 -07:00 committed by TensorFlower Gardener
parent 415771767b
commit 91d7124c3c
7 changed files with 158 additions and 22 deletions

View File

@ -22,9 +22,12 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.distribute.model_collection import model_collection_base
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.module import module
from tensorflow.python.ops import variables
_BATCH_SIZE = 10
@ -131,3 +134,27 @@ class SimpleSubclassModel(model_collection_base.ModelAndInput):
def get_batch_size(self):
return _BATCH_SIZE
class _SimpleModule(module.Module):
def __init__(self):
self.v = variables.Variable(3.0)
@def_function.function
def __call__(self, x):
return self.v * x
class SimpleTFModuleModel(model_collection_base.ModelAndInput):
"""A simple model based on tf.Module and its data."""
def get_model(self, **kwargs):
model = _SimpleModule()
return model, 'foo'
def get_data(self):
return _get_data_for_simple_models()
def get_batch_size(self):
return _BATCH_SIZE

View File

@ -29,3 +29,6 @@ simple_sequential_model = combinations.NamedObject(
simple_subclass_model = combinations.NamedObject(
"SimpleSubclassModel", simple_models.SimpleSubclassModel())
simple_tfmodule_model = combinations.NamedObject(
"SimpleTFModuleModel", simple_models.SimpleTFModuleModel())

View File

@ -21,14 +21,16 @@ from __future__ import print_function
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import saved_model_test_base as test_base
from tensorflow.python.eager import test
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.saved_model import saved_model
class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
class SavedModelKerasModelTest(test_base.TestSavedModelBase):
def setUp(self):
self._root_dir = 'saved_model_save_load'
super(SavedModelSaveAndLoadTest, self).setUp()
super(SavedModelKerasModelTest, self).setUp()
def _save_model(self, model, saved_dir):
saved_model.save(model, saved_dir)
@ -53,6 +55,7 @@ class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
distribution, save_in_scope,
experimental_run_tf_function):
if save_in_scope:
# TODO(b/134703272): Unskip this test when fixed.
self.skipTest(('Saving model within tf.distribute.Strategy scope is not ',
'supported.'))
self.run_test_save_strategy_restore_no_strategy(
@ -68,6 +71,80 @@ class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
save_in_scope,
experimental_run_tf_function):
if save_in_scope:
# TODO(b/134703272): Unskip this test when fixed.
self.skipTest(('Saving model within tf.distribute.Strategy scope is not ',
'supported.'))
self.run_test_save_strategy_restore_strategy(model_and_input,
distribution_for_saving,
distribution_for_restoring,
save_in_scope,
experimental_run_tf_function)
class SavedModelTFModuleTest(test_base.TestSavedModelBase):
def setUp(self):
self._root_dir = 'saved_model_save_load'
super(SavedModelTFModuleTest, self).setUp()
def _train_model(self, model, x_train, y_train, batch_size):
pass
def _predict_with_model(self, distribution, model, predict_dataset):
if distribution:
dist_predict_dataset = distribution.experimental_distribute_dataset(
predict_dataset)
per_replica_predict_data = next(iter(dist_predict_dataset))
result = distribution.experimental_run_v2(
model, args=(per_replica_predict_data,))
# Convert the per_replica value to a list, then concatenate them
reduced = distribution.experimental_local_results(result)
concat = array_ops.concat(reduced, 0)
return concat
else:
return model(next(iter(predict_dataset)))
def _save_model(self, model, saved_dir):
call = model.__call__.get_concrete_function(tensor_spec.TensorSpec(None))
saved_model.save(model, saved_dir, signatures=call)
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
output_name, experimental_run_tf_function):
del output_name, experimental_run_tf_function
model = saved_model.load(saved_dir)
return self._predict_with_model(distribution, model, predict_dataset)
@combinations.generate(test_base.tfmodule_models_with_strategies())
def test_save_no_strategy_restore_strategy(self, model_and_input,
distribution,
experimental_run_tf_function):
self.run_test_save_no_strategy_restore_strategy(
model_and_input, distribution, experimental_run_tf_function)
@combinations.generate(
combinations.times(test_base.tfmodule_models_with_strategies(),
combinations.combine(save_in_scope=[True, False])))
def test_save_strategy_restore_no_strategy(
self, model_and_input, distribution, save_in_scope,
experimental_run_tf_function):
if save_in_scope:
# TODO(b/134703272): Unskip this test when fixed.
self.skipTest(('Saving model within tf.distribute.Strategy scope is not ',
'supported.'))
self.run_test_save_strategy_restore_no_strategy(
model_and_input, distribution, save_in_scope,
experimental_run_tf_function)
@combinations.generate(
combinations.times(test_base.tfmodule_models_with_strategy_pairs(),
combinations.combine(save_in_scope=[True, False])))
def test_save_strategy_restore_strategy(self, model_and_input,
distribution_for_saving,
distribution_for_restoring,
save_in_scope,
experimental_run_tf_function):
if save_in_scope:
# TODO(b/134703272): Unskip this test when fixed.
self.skipTest(('Saving model within tf.distribute.Strategy scope is not ',
'supported.'))
self.run_test_save_strategy_restore_strategy(model_and_input,

View File

@ -75,6 +75,23 @@ def simple_models_with_strategy_pairs():
experimental_run_tf_function=[True, False])
def tfmodule_models_with_strategies():
return combinations.combine(
model_and_input=[model_combinations.simple_tfmodule_model],
distribution=strategies_minus_tpu,
mode=['eager'],
experimental_run_tf_function=[True])
def tfmodule_models_with_strategy_pairs():
return combinations.combine(
model_and_input=[model_combinations.simple_tfmodule_model],
distribution_for_saving=strategies_minus_tpu,
distribution_for_restoring=strategies_minus_tpu,
mode=['eager'],
experimental_run_tf_function=[True])
def load_and_run_with_saved_model_api(distribution, saved_dir, predict_dataset,
output_name):
"""Loads a saved_model using tf.saved_model API, and runs it."""
@ -146,6 +163,9 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
# Train the model for 1 epoch
model.fit(x=training_dataset, epochs=1, steps_per_epoch=100)
def _predict_with_model(self, distribution, model, predict_dataset):
return model.predict(predict_dataset, steps=PREDICT_STEPS)
def _get_predict_dataset(self, x_predict, batch_size):
predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict)
predict_dataset = predict_dataset.repeat()
@ -163,10 +183,10 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
experimental_run_tf_function=experimental_run_tf_function)
x_train, y_train, x_predict = model_and_input.get_data()
batch_size = model_and_input.get_batch_size()
predict_dataset = self._get_predict_dataset(x_predict, batch_size)
self._train_model(model, x_train, y_train, batch_size)
predict_dataset = self._get_predict_dataset(x_predict, batch_size)
result_before_save = model.predict(predict_dataset, steps=PREDICT_STEPS)
result_before_save = self._predict_with_model(None, model, predict_dataset)
self._save_model(model, saved_dir)
@ -195,7 +215,8 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
self._train_model(model, x_train, y_train, batch_size)
predict_dataset = self._get_predict_dataset(x_predict, batch_size)
result_before_save = model.predict(predict_dataset, steps=PREDICT_STEPS)
result_before_save = self._predict_with_model(
distribution, model, predict_dataset)
if save_in_scope:
with distribution.scope():
@ -229,7 +250,8 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
self._train_model(model, x_train, y_train, batch_size)
predict_dataset = self._get_predict_dataset(x_predict, batch_size)
result_before_save = model.predict(predict_dataset, steps=PREDICT_STEPS)
result_before_save = self._predict_with_model(
distribution_for_saving, model, predict_dataset)
if save_in_scope:
with distribution_for_saving.scope():

View File

@ -603,13 +603,7 @@ class Function(object):
concrete_functions.extend(
self._stateless_fn._function_cache.all_values())
# pylint: enable=protected-access
deduplicated_concrete_functions = []
seen_signatures = []
# We are using a list so that:
# - the returned collection is deterministic, and
# - we can use a custom equality operator (is_same_structure).
# This is run only at serialization time on likely very small inputs so we
# are not concerned about O(n^2) runtime.
for concrete_function in concrete_functions:
signature = concrete_function.structured_input_signature
flattened = nest.flatten(signature)
@ -621,9 +615,14 @@ class Function(object):
equal_to_signature = functools.partial(
function_lib.is_same_structure, signature, check_values=True)
if not any(equal_to_signature(s) for s in seen_signatures):
deduplicated_concrete_functions.append(concrete_function)
seen_signatures.append(signature)
return deduplicated_concrete_functions
# Re-create concrete functions for these signatures. Re-creating ensures
# that if the cache key has changed, the function will be traced again.
concrete_functions = []
for args, kwargs in seen_signatures:
concrete_functions.append(self.get_concrete_function(*args, **kwargs))
return concrete_functions
def get_concrete_function(self, *args, **kwargs):
"""Returns a `ConcreteFunction` specialized to inputs and execution context.

View File

@ -1852,19 +1852,25 @@ class Layer(module.Module):
return None
return input_masks
def _call_arg_was_passed(self, arg_name, args, kwargs):
def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
if arg_name in kwargs:
return True
# Ignore `inputs` arg.
if arg_name in dict(zip(self._call_fn_args[1:], args)):
call_fn_args = self._call_fn_args
if not inputs_in_args:
# Ignore `inputs` arg.
call_fn_args = call_fn_args[1:]
if arg_name in dict(zip(call_fn_args, args)):
return True
return False
def _get_call_arg_value(self, arg_name, args, kwargs):
def _get_call_arg_value(self, arg_name, args, kwargs, inputs_in_args=False):
if arg_name in kwargs:
return kwargs[arg_name]
# Ignore `inputs` arg.
args_dict = dict(zip(self._call_fn_args[1:], args))
call_fn_args = self._call_fn_args
if not inputs_in_args:
# Ignore `inputs` arg.
call_fn_args = call_fn_args[1:]
args_dict = dict(zip(call_fn_args, args))
return args_dict[arg_name]
def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs):

View File

@ -496,8 +496,10 @@ def maintain_losses(method):
layer = self.call_collection.layer
training = None
# pylint: disable=protected-access
if layer._call_arg_was_passed('training', args, kwargs):
training = layer._get_call_arg_value('training', args, kwargs)
if (args or kwargs) and layer._call_arg_was_passed(
'training', args, kwargs, inputs_in_args=True):
training = layer._get_call_arg_value(
'training', args, kwargs, inputs_in_args=True)
# pylint: enable=protected-access
original_losses = _reset_layer_losses(layer)
with base_layer_utils.call_context().enter(layer, None, True, training):