Set SavedModel as default format for model.save in TF2, and compile loaded model.
Additional change: - Sequential models are now revived as Sequential. PiperOrigin-RevId: 251723025
This commit is contained in:
parent
b9c7a8c6e9
commit
9f7f717179
@ -145,8 +145,11 @@ class KerasMultiWorkerCallbackTest(test_base.IndependentWorkerTestBase,
|
||||
# ensure every worker has a unique path. Note that in normal use case the
|
||||
# saving_filepath will be the same for all workers, but we use different
|
||||
# ones here just to test out chief saves checkpoint but non-chief doesn't.
|
||||
|
||||
# TODO(b/134551335): Must save to hdf5 until bug with copying
|
||||
# MirroredVariables is resolved.
|
||||
saving_filepath = os.path.join(
|
||||
test_obj.get_temp_dir(), 'checkpoint_%s_%d' %
|
||||
test_obj.get_temp_dir(), 'checkpoint_%s_%d.h5' %
|
||||
(test_base.get_task_type(), test_base.get_task_index()))
|
||||
|
||||
# The saving_filepath shouldn't exist at the beginning (as it's unique).
|
||||
|
@ -361,3 +361,7 @@ class Sequential(training.Model):
|
||||
if self.layers and hasattr(self.layers[0], 'input_spec'):
|
||||
return self.layers[0].input_spec
|
||||
return None
|
||||
|
||||
@property
|
||||
def _object_identifier(self):
|
||||
return '_tf_keras_sequential'
|
||||
|
@ -40,7 +40,6 @@ class KerasIntegrationTest(keras_parameterized.TestCase):
|
||||
fpath = os.path.join(self.temp_dir,
|
||||
'test_model_%s' % (random.randint(0, 1e7),))
|
||||
if context.executing_eagerly():
|
||||
keras.saving.save._KERAS_SAVED_MODEL_STILL_EXPERIMENTAL = False
|
||||
save_format = 'tf'
|
||||
else:
|
||||
if (not isinstance(model, keras.Sequential) and
|
||||
@ -155,8 +154,19 @@ class SequentialIntegrationTest(KerasIntegrationTest):
|
||||
validation_data=(x_train, y_train),
|
||||
verbose=2)
|
||||
model = self._save_and_reload_model(model)
|
||||
|
||||
# TODO(b/134537740): model.pop doesn't update model outputs properly when
|
||||
# model.outputs is already defined, so just set to `None` for now.
|
||||
model.inputs = None
|
||||
model.outputs = None
|
||||
|
||||
model.pop()
|
||||
model.add(keras.layers.Dense(y_train.shape[-1], activation='softmax'))
|
||||
|
||||
# TODO(b/134523282): There is an bug with Sequential models, so the model
|
||||
# must be marked as compiled=False to ensure the next compile goes through.
|
||||
model._is_compiled = False
|
||||
|
||||
model.compile(
|
||||
loss='categorical_crossentropy',
|
||||
optimizer=keras.optimizer_v2.adam.Adam(0.005),
|
||||
|
@ -26,7 +26,6 @@ import numpy as np
|
||||
from six.moves import zip # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import losses
|
||||
from tensorflow.python.keras import optimizers
|
||||
from tensorflow.python.keras.saving import model_config as model_config_lib
|
||||
from tensorflow.python.keras.saving import saving_utils
|
||||
@ -146,31 +145,6 @@ def load_model_from_hdf5(filepath, custom_objects=None, compile=True): # pylint
|
||||
if not custom_objects:
|
||||
custom_objects = {}
|
||||
|
||||
def convert_custom_objects(obj):
|
||||
"""Handles custom object lookup.
|
||||
|
||||
Arguments:
|
||||
obj: object, dict, or list.
|
||||
|
||||
Returns:
|
||||
The same structure, where occurrences
|
||||
of a custom object name have been replaced
|
||||
with the custom object.
|
||||
"""
|
||||
if isinstance(obj, list):
|
||||
deserialized = []
|
||||
for value in obj:
|
||||
deserialized.append(convert_custom_objects(value))
|
||||
return deserialized
|
||||
if isinstance(obj, dict):
|
||||
deserialized = {}
|
||||
for key, value in obj.items():
|
||||
deserialized[key] = convert_custom_objects(value)
|
||||
return deserialized
|
||||
if obj in custom_objects:
|
||||
return custom_objects[obj]
|
||||
return obj
|
||||
|
||||
opened_new_file = not isinstance(filepath, h5py.File)
|
||||
if opened_new_file:
|
||||
f = h5py.File(filepath, mode='r')
|
||||
@ -198,29 +172,10 @@ def load_model_from_hdf5(filepath, custom_objects=None, compile=True): # pylint
|
||||
'the model was *not* compiled. Compile it manually.')
|
||||
return model
|
||||
training_config = json.loads(training_config.decode('utf-8'))
|
||||
optimizer_config = training_config['optimizer_config']
|
||||
optimizer = optimizers.deserialize(
|
||||
optimizer_config, custom_objects=custom_objects)
|
||||
|
||||
# Recover loss functions and metrics.
|
||||
loss_config = training_config['loss'] # Deserialize loss class.
|
||||
if isinstance(loss_config, dict) and 'class_name' in loss_config:
|
||||
loss_config = losses.get(loss_config)
|
||||
loss = convert_custom_objects(loss_config)
|
||||
metrics = convert_custom_objects(training_config['metrics'])
|
||||
weighted_metrics = convert_custom_objects(
|
||||
training_config.get('weighted_metrics', None))
|
||||
sample_weight_mode = training_config['sample_weight_mode']
|
||||
loss_weights = training_config['loss_weights']
|
||||
|
||||
# Compile model.
|
||||
model.compile(
|
||||
optimizer=optimizer,
|
||||
loss=loss,
|
||||
metrics=metrics,
|
||||
weighted_metrics=weighted_metrics,
|
||||
loss_weights=loss_weights,
|
||||
sample_weight_mode=sample_weight_mode)
|
||||
model.compile(**saving_utils.compile_args_from_training_config(
|
||||
training_config, custom_objects))
|
||||
|
||||
# Set optimizer weights.
|
||||
if 'optimizer_weights' in f:
|
||||
|
@ -23,7 +23,6 @@ import os
|
||||
import six
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras.saving import hdf5_format
|
||||
from tensorflow.python.keras.saving import saved_model
|
||||
from tensorflow.python.saved_model import loader_impl
|
||||
@ -77,32 +76,16 @@ def save_model(model,
|
||||
location, or instead ask the user with a manual prompt.
|
||||
include_optimizer: If True, save optimizer's state together.
|
||||
save_format: Either 'tf' or 'h5', indicating whether to save the model
|
||||
to Tensorflow SavedModel or HDF5. The 'tf' option is currently disabled,
|
||||
and will be enabled when Keras SavedModel export is no longer
|
||||
experimental. (The experimental function is
|
||||
tf.keras.experimental.export_saved_model).
|
||||
to Tensorflow SavedModel or HDF5. Defaults to 'tf' in TF 2.X, and 'h5'
|
||||
in TF 1.X.
|
||||
|
||||
Raises:
|
||||
ImportError: If save format is hdf5, and h5py is not available.
|
||||
"""
|
||||
from tensorflow.python.keras.engine import sequential # pylint: disable=g-import-not-at-top
|
||||
|
||||
if (not tf2.enabled() and
|
||||
not ops.executing_eagerly_outside_functions()
|
||||
and save_format == 'tf'):
|
||||
raise NotImplementedError(
|
||||
'Saving the model as SavedModel is not supported in TensorFlow 1.X'
|
||||
'graph mode. Please enable eager execution or use the "h5" save format.'
|
||||
)
|
||||
|
||||
if _KERAS_SAVED_MODEL_STILL_EXPERIMENTAL and save_format == 'tf':
|
||||
raise NotImplementedError(
|
||||
'Saving the model as SavedModel is still in experimental stages. '
|
||||
'Please use tf.keras.experimental.export_saved_model, or use '
|
||||
'save_format="h5" to save to HDF5.')
|
||||
|
||||
# TODO(kathywu): Remove this when Keras SavedModel is not experimental.
|
||||
save_format = 'h5'
|
||||
default_format = 'tf' if tf2.enabled() else 'h5'
|
||||
save_format = save_format or default_format
|
||||
|
||||
if (save_format == 'h5' or
|
||||
(h5py is not None and isinstance(filepath, h5py.File)) or
|
||||
@ -119,7 +102,8 @@ def save_model(model,
|
||||
'or using `save_weights`.')
|
||||
hdf5_format.save_model_to_hdf5(
|
||||
model, filepath, overwrite, include_optimizer)
|
||||
return
|
||||
else:
|
||||
saved_model.save(model, filepath, overwrite, include_optimizer)
|
||||
|
||||
|
||||
@keras_export('keras.models.load_model')
|
||||
@ -148,14 +132,13 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=
|
||||
ImportError: if loading from an hdf5 file and h5py is not available.
|
||||
IOError: In case of an invalid savefile.
|
||||
"""
|
||||
if not tf2.enabled() or (
|
||||
h5py is not None and (
|
||||
isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
|
||||
if (h5py is not None and (
|
||||
isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
|
||||
return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
|
||||
|
||||
if isinstance(filepath, six.string_types):
|
||||
loader_impl.parse_saved_model(filepath)
|
||||
return saved_model.load_from_saved_model_v2(filepath)
|
||||
return saved_model.load_from_saved_model_v2(filepath, compile)
|
||||
|
||||
raise IOError(
|
||||
'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
|
||||
|
@ -24,6 +24,7 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.saving import save
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import loader_impl
|
||||
|
||||
try:
|
||||
import h5py # pylint:disable=g-import-not-at-top
|
||||
@ -43,28 +44,35 @@ class TestSaveModel(test.TestCase):
|
||||
'Model saved at path {} is not a valid hdf5 file.'
|
||||
.format(path))
|
||||
|
||||
def assert_saved_model(self, path):
|
||||
loader_impl.parse_saved_model(path)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def test_save_format_defaults(self):
|
||||
path = os.path.join(self.get_temp_dir(), 'model_path')
|
||||
|
||||
# The default is currently HDF5 no matter what the filepath is.
|
||||
save.save_model(self.model, path)
|
||||
self.assert_h5_format(path)
|
||||
self.assert_saved_model(path)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def test_save_hdf5(self):
|
||||
path = os.path.join(self.get_temp_dir(), 'model')
|
||||
save.save_model(self.model, path, save_format='h5')
|
||||
|
||||
self.assert_h5_format(path)
|
||||
with self.assertRaisesRegexp(
|
||||
NotImplementedError,
|
||||
'requires the model to be a Functional model or a Sequential model.'):
|
||||
save.save_model(self.subclassed_model, path, save_format='h5')
|
||||
|
||||
@test_util.run_v2_only
|
||||
def test_save_tf(self):
|
||||
path = os.path.join(self.get_temp_dir(), 'model')
|
||||
with self.assertRaisesRegexp(
|
||||
NotImplementedError,
|
||||
'Saving the model as SavedModel is still in experimental stages.'):
|
||||
save.save_model(self.model, path, save_format='tf')
|
||||
save.save_model(self.model, path, save_format='tf')
|
||||
self.assert_saved_model(path)
|
||||
with self.assertRaisesRegexp(ValueError, 'input shapes have not been set'):
|
||||
save.save_model(self.subclassed_model, path, save_format='tf')
|
||||
self.subclassed_model.predict(np.random.random((3, 5)))
|
||||
save.save_model(self.subclassed_model, path, save_format='tf')
|
||||
self.assert_saved_model(path)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -38,6 +38,7 @@ from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||
from tensorflow.python.keras.saving import model_from_json
|
||||
from tensorflow.python.keras.saving import saving_utils
|
||||
from tensorflow.python.keras.utils import mode_keys
|
||||
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
@ -54,6 +55,7 @@ from tensorflow.python.training.tracking import data_structures
|
||||
from tensorflow.python.training.tracking import graph_view
|
||||
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
|
||||
from tensorflow.python.training.tracking.tracking import AutoTrackable
|
||||
from tensorflow.python.training.tracking.tracking import delete_tracking
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
@ -81,6 +83,10 @@ sequential = LazyLoader(
|
||||
training_lib = LazyLoader(
|
||||
"training_lib", globals(),
|
||||
"tensorflow.python.keras.engine.training")
|
||||
input_layer = LazyLoader(
|
||||
"input_layer", globals(),
|
||||
"tensorflow.python.keras.engine.input_layer")
|
||||
|
||||
# pylint:enable=g-inconsistent-quotes
|
||||
|
||||
|
||||
@ -713,7 +719,7 @@ def serialize_all_attributes(layer, serialization_cache):
|
||||
except (ValueError, TypeError) as e:
|
||||
logging.warning('Skipping full serialization of object {}, because an '
|
||||
'error occurred while tracing layer functions. Error '
|
||||
'message: {}'.format(layer, e.message))
|
||||
'message: {}'.format(layer, e))
|
||||
else:
|
||||
# Add checkpointable objects and functions to the SerializedAttribute object
|
||||
# only if all functions are successfully traced.
|
||||
@ -743,10 +749,6 @@ def _should_skip_serialization(layer):
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
if not layer.input_spec:
|
||||
logging.warning('Skipping full serialization of Keras layer {}, because '
|
||||
'it does not have an input spec defined.'.format(layer))
|
||||
return True
|
||||
if not layer.built:
|
||||
logging.warning('Skipping full serialization of Keras layer {}, because '
|
||||
'it is not built.'.format(layer))
|
||||
@ -771,8 +773,7 @@ def _wrap_layer_objects(layer, serialization_cache):
|
||||
# First, generate list of all regularization losses in this layer and
|
||||
# sublayers.
|
||||
regularization_losses = layer._callable_losses[:] # pylint: disable=protected-access
|
||||
for child_layer in (
|
||||
trackable_layer_utils.filter_empty_layer_containers(layer._layers)): # pylint: disable=protected-access
|
||||
for child_layer in _list_all_layers(layer):
|
||||
regularization_losses.extend(child_layer._callable_losses) # pylint: disable=protected-access
|
||||
# Next, wrap all loss functions as tf.functions. Use the serialization cache
|
||||
# to store already-wrapped functions.
|
||||
@ -791,9 +792,7 @@ def _wrap_layer_objects(layer, serialization_cache):
|
||||
layer.trainable_variables),
|
||||
non_trainable_variables=data_structures.ListWrapper(
|
||||
layer.non_trainable_variables),
|
||||
layers=data_structures.ListWrapper(
|
||||
trackable_layer_utils.filter_empty_layer_containers(
|
||||
layer._layers)), # pylint: disable=protected-access
|
||||
layers=data_structures.ListWrapper(_list_all_layers(layer)),
|
||||
metrics=data_structures.ListWrapper(layer.metrics),
|
||||
regularization_losses=data_structures.ListWrapper(
|
||||
wrapped_loss_functions))
|
||||
@ -857,6 +856,13 @@ def _wrap_layer_functions(layer, serialization_cache,
|
||||
return fns
|
||||
|
||||
|
||||
def _list_all_layers(obj):
|
||||
if isinstance(obj, training_lib.Model):
|
||||
return obj.layers
|
||||
else:
|
||||
return trackable_layer_utils.filter_empty_layer_containers(obj._layers) # pylint: disable=protected-access
|
||||
|
||||
|
||||
def _replace_child_layer_functions(layer, serialization_cache):
|
||||
"""Replaces functions in the children layers with wrapped tf.functions.
|
||||
|
||||
@ -882,8 +888,7 @@ def _replace_child_layer_functions(layer, serialization_cache):
|
||||
}
|
||||
"""
|
||||
original_attrs = {}
|
||||
for child_layer in trackable_layer_utils.filter_empty_layer_containers(
|
||||
layer._layers): # pylint: disable=protected-access
|
||||
for child_layer in _list_all_layers(layer):
|
||||
# Save symbolic layer losses, which will be restored to maintain the same
|
||||
# state.
|
||||
original_attrs[child_layer] = {'losses': child_layer._losses[:]} # pylint: disable=protected-access
|
||||
@ -934,14 +939,15 @@ def _use_wrapped_call(layer, call_fn):
|
||||
function that calls call_fn and returns the outputs. Losses returned by
|
||||
call_fn are added to the layer losses.
|
||||
"""
|
||||
def wrapped_call(inputs, *args, **kwargs):
|
||||
# TODO(kathywu): Support mask argument and multi-input call functions.
|
||||
def wrapped_call(inputs, **kwargs):
|
||||
"""Returns the outputs from the call_fn, and adds the losses."""
|
||||
if layer._expects_training_arg: # pylint: disable=protected-access
|
||||
training = kwargs.pop('training', None)
|
||||
if training is None:
|
||||
training = K.learning_phase()
|
||||
training = math_ops.cast(training, dtypes.bool)
|
||||
outputs, losses = call_fn(inputs, training=training, *args, **kwargs)
|
||||
outputs, losses = call_fn(inputs, training=training)
|
||||
else:
|
||||
outputs, losses = call_fn(inputs)
|
||||
layer.add_loss(losses, inputs)
|
||||
@ -963,7 +969,7 @@ def _wrap_call_and_conditional_losses(layer):
|
||||
activity regularizer
|
||||
"""
|
||||
if isinstance(layer, RevivedLayer):
|
||||
return layer.call_and_return_conditional_losses
|
||||
return layer.keras_api.call_and_return_conditional_losses
|
||||
|
||||
if (isinstance(layer.call, def_function.Function) and
|
||||
layer.call.input_signature is not None):
|
||||
@ -972,7 +978,7 @@ def _wrap_call_and_conditional_losses(layer):
|
||||
if (isinstance(layer, training_lib.Model) and
|
||||
saving_utils.model_input_signature(layer) is not None):
|
||||
input_signature = saving_utils.model_input_signature(layer)
|
||||
else:
|
||||
elif layer.input_spec is not None:
|
||||
input_signature = [nest.map_structure(
|
||||
lambda x: input_spec.to_tensor_spec(x, layer.dtype),
|
||||
layer.input_spec)]
|
||||
@ -981,6 +987,8 @@ def _wrap_call_and_conditional_losses(layer):
|
||||
if spec.shape == tensor_shape.TensorShape(None):
|
||||
input_signature = None
|
||||
break
|
||||
else:
|
||||
input_signature = None
|
||||
|
||||
if input_signature is not None and layer._expects_training_arg: # pylint: disable=protected-access
|
||||
input_signature.append(
|
||||
@ -1007,7 +1015,7 @@ def _wrap_call_and_conditional_losses(layer):
|
||||
def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
|
||||
"""Returns a function that returns only call function outputs."""
|
||||
if isinstance(layer, RevivedLayer):
|
||||
return layer._original_call # pylint: disable=protected-access
|
||||
return layer.keras_api.__call__ # pylint: disable=protected-access
|
||||
if layer._expects_training_arg: # pylint: disable=protected-access
|
||||
def call(inputs, training):
|
||||
return call_and_return_conditional_losses(inputs, training)[0]
|
||||
@ -1076,7 +1084,7 @@ def _wrap_activity_regularizer(layer):
|
||||
input_signature=[tensor_spec.TensorSpec(None, layer.dtype or K.floatx())])
|
||||
|
||||
|
||||
def load_from_saved_model_v2(path):
|
||||
def load_from_saved_model_v2(path, compile=True): # pylint: disable=redefined-builtin
|
||||
"""Loads Keras objects from a SavedModel.
|
||||
|
||||
Any Keras layer or model saved to the SavedModel will be loaded back
|
||||
@ -1092,13 +1100,27 @@ def load_from_saved_model_v2(path):
|
||||
|
||||
Args:
|
||||
path: Path to SavedModel.
|
||||
compile: If true, compile the model after loading it.
|
||||
|
||||
Returns:
|
||||
Object loaded from SavedModel.
|
||||
"""
|
||||
# TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
|
||||
# TODO(kathywu): Add code to load from objects that contain all endpoints
|
||||
return load.load_internal(path, loader_cls=KerasObjectLoader)
|
||||
model = load.load_internal(path, loader_cls=KerasObjectLoader)
|
||||
|
||||
if isinstance(model, RevivedModel) and compile:
|
||||
# TODO(kathywu): Use compiled objects from SavedModel, instead of
|
||||
# creating new objects from the training config.
|
||||
if model._training_config is not None: # pylint: disable=protected-access
|
||||
model.compile(**saving_utils.compile_args_from_training_config(
|
||||
model._training_config)) # pylint: disable=protected-access
|
||||
|
||||
return model
|
||||
|
||||
PUBLIC_ATTRIBUTES = CommonEndpoints.all_functions.union(
|
||||
CommonEndpoints.all_checkpointable_objects)
|
||||
PUBLIC_ATTRIBUTES.add(_KERAS_ATTR)
|
||||
|
||||
|
||||
class KerasObjectLoader(load.Loader):
|
||||
@ -1111,6 +1133,20 @@ class KerasObjectLoader(load.Loader):
|
||||
def _finalize(self):
|
||||
# pylint: disable=protected-access
|
||||
for node in self._nodes:
|
||||
if isinstance(node, RevivedModel):
|
||||
input_signature = (
|
||||
node.keras_api.call_and_return_conditional_losses.input_signature[0]
|
||||
)
|
||||
if isinstance(node, RevivedSequential):
|
||||
with trackable.no_automatic_dependency_tracking_scope(node):
|
||||
node._layers = []
|
||||
for layer in node.keras_api.layers:
|
||||
node.add(layer)
|
||||
|
||||
if not node.inputs:
|
||||
# Since this revived object is technically a subclassed model (even if
|
||||
# the original model is functional/sequential), inputs should be set.
|
||||
node._set_inputs(input_signature)
|
||||
if isinstance(node, RevivedLayer):
|
||||
losses = node._serialized_attributes.get('regularization_losses', [])
|
||||
for loss in losses:
|
||||
@ -1122,20 +1158,25 @@ class KerasObjectLoader(load.Loader):
|
||||
node.activity_regularizer = getattr(node.keras_api,
|
||||
'activity_regularizer_fn', None)
|
||||
|
||||
if isinstance(node, RevivedModel):
|
||||
# Since this revived object is technically a subclassed model (even if
|
||||
# the original model is functional/sequential), inputs should be set.
|
||||
input_signature = (
|
||||
node.keras_api.call_and_return_conditional_losses.input_signature[0]
|
||||
)
|
||||
node._set_inputs(input_signature)
|
||||
# Now that the node object has been fully loaded and restored from the,
|
||||
# checkpoint, the object no longer needs to track objects added from
|
||||
# SerializedAttributes. (Note that saving a training checkpoint still
|
||||
# functions correctly, because layers and variables are tracked
|
||||
# separately by the Layer object.)
|
||||
# TODO(kathywu): Instead of outright deleting these nodes (which would
|
||||
# make restoring from a different checkpoint tricky), mark them as extra
|
||||
# dependencies that are OK to overwrite.
|
||||
for name in PUBLIC_ATTRIBUTES:
|
||||
delete_tracking(node, name)
|
||||
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def _recreate_base_user_object(self, proto):
|
||||
revived_classes = {
|
||||
'_tf_keras_layer': (RevivedLayer, base_layer.Layer),
|
||||
'_tf_keras_network': (RevivedNetwork, network_lib.Network),
|
||||
'_tf_keras_model': (RevivedModel, training_lib.Model)
|
||||
'_tf_keras_model': (RevivedModel, training_lib.Model),
|
||||
'_tf_keras_sequential': (RevivedSequential, models_lib.Sequential)
|
||||
}
|
||||
|
||||
parent_classes = revived_classes.get(proto.identifier, None)
|
||||
@ -1193,9 +1234,9 @@ class RevivedLayer(object):
|
||||
|
||||
def _revive_setter(self, name, value):
|
||||
"""Reattaches attributes from the SavedModel to the newly revived object."""
|
||||
if (name in CommonEndpoints.all_functions or
|
||||
name in CommonEndpoints.all_checkpointable_objects or
|
||||
name == _KERAS_ATTR):
|
||||
if name in PUBLIC_ATTRIBUTES:
|
||||
if isinstance(value, trackable.Trackable):
|
||||
self._track_trackable(value, name=name)
|
||||
self._serialized_attributes[name] = value
|
||||
else:
|
||||
setattr(self, name, value)
|
||||
@ -1258,7 +1299,50 @@ class RevivedModel(RevivedNetwork):
|
||||
revived_obj = super(RevivedModel, cls)._init_from_metadata(metadata)
|
||||
|
||||
with trackable.no_automatic_dependency_tracking_scope(revived_obj):
|
||||
if 'training_config' in metadata:
|
||||
revived_obj._training_config = metadata['training_config'] # pylint:disable=protected-access
|
||||
revived_obj._training_config = metadata.get('training_config') # pylint:disable=protected-access
|
||||
|
||||
return revived_obj
|
||||
|
||||
|
||||
class RevivedSequential(RevivedModel):
|
||||
"""Keras sequential model loaded from a SavedModel."""
|
||||
|
||||
@classmethod
|
||||
def _init_from_metadata(cls, metadata):
|
||||
"""Create revived Sequential model from SavedModel metadata."""
|
||||
revived_obj = super(RevivedSequential, cls)._init_from_metadata(metadata)
|
||||
return revived_obj
|
||||
|
||||
def call(self, *args, **kwargs):
|
||||
return models_lib.Sequential.call(self, *args, **kwargs)
|
||||
|
||||
|
||||
def save(model, filepath, overwrite, include_optimizer):
|
||||
"""Saves a model as a SavedModel to the filepath.
|
||||
|
||||
Args:
|
||||
model: Keras model instance to be saved.
|
||||
filepath: String path to save the model.
|
||||
overwrite: whether to overwrite the existing filepath.
|
||||
include_optimizer: If True, save the model's optimizer state.
|
||||
|
||||
Raises:
|
||||
ValueError: if the model's inputs have not been defined.
|
||||
"""
|
||||
# If file exists and should not be overwritten.
|
||||
if not overwrite and os.path.exists(filepath):
|
||||
proceed = ask_to_proceed_with_overwrite(filepath)
|
||||
if not proceed:
|
||||
return
|
||||
|
||||
if _should_skip_serialization(model):
|
||||
saving_utils.raise_model_input_error(model)
|
||||
|
||||
if not include_optimizer:
|
||||
orig_optimizer = model.optimizer
|
||||
model.optimizer = None
|
||||
|
||||
save_lib.save(model, filepath)
|
||||
|
||||
if not include_optimizer:
|
||||
model.optimizer = orig_optimizer
|
||||
|
@ -703,10 +703,6 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
|
||||
self.evaluate(getattr(loaded.keras_api, attr)))
|
||||
self.assertLen(loaded.regularization_losses, 1)
|
||||
expected_layers = len(model.layers)
|
||||
if testing_utils.get_model_type() == 'sequential':
|
||||
# The autogenerated Input layer is hidden in the model.layers list,
|
||||
# but included in the loaded sub-layers.
|
||||
expected_layers += 1
|
||||
self.assertEqual(expected_layers, len(loaded.keras_api.layers))
|
||||
input_arr = array_ops.ones((4, 3))
|
||||
training_bool = constant_op.constant(False)
|
||||
@ -718,5 +714,39 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
|
||||
self.assertAllClose(self.evaluate(model(input_arr)),
|
||||
self.evaluate(loaded(*call_args)))
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
def test_compiled_model(self):
|
||||
input_arr = np.random.random((1, 3))
|
||||
target_arr = np.random.random((1, 4))
|
||||
|
||||
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
|
||||
expected_predict = model.predict(input_arr)
|
||||
|
||||
# Compile and save model.
|
||||
model.compile('rmsprop', 'mse')
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
|
||||
# TODO(b/134519980): Issue with model.fit if the model call function uses
|
||||
# a tf.function (Graph mode only).
|
||||
with context.eager_mode():
|
||||
loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir)
|
||||
actual_predict = loaded.predict(input_arr)
|
||||
self.assertAllClose(expected_predict, actual_predict)
|
||||
|
||||
loss_before = loaded.evaluate(input_arr, target_arr)
|
||||
loaded.fit(input_arr, target_arr)
|
||||
loss_after = loaded.evaluate(input_arr, target_arr)
|
||||
self.assertLess(loss_after, loss_before)
|
||||
predict = loaded.predict(input_arr)
|
||||
|
||||
ckpt_path = os.path.join(self.get_temp_dir(), 'weights')
|
||||
loaded.save_weights(ckpt_path)
|
||||
|
||||
# Ensure that the checkpoint is compatible with the original model.
|
||||
model.load_weights(ckpt_path)
|
||||
self.assertAllClose(predict, model.predict(input_arr))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -18,11 +18,14 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import os
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import losses
|
||||
from tensorflow.python.keras import optimizers
|
||||
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
@ -94,6 +97,14 @@ def model_input_signature(model):
|
||||
return [input_specs]
|
||||
|
||||
|
||||
def raise_model_input_error(model):
|
||||
raise ValueError(
|
||||
'Model {} cannot be saved because the input shapes have not been '
|
||||
'set. Usually, input shapes are automatically determined from calling'
|
||||
' .fit() or .predict(). To manually set the shapes, call '
|
||||
'model._set_inputs(inputs).'.format(model))
|
||||
|
||||
|
||||
def trace_model_call(model, input_signature=None):
|
||||
"""Trace the model call to create a tf.function for exporting a Keras model.
|
||||
|
||||
@ -116,11 +127,7 @@ def trace_model_call(model, input_signature=None):
|
||||
input_signature = model_input_signature(model)
|
||||
|
||||
if input_signature is None:
|
||||
raise ValueError(
|
||||
'Model {} cannot be saved because the input shapes have not been '
|
||||
'set. Usually, input shapes are automatically determined from calling'
|
||||
' .fit() or .predict(). To manually set the shapes, call '
|
||||
'model._set_inputs(inputs).'.format(model))
|
||||
raise_model_input_error(model)
|
||||
|
||||
# TODO(mdan): Should the model's call be autographed by default?
|
||||
@def_function.function(input_signature=input_signature, autograph=False)
|
||||
@ -190,3 +197,43 @@ def model_metadata(model, include_optimizer=True, require_config=True):
|
||||
'config': model.optimizer.get_config()}
|
||||
metadata['training_config']['optimizer_config'] = optimizer_config
|
||||
return metadata
|
||||
|
||||
|
||||
def should_overwrite(filepath, overwrite):
|
||||
"""Returns whether the filepath should be overwritten."""
|
||||
# If file exists and should not be overwritten.
|
||||
if not overwrite and os.path.isfile(filepath):
|
||||
return ask_to_proceed_with_overwrite(filepath)
|
||||
return True
|
||||
|
||||
|
||||
def compile_args_from_training_config(training_config, custom_objects=None):
|
||||
"""Return model.compile arguments from training config."""
|
||||
if custom_objects is None:
|
||||
custom_objects = {}
|
||||
|
||||
optimizer_config = training_config['optimizer_config']
|
||||
optimizer = optimizers.deserialize(
|
||||
optimizer_config, custom_objects=custom_objects)
|
||||
|
||||
# Recover loss functions and metrics.
|
||||
loss_config = training_config['loss'] # Deserialize loss class.
|
||||
if isinstance(loss_config, dict) and 'class_name' in loss_config:
|
||||
loss_config = losses.get(loss_config)
|
||||
loss = nest.map_structure(
|
||||
lambda obj: custom_objects.get(obj, obj), loss_config)
|
||||
metrics = nest.map_structure(
|
||||
lambda obj: custom_objects.get(obj, obj), training_config['metrics'])
|
||||
weighted_metrics = nest.map_structure(
|
||||
lambda obj: custom_objects.get(obj, obj),
|
||||
training_config.get('weighted_metrics', None))
|
||||
sample_weight_mode = training_config['sample_weight_mode']
|
||||
loss_weights = training_config['loss_weights']
|
||||
|
||||
return dict(
|
||||
optimizer=optimizer,
|
||||
loss=loss,
|
||||
metrics=metrics,
|
||||
weighted_metrics=weighted_metrics,
|
||||
loss_weights=loss_weights,
|
||||
sample_weight_mode=sample_weight_mode)
|
||||
|
@ -81,13 +81,7 @@ class AutoTrackable(base.Trackable):
|
||||
|
||||
def __delattr__(self, name):
|
||||
self._maybe_initialize_trackable()
|
||||
if name in self._unconditional_dependency_names:
|
||||
del self._unconditional_dependency_names[name]
|
||||
for index, (dep_name, _) in enumerate(
|
||||
self._unconditional_checkpoint_dependencies):
|
||||
if dep_name == name:
|
||||
del self._unconditional_checkpoint_dependencies[index]
|
||||
break
|
||||
delete_tracking(self, name)
|
||||
super(AutoTrackable, self).__delattr__(name)
|
||||
|
||||
def _no_dependency(self, value):
|
||||
@ -110,6 +104,19 @@ class AutoTrackable(base.Trackable):
|
||||
return functions
|
||||
|
||||
|
||||
def delete_tracking(obj, name):
|
||||
"""Removes the tracking of name from object."""
|
||||
# pylint: disable=protected-access
|
||||
if name in obj._unconditional_dependency_names:
|
||||
del obj._unconditional_dependency_names[name]
|
||||
for index, (dep_name, _) in enumerate(
|
||||
obj._unconditional_checkpoint_dependencies):
|
||||
if dep_name == name:
|
||||
del obj._unconditional_checkpoint_dependencies[index]
|
||||
break
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
class ResourceTracker(object):
|
||||
"""An object that tracks a list of resources."""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user