From d8f6cd24f28467f95b89ddf7bfbb5e0a2cafeb61 Mon Sep 17 00:00:00 2001
From: Katherine Wu
Date: Thu, 5 Sep 2019 18:01:26 -0700
Subject: [PATCH] Convert tensors to constants when saving Keras functional and
 graph models to SavedModel.

This includes changes to backend.get_value() to ensure that the correct graph
context is created when getting the value of a tensor.

PiperOrigin-RevId: 267498198
---
 tensorflow/python/keras/backend.py            |  5 ++--
 .../python/keras/engine/base_layer_utils.py   |  7 +++++-
 tensorflow/python/keras/engine/network.py     | 23 +++++++++++++++++--
 .../keras/saving/saved_model/save_impl.py     |  6 +++--
 .../saving/saved_model/saved_model_test.py    | 23 +++++++++++++++++++
 .../python/keras/saving/saving_utils.py       |  7 +++++-
 6 files changed, 63 insertions(+), 8 deletions(-)

diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 2ca2e7382e1..947aa43e9c9 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -3147,7 +3147,7 @@ def get_value(x):
   """
   if not tensor_util.is_tensor(x):
     return x
-  if context.executing_eagerly():
+  if context.executing_eagerly() or isinstance(x, ops.EagerTensor):
     return x.numpy()
   if not getattr(x, '_in_graph_mode', True):
     # This is a variable which was created in an eager context, but is being
@@ -3159,7 +3159,8 @@ def get_value(x):
     # This method of evaluating works inside the Keras FuncGraph.
     return function([], x)(x)
 
-  return x.eval(session=get_session((x,)))
+  with x.graph.as_default():
+    return x.eval(session=get_session((x,)))
 
 
 @keras_export('keras.backend.batch_get_value')
diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py
index fbcc452ecb8..8a88932b343 100644
--- a/tensorflow/python/keras/engine/base_layer_utils.py
+++ b/tensorflow/python/keras/engine/base_layer_utils.py
@@ -386,6 +386,7 @@ class CallContext(object):
     in_call: Whether currently inside the `call` of a Layer.
     training: Whether currently executing in training or inference mode.
    in_keras_graph: Whether executing inside the Keras Graph.
+    saving: Whether currently saving to SavedModel.
""" def __init__(self): @@ -395,9 +396,10 @@ class CallContext(object): self.in_call = False self.training = None self._in_keras_graph = False + self.saving = False @tf_contextlib.contextmanager - def enter(self, layer, inputs, build_graph, training): + def enter(self, layer, inputs, build_graph, training, saving=None): """Push a Layer and its inputs and state onto the current call context.""" prev_layer = self.layer prev_inputs = self.inputs @@ -405,6 +407,7 @@ class CallContext(object): prev_in_call = self.in_call prev_training = self.training prev_in_keras_graph = self._in_keras_graph + prev_saving = self.saving self.layer = layer self.inputs = inputs @@ -415,6 +418,7 @@ class CallContext(object): self._in_keras_graph or (build_graph and getattr(backend.get_graph(), 'name', None) == 'keras_graph')) + self.saving = prev_saving if saving is None else saving try: yield @@ -425,6 +429,7 @@ class CallContext(object): self.in_call = prev_in_call self.training = prev_training self._in_keras_graph = prev_in_keras_graph + self.saving = prev_saving @property def in_keras_graph(self): diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 9655c51f708..80cc0790286 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -31,6 +31,7 @@ from six.moves import zip # pylint: disable=redefined-builtin from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import errors_impl from tensorflow.python.framework import func_graph @@ -701,7 +702,9 @@ class Network(base_layer.Layer): raise NotImplementedError('When subclassing the `Model` class, you should' ' implement a `call` method.') - return self._run_internal_graph(inputs, training=training, mask=mask) + return self._run_internal_graph( + inputs, training=training, mask=mask, + convert_kwargs_to_constants=base_layer_utils.call_context().saving) def compute_output_shape(self, input_shape): if not self._is_graph_network: @@ -775,7 +778,8 @@ class Network(base_layer.Layer): # Return shapes as TensorShapes. return output_shapes - def _run_internal_graph(self, inputs, training=None, mask=None): + def _run_internal_graph(self, inputs, training=None, mask=None, + convert_kwargs_to_constants=False): """Computes output tensors for new inputs. # Note: @@ -785,6 +789,9 @@ class Network(base_layer.Layer): inputs: Tensor or nested structure of Tensors. training: Boolean learning phase. mask: (Optional) Tensor or nested structure of Tensors. + convert_kwargs_to_constants: Whether to convert Tensor kwargs to + constants. This is used when tracing the model call function during + saving to ensure that external tensors aren't captured. Returns: Two lists: output_tensors, output_masks @@ -832,6 +839,9 @@ class Network(base_layer.Layer): # Ensure `training` arg propagation if applicable. 
           kwargs = copy.copy(node.arguments) if node.arguments else {}
+          if convert_kwargs_to_constants:
+            kwargs = _map_tensors_to_constants(kwargs)
+
           argspec = self._layer_call_argspecs[layer].args
           if 'training' in argspec:
             kwargs.setdefault('training', training)
@@ -1882,6 +1892,15 @@ def _serialize_tensors(kwargs):
   return nest.map_structure(_serialize_keras_tensor, kwargs)
 
 
+def _map_tensors_to_constants(kwargs):
+
+  def _map_to_constants(t):
+    if not hasattr(t, '_keras_history') and isinstance(t, ops.Tensor):
+      return constant_op.constant(backend.get_value(t))
+    return t
+
+  return nest.map_structure(_map_to_constants, kwargs)
+
 
 def _deserialize_keras_tensors(kwargs, layer_map):
   """Deserializes Keras Tensors passed to `call`.."""
diff --git a/tensorflow/python/keras/saving/saved_model/save_impl.py b/tensorflow/python/keras/saving/saved_model/save_impl.py
index 3923eaad533..6c766dc0ec3 100644
--- a/tensorflow/python/keras/saving/saved_model/save_impl.py
+++ b/tensorflow/python/keras/saving/saved_model/save_impl.py
@@ -186,7 +186,8 @@ def wrap_layer_functions(layer, serialization_cache):
   # Manually trigger traces before restoring the overwritten functions. The
   # functions are traced within the layer call context to ensure that layer
   # functions (e.g. add_loss) behave as though running in graph mode.
-  with base_layer_utils.call_context().enter(layer, None, True, None):
+  with base_layer_utils.call_context().enter(
+      layer, inputs=None, build_graph=True, training=None, saving=True):
     for fn in fns.values():
       if fn is not None and fn.input_signature is not None:
         fn.get_concrete_function()
@@ -504,7 +505,8 @@ def layer_call_wrapper(call_collection, method):
     # pylint: enable=protected-access
     original_losses = _reset_layer_losses(layer)
     with base_layer_utils.call_context().enter(
-        layer, inputs=inputs, build_graph=False, training=training):
+        layer, inputs=inputs, build_graph=False, training=training,
+        saving=True):
       ret = method(*args, **kwargs)
     _restore_layer_losses(original_losses)
     return ret
diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py
index f13f3fa2dd9..18500443c23 100644
--- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py
+++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py
@@ -474,6 +474,29 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
     loaded_predictions = loaded.predict(features)
     self.assertAllClose(predictions, loaded_predictions)
 
+  def testSaveTensorKwarg(self):
+
+    class LayerWithTensorKwarg(keras.layers.Layer):
+
+      def call(self, inputs, tensor=None):
+        if tensor is not None:
+          return inputs * math_ops.cast(tensor, dtypes.float32)
+        else:
+          return inputs
+
+    t = array_ops.sequence_mask(1)
+    inputs = keras.layers.Input(shape=(3))
+    model = keras.models.Model(inputs, LayerWithTensorKwarg()(inputs, t))
+
+    input_arr = np.random.random((1, 3)).astype(np.float32)
+    predictions = model.predict(input_arr)
+
+    saved_model_dir = self._save_model_dir()
+    model.save(saved_model_dir, save_format='tf')
+    loaded = keras_load.load(saved_model_dir)
+    loaded_predictions = loaded.predict(input_arr)
+    self.assertAllClose(predictions, loaded_predictions)
+
 
 class TestLayerCallTracing(test.TestCase):
 
diff --git a/tensorflow/python/keras/saving/saving_utils.py b/tensorflow/python/keras/saving/saving_utils.py
index a1019022c01..d5bc66aa76b 100644
--- a/tensorflow/python/keras/saving/saving_utils.py
+++ b/tensorflow/python/keras/saving/saving_utils.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import tensor_spec
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import losses
 from tensorflow.python.keras import optimizers
+from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
@@ -136,7 +137,11 @@ def trace_model_call(model, input_signature=None):
     # When given a single input, Keras models will call the model on the tensor
     # rather than a list consisting of the single tensor.
     inputs = args[0] if len(input_signature) == 1 else list(args)
-    outputs_list = nest.flatten(model(inputs=inputs, training=False))
+
+    with base_layer_utils.call_context().enter(
+        model, inputs=inputs, build_graph=False, training=False, saving=True):
+      outputs_list = nest.flatten(model(inputs=inputs, training=False))
+
    try:
       output_names = model.output_names
     except AttributeError:
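
Reviewer note (not part of the patch): a minimal end-to-end sketch of the scenario
the new testSaveTensorKwarg covers, rewritten against the public tf.keras API
instead of the internal test helpers, and assuming a recent TF 2.x release; the
save path is a placeholder chosen for illustration.

    # A functional model whose layer call() receives an external tensor kwarg.
    # With this change, that tensor is folded into a constant while the model
    # is traced for SavedModel export, so it is not captured as an external
    # graph tensor.
    import numpy as np
    import tensorflow as tf

    class LayerWithTensorKwarg(tf.keras.layers.Layer):

      def call(self, inputs, tensor=None):
        if tensor is not None:
          return inputs * tf.cast(tensor, tf.float32)
        return inputs

    t = tf.sequence_mask(1)  # external tensor passed into call()
    inputs = tf.keras.Input(shape=(3,))
    model = tf.keras.Model(inputs, LayerWithTensorKwarg()(inputs, t))

    x = np.random.random((1, 3)).astype(np.float32)
    before = model.predict(x)

    # Save as a SavedModel and check that the reloaded model reproduces the
    # original predictions (the tensor kwarg now lives in the graph as a constant).
    model.save('/tmp/model_with_tensor_kwarg', save_format='tf')
    loaded = tf.keras.models.load_model('/tmp/model_with_tensor_kwarg')
    np.testing.assert_allclose(before, loaded.predict(x), rtol=1e-6)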