Automated rollback of commit 65b507e8a1

PiperOrigin-RevId: 252924374
Katherine Wu 2019-06-12 16:30:29 -07:00 committed by Kathy Wu
parent 319e32730a
commit 974ff69e6e
7 changed files with 272 additions and 134 deletions

View File

@@ -328,9 +328,9 @@ class FromSavedModelTest(TestModels):
self.assertIn('This converter can only convert a single ConcreteFunction',
str(error.exception))
@test_util.run_v2_only
def testKerasSequentialModel(self):
"""Test a simple sequential tf.Keras model."""
self.skipTest('b/134660903')
input_data = constant_op.constant(1., shape=[1, 1])
x = np.array([[1.], [2.]])

View File

@@ -164,7 +164,7 @@ def _compatible_shapes(flat_relaxed, flat_to_check):
for relaxed, to_check in zip(flat_relaxed, flat_to_check))
def _common_shape(x, y):
def common_shape(x, y):
"""Find a `TensorShape` that is compatible with both `x` and `y`."""
if (x is None) != (y is None):
raise RuntimeError(
@@ -1577,7 +1577,7 @@ class Function(object):
"relaxed_arg_shapes len: %d vs. %d"
% (len(arg_shapes), len(relaxed_arg_shapes)))
relaxed_arg_shapes = [
_common_shape(x, y) for (x, y) in zip(
common_shape(x, y) for (x, y) in zip(
arg_shapes, relaxed_arg_shapes)]
self._function_cache.arg_relaxed_shapes[rank_only_cache_key] = (
relaxed_arg_shapes)
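
Editor's note: the shape relaxation above merges the shapes seen across calls so one traced function can serve all of them. A minimal standalone sketch of the idea behind `common_shape`, using plain Python lists in place of `TensorShape` (the helper name here is illustrative, not TF API):

def relax_shapes(x, y):
    # Stand-in for `common_shape`: keep dimensions the two shapes agree
    # on, relax the rest to None (unknown). A bare None stands for a
    # fully unknown shape, mirroring TensorShape(None).
    if (x is None) != (y is None):
        raise RuntimeError('One shape is unknown but the other is not: '
                           '%s vs. %s' % (x, y))
    if x is None:
        return None
    if len(x) != len(y):
        return None  # Ranks differ: relax to a fully unknown shape.
    return [dx if dx == dy else None for dx, dy in zip(x, y)]

assert relax_shapes([1, 2], [1, 3]) == [1, None]
assert relax_shapes([1, 2], [1, 2, 3]) is None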
@@ -1679,8 +1679,9 @@ def register(func, *args, **kwargs):
def validate_signature(signature):
if any(not isinstance(arg, tensor_spec.TensorSpec)
for arg in nest.flatten(signature, expand_composites=True)):
raise TypeError("Invalid input_signature %s; input_signature must be "
"a possibly nested sequence of TensorSpec objects.")
raise TypeError("Invalid input_signature {}; input_signature must be "
"a possibly nested sequence of TensorSpec objects."
.format(signature))
def defun(func=None,

View File

@@ -20,12 +20,12 @@ from __future__ import print_function
import functools
import json
import os
import weakref
import six
from tensorflow.python.client import session
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
@@ -38,10 +38,10 @@ from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.saving import model_from_json
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.utils import mode_keys
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder as saved_model_builder
@@ -713,10 +713,20 @@ def serialize_all_attributes(layer, serialization_cache):
if _should_skip_serialization(layer):
return serialized_attr
function_dict = {}
if save_model_default_signature:
# For compatibility with the tf.Lite Converter, the default save signature
# should be traced without nested calls to other wrapped functions.
# TODO(kathywu): Investigate why having nested calls results in a stateful
# function. Perhaps something to do with losses, which are traced in nested
# calls but not in the flat call.
function_dict['_default_save_signature'] = _default_save_signature(layer)
else:
function_dict['_default_save_signature'] = None
object_dict = _wrap_layer_objects(layer, serialization_cache)
try:
function_dict = _wrap_layer_functions(layer, serialization_cache,
save_model_default_signature)
function_dict.update(_wrap_layer_functions(layer, serialization_cache))
except (ValueError, TypeError) as e:
logging.warning('Skipping full serialization of object {}, because an '
'error occurred while tracing layer functions. Error '
@@ -799,44 +809,53 @@ def _wrap_layer_objects(layer, serialization_cache):
wrapped_loss_functions))
def _wrap_layer_functions(layer, serialization_cache,
save_model_default_signature=False):
def _wrap_layer_functions(layer, serialization_cache):
"""Returns dict of wrapped layer call function and losses in tf.functions.
Args:
layer: Keras Layer object.
serialization_cache: Dictionary shared between all objects during
serialization.
save_model_default_signature: Whether to save traced model call function.
Returns:
A dictionary containing all keras tf.functions to serialize. See
LayerAttributes and ModelAttributes for the list of all attributes.
"""
# Since Sequential models may be modified in place using model.add() or
# model.pop(), don't use saved functions.
if (isinstance(layer, RevivedLayer) and
not isinstance(layer, RevivedSequential)):
return {fn_name: getattr(layer.keras_api, fn_name, None)
for fn_name in LayerAttributes.all_functions}
# Reset the losses of the layer and its children. The call function in each
# child layer is replaced with tf.functions.
original_attrs = _replace_child_layer_functions(layer, serialization_cache)
original_layer_losses = layer._losses[:] # pylint: disable=protected-access
with trackable.no_automatic_dependency_tracking_scope(layer):
layer._losses = [] # pylint: disable=protected-access
# Note that eager losses do not need to be saved since these functions
# create symbolic losses.
original_fns = _replace_child_layer_functions(layer, serialization_cache)
original_losses = _reset_layer_losses(layer)
# Wrap all the layer call and activity regularizer functions.
call_fn_with_losses = _wrap_call_and_conditional_losses(layer)
fns = {'call_and_return_conditional_losses': call_fn_with_losses,
'__call__': _extract_outputs_from_fn(layer, call_fn_with_losses)}
if save_model_default_signature:
fns['_default_save_signature'] = saving_utils.trace_model_call(layer)
else:
fns['_default_save_signature'] = None
# Use LayerCallCollection to ensure that all layer call functions (__call__,
# call with losses) are traced with the same inputs.
call_collection = LayerCallCollection(layer)
call_fn_with_losses = call_collection.add_function(
_wrap_call_and_conditional_losses(layer),
'{}_layer_call_and_return_conditional_losses'.format(layer.name))
call_fn = call_collection.add_function(
_extract_outputs_from_fn(layer, call_fn_with_losses),
'{}_layer_call_fn'.format(layer.name))
fns = {'call_and_return_conditional_losses': call_fn_with_losses,
'__call__': call_fn}
if layer.activity_regularizer is not None:
fns['activity_regularizer_fn'] = _wrap_activity_regularizer(layer)
fns['call_and_return_all_conditional_losses'] = (
_append_activity_regularizer_loss(
layer, call_fn_with_losses, fns['activity_regularizer_fn']))
call_collection.add_function(
_append_activity_regularizer_loss(call_fn_with_losses,
fns['activity_regularizer_fn']),
'{}_layer_call_and_return_all_conditional_losses'.format(layer.name)
))
else:
fns['activity_regularizer_fn'] = None
fns['call_and_return_all_conditional_losses'] = call_fn_with_losses
@@ -849,14 +868,21 @@ def _wrap_layer_functions(layer, serialization_cache,
if fn is not None and fn.input_signature is not None:
fn.get_concrete_function()
# Restore overwritten functions/losses
with trackable.no_automatic_dependency_tracking_scope(layer):
layer._losses = original_layer_losses # pylint: disable=protected-access
_restore_child_layer_functions(original_attrs)
# Restore overwritten functions and losses
_restore_child_layer_functions(original_fns)
_restore_layer_losses(original_losses)
return fns
def _default_save_signature(layer):
original_losses = _reset_layer_losses(layer)
fn = saving_utils.trace_model_call(layer)
fn.get_concrete_function()
_restore_layer_losses(original_losses)
return fn
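
Editor's note: for context, a sketch of what tracing a default save signature amounts to, written against the public TF 2.x API (the model is a hypothetical stand-in; `saving_utils.trace_model_call` does roughly this with the model's own call function):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])

@tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
def serving_fn(inputs):
    # Roughly what the default save signature wraps: one flat call into
    # the model, with no nested calls to other wrapped functions.
    return model(inputs, training=False)

# Force the trace up front, as `_default_save_signature` does via
# `fn.get_concrete_function()`.
concrete = serving_fn.get_concrete_function()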
def _list_all_layers(obj):
if isinstance(obj, training_lib.Model):
return obj.layers
@@ -888,11 +914,9 @@ def _replace_child_layer_functions(layer, serialization_cache):
Child layer 2: ...
}
"""
original_attrs = {}
# pylint: disable=protected-access
original_fns = {}
for child_layer in _list_all_layers(layer):
# Save symbolic layer losses, which will be restored to maintain the same
# state.
original_attrs[child_layer] = {'losses': child_layer._losses[:]} # pylint: disable=protected-access
if child_layer not in serialization_cache[_KERAS_CACHE_KEY]:
layer_fns = (serialize_all_attributes(child_layer, serialization_cache)
.functions)
@@ -906,27 +930,46 @@ def _replace_child_layer_functions(layer, serialization_cache):
# wrapped. In this case, no replacement is necessary so move on to the
# next child.
continue
original_attrs[child_layer]['call'] = child_layer.call
original_attrs[child_layer]['activity_regularizer'] = (
child_layer.activity_regularizer)
original_fns[child_layer] = {
'call': child_layer.call,
'activity_regularizer': child_layer.activity_regularizer
}
with trackable.no_automatic_dependency_tracking_scope(child_layer):
child_layer.activity_regularizer = layer_fns.get(
'activity_regularizer_fn')
child_layer.call = _use_wrapped_call(
child_layer, layer_fns['call_and_return_conditional_losses'])
child_layer._losses = [] # pylint: disable=protected-access
return original_attrs
return original_fns
# pylint: enable=protected-access
def _restore_child_layer_functions(original_attrs):
def _restore_child_layer_functions(original_fns):
"""Restores attributes replaced with `_replace_child_layer_functions`."""
for child_layer, attrs in original_attrs.items():
for child_layer, fns in original_fns.items():
with trackable.no_automatic_dependency_tracking_scope(child_layer):
child_layer._losses = attrs['losses'] # pylint: disable=protected-access
if 'call' in attrs:
child_layer.call = attrs['call']
child_layer.activity_regularizer = attrs['activity_regularizer']
child_layer.call = fns['call']
child_layer.activity_regularizer = fns['activity_regularizer']
# pylint: disable=protected-access
def _reset_layer_losses(parent_layer):
"""Resets losses of layer and its sublayers, and returns original losses."""
losses_dict = {}
for layer in _list_all_layers(parent_layer) + [parent_layer]:
losses_dict[layer] = {'losses': layer._losses[:],
'eager_losses': layer._eager_losses[:]}
with trackable.no_automatic_dependency_tracking_scope(layer):
layer._losses = []
layer._eager_losses = []
return losses_dict
def _restore_layer_losses(losses_dict):
for layer in losses_dict:
with trackable.no_automatic_dependency_tracking_scope(layer):
layer._losses = losses_dict[layer]['losses']
layer._eager_losses = losses_dict[layer]['eager_losses']
# pylint: enable=protected-access
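
Editor's note: the `_reset_layer_losses`/`_restore_layer_losses` pair is a stash-and-restore pattern. An equivalent hypothetical context manager makes the pairing explicit (attribute names assumed from the code above; the real code additionally suspends automatic dependency tracking while mutating these attributes):

import contextlib

@contextlib.contextmanager
def cleared_losses(layers):
    # Stash each layer's symbolic and eager losses, give the caller a
    # clean slate to trace with, then restore the originals even if
    # tracing raises.
    saved = {l: (list(l._losses), list(l._eager_losses)) for l in layers}
    for l in layers:
        l._losses, l._eager_losses = [], []
    try:
        yield
    finally:
        for l, (losses, eager) in saved.items():
            l._losses, l._eager_losses = losses, eager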
def _use_wrapped_call(layer, call_fn):
@@ -947,8 +990,10 @@ def _use_wrapped_call(layer, call_fn):
training = kwargs.pop('training', None)
if training is None:
training = K.learning_phase()
training = math_ops.cast(training, dtypes.bool)
outputs, losses = call_fn(inputs, training=training)
outputs, losses = tf_utils.smart_cond(
training,
lambda: call_fn(inputs, training=True),
lambda: call_fn(inputs, training=False))
else:
outputs, losses = call_fn(inputs)
layer.add_loss(losses, inputs)
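
Editor's note: the `smart_cond` call above replaces the earlier unconditional cast: a static Python boolean picks a branch immediately, and only a tensor-valued `training` pays for tracing both branches. A rough sketch of that dispatch rule (this mirrors the behavior of the internal `tf_utils.smart_cond`, not its exact implementation):

import tensorflow as tf

def smart_cond(pred, true_fn, false_fn):
    # Static Python booleans pick a branch at trace time; tensor
    # predicates fall back to a graph-level tf.cond, which traces both
    # branches and selects between them at runtime.
    if isinstance(pred, bool):
        return true_fn() if pred else false_fn()
    return tf.cond(pred, true_fn, false_fn)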
@@ -956,6 +1001,128 @@ def _use_wrapped_call(layer, call_fn):
return wrapped_call
class LayerCallCollection(object):
"""Groups wrapped layer call functions.
This is used to ensure that all layer call functions are traced with the same
inputs:
- call
- call_and_return_conditional_losses
- call_and_return_all_conditional_losses
"""
def __init__(self, layer):
self._layer = layer
self._expects_training_arg = layer._expects_training_arg # pylint: disable=protected-access
self._input_signature = self._generate_input_signature(layer)
self._functions = weakref.WeakValueDictionary()
# Bool indicating whether this object is currently tracing the layer call
# functions.
self.tracing = False
def _generate_input_signature(self, layer):
"""Inspects layer object and returns the inferred input signature.
Args:
layer: Layer object.
Returns:
List of possibly nested TensorSpecs of the layer call function inputs.
The list does not contain the `training` argument.
"""
if (isinstance(layer.call, def_function.Function) and
layer.call.input_signature is not None):
return layer.call.input_signature
else:
if isinstance(layer, training_lib.Model):
return saving_utils.model_input_signature(layer)
elif layer.input_spec is not None:
def to_tensor_spec_or_none(x):
spec = input_spec.to_tensor_spec(x, layer.dtype)
# If the shape is too general (e.g. multiple dimensions are allowed),
# return None so that separate functions can be generated for each
# inferred input signature.
# TODO(b/134962016): currently partial signatures are not supported.
if spec.shape == tensor_shape.TensorShape(None):
return None
return spec
input_signature = [nest.map_structure(
to_tensor_spec_or_none, layer.input_spec)]
return input_signature
else:
return None
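
Editor's note: the fully-unknown-shape check above can be seen in isolation with the public API (the helper name is hypothetical):

import tensorflow as tf

def spec_or_none(shape, dtype=tf.float32):
    # Mirrors `to_tensor_spec_or_none`: if nothing about the shape is
    # known, return None so no fixed input signature is recorded and a
    # separate function is traced per observed input shape.
    spec = tf.TensorSpec(tf.TensorShape(shape), dtype)
    if spec.shape == tf.TensorShape(None):
        return None
    return spec

assert spec_or_none(None) is None
assert spec_or_none([None, 3]) is not None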
def add_trace(self, *args, **kwargs):
"""Traces all functions with the same args and kwargs.
Args:
*args: Positional args passed to the original function.
**kwargs: Keyword args passed to the original function.
"""
kwargs = kwargs.copy()
self.tracing = True
for fn in self._functions.values():
# TODO(kathywu): Replace arguments with broader shapes defined in the
# input signature.
if self._expects_training_arg:
kwargs['training'] = False
fn.original_get_concrete_function(*args, **kwargs)
kwargs['training'] = True
fn.original_get_concrete_function(*args, **kwargs)
else:
fn.original_get_concrete_function(*args, **kwargs)
self.tracing = False
@property
def fn_input_signature(self):
"""Returns input signature for the wrapped layer call function."""
if self._expects_training_arg:
# The training arg is left as a python boolean, so the call functions
# will not have an input signature (input signatures may only describe
# tensor arguments).
return None
if None in nest.flatten(self._input_signature):
# TODO(b/134962016): Partially defined input signatures are not yet supported.
return None
return self._input_signature
def add_function(self, python_function, name):
"""Adds a layer call function to the collection."""
self._functions[name] = fn = LayerCall(
self, python_function, name,
input_signature=self.fn_input_signature)
if (None not in nest.flatten(self._input_signature) and
self._expects_training_arg):
# Manually add traces for layers that expect a training argument and have
# a fully defined input signature.
self.add_trace(*self._input_signature)
return fn
class LayerCall(def_function.Function):
"""Function that triggers traces of other functions in the same collection."""
def __init__(self, call_collection, *args, **kwargs):
super(LayerCall, self).__init__(*args, **kwargs)
self.call_collection = call_collection
def __call__(self, *args, **kwargs):
if not self.call_collection.tracing:
self.call_collection.add_trace(*args, **kwargs)
return super(LayerCall, self).__call__(*args, **kwargs)
def get_concrete_function(self, *args, **kwargs):
if not self.call_collection.tracing:
self.call_collection.add_trace(*args, **kwargs)
return super(LayerCall, self).get_concrete_function(*args, **kwargs)
def original_get_concrete_function(self, *args, **kwargs):
return super(LayerCall, self).get_concrete_function(*args, **kwargs)
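
Editor's note: the mechanics of `LayerCallCollection`/`LayerCall` reduce to a small pattern: every entry point first triggers a trace of all sibling functions with the same arguments, and a `tracing` flag keeps those traces from recursing. A framework-free sketch (class and method names hypothetical):

class CallCollection:
    """Keeps sibling functions exercised with identical inputs."""

    def __init__(self):
        self._fns = []
        self.tracing = False

    def add(self, fn):
        def wrapped(*args, **kwargs):
            if not self.tracing:
                self.tracing = True
                try:
                    for sibling in self._fns:
                        sibling(*args, **kwargs)  # Stand-in for tracing.
                finally:
                    self.tracing = False
            return fn(*args, **kwargs)
        self._fns.append(wrapped)
        return wrapped

Calling any wrapped function once exercises every function in the collection with the same inputs, which is what keeps the concrete traces of `__call__`, `call_and_return_conditional_losses`, etc. aligned.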
def _wrap_call_and_conditional_losses(layer):
"""Wraps call function that returns a tuple of (outputs, losses).
@@ -966,51 +1133,19 @@ def _wrap_call_and_conditional_losses(layer):
layer: a Keras layer object
Returns:
call function that returns outputs and conditional losses -- excludes
python call function that returns outputs and conditional losses -- excludes
activity regularizer
"""
if isinstance(layer, RevivedLayer):
return layer.keras_api.call_and_return_conditional_losses
if (isinstance(layer.call, def_function.Function) and
layer.call.input_signature is not None):
input_signature = layer.call.input_signature
else:
if (isinstance(layer, training_lib.Model) and
saving_utils.model_input_signature(layer) is not None):
input_signature = saving_utils.model_input_signature(layer)
elif layer.input_spec is not None:
input_signature = [nest.map_structure(
lambda x: input_spec.to_tensor_spec(x, layer.dtype),
layer.input_spec)]
# If input spec is too general, then don't define an input signature.
for spec in nest.flatten(input_signature):
if spec.shape == tensor_shape.TensorShape(None):
input_signature = None
break
else:
input_signature = None
if input_signature is not None and layer._expects_training_arg: # pylint: disable=protected-access
input_signature.append(
tensor_spec.TensorSpec(shape=[], dtype=dtypes.bool))
# Create function that generates both outputs and losses
layer_call = layer.call
if layer._expects_training_arg: # pylint: disable=protected-access
def call_and_return_conditional_losses(inputs, training):
_set_symbolic_learning_phase(training)
def call_and_return_conditional_losses(inputs, training=False):
return layer_call(inputs, training=training), layer.get_losses_for(inputs)
else:
def call_and_return_conditional_losses(inputs):
K.set_learning_phase(0)
return layer_call(inputs), layer.get_losses_for(inputs)
return def_function.Function(
call_and_return_conditional_losses,
'{}_layer_call_and_return_conditional_losses'.format(layer.name),
input_signature=input_signature,
# TODO(kathywu): Investigate autograph error.
autograph=False)
return call_and_return_conditional_losses
def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
@@ -1018,50 +1153,22 @@ def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
if isinstance(layer, RevivedLayer):
return layer.keras_api.__call__ # pylint: disable=protected-access
if layer._expects_training_arg: # pylint: disable=protected-access
def call(inputs, training):
return call_and_return_conditional_losses(inputs, training)[0]
def call(inputs, training=False):
return call_and_return_conditional_losses(inputs, training=training)[0]
else:
def call(inputs):
return call_and_return_conditional_losses(inputs)[0]
return def_function.Function(
call, '{}_layer_call_fn'.format(layer.name),
input_signature=call_and_return_conditional_losses.input_signature,
# TODO(kathywu): Investigate autograph error.
autograph=False)
def _set_symbolic_learning_phase(value):
"""Set learning phase to a tensor value (for internal use only).
This is used when wrapping call functions as tf.functions that have training
as a tensor input. Thus, when `learning_phase()` is called, the training
tensor is returned. This function is called when saving a model to SavedModel.
Args:
value: A Tensor object.
Raises:
ValueError: If the input value is not a graph tensor
"""
graph = K.get_graph()
if not isinstance(value, ops.Tensor):
raise ValueError('Symbolic learning phase must be a graph tensor.')
K._GRAPH_LEARNING_PHASES[graph] = value # pylint: disable=protected-access
return call
def _append_activity_regularizer_loss(
layer, call_fn_with_losses, activity_regularizer_fn):
call_fn_with_losses, activity_regularizer_fn):
"""Appends activity regularizer loss to losses returned by the wrapped fn."""
def fn(*args):
outputs, losses = call_fn_with_losses(*args)
def fn(*args, **kwargs):
outputs, losses = call_fn_with_losses(*args, **kwargs)
losses.append(activity_regularizer_fn(outputs))
return outputs, losses
return def_function.Function(
fn,
'{}_layer_call_and_return_all_conditional_losses'.format(layer.name),
input_signature=call_fn_with_losses.input_signature,
# TODO(kathywu): Investigate autograph error.
autograph=False)
return fn
def _wrap_unconditional_loss(loss_fn, index):
@@ -1135,9 +1242,11 @@ class KerasObjectLoader(load.Loader):
# pylint: disable=protected-access
for node in self._nodes:
if isinstance(node, RevivedModel):
input_signature = (
node.keras_api.call_and_return_conditional_losses.input_signature[0]
)
call_fn = node.keras_api.call_and_return_conditional_losses
if call_fn.input_signature is None:
inputs = infer_inputs_from_restored_call_function(call_fn)
else:
inputs = call_fn.input_signature[0]
if isinstance(node, RevivedSequential):
with trackable.no_automatic_dependency_tracking_scope(node):
node._layers = []
@@ -1147,7 +1256,7 @@ class KerasObjectLoader(load.Loader):
if not node.inputs:
# Since this revived object is technically a subclassed model (even if
# the original model is functional/sequential), inputs should be set.
node._set_inputs(input_signature)
node._set_inputs(inputs)
if isinstance(node, RevivedLayer):
losses = node._serialized_attributes.get('regularization_losses', [])
for loss in losses:
@@ -1276,6 +1385,26 @@ def recursively_deserialize_keras_object(config, module_objects=None):
raise ValueError('Unable to decode config: {}'.format(config))
def infer_inputs_from_restored_call_function(fn):
"""Returns TensorSpec of inputs from a restored call function.
Args:
fn: Restored layer call function. It is assumed that the inputs are entirely
in the first argument.
Returns:
TensorSpec of call function inputs.
"""
def common_spec(x, y):
return tensor_spec.TensorSpec(defun.common_shape(x.shape, y.shape),
x.dtype, x.name)
spec = fn.concrete_functions[0].structured_input_signature[0][0]
for concrete in fn.concrete_functions[1:]:
spec2 = concrete.structured_input_signature[0][0]
spec = nest.map_structure(common_spec, spec, spec2)
return spec
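
Editor's note: `infer_inputs_from_restored_call_function` merges the input signatures of every concrete trace into one relaxed spec. The per-spec merge can be expressed against the public API like this (a sketch; the real helper delegates the shape merge to `defun.common_shape`):

import tensorflow as tf

def merge_specs(x, y):
    # Keep dimensions the two specs agree on; relax the rest to None.
    if x.shape.rank is None or x.shape.rank != y.shape.rank:
        shape = tf.TensorShape(None)
    else:
        shape = tf.TensorShape(
            [a if a == b else None
             for a, b in zip(x.shape.as_list(), y.shape.as_list())])
    return tf.TensorSpec(shape, x.dtype, x.name)

a = tf.TensorSpec([1, 3], tf.float32)
b = tf.TensorSpec([2, 3], tf.float32)
print(merge_specs(a, b).shape)  # (None, 3)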
class RevivedNetwork(RevivedLayer):
"""Keras network of layers loaded from a SavedModel."""

View File

@@ -705,14 +705,8 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
expected_layers = len(model.layers)
self.assertEqual(expected_layers, len(loaded.keras_api.layers))
input_arr = array_ops.ones((4, 3))
training_bool = constant_op.constant(False)
if model._expects_training_arg:
call_args = [input_arr, training_bool]
else:
call_args = [input_arr]
self.assertAllClose(self.evaluate(model(input_arr)),
self.evaluate(loaded(*call_args)))
self.evaluate(loaded(input_arr)))
@keras_parameterized.run_with_all_model_types
def test_compiled_model(self):
@@ -765,6 +759,20 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
self.assertAllEqual([None, 2, 3], loaded.input_spec['b'].shape)
self.assertEqual('float16', loaded.input_spec['b'].dtype)
def test_multi_input_model(self):
input_1 = keras.layers.Input(shape=(3,))
input_2 = keras.layers.Input(shape=(5,))
model = keras.Model([input_1, input_2], [input_1, input_2])
saved_model_dir = self._save_model_dir()
model.save(saved_model_dir, save_format='tf')
loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir)
input_arr_1 = np.random.random((1, 3)).astype('float32')
input_arr_2 = np.random.random((1, 5)).astype('float32')
outputs = loaded([input_arr_1, input_arr_2])
self.assertAllEqual(input_arr_1, outputs[0])
self.assertAllEqual(input_arr_2, outputs[1])
if __name__ == '__main__':
test.main()

View File

@@ -66,7 +66,7 @@ def model_input_signature(model):
Returns:
A list containing either a single TensorSpec or an object with nested
TensorSpecs.
TensorSpecs. This list does not contain the `training` argument.
"""
try:
inputs = model.inputs

View File

@@ -177,11 +177,11 @@ class RestoredFunction(def_function.Function):
# TODO(mdan): We may enable autograph once exceptions are supported.
super(RestoredFunction, self).__init__(
python_function, name, autograph=False)
self._concrete_functions = concrete_functions
self.concrete_functions = concrete_functions
self._function_spec = function_spec
def _list_all_concrete_functions_for_serialization(self):
return self._concrete_functions
return self.concrete_functions
def _defun_with_scope(self, scope):
func = super(RestoredFunction, self)._defun_with_scope(scope)

View File

@@ -179,7 +179,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
# Calling get_concrete_function wraps in a second call operation; we want to
# inspect the original function body for the control output; digging into
# graph.as_graph_def() and its FunctionDefLibrary is another option.
imported_concrete, = imported.f._concrete_functions
imported_concrete, = imported.f.concrete_functions
imported_graph = imported_concrete.graph
self.assertIn(
imported_graph.get_operation_by_name("should_be_control_output"),