Merge pull request #29722 from k-w-w/cherrypicks_WPCYV

Cherrypicks for Keras SavedModel
This commit is contained in:
Goldie Gadde 2019-06-12 19:44:07 -07:00 committed by GitHub
commit 0c59cc94fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
356 changed files with 1720 additions and 137 deletions

View File

@ -328,9 +328,9 @@ class FromSavedModelTest(TestModels):
self.assertIn('This converter can only convert a single ConcreteFunction', self.assertIn('This converter can only convert a single ConcreteFunction',
str(error.exception)) str(error.exception))
@test_util.run_v2_only
def testKerasSequentialModel(self): def testKerasSequentialModel(self):
"""Test a simple sequential tf.Keras model.""" """Test a simple sequential tf.Keras model."""
self.skipTest('b/134660903')
input_data = constant_op.constant(1., shape=[1, 1]) input_data = constant_op.constant(1., shape=[1, 1])
x = np.array([[1.], [2.]]) x = np.array([[1.], [2.]])

View File

@ -164,7 +164,7 @@ def _compatible_shapes(flat_relaxed, flat_to_check):
for relaxed, to_check in zip(flat_relaxed, flat_to_check)) for relaxed, to_check in zip(flat_relaxed, flat_to_check))
def _common_shape(x, y): def common_shape(x, y):
"""Find a `TensorShape` that is compatible with both `x` and `y`.""" """Find a `TensorShape` that is compatible with both `x` and `y`."""
if x is None != y is None: if x is None != y is None:
raise RuntimeError( raise RuntimeError(
@ -1577,7 +1577,7 @@ class Function(object):
"relaxed_arg_shapes len: %d vs. %d" "relaxed_arg_shapes len: %d vs. %d"
% (len(arg_shapes), len(relaxed_arg_shapes))) % (len(arg_shapes), len(relaxed_arg_shapes)))
relaxed_arg_shapes = [ relaxed_arg_shapes = [
_common_shape(x, y) for (x, y) in zip( common_shape(x, y) for (x, y) in zip(
arg_shapes, relaxed_arg_shapes)] arg_shapes, relaxed_arg_shapes)]
self._function_cache.arg_relaxed_shapes[rank_only_cache_key] = ( self._function_cache.arg_relaxed_shapes[rank_only_cache_key] = (
relaxed_arg_shapes) relaxed_arg_shapes)
@ -1679,8 +1679,9 @@ def register(func, *args, **kwargs):
def validate_signature(signature): def validate_signature(signature):
if any(not isinstance(arg, tensor_spec.TensorSpec) if any(not isinstance(arg, tensor_spec.TensorSpec)
for arg in nest.flatten(signature, expand_composites=True)): for arg in nest.flatten(signature, expand_composites=True)):
raise TypeError("Invalid input_signature %s; input_signature must be " raise TypeError("Invalid input_signature {}; input_signature must be "
"a possibly nested sequence of TensorSpec objects.") "a possibly nested sequence of TensorSpec objects."
.format(signature))
def defun(func=None, def defun(func=None,

View File

@ -747,6 +747,21 @@ class Layer(module.Module):
"""Optional regularizer function for the output of this layer.""" """Optional regularizer function for the output of this layer."""
self._activity_regularizer = regularizer self._activity_regularizer = regularizer
@property
def input_spec(self):
return self._input_spec
@input_spec.setter
# Must be decorated to prevent tracking, since the input_spec can be nested
# InputSpec objects.
@trackable.no_automatic_dependency_tracking
def input_spec(self, value):
for v in nest.flatten(value):
if v is not None and not isinstance(v, InputSpec):
raise TypeError('Layer input_spec must be an instance of InputSpec. '
'Got: {}'.format(v))
self._input_spec = value
@property @property
def trainable_weights(self): def trainable_weights(self):
if self.trainable: if self.trainable:
@ -2183,8 +2198,10 @@ class Layer(module.Module):
# a NotImplementedError. # a NotImplementedError.
pass pass
if self.input_spec is not None: if self.input_spec is not None:
# Layer's input_spec has already been type-checked in the property setter.
metadata['input_spec'] = nest.map_structure( metadata['input_spec'] = nest.map_structure(
lambda x: x.get_config(), self.input_spec) lambda x: None if x is None else serialize_keras_object(x),
self.input_spec)
else: else:
metadata['input_spec'] = None metadata['input_spec'] = None
if (self.activity_regularizer is not None and if (self.activity_regularizer is not None and

View File

@ -20,12 +20,12 @@ from __future__ import print_function
import functools import functools
import json import json
import os import os
import weakref
import six import six
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun from tensorflow.python.eager import function as defun
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
@ -38,9 +38,10 @@ from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.saving import model_from_json from tensorflow.python.keras.saving import model_from_json
from tensorflow.python.keras.saving import saving_utils from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.utils import mode_keys from tensorflow.python.keras.utils import mode_keys
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.lib.io import file_io from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import builder as saved_model_builder
@ -712,10 +713,20 @@ def serialize_all_attributes(layer, serialization_cache):
if _should_skip_serialization(layer): if _should_skip_serialization(layer):
return serialized_attr return serialized_attr
function_dict = {}
if save_model_default_signature:
# For compatibility with the tf.Lite Converter, the default save signature
# should be traced without nested calls to other wrapped functions.
# TODO(kathywu): Investigate why having nested calls results in a stateful
# function. Perhaps something to do with losses, which are traced in nested
# calls but not in the flat call.
function_dict['_default_save_signature'] = _default_save_signature(layer)
else:
function_dict['_default_save_signature'] = None
object_dict = _wrap_layer_objects(layer, serialization_cache) object_dict = _wrap_layer_objects(layer, serialization_cache)
try: try:
function_dict = _wrap_layer_functions(layer, serialization_cache, function_dict.update(_wrap_layer_functions(layer, serialization_cache))
save_model_default_signature)
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
logging.warning('Skipping full serialization of object {}, because an ' logging.warning('Skipping full serialization of object {}, because an '
'error occurred while tracing layer functions. Error ' 'error occurred while tracing layer functions. Error '
@ -798,44 +809,53 @@ def _wrap_layer_objects(layer, serialization_cache):
wrapped_loss_functions)) wrapped_loss_functions))
def _wrap_layer_functions(layer, serialization_cache, def _wrap_layer_functions(layer, serialization_cache):
save_model_default_signature=False):
"""Returns dict of wrapped layer call function and losses in tf.functions. """Returns dict of wrapped layer call function and losses in tf.functions.
Args: Args:
layer: Keras Layer object. layer: Keras Layer object.
serialization_cache: Dictionary shared between all objects during serialization_cache: Dictionary shared between all objects during
serialization. serialization.
save_model_default_signature: Whether to save traced model call function.
Returns: Returns:
A dictionary containing all keras tf.functions to serialize. See A dictionary containing all keras tf.functions to serialize. See
LayerAttributes and ModelAttributes for the list of all attributes. LayerAttributes and ModelAttributes for the list of all attributes.
""" """
# Since Sequential models may be modified in place using model.add() or
# model.pop(), don't use saved functions.
if (isinstance(layer, RevivedLayer) and
not isinstance(layer, RevivedSequential)):
return {fn_name: getattr(layer.keras_api, fn_name, None)
for fn_name in LayerAttributes.all_functions}
# Reset the losses of the layer and its children. The call function in each # Reset the losses of the layer and its children. The call function in each
# child layer is replaced with tf.functions. # child layer is replaced with tf.functions.
original_attrs = _replace_child_layer_functions(layer, serialization_cache) original_fns = _replace_child_layer_functions(layer, serialization_cache)
original_layer_losses = layer._losses[:] # pylint: disable=protected-access original_losses = _reset_layer_losses(layer)
with trackable.no_automatic_dependency_tracking_scope(layer):
layer._losses = [] # pylint: disable=protected-access
# Note that eager losses do not need to be saved since these functions
# create symbolic losses.
# Wrap all the layer call and activity regularizer functions. # Wrap all the layer call and activity regularizer functions.
call_fn_with_losses = _wrap_call_and_conditional_losses(layer)
fns = {'call_and_return_conditional_losses': call_fn_with_losses,
'__call__': _extract_outputs_from_fn(layer, call_fn_with_losses)}
if save_model_default_signature: # Use LayerCallCollection to ensure that all layer call functions (__call__,
fns['_default_save_signature'] = saving_utils.trace_model_call(layer) # call with losses) are traced with the same inputs.
else: call_collection = LayerCallCollection(layer)
fns['_default_save_signature'] = None call_fn_with_losses = call_collection.add_function(
_wrap_call_and_conditional_losses(layer),
'{}_layer_call_and_return_conditional_losses'.format(layer.name))
call_fn = call_collection.add_function(
_extract_outputs_from_fn(layer, call_fn_with_losses),
'{}_layer_call_fn'.format(layer.name))
fns = {'call_and_return_conditional_losses': call_fn_with_losses,
'__call__': call_fn}
if layer.activity_regularizer is not None: if layer.activity_regularizer is not None:
fns['activity_regularizer_fn'] = _wrap_activity_regularizer(layer) fns['activity_regularizer_fn'] = _wrap_activity_regularizer(layer)
fns['call_and_return_all_conditional_losses'] = ( fns['call_and_return_all_conditional_losses'] = (
_append_activity_regularizer_loss( call_collection.add_function(
layer, call_fn_with_losses, fns['activity_regularizer_fn'])) _append_activity_regularizer_loss(call_fn_with_losses,
fns['activity_regularizer_fn']),
'{}_layer_call_and_return_all_conditional_losses'.format(layer.name)
))
else: else:
fns['activity_regularizer_fn'] = None fns['activity_regularizer_fn'] = None
fns['call_and_return_all_conditional_losses'] = call_fn_with_losses fns['call_and_return_all_conditional_losses'] = call_fn_with_losses
@ -848,14 +868,21 @@ def _wrap_layer_functions(layer, serialization_cache,
if fn is not None and fn.input_signature is not None: if fn is not None and fn.input_signature is not None:
fn.get_concrete_function() fn.get_concrete_function()
# Restore overwritten functions/losses # Restore overwritten functions and losses
with trackable.no_automatic_dependency_tracking_scope(layer): _restore_child_layer_functions(original_fns)
layer._losses = original_layer_losses # pylint: disable=protected-access _restore_layer_losses(original_losses)
_restore_child_layer_functions(original_attrs)
return fns return fns
def _default_save_signature(layer):
original_losses = _reset_layer_losses(layer)
fn = saving_utils.trace_model_call(layer)
fn.get_concrete_function()
_restore_layer_losses(original_losses)
return fn
def _list_all_layers(obj): def _list_all_layers(obj):
if isinstance(obj, training_lib.Model): if isinstance(obj, training_lib.Model):
return obj.layers return obj.layers
@ -887,11 +914,9 @@ def _replace_child_layer_functions(layer, serialization_cache):
Child layer 2: ... Child layer 2: ...
} }
""" """
original_attrs = {} # pylint: disable=protected-access
original_fns = {}
for child_layer in _list_all_layers(layer): for child_layer in _list_all_layers(layer):
# Save symbolic layer losses, which will be restored to maintain the same
# state.
original_attrs[child_layer] = {'losses': child_layer._losses[:]} # pylint: disable=protected-access
if child_layer not in serialization_cache[_KERAS_CACHE_KEY]: if child_layer not in serialization_cache[_KERAS_CACHE_KEY]:
layer_fns = (serialize_all_attributes(child_layer, serialization_cache) layer_fns = (serialize_all_attributes(child_layer, serialization_cache)
.functions) .functions)
@ -905,27 +930,46 @@ def _replace_child_layer_functions(layer, serialization_cache):
# wrapped. In this case, no replacement is necessary so move on to the # wrapped. In this case, no replacement is necessary so move on to the
# next child. # next child.
continue continue
original_fns[child_layer] = {
original_attrs[child_layer]['call'] = child_layer.call 'call': child_layer.call,
original_attrs[child_layer]['activity_regularizer'] = ( 'activity_regularizer': child_layer.activity_regularizer
child_layer.activity_regularizer) }
with trackable.no_automatic_dependency_tracking_scope(child_layer): with trackable.no_automatic_dependency_tracking_scope(child_layer):
child_layer.activity_regularizer = layer_fns.get( child_layer.activity_regularizer = layer_fns.get(
'activity_regularizer_fn') 'activity_regularizer_fn')
child_layer.call = _use_wrapped_call( child_layer.call = _use_wrapped_call(
child_layer, layer_fns['call_and_return_conditional_losses']) child_layer, layer_fns['call_and_return_conditional_losses'])
child_layer._losses = [] # pylint: disable=protected-access return original_fns
return original_attrs # pylint: enable=protected-access
def _restore_child_layer_functions(original_attrs): def _restore_child_layer_functions(original_fns):
"""Restores attributes replaced with `_replace_child_layer_functions`.""" """Restores attributes replaced with `_replace_child_layer_functions`."""
for child_layer, attrs in original_attrs.items(): for child_layer, fns in original_fns.items():
with trackable.no_automatic_dependency_tracking_scope(child_layer): with trackable.no_automatic_dependency_tracking_scope(child_layer):
child_layer._losses = attrs['losses'] # pylint: disable=protected-access child_layer.call = fns['call']
if 'call' in attrs: child_layer.activity_regularizer = fns['activity_regularizer']
child_layer.call = attrs['call']
child_layer.activity_regularizer = attrs['activity_regularizer']
# pylint: disable=protected-access
def _reset_layer_losses(parent_layer):
"""Resets losses of layer and its sublayers, and returns original losses."""
losses_dict = {}
for layer in _list_all_layers(parent_layer) + [parent_layer]:
losses_dict[layer] = {'losses': layer._losses[:],
'eager_losses': layer._eager_losses[:]}
with trackable.no_automatic_dependency_tracking_scope(layer):
layer._losses = []
layer._eager_losses = []
return losses_dict
def _restore_layer_losses(losses_dict):
for layer in losses_dict:
with trackable.no_automatic_dependency_tracking_scope(layer):
layer._losses = losses_dict[layer]['losses']
layer._eager_losses = losses_dict[layer]['eager_losses']
# pylint: enable=protected-access
def _use_wrapped_call(layer, call_fn): def _use_wrapped_call(layer, call_fn):
@ -946,8 +990,10 @@ def _use_wrapped_call(layer, call_fn):
training = kwargs.pop('training', None) training = kwargs.pop('training', None)
if training is None: if training is None:
training = K.learning_phase() training = K.learning_phase()
training = math_ops.cast(training, dtypes.bool) outputs, losses = tf_utils.smart_cond(
outputs, losses = call_fn(inputs, training=training) training,
lambda: call_fn(inputs, training=True),
lambda: call_fn(inputs, training=False))
else: else:
outputs, losses = call_fn(inputs) outputs, losses = call_fn(inputs)
layer.add_loss(losses, inputs) layer.add_loss(losses, inputs)
@ -955,6 +1001,128 @@ def _use_wrapped_call(layer, call_fn):
return wrapped_call return wrapped_call
class LayerCallCollection(object):
"""Groups wrapped layer call functions.
This is used to ensure that all layer call functions are traced with the same
inputs-
- call
- call_and_return_conditional_losses
- call_and_return_all_conditional_losses
"""
def __init__(self, layer):
self._layer = layer
self._expects_training_arg = layer._expects_training_arg # pylint: disable=protected-access
self._input_signature = self._generate_input_signature(layer)
self._functions = weakref.WeakValueDictionary()
# Bool indicating whether this object is currently tracing the layer call
# functions.
self.tracing = False
def _generate_input_signature(self, layer):
"""Inspects layer object and returns the inferred input signature.
Args:
layer: Layer object.
Returns:
List of possibly nested TensorSpecs of the layer call function inputs.
The list does not contain the `training` argument.
"""
if (isinstance(layer.call, def_function.Function) and
layer.call.input_signature is not None):
return layer.call.input_signature
else:
if isinstance(layer, training_lib.Model):
return saving_utils.model_input_signature(layer)
elif layer.input_spec is not None:
def to_tensor_spec_or_none(x):
spec = input_spec.to_tensor_spec(x, layer.dtype)
# If the shape is too general (e.g. multiple dimensions are allowed),
# return None so that separate functions can be generated for each
# inferred input signature.
# TODO(b/134962016): currently partial signatures are not supported.
if spec.shape == tensor_shape.TensorShape(None):
return None
return spec
input_signature = [nest.map_structure(
to_tensor_spec_or_none, layer.input_spec)]
return input_signature
else:
return None
def add_trace(self, *args, **kwargs):
"""Traces all functions with the same args and kwargs.
Args:
*args: Positional args passed to the original function.
**kwargs: Keyword args passed to the original function.
"""
kwargs = kwargs.copy()
self.tracing = True
for fn in self._functions.values():
# TODO(kathywu): Replace arguments with broader shapes defined in the
# input signature.
if self._expects_training_arg:
kwargs['training'] = False
fn.original_get_concrete_function(*args, **kwargs)
kwargs['training'] = True
fn.original_get_concrete_function(*args, **kwargs)
else:
fn.original_get_concrete_function(*args, **kwargs)
self.tracing = False
@property
def fn_input_signature(self):
"""Returns input signature for the wrapped layer call function."""
if self._expects_training_arg:
# The training arg is left as a python boolean, so the call functions
# will not have an input signature (input signatures may only describe
# tensor arguments).
return None
if None in nest.flatten(self._input_signature):
# TODO(b/134962016): If input signature cannot be partially defined.
return None
return self._input_signature
def add_function(self, python_function, name):
"""Adds a layer call function to the collection."""
self._functions[name] = fn = LayerCall(
self, python_function, name,
input_signature=self.fn_input_signature)
if (None not in nest.flatten(self._input_signature) and
self._expects_training_arg):
# Manually add traces for layers that expect a training argument and have
# a fully defined input signature.
self.add_trace(*self._input_signature)
return fn
class LayerCall(def_function.Function):
"""Function that triggers traces of other functions in the same collection."""
def __init__(self, call_collection, *args, **kwargs):
super(LayerCall, self).__init__(*args, **kwargs)
self.call_collection = call_collection
def __call__(self, *args, **kwargs):
if not self.call_collection.tracing:
self.call_collection.add_trace(*args, **kwargs)
return super(LayerCall, self).__call__(*args, **kwargs)
def get_concrete_function(self, *args, **kwargs):
if not self.call_collection.tracing:
self.call_collection.add_trace(*args, **kwargs)
return super(LayerCall, self).get_concrete_function(*args, **kwargs)
def original_get_concrete_function(self, *args, **kwargs):
return super(LayerCall, self).get_concrete_function(*args, **kwargs)
def _wrap_call_and_conditional_losses(layer): def _wrap_call_and_conditional_losses(layer):
"""Wraps call function that returns a tuple of (outputs, losses). """Wraps call function that returns a tuple of (outputs, losses).
@ -965,51 +1133,19 @@ def _wrap_call_and_conditional_losses(layer):
layer: a Keras layer object layer: a Keras layer object
Returns: Returns:
call function that returns outputs and conditional losses -- excludes python call function that returns outputs and conditional losses -- excludes
activity regularizer activity regularizer
""" """
if isinstance(layer, RevivedLayer):
return layer.keras_api.call_and_return_conditional_losses
if (isinstance(layer.call, def_function.Function) and
layer.call.input_signature is not None):
input_signature = layer.call.input_signature
else:
if (isinstance(layer, training_lib.Model) and
saving_utils.model_input_signature(layer) is not None):
input_signature = saving_utils.model_input_signature(layer)
elif layer.input_spec is not None:
input_signature = [nest.map_structure(
lambda x: input_spec.to_tensor_spec(x, layer.dtype),
layer.input_spec)]
# If input spec is too general, then don't define an input signature.
for spec in nest.flatten(input_signature):
if spec.shape == tensor_shape.TensorShape(None):
input_signature = None
break
else:
input_signature = None
if input_signature is not None and layer._expects_training_arg: # pylint: disable=protected-access
input_signature.append(
tensor_spec.TensorSpec(shape=[], dtype=dtypes.bool))
# Create function that generates both outputs and losses # Create function that generates both outputs and losses
layer_call = layer.call layer_call = layer.call
if layer._expects_training_arg: # pylint: disable=protected-access if layer._expects_training_arg: # pylint: disable=protected-access
def call_and_return_conditional_losses(inputs, training): def call_and_return_conditional_losses(inputs, training=False):
_set_symbolic_learning_phase(training)
return layer_call(inputs, training=training), layer.get_losses_for(inputs) return layer_call(inputs, training=training), layer.get_losses_for(inputs)
else: else:
def call_and_return_conditional_losses(inputs): def call_and_return_conditional_losses(inputs):
K.set_learning_phase(0) K.set_learning_phase(0)
return layer_call(inputs), layer.get_losses_for(inputs) return layer_call(inputs), layer.get_losses_for(inputs)
return def_function.Function( return call_and_return_conditional_losses
call_and_return_conditional_losses,
'{}_layer_call_and_return_conditional_losses'.format(layer.name),
input_signature=input_signature,
# TODO(kathywu): Investigate autograph error.
autograph=False)
def _extract_outputs_from_fn(layer, call_and_return_conditional_losses): def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
@ -1017,50 +1153,22 @@ def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
if isinstance(layer, RevivedLayer): if isinstance(layer, RevivedLayer):
return layer.keras_api.__call__ # pylint: disable=protected-access return layer.keras_api.__call__ # pylint: disable=protected-access
if layer._expects_training_arg: # pylint: disable=protected-access if layer._expects_training_arg: # pylint: disable=protected-access
def call(inputs, training): def call(inputs, training=False):
return call_and_return_conditional_losses(inputs, training)[0] return call_and_return_conditional_losses(inputs, training=training)[0]
else: else:
def call(inputs): def call(inputs):
return call_and_return_conditional_losses(inputs)[0] return call_and_return_conditional_losses(inputs)[0]
return def_function.Function( return call
call, '{}_layer_call_fn'.format(layer.name),
input_signature=call_and_return_conditional_losses.input_signature,
# TODO(kathywu): Investigate autograph error.
autograph=False)
def _set_symbolic_learning_phase(value):
"""Set learning phase to a tensor value (for internal use only).
This is used when wrapping call functions as tf.functions that have training
as a tensor input. Thus, when `learning_phase()` is called, the training
tensor is returned. This function is called when saving a model to SavedModel.
Args:
value: A Tensor object.
Raises:
ValueError: If the input value is not a graph tensor
"""
graph = K.get_graph()
if not isinstance(value, ops.Tensor):
raise ValueError('Symbolic learning phase must be a graph tensor.')
K._GRAPH_LEARNING_PHASES[graph] = value # pylint: disable=protected-access
def _append_activity_regularizer_loss( def _append_activity_regularizer_loss(
layer, call_fn_with_losses, activity_regularizer_fn): call_fn_with_losses, activity_regularizer_fn):
"""Appends activity regularizer loss to losses returned by the wrapped fn.""" """Appends activity regularizer loss to losses returned by the wrapped fn."""
def fn(*args): def fn(*args, **kwargs):
outputs, losses = call_fn_with_losses(*args) outputs, losses = call_fn_with_losses(*args, **kwargs)
losses.append(activity_regularizer_fn(outputs)) losses.append(activity_regularizer_fn(outputs))
return outputs, losses return outputs, losses
return def_function.Function( return fn
fn,
'{}_layer_call_and_return_all_conditional_losses'.format(layer.name),
input_signature=call_fn_with_losses.input_signature,
# TODO(kathywu): Investigate autograph error.
autograph=False)
def _wrap_unconditional_loss(loss_fn, index): def _wrap_unconditional_loss(loss_fn, index):
@ -1134,9 +1242,11 @@ class KerasObjectLoader(load.Loader):
# pylint: disable=protected-access # pylint: disable=protected-access
for node in self._nodes: for node in self._nodes:
if isinstance(node, RevivedModel): if isinstance(node, RevivedModel):
input_signature = ( call_fn = node.keras_api.call_and_return_conditional_losses
node.keras_api.call_and_return_conditional_losses.input_signature[0] if call_fn.input_signature is None:
) inputs = infer_inputs_from_restored_call_function(call_fn)
else:
inputs = call_fn.input_signature[0]
if isinstance(node, RevivedSequential): if isinstance(node, RevivedSequential):
with trackable.no_automatic_dependency_tracking_scope(node): with trackable.no_automatic_dependency_tracking_scope(node):
node._layers = [] node._layers = []
@ -1146,7 +1256,7 @@ class KerasObjectLoader(load.Loader):
if not node.inputs: if not node.inputs:
# Since this revived object is technically a subclassed model (even if # Since this revived object is technically a subclassed model (even if
# the original model is functional/sequential), inputs should be set. # the original model is functional/sequential), inputs should be set.
node._set_inputs(input_signature) node._set_inputs(inputs)
if isinstance(node, RevivedLayer): if isinstance(node, RevivedLayer):
losses = node._serialized_attributes.get('regularization_losses', []) losses = node._serialized_attributes.get('regularization_losses', [])
for loss in losses: for loss in losses:
@ -1218,8 +1328,9 @@ class RevivedLayer(object):
if metadata.get('config') is not None: if metadata.get('config') is not None:
revived_obj._config = metadata['config'] revived_obj._config = metadata['config']
if metadata.get('input_spec') is not None: if metadata.get('input_spec') is not None:
revived_obj.input_spec = input_spec.InputSpec.from_config( revived_obj.input_spec = recursively_deserialize_keras_object(
metadata['input_spec']) metadata['input_spec'],
module_objects={'InputSpec': input_spec.InputSpec})
if metadata.get('activity_regularizer') is not None: if metadata.get('activity_regularizer') is not None:
revived_obj.activity_regularizer = regularizers.deserialize( revived_obj.activity_regularizer = regularizers.deserialize(
metadata['activity_regularizer']) metadata['activity_regularizer'])
@ -1258,6 +1369,42 @@ class RevivedLayer(object):
return call_fn(inputs, *args, **kwargs) return call_fn(inputs, *args, **kwargs)
def recursively_deserialize_keras_object(config, module_objects=None):
"""Deserialize Keras object from a nested structure."""
if isinstance(config, dict):
if 'class_name' in config:
return deserialize_keras_object(config, module_objects=module_objects)
else:
return {key: recursively_deserialize_keras_object(config[key],
module_objects)
for key in config}
if isinstance(config, (tuple, list)):
return [recursively_deserialize_keras_object(x, module_objects)
for x in config]
else:
raise ValueError('Unable to decode config: {}'.format(config))
def infer_inputs_from_restored_call_function(fn):
"""Returns TensorSpec of inputs from a restored call function.
Args:
fn: Restored layer call function. It is assumed that the inputs are entirely
in the first argument.
Returns:
TensorSpec of call function inputs.
"""
def common_spec(x, y):
return tensor_spec.TensorSpec(defun.common_shape(x.shape, y.shape),
x.dtype, x.name)
spec = fn.concrete_functions[0].structured_input_signature[0][0]
for concrete in fn.concrete_functions[1:]:
spec2 = concrete.structured_input_signature[0][0]
spec = nest.map_structure(common_spec, spec, spec2)
return spec
class RevivedNetwork(RevivedLayer): class RevivedNetwork(RevivedLayer):
"""Keras network of layers loaded from a SavedModel.""" """Keras network of layers loaded from a SavedModel."""

View File

@ -705,14 +705,8 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
expected_layers = len(model.layers) expected_layers = len(model.layers)
self.assertEqual(expected_layers, len(loaded.keras_api.layers)) self.assertEqual(expected_layers, len(loaded.keras_api.layers))
input_arr = array_ops.ones((4, 3)) input_arr = array_ops.ones((4, 3))
training_bool = constant_op.constant(False)
if model._expects_training_arg:
call_args = [input_arr, training_bool]
else:
call_args = [input_arr]
self.assertAllClose(self.evaluate(model(input_arr)), self.assertAllClose(self.evaluate(model(input_arr)),
self.evaluate(loaded(*call_args))) self.evaluate(loaded(input_arr)))
@keras_parameterized.run_with_all_model_types @keras_parameterized.run_with_all_model_types
def test_compiled_model(self): def test_compiled_model(self):
@ -747,6 +741,38 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
model.load_weights(ckpt_path) model.load_weights(ckpt_path)
self.assertAllClose(predict, model.predict(input_arr)) self.assertAllClose(predict, model.predict(input_arr))
def test_metadata_input_spec(self):
class LayerWithNestedSpec(keras.layers.Layer):
def __init__(self):
super(LayerWithNestedSpec, self).__init__()
self.input_spec = {
'a': keras.layers.InputSpec(max_ndim=3, axes={-1: 2}),
'b': keras.layers.InputSpec(shape=(None, 2, 3), dtype='float16')}
layer = LayerWithNestedSpec()
saved_model_dir = self._save_model_dir()
tf_save.save(layer, saved_model_dir)
loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir)
self.assertEqual(3, loaded.input_spec['a'].max_ndim)
self.assertEqual({-1: 2}, loaded.input_spec['a'].axes)
self.assertAllEqual([None, 2, 3], loaded.input_spec['b'].shape)
self.assertEqual('float16', loaded.input_spec['b'].dtype)
def test_multi_input_model(self):
input_1 = keras.layers.Input(shape=(3,))
input_2 = keras.layers.Input(shape=(5,))
model = keras.Model([input_1, input_2], [input_1, input_2])
saved_model_dir = self._save_model_dir()
model.save(saved_model_dir, save_format='tf')
loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir)
input_arr_1 = np.random.random((1, 3)).astype('float32')
input_arr_2 = np.random.random((1, 5)).astype('float32')
outputs = loaded([input_arr_1, input_arr_2])
self.assertAllEqual(input_arr_1, outputs[0])
self.assertAllEqual(input_arr_2, outputs[1])
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()

