Fix error with tracing function with extra argument. This appeared when training argument is passed in args instead of kwargs.
PiperOrigin-RevId: 257492170
This commit is contained in:
parent
05f0a6c5e0
commit
22f1546586
@ -22,3 +22,7 @@ from __future__ import print_function
|
||||
# e.g. the list of layers can be accessed using `loaded.keras_api.layers`, in an
|
||||
# object loaded from `tf.saved_model.load()`.
|
||||
KERAS_ATTR = 'keras_api'
|
||||
|
||||
# Keys for the serialization cache.
|
||||
# Maps to the keras serialization dict {Layer --> SerializedAttributes object}
|
||||
KERAS_CACHE_KEY = 'keras_serialized_attributes'
|
||||
|
@ -28,6 +28,7 @@ from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.engine import input_spec
|
||||
from tensorflow.python.keras.saving import saving_utils
|
||||
from tensorflow.python.keras.saving.saved_model import constants
|
||||
from tensorflow.python.keras.saving.saved_model import load as keras_load
|
||||
from tensorflow.python.keras.saving.saved_model import serialized_attributes
|
||||
from tensorflow.python.keras.saving.saved_model import utils
|
||||
@ -38,6 +39,7 @@ from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.training.tracking import data_structures
|
||||
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_inspect
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
|
||||
# To avoid circular dependencies between keras/engine and keras/saving,
|
||||
@ -86,23 +88,18 @@ def save(model, filepath, overwrite, include_optimizer):
|
||||
model.optimizer = orig_optimizer
|
||||
|
||||
|
||||
# Keys for the serialization cache.
|
||||
# Maps to the keras serialization dict {Layer --> SerializedAttributes object}
|
||||
_KERAS_CACHE_KEY = 'keras_serialized_attributes'
|
||||
|
||||
|
||||
def serialize_all_attributes(layer, serialization_cache):
|
||||
"""Serialize all attributes in the layer."""
|
||||
save_model_default_signature = False
|
||||
if _KERAS_CACHE_KEY not in serialization_cache:
|
||||
keras_cache = serialization_cache[_KERAS_CACHE_KEY] = {}
|
||||
if constants.KERAS_CACHE_KEY not in serialization_cache:
|
||||
keras_cache = serialization_cache[constants.KERAS_CACHE_KEY] = {}
|
||||
if isinstance(layer, training_lib.Model):
|
||||
# Only trace default signature if the root object is a Model. Since the
|
||||
# keras cache key is only created in this method, we know that the object
|
||||
# is root if the key does not yet exist in the cache.
|
||||
save_model_default_signature = True
|
||||
else:
|
||||
keras_cache = serialization_cache[_KERAS_CACHE_KEY]
|
||||
keras_cache = serialization_cache[constants.KERAS_CACHE_KEY]
|
||||
|
||||
if layer in keras_cache:
|
||||
return keras_cache[layer]
|
||||
@ -320,11 +317,12 @@ def _replace_child_layer_functions(layer, serialization_cache):
|
||||
# pylint: disable=protected-access
|
||||
original_fns = {}
|
||||
for child_layer in _list_all_layers(layer):
|
||||
if child_layer not in serialization_cache[_KERAS_CACHE_KEY]:
|
||||
if child_layer not in serialization_cache[constants.KERAS_CACHE_KEY]:
|
||||
layer_fns = (serialize_all_attributes(child_layer, serialization_cache)
|
||||
.functions)
|
||||
else:
|
||||
layer_fns = serialization_cache[_KERAS_CACHE_KEY][child_layer].functions
|
||||
layer_fns = (
|
||||
serialization_cache[constants.KERAS_CACHE_KEY][child_layer].functions)
|
||||
if not layer_fns:
|
||||
# This indicates either:
|
||||
# - circular dependency, which means the current layer's functions
|
||||
@ -442,15 +440,28 @@ class LayerCallCollection(object):
|
||||
*args: Positional args passed to the original function.
|
||||
**kwargs: Keyword args passed to the original function.
|
||||
"""
|
||||
args = list(args)
|
||||
kwargs = kwargs.copy()
|
||||
self.tracing = True
|
||||
for fn in self._functions.values():
|
||||
# TODO(kathywu): Replace arguments with broader shapes defined in the
|
||||
# input signature.
|
||||
if self._expects_training_arg:
|
||||
kwargs['training'] = False
|
||||
arg_list = tf_inspect.getfullargspec(fn.python_function).args
|
||||
if 'training' in arg_list:
|
||||
training_arg_index = arg_list.index('training')
|
||||
else:
|
||||
training_arg_index = -1
|
||||
|
||||
def set_training_arg(training, index=training_arg_index):
|
||||
if index >= 0 and len(args) > index:
|
||||
args[index] = training
|
||||
else:
|
||||
kwargs['training'] = training
|
||||
|
||||
set_training_arg(False)
|
||||
fn.original_get_concrete_function(*args, **kwargs)
|
||||
kwargs['training'] = True
|
||||
set_training_arg(True)
|
||||
fn.original_get_concrete_function(*args, **kwargs)
|
||||
else:
|
||||
fn.original_get_concrete_function(*args, **kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user