Fix error when tracing layer call functions that take an extra `training` argument. The error appeared when the training argument was passed positionally (in args) instead of as a keyword (in kwargs).

PiperOrigin-RevId: 257492170
Katherine Wu 2019-07-10 15:17:20 -07:00 committed by TensorFlower Gardener
parent 05f0a6c5e0
commit 22f1546586
2 changed files with 27 additions and 12 deletions
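The failure this fixes: when a layer's call function accepts a `training` argument and the caller supplies it positionally, unconditionally writing kwargs['training'] during tracing collides with the positional value. The snippet below is a minimal standalone sketch of the fix pattern, using plain inspect and hypothetical names (call, trace_with_training) rather than TensorFlow internals.

import inspect

def call(inputs, training=None):
  # Stand-in for a layer call function with a training argument.
  return (inputs, training)

def trace_with_training(fn, args, kwargs, training_value):
  """Sets `training` wherever the caller put it, then invokes fn."""
  args = list(args)
  kwargs = dict(kwargs)
  arg_list = inspect.getfullargspec(fn).args
  index = arg_list.index('training') if 'training' in arg_list else -1
  if index >= 0 and len(args) > index:
    args[index] = training_value          # training was passed positionally
  else:
    kwargs['training'] = training_value   # fall back to a keyword argument
  return fn(*args, **kwargs)

# training was passed positionally; forcing a keyword here would have
# produced a conflicting argument before this fix.
print(trace_with_training(call, args=([1.0], True), kwargs={}, training_value=False))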


@@ -22,3 +22,7 @@ from __future__ import print_function
 # e.g. the list of layers can be accessed using `loaded.keras_api.layers`, in an
 # object loaded from `tf.saved_model.load()`.
 KERAS_ATTR = 'keras_api'
+# Keys for the serialization cache.
+# Maps to the keras serialization dict {Layer --> SerializedAttributes object}
+KERAS_CACHE_KEY = 'keras_serialized_attributes'
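For context, the new constant keys a per-save cache that maps each already-serialized Layer to its SerializedAttributes object; defining it in constants.py lets save.py and the other saved_model modules share one definition. A rough sketch of the cache shape, assuming a plain dict (illustration only, not TensorFlow source):

KERAS_CACHE_KEY = 'keras_serialized_attributes'

serialization_cache = {}                                    # shared across one save() call
keras_cache = serialization_cache.setdefault(KERAS_CACHE_KEY, {})
# keras_cache maps layer objects to their serialized attributes, e.g.:
#   keras_cache[layer] = serialized_attributes_for(layer)   # hypothetical helper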


@@ -28,6 +28,7 @@ from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.engine import input_spec
 from tensorflow.python.keras.saving import saving_utils
+from tensorflow.python.keras.saving.saved_model import constants
 from tensorflow.python.keras.saving.saved_model import load as keras_load
 from tensorflow.python.keras.saving.saved_model import serialized_attributes
 from tensorflow.python.keras.saving.saved_model import utils
@@ -38,6 +39,7 @@ from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.training.tracking import data_structures
 from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
 from tensorflow.python.util import nest
+from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.lazy_loader import LazyLoader
 # To avoid circular dependencies between keras/engine and keras/saving,
@@ -86,23 +88,18 @@ def save(model, filepath, overwrite, include_optimizer):
     model.optimizer = orig_optimizer
-# Keys for the serialization cache.
-# Maps to the keras serialization dict {Layer --> SerializedAttributes object}
-_KERAS_CACHE_KEY = 'keras_serialized_attributes'
 def serialize_all_attributes(layer, serialization_cache):
   """Serialize all attributes in the layer."""
   save_model_default_signature = False
-  if _KERAS_CACHE_KEY not in serialization_cache:
-    keras_cache = serialization_cache[_KERAS_CACHE_KEY] = {}
+  if constants.KERAS_CACHE_KEY not in serialization_cache:
+    keras_cache = serialization_cache[constants.KERAS_CACHE_KEY] = {}
     if isinstance(layer, training_lib.Model):
       # Only trace default signature if the root object is a Model. Since the
       # keras cache key is only created in this method, we know that the object
       # is root if the key does not yet exist in the cache.
       save_model_default_signature = True
   else:
-    keras_cache = serialization_cache[_KERAS_CACHE_KEY]
+    keras_cache = serialization_cache[constants.KERAS_CACHE_KEY]
   if layer in keras_cache:
     return keras_cache[layer]
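Because the cache key is created only inside serialize_all_attributes, an empty cache also identifies the root object of the traversal, which is why the default serving signature is traced only on that first call. A hedged sketch of the memoization pattern with simplified, hypothetical names (serialize, _CACHE_KEY), not the module's actual code:

_CACHE_KEY = 'keras_serialized_attributes'

def serialize(layer, serialization_cache):
  if _CACHE_KEY not in serialization_cache:
    keras_cache = serialization_cache[_CACHE_KEY] = {}
    is_root = True    # only the first caller sees an empty cache
  else:
    keras_cache = serialization_cache[_CACHE_KEY]
    is_root = False
  if layer in keras_cache:
    return keras_cache[layer]                 # already serialized during this save
  attrs = {'name': getattr(layer, 'name', str(layer)), 'is_root': is_root}
  keras_cache[layer] = attrs
  for child in getattr(layer, 'layers', []):  # recurse into sublayers, if any
    serialize(child, serialization_cache)
  return attrs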
@@ -320,11 +317,12 @@ def _replace_child_layer_functions(layer, serialization_cache):
   # pylint: disable=protected-access
   original_fns = {}
   for child_layer in _list_all_layers(layer):
-    if child_layer not in serialization_cache[_KERAS_CACHE_KEY]:
+    if child_layer not in serialization_cache[constants.KERAS_CACHE_KEY]:
       layer_fns = (serialize_all_attributes(child_layer, serialization_cache)
                    .functions)
     else:
-      layer_fns = serialization_cache[_KERAS_CACHE_KEY][child_layer].functions
+      layer_fns = (
+          serialization_cache[constants.KERAS_CACHE_KEY][child_layer].functions)
     if not layer_fns:
       # This indicates either:
       #   - circular dependency, which means the current layer's functions
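Each child layer's functions are either looked up from the shared cache or produced by serializing the child on the spot; an empty result tells the caller to leave that child untouched (for example, a circular dependency). A simplified sketch of that lookup-or-serialize step, with hypothetical names (child_functions, serialize_fn), not the module's actual helpers:

def child_functions(child_layer, serialization_cache, serialize_fn):
  keras_cache = serialization_cache.setdefault('keras_serialized_attributes', {})
  if child_layer not in keras_cache:
    # Not serialized yet: serialize it now, which also populates the cache.
    layer_fns = serialize_fn(child_layer, serialization_cache).get('functions')
  else:
    layer_fns = keras_cache[child_layer].get('functions')
  return layer_fns or None  # None => skip replacing this child's functions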
@@ -442,15 +440,28 @@ class LayerCallCollection(object):
       *args: Positional args passed to the original function.
       **kwargs: Keyword args passed to the original function.
     """
+    args = list(args)
+    kwargs = kwargs.copy()
     self.tracing = True
     for fn in self._functions.values():
       # TODO(kathywu): Replace arguments with broader shapes defined in the
       # input signature.
       if self._expects_training_arg:
-        kwargs['training'] = False
+        arg_list = tf_inspect.getfullargspec(fn.python_function).args
+        if 'training' in arg_list:
+          training_arg_index = arg_list.index('training')
+        else:
+          training_arg_index = -1
+        def set_training_arg(training, index=training_arg_index):
+          if index >= 0 and len(args) > index:
+            args[index] = training
+          else:
+            kwargs['training'] = training
+        set_training_arg(False)
         fn.original_get_concrete_function(*args, **kwargs)
-        kwargs['training'] = True
+        set_training_arg(True)
         fn.original_get_concrete_function(*args, **kwargs)
       else:
         fn.original_get_concrete_function(*args, **kwargs)
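With set_training_arg in place, each call function is still traced twice, once with training=False and once with training=True, so the SavedModel captures both behaviours no matter how the caller passed the flag. Below is a hedged, standalone illustration of that double tracing with an ordinary tf.function; it is not the saver's internal code.

import tensorflow as tf

@tf.function
def call(inputs, training=False):
  # Toy stand-in for a layer call with training-dependent behaviour.
  return inputs * 0.5 if training else inputs

spec = tf.TensorSpec(shape=[None], dtype=tf.float32)
# Trace both branches up front, mirroring set_training_arg(False)/set_training_arg(True).
inference_fn = call.get_concrete_function(spec, training=False)
training_fn = call.get_concrete_function(spec, training=True)
print(inference_fn(tf.constant([1.0, 2.0])))  # [1.0, 2.0]
print(training_fn(tf.constant([1.0, 2.0])))   # [0.5, 1.0]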