Convert tensors to constants when saving Keras functional and graph models to SavedModel.
This includes changes to backend.get_value() to ensure that the correct graph context created when getting the value of a tensor. PiperOrigin-RevId: 267498198
This commit is contained in:
parent
562ffa913e
commit
d8f6cd24f2
@ -3147,7 +3147,7 @@ def get_value(x):
|
||||
"""
|
||||
if not tensor_util.is_tensor(x):
|
||||
return x
|
||||
if context.executing_eagerly():
|
||||
if context.executing_eagerly() or isinstance(x, ops.EagerTensor):
|
||||
return x.numpy()
|
||||
if not getattr(x, '_in_graph_mode', True):
|
||||
# This is a variable which was created in an eager context, but is being
|
||||
@ -3159,7 +3159,8 @@ def get_value(x):
|
||||
# This method of evaluating works inside the Keras FuncGraph.
|
||||
return function([], x)(x)
|
||||
|
||||
return x.eval(session=get_session((x,)))
|
||||
with x.graph.as_default():
|
||||
return x.eval(session=get_session((x,)))
|
||||
|
||||
|
||||
@keras_export('keras.backend.batch_get_value')
|
||||
|
@ -386,6 +386,7 @@ class CallContext(object):
|
||||
in_call: Whether currently inside the `call` of a Layer.
|
||||
training: Whether currently executing in training or inference mode.
|
||||
in_keras_graph: Whether executing inside the Keras Graph.
|
||||
saving: Whether currently saving to SavedModel.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@ -395,9 +396,10 @@ class CallContext(object):
|
||||
self.in_call = False
|
||||
self.training = None
|
||||
self._in_keras_graph = False
|
||||
self.saving = False
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def enter(self, layer, inputs, build_graph, training):
|
||||
def enter(self, layer, inputs, build_graph, training, saving=None):
|
||||
"""Push a Layer and its inputs and state onto the current call context."""
|
||||
prev_layer = self.layer
|
||||
prev_inputs = self.inputs
|
||||
@ -405,6 +407,7 @@ class CallContext(object):
|
||||
prev_in_call = self.in_call
|
||||
prev_training = self.training
|
||||
prev_in_keras_graph = self._in_keras_graph
|
||||
prev_saving = self.saving
|
||||
|
||||
self.layer = layer
|
||||
self.inputs = inputs
|
||||
@ -415,6 +418,7 @@ class CallContext(object):
|
||||
self._in_keras_graph or
|
||||
(build_graph and
|
||||
getattr(backend.get_graph(), 'name', None) == 'keras_graph'))
|
||||
self.saving = prev_saving if saving is None else saving
|
||||
|
||||
try:
|
||||
yield
|
||||
@ -425,6 +429,7 @@ class CallContext(object):
|
||||
self.in_call = prev_in_call
|
||||
self.training = prev_training
|
||||
self._in_keras_graph = prev_in_keras_graph
|
||||
self.saving = prev_saving
|
||||
|
||||
@property
|
||||
def in_keras_graph(self):
|
||||
|
@ -31,6 +31,7 @@ from six.moves import zip # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import func_graph
|
||||
@ -701,7 +702,9 @@ class Network(base_layer.Layer):
|
||||
raise NotImplementedError('When subclassing the `Model` class, you should'
|
||||
' implement a `call` method.')
|
||||
|
||||
return self._run_internal_graph(inputs, training=training, mask=mask)
|
||||
return self._run_internal_graph(
|
||||
inputs, training=training, mask=mask,
|
||||
convert_kwargs_to_constants=base_layer_utils.call_context().saving)
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
if not self._is_graph_network:
|
||||
@ -775,7 +778,8 @@ class Network(base_layer.Layer):
|
||||
# Return shapes as TensorShapes.
|
||||
return output_shapes
|
||||
|
||||
def _run_internal_graph(self, inputs, training=None, mask=None):
|
||||
def _run_internal_graph(self, inputs, training=None, mask=None,
|
||||
convert_kwargs_to_constants=False):
|
||||
"""Computes output tensors for new inputs.
|
||||
|
||||
# Note:
|
||||
@ -785,6 +789,9 @@ class Network(base_layer.Layer):
|
||||
inputs: Tensor or nested structure of Tensors.
|
||||
training: Boolean learning phase.
|
||||
mask: (Optional) Tensor or nested structure of Tensors.
|
||||
convert_kwargs_to_constants: Whether to convert Tensor kwargs to
|
||||
constants. This is used when tracing the model call function during
|
||||
saving to ensure that external tensors aren't captured.
|
||||
|
||||
Returns:
|
||||
Two lists: output_tensors, output_masks
|
||||
@ -832,6 +839,9 @@ class Network(base_layer.Layer):
|
||||
|
||||
# Ensure `training` arg propagation if applicable.
|
||||
kwargs = copy.copy(node.arguments) if node.arguments else {}
|
||||
if convert_kwargs_to_constants:
|
||||
kwargs = _map_tensors_to_constants(kwargs)
|
||||
|
||||
argspec = self._layer_call_argspecs[layer].args
|
||||
if 'training' in argspec:
|
||||
kwargs.setdefault('training', training)
|
||||
@ -1882,6 +1892,15 @@ def _serialize_tensors(kwargs):
|
||||
|
||||
return nest.map_structure(_serialize_keras_tensor, kwargs)
|
||||
|
||||
def _map_tensors_to_constants(kwargs):
|
||||
|
||||
def _map_to_constants(t):
|
||||
if not hasattr(t, '_keras_history') and isinstance(t, ops.Tensor):
|
||||
return constant_op.constant(backend.get_value(t))
|
||||
return t
|
||||
|
||||
return nest.map_structure(_map_to_constants, kwargs)
|
||||
|
||||
|
||||
def _deserialize_keras_tensors(kwargs, layer_map):
|
||||
"""Deserializes Keras Tensors passed to `call`.."""
|
||||
|
@ -186,7 +186,8 @@ def wrap_layer_functions(layer, serialization_cache):
|
||||
# Manually trigger traces before restoring the overwritten functions. The
|
||||
# functions are traced within the layer call context to ensure that layer
|
||||
# functions (e.g. add_loss) behave as though running in graph mode.
|
||||
with base_layer_utils.call_context().enter(layer, None, True, None):
|
||||
with base_layer_utils.call_context().enter(
|
||||
layer, inputs=None, build_graph=True, training=None, saving=True):
|
||||
for fn in fns.values():
|
||||
if fn is not None and fn.input_signature is not None:
|
||||
fn.get_concrete_function()
|
||||
@ -504,7 +505,8 @@ def layer_call_wrapper(call_collection, method):
|
||||
# pylint: enable=protected-access
|
||||
original_losses = _reset_layer_losses(layer)
|
||||
with base_layer_utils.call_context().enter(
|
||||
layer, inputs=inputs, build_graph=False, training=training):
|
||||
layer, inputs=inputs, build_graph=False, training=training,
|
||||
saving=True):
|
||||
ret = method(*args, **kwargs)
|
||||
_restore_layer_losses(original_losses)
|
||||
return ret
|
||||
|
@ -474,6 +474,29 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
|
||||
loaded_predictions = loaded.predict(features)
|
||||
self.assertAllClose(predictions, loaded_predictions)
|
||||
|
||||
def testSaveTensorKwarg(self):
|
||||
|
||||
class LayerWithTensorKwarg(keras.layers.Layer):
|
||||
|
||||
def call(self, inputs, tensor=None):
|
||||
if tensor is not None:
|
||||
return inputs * math_ops.cast(tensor, dtypes.float32)
|
||||
else:
|
||||
return inputs
|
||||
|
||||
t = array_ops.sequence_mask(1)
|
||||
inputs = keras.layers.Input(shape=(3))
|
||||
model = keras.models.Model(inputs, LayerWithTensorKwarg()(inputs, t))
|
||||
|
||||
input_arr = np.random.random((1, 3)).astype(np.float32)
|
||||
predictions = model.predict(input_arr)
|
||||
|
||||
saved_model_dir = self._save_model_dir()
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
loaded_predictions = loaded.predict(input_arr)
|
||||
self.assertAllClose(predictions, loaded_predictions)
|
||||
|
||||
|
||||
class TestLayerCallTracing(test.TestCase):
|
||||
|
||||
|
@ -25,6 +25,7 @@ from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import losses
|
||||
from tensorflow.python.keras import optimizers
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
@ -136,7 +137,11 @@ def trace_model_call(model, input_signature=None):
|
||||
# When given a single input, Keras models will call the model on the tensor
|
||||
# rather than a list consisting of the single tensor.
|
||||
inputs = args[0] if len(input_signature) == 1 else list(args)
|
||||
outputs_list = nest.flatten(model(inputs=inputs, training=False))
|
||||
|
||||
with base_layer_utils.call_context().enter(
|
||||
model, inputs=inputs, build_graph=False, training=False, saving=True):
|
||||
outputs_list = nest.flatten(model(inputs=inputs, training=False))
|
||||
|
||||
try:
|
||||
output_names = model.output_names
|
||||
except AttributeError:
|
||||
|
Loading…
Reference in New Issue
Block a user