Convert tensors to constants when saving Keras functional and graph models to SavedModel.

This includes changes to backend.get_value() to ensure that the correct graph context is entered when getting the value of a tensor.

PiperOrigin-RevId: 267498198
commit d8f6cd24f2
parent 562ffa913e
Author: Katherine Wu
Date: 2019-09-05 18:01:26 -07:00
Committed by: TensorFlower Gardener
6 changed files with 63 additions and 8 deletions
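When the call function of a functional/graph model is traced for SavedModel export, a plain tensor passed as a layer-call kwarg would otherwise be captured as an external tensor of the traced function; converting such kwargs to constants bakes their values into the exported graph instead. A minimal sketch of the behavior this enables, written against the public `tf.keras` API (the layer name and save path are illustrative, not from this commit):

```python
import numpy as np
import tensorflow as tf

class ScaleByTensor(tf.keras.layers.Layer):
  """Toy layer whose call() accepts an optional tensor kwarg."""

  def call(self, inputs, scale=None):
    return inputs if scale is None else inputs * scale

scale = tf.constant(2.0)                  # an "external" tensor kwarg
inputs = tf.keras.Input(shape=(3,))
model = tf.keras.Model(inputs, ScaleByTensor()(inputs, scale=scale))

x = np.ones((1, 3), np.float32)
model.save('/tmp/scale_model', save_format='tf')  # kwarg is folded to a constant
loaded = tf.keras.models.load_model('/tmp/scale_model')
np.testing.assert_allclose(model.predict(x), loaded.predict(x))
```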


@@ -3147,7 +3147,7 @@ def get_value(x):
   """
   if not tensor_util.is_tensor(x):
     return x
-  if context.executing_eagerly():
+  if context.executing_eagerly() or isinstance(x, ops.EagerTensor):
     return x.numpy()
   if not getattr(x, '_in_graph_mode', True):
     # This is a variable which was created in an eager context, but is being
@@ -3159,6 +3159,7 @@ def get_value(x):
     # This method of evaluating works inside the Keras FuncGraph.
     return function([], x)(x)
-  return x.eval(session=get_session((x,)))
+  with x.graph.as_default():
+    return x.eval(session=get_session((x,)))
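A sketch of the `get_value()` behavior the two hunks above target, assuming TF 2.x with eager execution enabled: inside a `FuncGraph` trace, `context.executing_eagerly()` returns False even when `x` is an `EagerTensor` captured from outside, so the added `isinstance` check keeps the `.numpy()` fast path; for genuine graph tensors, entering `x.graph.as_default()` ensures the session is created for the tensor's own graph.

```python
import tensorflow as tf

t = tf.constant([1.0, 2.0])               # an EagerTensor
print(tf.keras.backend.get_value(t))      # -> [1. 2.], via t.numpy()

@tf.function
def traced():
  # While tracing, executing_eagerly() is False, but `t` is still an
  # EagerTensor; the isinstance check lets get_value() return t.numpy()
  # rather than trying to evaluate `t` through a Session.
  return tf.constant(tf.keras.backend.get_value(t))

traced()
```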


@@ -386,6 +386,7 @@ class CallContext(object):
     in_call: Whether currently inside the `call` of a Layer.
     training: Whether currently executing in training or inference mode.
     in_keras_graph: Whether executing inside the Keras Graph.
+    saving: Whether currently saving to SavedModel.
   """

   def __init__(self):
@@ -395,9 +396,10 @@ class CallContext(object):
     self.in_call = False
     self.training = None
     self._in_keras_graph = False
+    self.saving = False

   @tf_contextlib.contextmanager
-  def enter(self, layer, inputs, build_graph, training):
+  def enter(self, layer, inputs, build_graph, training, saving=None):
     """Push a Layer and its inputs and state onto the current call context."""
     prev_layer = self.layer
     prev_inputs = self.inputs
@@ -405,6 +407,7 @@ class CallContext(object):
     prev_in_call = self.in_call
     prev_training = self.training
     prev_in_keras_graph = self._in_keras_graph
+    prev_saving = self.saving

     self.layer = layer
     self.inputs = inputs
@@ -415,6 +418,7 @@ class CallContext(object):
         self._in_keras_graph or
         (build_graph and
          getattr(backend.get_graph(), 'name', None) == 'keras_graph'))
+    self.saving = prev_saving if saving is None else saving

     try:
       yield
@@ -425,6 +429,7 @@ class CallContext(object):
       self.in_call = prev_in_call
       self.training = prev_training
       self._in_keras_graph = prev_in_keras_graph
+      self.saving = prev_saving

   @property
   def in_keras_graph(self):
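A small sketch of the intended `saving` semantics added above (internal API, written against this patch, not a stable interface): an explicit `saving=True`/`False` sets the flag for the scope, the default `saving=None` inherits the enclosing context, and the previous value is restored on exit.

```python
from tensorflow.python.keras import layers
from tensorflow.python.keras.engine import base_layer_utils

layer = layers.Dense(1)
ctx = base_layer_utils.call_context()
with ctx.enter(layer, inputs=None, build_graph=True, training=None,
               saving=True):
  assert ctx.saving                 # explicitly set for this scope
  with ctx.enter(layer, inputs=None, build_graph=True, training=None):
    assert ctx.saving               # saving=None inherits the outer value
assert not ctx.saving               # previous value restored on exit
```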


@@ -31,6 +31,7 @@ from six.moves import zip  # pylint: disable=redefined-builtin
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import func_graph
@@ -701,7 +702,9 @@ class Network(base_layer.Layer):
       raise NotImplementedError('When subclassing the `Model` class, you should'
                                 ' implement a `call` method.')

-    return self._run_internal_graph(inputs, training=training, mask=mask)
+    return self._run_internal_graph(
+        inputs, training=training, mask=mask,
+        convert_kwargs_to_constants=base_layer_utils.call_context().saving)

   def compute_output_shape(self, input_shape):
     if not self._is_graph_network:
@@ -775,7 +778,8 @@ class Network(base_layer.Layer):
     # Return shapes as TensorShapes.
     return output_shapes

-  def _run_internal_graph(self, inputs, training=None, mask=None):
+  def _run_internal_graph(self, inputs, training=None, mask=None,
+                          convert_kwargs_to_constants=False):
     """Computes output tensors for new inputs.

     # Note:
@@ -785,6 +789,9 @@ class Network(base_layer.Layer):
       inputs: Tensor or nested structure of Tensors.
       training: Boolean learning phase.
       mask: (Optional) Tensor or nested structure of Tensors.
+      convert_kwargs_to_constants: Whether to convert Tensor kwargs to
+        constants. This is used when tracing the model call function during
+        saving to ensure that external tensors aren't captured.

     Returns:
       Two lists: output_tensors, output_masks
@@ -832,6 +839,9 @@ class Network(base_layer.Layer):
         # Ensure `training` arg propagation if applicable.
         kwargs = copy.copy(node.arguments) if node.arguments else {}
+        if convert_kwargs_to_constants:
+          kwargs = _map_tensors_to_constants(kwargs)
+
         argspec = self._layer_call_argspecs[layer].args
         if 'training' in argspec:
           kwargs.setdefault('training', training)
@@ -1882,6 +1892,15 @@ def _serialize_tensors(kwargs):
   return nest.map_structure(_serialize_keras_tensor, kwargs)

+def _map_tensors_to_constants(kwargs):
+
+  def _map_to_constants(t):
+    if not hasattr(t, '_keras_history') and isinstance(t, ops.Tensor):
+      return constant_op.constant(backend.get_value(t))
+    return t
+
+  return nest.map_structure(_map_to_constants, kwargs)
+
 def _deserialize_keras_tensors(kwargs, layer_map):
   """Deserializes Keras Tensors passed to `call`.."""


@@ -186,7 +186,8 @@ def wrap_layer_functions(layer, serialization_cache):
   # Manually trigger traces before restoring the overwritten functions. The
   # functions are traced within the layer call context to ensure that layer
   # functions (e.g. add_loss) behave as though running in graph mode.
-  with base_layer_utils.call_context().enter(layer, None, True, None):
+  with base_layer_utils.call_context().enter(
+      layer, inputs=None, build_graph=True, training=None, saving=True):
     for fn in fns.values():
       if fn is not None and fn.input_signature is not None:
         fn.get_concrete_function()
@@ -504,7 +505,8 @@ def layer_call_wrapper(call_collection, method):
     # pylint: enable=protected-access
     original_losses = _reset_layer_losses(layer)
     with base_layer_utils.call_context().enter(
-        layer, inputs=inputs, build_graph=False, training=training):
+        layer, inputs=inputs, build_graph=False, training=training,
+        saving=True):
       ret = method(*args, **kwargs)
     _restore_layer_losses(original_losses)
     return ret
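Both SavedModel tracing entry points now run under `saving=True`: `wrap_layer_functions` when it pre-traces the wrapped layer functions, and `layer_call_wrapper` around each exported call method. That flag is what routes `Network.call` into `_run_internal_graph(..., convert_kwargs_to_constants=True)` during export.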


@@ -474,6 +474,29 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
     loaded_predictions = loaded.predict(features)
     self.assertAllClose(predictions, loaded_predictions)

+  def testSaveTensorKwarg(self):
+
+    class LayerWithTensorKwarg(keras.layers.Layer):
+
+      def call(self, inputs, tensor=None):
+        if tensor is not None:
+          return inputs * math_ops.cast(tensor, dtypes.float32)
+        else:
+          return inputs
+
+    t = array_ops.sequence_mask(1)
+    inputs = keras.layers.Input(shape=(3,))
+    model = keras.models.Model(inputs, LayerWithTensorKwarg()(inputs, t))
+
+    input_arr = np.random.random((1, 3)).astype(np.float32)
+    predictions = model.predict(input_arr)
+
+    saved_model_dir = self._save_model_dir()
+    model.save(saved_model_dir, save_format='tf')
+    loaded = keras_load.load(saved_model_dir)
+    loaded_predictions = loaded.predict(input_arr)
+    self.assertAllClose(predictions, loaded_predictions)
+
 class TestLayerCallTracing(test.TestCase):


@@ -25,6 +25,7 @@ from tensorflow.python.framework import tensor_spec
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import losses
 from tensorflow.python.keras import optimizers
+from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
@@ -136,7 +137,11 @@ def trace_model_call(model, input_signature=None):
     # When given a single input, Keras models will call the model on the tensor
     # rather than a list consisting of the single tensor.
     inputs = args[0] if len(input_signature) == 1 else list(args)

-    outputs_list = nest.flatten(model(inputs=inputs, training=False))
+    with base_layer_utils.call_context().enter(
+        model, inputs=inputs, build_graph=False, training=False, saving=True):
+      outputs_list = nest.flatten(model(inputs=inputs, training=False))

     try:
       output_names = model.output_names
     except AttributeError:
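For completeness, a hedged usage sketch of the internal entry point changed above; `trace_model_call` is private API (the public path is `model.save(..., save_format='tf')`), and `model` is assumed to be a built functional model like the one in the first sketch:

```python
from tensorflow.python.keras.saving import saving_utils

# trace_model_call returns a tf.function carrying a full input signature, so
# a concrete function can be traced without arguments; with this change the
# trace runs under the saving call context, folding tensor kwargs to constants.
serving_fn = saving_utils.trace_model_call(model)
concrete_fn = serving_fn.get_concrete_function()
```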