|
|
|
@ -20,12 +20,12 @@ from __future__ import print_function
|
|
|
|
|
import functools
|
|
|
|
|
import json
|
|
|
|
|
import os
|
|
|
|
|
import weakref
|
|
|
|
|
import six
|
|
|
|
|
|
|
|
|
|
from tensorflow.python.client import session
|
|
|
|
|
from tensorflow.python.eager import def_function
|
|
|
|
|
from tensorflow.python.eager import function as defun
|
|
|
|
|
from tensorflow.python.framework import dtypes
|
|
|
|
|
from tensorflow.python.framework import ops
|
|
|
|
|
from tensorflow.python.framework import tensor_shape
|
|
|
|
|
from tensorflow.python.framework import tensor_spec
|
|
|
|
@ -38,10 +38,10 @@ from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
|
|
|
|
from tensorflow.python.keras.saving import model_from_json
|
|
|
|
|
from tensorflow.python.keras.saving import saving_utils
|
|
|
|
|
from tensorflow.python.keras.utils import mode_keys
|
|
|
|
|
from tensorflow.python.keras.utils import tf_utils
|
|
|
|
|
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
|
|
|
|
|
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
|
|
|
|
from tensorflow.python.lib.io import file_io
|
|
|
|
|
from tensorflow.python.ops import math_ops
|
|
|
|
|
from tensorflow.python.ops import variables
|
|
|
|
|
from tensorflow.python.platform import tf_logging as logging
|
|
|
|
|
from tensorflow.python.saved_model import builder as saved_model_builder
|
|
|
|
@ -713,10 +713,20 @@ def serialize_all_attributes(layer, serialization_cache):
|
|
|
|
|
if _should_skip_serialization(layer):
|
|
|
|
|
return serialized_attr
|
|
|
|
|
|
|
|
|
|
function_dict = {}
|
|
|
|
|
if save_model_default_signature:
|
|
|
|
|
# For compatibility with the tf.Lite Converter, the default save signature
|
|
|
|
|
# should be traced without nested calls to other wrapped functions.
|
|
|
|
|
# TODO(kathywu): Investigate why having nested calls results in a stateful
|
|
|
|
|
# function. Perhaps something to do with losses, which are traced in nested
|
|
|
|
|
# calls but not in the flat call.
|
|
|
|
|
function_dict['_default_save_signature'] = _default_save_signature(layer)
|
|
|
|
|
else:
|
|
|
|
|
function_dict['_default_save_signature'] = None
|
|
|
|
|
|
|
|
|
|
object_dict = _wrap_layer_objects(layer, serialization_cache)
|
|
|
|
|
try:
|
|
|
|
|
function_dict = _wrap_layer_functions(layer, serialization_cache,
|
|
|
|
|
save_model_default_signature)
|
|
|
|
|
function_dict.update(_wrap_layer_functions(layer, serialization_cache))
|
|
|
|
|
except (ValueError, TypeError) as e:
|
|
|
|
|
logging.warning('Skipping full serialization of object {}, because an '
|
|
|
|
|
'error occurred while tracing layer functions. Error '
|
|
|
|
@ -799,44 +809,53 @@ def _wrap_layer_objects(layer, serialization_cache):
|
|
|
|
|
wrapped_loss_functions))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _wrap_layer_functions(layer, serialization_cache,
|
|
|
|
|
save_model_default_signature=False):
|
|
|
|
|
def _wrap_layer_functions(layer, serialization_cache):
|
|
|
|
|
"""Returns dict of wrapped layer call function and losses in tf.functions.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
layer: Keras Layer object.
|
|
|
|
|
serialization_cache: Dictionary shared between all objects during
|
|
|
|
|
serialization.
|
|
|
|
|
save_model_default_signature: Whether to save traced model call function.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
A dictionary containing all keras tf.functions to serialize. See
|
|
|
|
|
LayerAttributes and ModelAttributes for the list of all attributes.
|
|
|
|
|
"""
|
|
|
|
|
# Since Sequential models may be modified in place using model.add() or
|
|
|
|
|
# model.pop(), don't use saved functions.
|
|
|
|
|
if (isinstance(layer, RevivedLayer) and
|
|
|
|
|
not isinstance(layer, RevivedSequential)):
|
|
|
|
|
return {fn_name: getattr(layer.keras_api, fn_name, None)
|
|
|
|
|
for fn_name in LayerAttributes.all_functions}
|
|
|
|
|
|
|
|
|
|
# Reset the losses of the layer and its children. The call function in each
|
|
|
|
|
# child layer is replaced with tf.functions.
|
|
|
|
|
original_attrs = _replace_child_layer_functions(layer, serialization_cache)
|
|
|
|
|
original_layer_losses = layer._losses[:] # pylint: disable=protected-access
|
|
|
|
|
with trackable.no_automatic_dependency_tracking_scope(layer):
|
|
|
|
|
layer._losses = [] # pylint: disable=protected-access
|
|
|
|
|
# Note that eager losses do not need to be saved since these functions
|
|
|
|
|
# create symbolic losses.
|
|
|
|
|
original_fns = _replace_child_layer_functions(layer, serialization_cache)
|
|
|
|
|
original_losses = _reset_layer_losses(layer)
|
|
|
|
|
|
|
|
|
|
# Wrap all the layer call and activity regularizer functions.
|
|
|
|
|
call_fn_with_losses = _wrap_call_and_conditional_losses(layer)
|
|
|
|
|
fns = {'call_and_return_conditional_losses': call_fn_with_losses,
|
|
|
|
|
'__call__': _extract_outputs_from_fn(layer, call_fn_with_losses)}
|
|
|
|
|
|
|
|
|
|
if save_model_default_signature:
|
|
|
|
|
fns['_default_save_signature'] = saving_utils.trace_model_call(layer)
|
|
|
|
|
else:
|
|
|
|
|
fns['_default_save_signature'] = None
|
|
|
|
|
# Use LayerCallCollection to ensure that all layer call functions (__call__,
|
|
|
|
|
# call with losses) are traced with the same inputs.
|
|
|
|
|
call_collection = LayerCallCollection(layer)
|
|
|
|
|
call_fn_with_losses = call_collection.add_function(
|
|
|
|
|
_wrap_call_and_conditional_losses(layer),
|
|
|
|
|
'{}_layer_call_and_return_conditional_losses'.format(layer.name))
|
|
|
|
|
call_fn = call_collection.add_function(
|
|
|
|
|
_extract_outputs_from_fn(layer, call_fn_with_losses),
|
|
|
|
|
'{}_layer_call_fn'.format(layer.name))
|
|
|
|
|
|
|
|
|
|
fns = {'call_and_return_conditional_losses': call_fn_with_losses,
|
|
|
|
|
'__call__': call_fn}
|
|
|
|
|
|
|
|
|
|
if layer.activity_regularizer is not None:
|
|
|
|
|
fns['activity_regularizer_fn'] = _wrap_activity_regularizer(layer)
|
|
|
|
|
fns['call_and_return_all_conditional_losses'] = (
|
|
|
|
|
_append_activity_regularizer_loss(
|
|
|
|
|
layer, call_fn_with_losses, fns['activity_regularizer_fn']))
|
|
|
|
|
call_collection.add_function(
|
|
|
|
|
_append_activity_regularizer_loss(call_fn_with_losses,
|
|
|
|
|
fns['activity_regularizer_fn']),
|
|
|
|
|
'{}_layer_call_and_return_all_conditional_losses'.format(layer.name)
|
|
|
|
|
))
|
|
|
|
|
else:
|
|
|
|
|
fns['activity_regularizer_fn'] = None
|
|
|
|
|
fns['call_and_return_all_conditional_losses'] = call_fn_with_losses
|
|
|
|
@ -849,14 +868,21 @@ def _wrap_layer_functions(layer, serialization_cache,
|
|
|
|
|
if fn is not None and fn.input_signature is not None:
|
|
|
|
|
fn.get_concrete_function()
|
|
|
|
|
|
|
|
|
|
# Restore overwritten functions/losses
|
|
|
|
|
with trackable.no_automatic_dependency_tracking_scope(layer):
|
|
|
|
|
layer._losses = original_layer_losses # pylint: disable=protected-access
|
|
|
|
|
_restore_child_layer_functions(original_attrs)
|
|
|
|
|
# Restore overwritten functions and losses
|
|
|
|
|
_restore_child_layer_functions(original_fns)
|
|
|
|
|
_restore_layer_losses(original_losses)
|
|
|
|
|
|
|
|
|
|
return fns
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _default_save_signature(layer):
|
|
|
|
|
original_losses = _reset_layer_losses(layer)
|
|
|
|
|
fn = saving_utils.trace_model_call(layer)
|
|
|
|
|
fn.get_concrete_function()
|
|
|
|
|
_restore_layer_losses(original_losses)
|
|
|
|
|
return fn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _list_all_layers(obj):
|
|
|
|
|
if isinstance(obj, training_lib.Model):
|
|
|
|
|
return obj.layers
|
|
|
|
@ -888,11 +914,9 @@ def _replace_child_layer_functions(layer, serialization_cache):
|
|
|
|
|
Child layer 2: ...
|
|
|
|
|
}
|
|
|
|
|
"""
|
|
|
|
|
original_attrs = {}
|
|
|
|
|
# pylint: disable=protected-access
|
|
|
|
|
original_fns = {}
|
|
|
|
|
for child_layer in _list_all_layers(layer):
|
|
|
|
|
# Save symbolic layer losses, which will be restored to maintain the same
|
|
|
|
|
# state.
|
|
|
|
|
original_attrs[child_layer] = {'losses': child_layer._losses[:]} # pylint: disable=protected-access
|
|
|
|
|
if child_layer not in serialization_cache[_KERAS_CACHE_KEY]:
|
|
|
|
|
layer_fns = (serialize_all_attributes(child_layer, serialization_cache)
|
|
|
|
|
.functions)
|
|
|
|
@ -906,27 +930,46 @@ def _replace_child_layer_functions(layer, serialization_cache):
|
|
|
|
|
# wrapped. In this case, no replacement is necessary so move on to the
|
|
|
|
|
# next child.
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
original_attrs[child_layer]['call'] = child_layer.call
|
|
|
|
|
original_attrs[child_layer]['activity_regularizer'] = (
|
|
|
|
|
child_layer.activity_regularizer)
|
|
|
|
|
original_fns[child_layer] = {
|
|
|
|
|
'call': child_layer.call,
|
|
|
|
|
'activity_regularizer': child_layer.activity_regularizer
|
|
|
|
|
}
|
|
|
|
|
with trackable.no_automatic_dependency_tracking_scope(child_layer):
|
|
|
|
|
child_layer.activity_regularizer = layer_fns.get(
|
|
|
|
|
'activity_regularizer_fn')
|
|
|
|
|
child_layer.call = _use_wrapped_call(
|
|
|
|
|
child_layer, layer_fns['call_and_return_conditional_losses'])
|
|
|
|
|
child_layer._losses = [] # pylint: disable=protected-access
|
|
|
|
|
return original_attrs
|
|
|
|
|
return original_fns
|
|
|
|
|
# pylint: enable=protected-access
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _restore_child_layer_functions(original_attrs):
|
|
|
|
|
def _restore_child_layer_functions(original_fns):
|
|
|
|
|
"""Restores attributes replaced with `_replace_child_layer_functions`."""
|
|
|
|
|
for child_layer, attrs in original_attrs.items():
|
|
|
|
|
for child_layer, fns in original_fns.items():
|
|
|
|
|
with trackable.no_automatic_dependency_tracking_scope(child_layer):
|
|
|
|
|
child_layer._losses = attrs['losses'] # pylint: disable=protected-access
|
|
|
|
|
if 'call' in attrs:
|
|
|
|
|
child_layer.call = attrs['call']
|
|
|
|
|
child_layer.activity_regularizer = attrs['activity_regularizer']
|
|
|
|
|
child_layer.call = fns['call']
|
|
|
|
|
child_layer.activity_regularizer = fns['activity_regularizer']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# pylint: disable=protected-access
|
|
|
|
|
def _reset_layer_losses(parent_layer):
|
|
|
|
|
"""Resets losses of layer and its sublayers, and returns original losses."""
|
|
|
|
|
losses_dict = {}
|
|
|
|
|
for layer in _list_all_layers(parent_layer) + [parent_layer]:
|
|
|
|
|
losses_dict[layer] = {'losses': layer._losses[:],
|
|
|
|
|
'eager_losses': layer._eager_losses[:]}
|
|
|
|
|
with trackable.no_automatic_dependency_tracking_scope(layer):
|
|
|
|
|
layer._losses = []
|
|
|
|
|
layer._eager_losses = []
|
|
|
|
|
return losses_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _restore_layer_losses(losses_dict):
|
|
|
|
|
for layer in losses_dict:
|
|
|
|
|
with trackable.no_automatic_dependency_tracking_scope(layer):
|
|
|
|
|
layer._losses = losses_dict[layer]['losses']
|
|
|
|
|
layer._eager_losses = losses_dict[layer]['eager_losses']
|
|
|
|
|
# pylint: enable=protected-access
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _use_wrapped_call(layer, call_fn):
|
|
|
|
@ -947,8 +990,10 @@ def _use_wrapped_call(layer, call_fn):
|
|
|
|
|
training = kwargs.pop('training', None)
|
|
|
|
|
if training is None:
|
|
|
|
|
training = K.learning_phase()
|
|
|
|
|
training = math_ops.cast(training, dtypes.bool)
|
|
|
|
|
outputs, losses = call_fn(inputs, training=training)
|
|
|
|
|
outputs, losses = tf_utils.smart_cond(
|
|
|
|
|
training,
|
|
|
|
|
lambda: call_fn(inputs, training=True),
|
|
|
|
|
lambda: call_fn(inputs, training=False))
|
|
|
|
|
else:
|
|
|
|
|
outputs, losses = call_fn(inputs)
|
|
|
|
|
layer.add_loss(losses, inputs)
|
|
|
|
@ -956,6 +1001,128 @@ def _use_wrapped_call(layer, call_fn):
|
|
|
|
|
return wrapped_call
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LayerCallCollection(object):
|
|
|
|
|
"""Groups wrapped layer call functions.
|
|
|
|
|
|
|
|
|
|
This is used to ensure that all layer call functions are traced with the same
|
|
|
|
|
inputs-
|
|
|
|
|
- call
|
|
|
|
|
- call_and_return_conditional_losses
|
|
|
|
|
- call_and_return_all_conditional_losses
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, layer):
|
|
|
|
|
self._layer = layer
|
|
|
|
|
self._expects_training_arg = layer._expects_training_arg # pylint: disable=protected-access
|
|
|
|
|
self._input_signature = self._generate_input_signature(layer)
|
|
|
|
|
self._functions = weakref.WeakValueDictionary()
|
|
|
|
|
# Bool indicating whether this object is currently tracing the layer call
|
|
|
|
|
# functions.
|
|
|
|
|
self.tracing = False
|
|
|
|
|
|
|
|
|
|
def _generate_input_signature(self, layer):
|
|
|
|
|
"""Inspects layer object and returns the inferred input signature.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
layer: Layer object.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
List of possibly nested TensorSpecs of the layer call function inputs.
|
|
|
|
|
The list does not contain the `training` argument.
|
|
|
|
|
"""
|
|
|
|
|
if (isinstance(layer.call, def_function.Function) and
|
|
|
|
|
layer.call.input_signature is not None):
|
|
|
|
|
return layer.call.input_signature
|
|
|
|
|
else:
|
|
|
|
|
if isinstance(layer, training_lib.Model):
|
|
|
|
|
return saving_utils.model_input_signature(layer)
|
|
|
|
|
elif layer.input_spec is not None:
|
|
|
|
|
|
|
|
|
|
def to_tensor_spec_or_none(x):
|
|
|
|
|
spec = input_spec.to_tensor_spec(x, layer.dtype)
|
|
|
|
|
# If the shape is too general (e.g. multiple dimensions are allowed),
|
|
|
|
|
# return None so that separate functions can be generated for each
|
|
|
|
|
# inferred input signature.
|
|
|
|
|
# TODO(b/134962016): currently partial signatures are not supported.
|
|
|
|
|
if spec.shape == tensor_shape.TensorShape(None):
|
|
|
|
|
return None
|
|
|
|
|
return spec
|
|
|
|
|
input_signature = [nest.map_structure(
|
|
|
|
|
to_tensor_spec_or_none, layer.input_spec)]
|
|
|
|
|
|
|
|
|
|
return input_signature
|
|
|
|
|
else:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def add_trace(self, *args, **kwargs):
|
|
|
|
|
"""Traces all functions with the same args and kwargs.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
*args: Positional args passed to the original function.
|
|
|
|
|
**kwargs: Keyword args passed to the original function.
|
|
|
|
|
"""
|
|
|
|
|
kwargs = kwargs.copy()
|
|
|
|
|
self.tracing = True
|
|
|
|
|
for fn in self._functions.values():
|
|
|
|
|
# TODO(kathywu): Replace arguments with broader shapes defined in the
|
|
|
|
|
# input signature.
|
|
|
|
|
if self._expects_training_arg:
|
|
|
|
|
kwargs['training'] = False
|
|
|
|
|
fn.original_get_concrete_function(*args, **kwargs)
|
|
|
|
|
kwargs['training'] = True
|
|
|
|
|
fn.original_get_concrete_function(*args, **kwargs)
|
|
|
|
|
else:
|
|
|
|
|
fn.original_get_concrete_function(*args, **kwargs)
|
|
|
|
|
self.tracing = False
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def fn_input_signature(self):
|
|
|
|
|
"""Returns input signature for the wrapped layer call function."""
|
|
|
|
|
if self._expects_training_arg:
|
|
|
|
|
# The training arg is left as a python boolean, so the call functions
|
|
|
|
|
# will not have an input signature (input signatures may only describe
|
|
|
|
|
# tensor arguments).
|
|
|
|
|
return None
|
|
|
|
|
if None in nest.flatten(self._input_signature):
|
|
|
|
|
# TODO(b/134962016): If input signature cannot be partially defined.
|
|
|
|
|
return None
|
|
|
|
|
return self._input_signature
|
|
|
|
|
|
|
|
|
|
def add_function(self, python_function, name):
|
|
|
|
|
"""Adds a layer call function to the collection."""
|
|
|
|
|
self._functions[name] = fn = LayerCall(
|
|
|
|
|
self, python_function, name,
|
|
|
|
|
input_signature=self.fn_input_signature)
|
|
|
|
|
|
|
|
|
|
if (None not in nest.flatten(self._input_signature) and
|
|
|
|
|
self._expects_training_arg):
|
|
|
|
|
# Manually add traces for layers that expect a training argument and have
|
|
|
|
|
# a fully defined input signature.
|
|
|
|
|
self.add_trace(*self._input_signature)
|
|
|
|
|
return fn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LayerCall(def_function.Function):
|
|
|
|
|
"""Function that triggers traces of other functions in the same collection."""
|
|
|
|
|
|
|
|
|
|
def __init__(self, call_collection, *args, **kwargs):
|
|
|
|
|
super(LayerCall, self).__init__(*args, **kwargs)
|
|
|
|
|
self.call_collection = call_collection
|
|
|
|
|
|
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
|
|
|
if not self.call_collection.tracing:
|
|
|
|
|
self.call_collection.add_trace(*args, **kwargs)
|
|
|
|
|
return super(LayerCall, self).__call__(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
def get_concrete_function(self, *args, **kwargs):
|
|
|
|
|
if not self.call_collection.tracing:
|
|
|
|
|
self.call_collection.add_trace(*args, **kwargs)
|
|
|
|
|
return super(LayerCall, self).get_concrete_function(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
def original_get_concrete_function(self, *args, **kwargs):
|
|
|
|
|
return super(LayerCall, self).get_concrete_function(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _wrap_call_and_conditional_losses(layer):
|
|
|
|
|
"""Wraps call function that returns a tuple of (outputs, losses).
|
|
|
|
|
|
|
|
|
@ -966,51 +1133,19 @@ def _wrap_call_and_conditional_losses(layer):
|
|
|
|
|
layer: a Keras layer object
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
call function that returns outputs and conditional losses -- excludes
|
|
|
|
|
python call function that returns outputs and conditional losses -- excludes
|
|
|
|
|
activity regularizer
|
|
|
|
|
"""
|
|
|
|
|
if isinstance(layer, RevivedLayer):
|
|
|
|
|
return layer.keras_api.call_and_return_conditional_losses
|
|
|
|
|
|
|
|
|
|
if (isinstance(layer.call, def_function.Function) and
|
|
|
|
|
layer.call.input_signature is not None):
|
|
|
|
|
input_signature = layer.call.input_signature
|
|
|
|
|
else:
|
|
|
|
|
if (isinstance(layer, training_lib.Model) and
|
|
|
|
|
saving_utils.model_input_signature(layer) is not None):
|
|
|
|
|
input_signature = saving_utils.model_input_signature(layer)
|
|
|
|
|
elif layer.input_spec is not None:
|
|
|
|
|
input_signature = [nest.map_structure(
|
|
|
|
|
lambda x: input_spec.to_tensor_spec(x, layer.dtype),
|
|
|
|
|
layer.input_spec)]
|
|
|
|
|
# If input spec is too general, then don't define an input signature.
|
|
|
|
|
for spec in nest.flatten(input_signature):
|
|
|
|
|
if spec.shape == tensor_shape.TensorShape(None):
|
|
|
|
|
input_signature = None
|
|
|
|
|
break
|
|
|
|
|
else:
|
|
|
|
|
input_signature = None
|
|
|
|
|
|
|
|
|
|
if input_signature is not None and layer._expects_training_arg: # pylint: disable=protected-access
|
|
|
|
|
input_signature.append(
|
|
|
|
|
tensor_spec.TensorSpec(shape=[], dtype=dtypes.bool))
|
|
|
|
|
|
|
|
|
|
# Create function that generates both outputs and losses
|
|
|
|
|
layer_call = layer.call
|
|
|
|
|
if layer._expects_training_arg: # pylint: disable=protected-access
|
|
|
|
|
def call_and_return_conditional_losses(inputs, training):
|
|
|
|
|
_set_symbolic_learning_phase(training)
|
|
|
|
|
def call_and_return_conditional_losses(inputs, training=False):
|
|
|
|
|
return layer_call(inputs, training=training), layer.get_losses_for(inputs)
|
|
|
|
|
else:
|
|
|
|
|
def call_and_return_conditional_losses(inputs):
|
|
|
|
|
K.set_learning_phase(0)
|
|
|
|
|
return layer_call(inputs), layer.get_losses_for(inputs)
|
|
|
|
|
return def_function.Function(
|
|
|
|
|
call_and_return_conditional_losses,
|
|
|
|
|
'{}_layer_call_and_return_conditional_losses'.format(layer.name),
|
|
|
|
|
input_signature=input_signature,
|
|
|
|
|
# TODO(kathywu): Investigate autograph error.
|
|
|
|
|
autograph=False)
|
|
|
|
|
return call_and_return_conditional_losses
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
|
|
|
|
@ -1018,50 +1153,22 @@ def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
|
|
|
|
|
if isinstance(layer, RevivedLayer):
|
|
|
|
|
return layer.keras_api.__call__ # pylint: disable=protected-access
|
|
|
|
|
if layer._expects_training_arg: # pylint: disable=protected-access
|
|
|
|
|
def call(inputs, training):
|
|
|
|
|
return call_and_return_conditional_losses(inputs, training)[0]
|
|
|
|
|
def call(inputs, training=False):
|
|
|
|
|
return call_and_return_conditional_losses(inputs, training=training)[0]
|
|
|
|
|
else:
|
|
|
|
|
def call(inputs):
|
|
|
|
|
return call_and_return_conditional_losses(inputs)[0]
|
|
|
|
|
return def_function.Function(
|
|
|
|
|
call, '{}_layer_call_fn'.format(layer.name),
|
|
|
|
|
input_signature=call_and_return_conditional_losses.input_signature,
|
|
|
|
|
# TODO(kathywu): Investigate autograph error.
|
|
|
|
|
autograph=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _set_symbolic_learning_phase(value):
|
|
|
|
|
"""Set learning phase to a tensor value (for internal use only).
|
|
|
|
|
|
|
|
|
|
This is used when wrapping call functions as tf.functions that have training
|
|
|
|
|
as a tensor input. Thus, when `learning_phase()` is called, the training
|
|
|
|
|
tensor is returned. This function is called when saving a model to SavedModel.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
value: A Tensor object.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: If the input value is not a graph tensor
|
|
|
|
|
"""
|
|
|
|
|
graph = K.get_graph()
|
|
|
|
|
if not isinstance(value, ops.Tensor):
|
|
|
|
|
raise ValueError('Symbolic learning phase must be a graph tensor.')
|
|
|
|
|
K._GRAPH_LEARNING_PHASES[graph] = value # pylint: disable=protected-access
|
|
|
|
|
return call
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _append_activity_regularizer_loss(
|
|
|
|
|
layer, call_fn_with_losses, activity_regularizer_fn):
|
|
|
|
|
call_fn_with_losses, activity_regularizer_fn):
|
|
|
|
|
"""Appends activity regularizer loss to losses returned by the wrapped fn."""
|
|
|
|
|
def fn(*args):
|
|
|
|
|
outputs, losses = call_fn_with_losses(*args)
|
|
|
|
|
def fn(*args, **kwargs):
|
|
|
|
|
outputs, losses = call_fn_with_losses(*args, **kwargs)
|
|
|
|
|
losses.append(activity_regularizer_fn(outputs))
|
|
|
|
|
return outputs, losses
|
|
|
|
|
return def_function.Function(
|
|
|
|
|
fn,
|
|
|
|
|
'{}_layer_call_and_return_all_conditional_losses'.format(layer.name),
|
|
|
|
|
input_signature=call_fn_with_losses.input_signature,
|
|
|
|
|
# TODO(kathywu): Investigate autograph error.
|
|
|
|
|
autograph=False)
|
|
|
|
|
return fn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _wrap_unconditional_loss(loss_fn, index):
|
|
|
|
@ -1135,9 +1242,11 @@ class KerasObjectLoader(load.Loader):
|
|
|
|
|
# pylint: disable=protected-access
|
|
|
|
|
for node in self._nodes:
|
|
|
|
|
if isinstance(node, RevivedModel):
|
|
|
|
|
input_signature = (
|
|
|
|
|
node.keras_api.call_and_return_conditional_losses.input_signature[0]
|
|
|
|
|
)
|
|
|
|
|
call_fn = node.keras_api.call_and_return_conditional_losses
|
|
|
|
|
if call_fn.input_signature is None:
|
|
|
|
|
inputs = infer_inputs_from_restored_call_function(call_fn)
|
|
|
|
|
else:
|
|
|
|
|
inputs = call_fn.input_signature[0]
|
|
|
|
|
if isinstance(node, RevivedSequential):
|
|
|
|
|
with trackable.no_automatic_dependency_tracking_scope(node):
|
|
|
|
|
node._layers = []
|
|
|
|
@ -1147,7 +1256,7 @@ class KerasObjectLoader(load.Loader):
|
|
|
|
|
if not node.inputs:
|
|
|
|
|
# Since this revived object is technically a subclassed model (even if
|
|
|
|
|
# the original model is functional/sequential), inputs should be set.
|
|
|
|
|
node._set_inputs(input_signature)
|
|
|
|
|
node._set_inputs(inputs)
|
|
|
|
|
if isinstance(node, RevivedLayer):
|
|
|
|
|
losses = node._serialized_attributes.get('regularization_losses', [])
|
|
|
|
|
for loss in losses:
|
|
|
|
@ -1276,6 +1385,26 @@ def recursively_deserialize_keras_object(config, module_objects=None):
|
|
|
|
|
raise ValueError('Unable to decode config: {}'.format(config))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def infer_inputs_from_restored_call_function(fn):
|
|
|
|
|
"""Returns TensorSpec of inputs from a restored call function.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
fn: Restored layer call function. It is assumed that the inputs are entirely
|
|
|
|
|
in the first argument.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
TensorSpec of call function inputs.
|
|
|
|
|
"""
|
|
|
|
|
def common_spec(x, y):
|
|
|
|
|
return tensor_spec.TensorSpec(defun.common_shape(x.shape, y.shape),
|
|
|
|
|
x.dtype, x.name)
|
|
|
|
|
spec = fn.concrete_functions[0].structured_input_signature[0][0]
|
|
|
|
|
for concrete in fn.concrete_functions[1:]:
|
|
|
|
|
spec2 = concrete.structured_input_signature[0][0]
|
|
|
|
|
spec = nest.map_structure(common_spec, spec, spec2)
|
|
|
|
|
return spec
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RevivedNetwork(RevivedLayer):
|
|
|
|
|
"""Keras network of layers loaded from a SavedModel."""
|
|
|
|
|
|
|
|
|
|