Convert tensors to constants when saving Keras functional and graph models to SavedModel.
This includes changes to backend.get_value() to ensure that the correct graph context is created when getting the value of a tensor. PiperOrigin-RevId: 267498198
This commit is contained in:
parent
562ffa913e
commit
d8f6cd24f2
@ -3147,7 +3147,7 @@ def get_value(x):
|
|||||||
"""
|
"""
|
||||||
if not tensor_util.is_tensor(x):
|
if not tensor_util.is_tensor(x):
|
||||||
return x
|
return x
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly() or isinstance(x, ops.EagerTensor):
|
||||||
return x.numpy()
|
return x.numpy()
|
||||||
if not getattr(x, '_in_graph_mode', True):
|
if not getattr(x, '_in_graph_mode', True):
|
||||||
# This is a variable which was created in an eager context, but is being
|
# This is a variable which was created in an eager context, but is being
|
||||||
@ -3159,6 +3159,7 @@ def get_value(x):
|
|||||||
# This method of evaluating works inside the Keras FuncGraph.
|
# This method of evaluating works inside the Keras FuncGraph.
|
||||||
return function([], x)(x)
|
return function([], x)(x)
|
||||||
|
|
||||||
|
with x.graph.as_default():
|
||||||
return x.eval(session=get_session((x,)))
|
return x.eval(session=get_session((x,)))
|
||||||
|
|
||||||
|
|
||||||
|
@ -386,6 +386,7 @@ class CallContext(object):
|
|||||||
in_call: Whether currently inside the `call` of a Layer.
|
in_call: Whether currently inside the `call` of a Layer.
|
||||||
training: Whether currently executing in training or inference mode.
|
training: Whether currently executing in training or inference mode.
|
||||||
in_keras_graph: Whether executing inside the Keras Graph.
|
in_keras_graph: Whether executing inside the Keras Graph.
|
||||||
|
saving: Whether currently saving to SavedModel.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -395,9 +396,10 @@ class CallContext(object):
|
|||||||
self.in_call = False
|
self.in_call = False
|
||||||
self.training = None
|
self.training = None
|
||||||
self._in_keras_graph = False
|
self._in_keras_graph = False
|
||||||
|
self.saving = False
|
||||||
|
|
||||||
@tf_contextlib.contextmanager
|
@tf_contextlib.contextmanager
|
||||||
def enter(self, layer, inputs, build_graph, training):
|
def enter(self, layer, inputs, build_graph, training, saving=None):
|
||||||
"""Push a Layer and its inputs and state onto the current call context."""
|
"""Push a Layer and its inputs and state onto the current call context."""
|
||||||
prev_layer = self.layer
|
prev_layer = self.layer
|
||||||
prev_inputs = self.inputs
|
prev_inputs = self.inputs
|
||||||
@ -405,6 +407,7 @@ class CallContext(object):
|
|||||||
prev_in_call = self.in_call
|
prev_in_call = self.in_call
|
||||||
prev_training = self.training
|
prev_training = self.training
|
||||||
prev_in_keras_graph = self._in_keras_graph
|
prev_in_keras_graph = self._in_keras_graph
|
||||||
|
prev_saving = self.saving
|
||||||
|
|
||||||
self.layer = layer
|
self.layer = layer
|
||||||
self.inputs = inputs
|
self.inputs = inputs
|
||||||
@ -415,6 +418,7 @@ class CallContext(object):
|
|||||||
self._in_keras_graph or
|
self._in_keras_graph or
|
||||||
(build_graph and
|
(build_graph and
|
||||||
getattr(backend.get_graph(), 'name', None) == 'keras_graph'))
|
getattr(backend.get_graph(), 'name', None) == 'keras_graph'))
|
||||||
|
self.saving = prev_saving if saving is None else saving
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
@ -425,6 +429,7 @@ class CallContext(object):
|
|||||||
self.in_call = prev_in_call
|
self.in_call = prev_in_call
|
||||||
self.training = prev_training
|
self.training = prev_training
|
||||||
self._in_keras_graph = prev_in_keras_graph
|
self._in_keras_graph = prev_in_keras_graph
|
||||||
|
self.saving = prev_saving
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def in_keras_graph(self):
|
def in_keras_graph(self):
|
||||||
|
@ -31,6 +31,7 @@ from six.moves import zip # pylint: disable=redefined-builtin
|
|||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
from tensorflow.python.framework import func_graph
|
from tensorflow.python.framework import func_graph
|
||||||
@ -701,7 +702,9 @@ class Network(base_layer.Layer):
|
|||||||
raise NotImplementedError('When subclassing the `Model` class, you should'
|
raise NotImplementedError('When subclassing the `Model` class, you should'
|
||||||
' implement a `call` method.')
|
' implement a `call` method.')
|
||||||
|
|
||||||
return self._run_internal_graph(inputs, training=training, mask=mask)
|
return self._run_internal_graph(
|
||||||
|
inputs, training=training, mask=mask,
|
||||||
|
convert_kwargs_to_constants=base_layer_utils.call_context().saving)
|
||||||
|
|
||||||
def compute_output_shape(self, input_shape):
|
def compute_output_shape(self, input_shape):
|
||||||
if not self._is_graph_network:
|
if not self._is_graph_network:
|
||||||
@ -775,7 +778,8 @@ class Network(base_layer.Layer):
|
|||||||
# Return shapes as TensorShapes.
|
# Return shapes as TensorShapes.
|
||||||
return output_shapes
|
return output_shapes
|
||||||
|
|
||||||
def _run_internal_graph(self, inputs, training=None, mask=None):
|
def _run_internal_graph(self, inputs, training=None, mask=None,
|
||||||
|
convert_kwargs_to_constants=False):
|
||||||
"""Computes output tensors for new inputs.
|
"""Computes output tensors for new inputs.
|
||||||
|
|
||||||
# Note:
|
# Note:
|
||||||
@ -785,6 +789,9 @@ class Network(base_layer.Layer):
|
|||||||
inputs: Tensor or nested structure of Tensors.
|
inputs: Tensor or nested structure of Tensors.
|
||||||
training: Boolean learning phase.
|
training: Boolean learning phase.
|
||||||
mask: (Optional) Tensor or nested structure of Tensors.
|
mask: (Optional) Tensor or nested structure of Tensors.
|
||||||
|
convert_kwargs_to_constants: Whether to convert Tensor kwargs to
|
||||||
|
constants. This is used when tracing the model call function during
|
||||||
|
saving to ensure that external tensors aren't captured.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Two lists: output_tensors, output_masks
|
Two lists: output_tensors, output_masks
|
||||||
@ -832,6 +839,9 @@ class Network(base_layer.Layer):
|
|||||||
|
|
||||||
# Ensure `training` arg propagation if applicable.
|
# Ensure `training` arg propagation if applicable.
|
||||||
kwargs = copy.copy(node.arguments) if node.arguments else {}
|
kwargs = copy.copy(node.arguments) if node.arguments else {}
|
||||||
|
if convert_kwargs_to_constants:
|
||||||
|
kwargs = _map_tensors_to_constants(kwargs)
|
||||||
|
|
||||||
argspec = self._layer_call_argspecs[layer].args
|
argspec = self._layer_call_argspecs[layer].args
|
||||||
if 'training' in argspec:
|
if 'training' in argspec:
|
||||||
kwargs.setdefault('training', training)
|
kwargs.setdefault('training', training)
|
||||||
@ -1882,6 +1892,15 @@ def _serialize_tensors(kwargs):
|
|||||||
|
|
||||||
return nest.map_structure(_serialize_keras_tensor, kwargs)
|
return nest.map_structure(_serialize_keras_tensor, kwargs)
|
||||||
|
|
||||||
|
def _map_tensors_to_constants(kwargs):
|
||||||
|
|
||||||
|
def _map_to_constants(t):
|
||||||
|
if not hasattr(t, '_keras_history') and isinstance(t, ops.Tensor):
|
||||||
|
return constant_op.constant(backend.get_value(t))
|
||||||
|
return t
|
||||||
|
|
||||||
|
return nest.map_structure(_map_to_constants, kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _deserialize_keras_tensors(kwargs, layer_map):
|
def _deserialize_keras_tensors(kwargs, layer_map):
|
||||||
"""Deserializes Keras Tensors passed to `call`.."""
|
"""Deserializes Keras Tensors passed to `call`.."""
|
||||||
|
@ -186,7 +186,8 @@ def wrap_layer_functions(layer, serialization_cache):
|
|||||||
# Manually trigger traces before restoring the overwritten functions. The
|
# Manually trigger traces before restoring the overwritten functions. The
|
||||||
# functions are traced within the layer call context to ensure that layer
|
# functions are traced within the layer call context to ensure that layer
|
||||||
# functions (e.g. add_loss) behave as though running in graph mode.
|
# functions (e.g. add_loss) behave as though running in graph mode.
|
||||||
with base_layer_utils.call_context().enter(layer, None, True, None):
|
with base_layer_utils.call_context().enter(
|
||||||
|
layer, inputs=None, build_graph=True, training=None, saving=True):
|
||||||
for fn in fns.values():
|
for fn in fns.values():
|
||||||
if fn is not None and fn.input_signature is not None:
|
if fn is not None and fn.input_signature is not None:
|
||||||
fn.get_concrete_function()
|
fn.get_concrete_function()
|
||||||
@ -504,7 +505,8 @@ def layer_call_wrapper(call_collection, method):
|
|||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
original_losses = _reset_layer_losses(layer)
|
original_losses = _reset_layer_losses(layer)
|
||||||
with base_layer_utils.call_context().enter(
|
with base_layer_utils.call_context().enter(
|
||||||
layer, inputs=inputs, build_graph=False, training=training):
|
layer, inputs=inputs, build_graph=False, training=training,
|
||||||
|
saving=True):
|
||||||
ret = method(*args, **kwargs)
|
ret = method(*args, **kwargs)
|
||||||
_restore_layer_losses(original_losses)
|
_restore_layer_losses(original_losses)
|
||||||
return ret
|
return ret
|
||||||
|
@ -474,6 +474,29 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
|
|||||||
loaded_predictions = loaded.predict(features)
|
loaded_predictions = loaded.predict(features)
|
||||||
self.assertAllClose(predictions, loaded_predictions)
|
self.assertAllClose(predictions, loaded_predictions)
|
||||||
|
|
||||||
|
def testSaveTensorKwarg(self):
|
||||||
|
|
||||||
|
class LayerWithTensorKwarg(keras.layers.Layer):
|
||||||
|
|
||||||
|
def call(self, inputs, tensor=None):
|
||||||
|
if tensor is not None:
|
||||||
|
return inputs * math_ops.cast(tensor, dtypes.float32)
|
||||||
|
else:
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
t = array_ops.sequence_mask(1)
|
||||||
|
inputs = keras.layers.Input(shape=(3))
|
||||||
|
model = keras.models.Model(inputs, LayerWithTensorKwarg()(inputs, t))
|
||||||
|
|
||||||
|
input_arr = np.random.random((1, 3)).astype(np.float32)
|
||||||
|
predictions = model.predict(input_arr)
|
||||||
|
|
||||||
|
saved_model_dir = self._save_model_dir()
|
||||||
|
model.save(saved_model_dir, save_format='tf')
|
||||||
|
loaded = keras_load.load(saved_model_dir)
|
||||||
|
loaded_predictions = loaded.predict(input_arr)
|
||||||
|
self.assertAllClose(predictions, loaded_predictions)
|
||||||
|
|
||||||
|
|
||||||
class TestLayerCallTracing(test.TestCase):
|
class TestLayerCallTracing(test.TestCase):
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ from tensorflow.python.framework import tensor_spec
|
|||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
from tensorflow.python.keras import losses
|
from tensorflow.python.keras import losses
|
||||||
from tensorflow.python.keras import optimizers
|
from tensorflow.python.keras import optimizers
|
||||||
|
from tensorflow.python.keras.engine import base_layer_utils
|
||||||
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
@ -136,7 +137,11 @@ def trace_model_call(model, input_signature=None):
|
|||||||
# When given a single input, Keras models will call the model on the tensor
|
# When given a single input, Keras models will call the model on the tensor
|
||||||
# rather than a list consisting of the single tensor.
|
# rather than a list consisting of the single tensor.
|
||||||
inputs = args[0] if len(input_signature) == 1 else list(args)
|
inputs = args[0] if len(input_signature) == 1 else list(args)
|
||||||
|
|
||||||
|
with base_layer_utils.call_context().enter(
|
||||||
|
model, inputs=inputs, build_graph=False, training=False, saving=True):
|
||||||
outputs_list = nest.flatten(model(inputs=inputs, training=False))
|
outputs_list = nest.flatten(model(inputs=inputs, training=False))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
output_names = model.output_names
|
output_names = model.output_names
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
|
Loading…
Reference in New Issue
Block a user