View File

@ -66,7 +66,7 @@ def model_input_signature(model):
Returns: Returns:
A list containing either a single TensorSpec or an object with nested A list containing either a single TensorSpec or an object with nested
TensorSpecs. TensorSpecs. This list does not contain the `training` argument.
""" """
try: try:
inputs = model.inputs inputs = model.inputs

View File

@ -177,11 +177,11 @@ class RestoredFunction(def_function.Function):
# TODO(mdan): We may enable autograph once exceptions are supported. # TODO(mdan): We may enable autograph once exceptions are supported.
super(RestoredFunction, self).__init__( super(RestoredFunction, self).__init__(
python_function, name, autograph=False) python_function, name, autograph=False)
self._concrete_functions = concrete_functions self.concrete_functions = concrete_functions
self._function_spec = function_spec self._function_spec = function_spec
def _list_all_concrete_functions_for_serialization(self): def _list_all_concrete_functions_for_serialization(self):
return self._concrete_functions return self.concrete_functions
def _defun_with_scope(self, scope): def _defun_with_scope(self, scope):
func = super(RestoredFunction, self)._defun_with_scope(scope) func = super(RestoredFunction, self)._defun_with_scope(scope)

View File

@ -179,7 +179,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
# Calling get_concrete_function wraps in a second call operation; we want to # Calling get_concrete_function wraps in a second call operation; we want to
# inspect the original function body for the control output; digging into # inspect the original function body for the control output; digging into
# graph.as_graph_def() and its FunctionDefLibrary is another option. # graph.as_graph_def() and its FunctionDefLibrary is another option.
imported_concrete, = imported.f._concrete_functions imported_concrete, = imported.f.concrete_functions
imported_graph = imported_concrete.graph imported_graph = imported_concrete.graph
self.assertIn( self.assertIn(
imported_graph.get_operation_by_name("should_be_control_output"), imported_graph.get_operation_by_name("should_be_control_output"),

View File

@ -36,6 +36,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -39,6 +39,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -68,6 +68,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "kernel_constraint" name: "kernel_constraint"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -36,6 +36,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -36,6 +36,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -36,6 +36,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -36,6 +36,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -40,6 +40,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -40,6 +40,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -36,6 +36,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -59,6 +59,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "kernel_constraint" name: "kernel_constraint"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -59,6 +59,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "kernel_constraint" name: "kernel_constraint"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -33,6 +33,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -36,6 +36,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -36,6 +36,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -36,6 +36,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -36,6 +36,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -55,6 +55,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "kernel_constraint" name: "kernel_constraint"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -34,6 +34,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -35,6 +35,10 @@ tf_class {
name: "input_shape" name: "input_shape"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member { member {
name: "losses" name: "losses"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

Some files were not shown because too many files have changed in this diff Show More