Set SavedModel as default format for model.save in TF2, and compile loaded model.

Additional change:
- Sequential models are now revived as Sequential.

PiperOrigin-RevId: 251723025
This commit is contained in:
Katherine Wu 2019-06-05 14:46:05 -07:00 committed by Gunhan Gulsoy
parent b9c7a8c6e9
commit 9f7f717179
10 changed files with 262 additions and 131 deletions

View File

@ -145,8 +145,11 @@ class KerasMultiWorkerCallbackTest(test_base.IndependentWorkerTestBase,
# ensure every worker has a unique path. Note that in normal use case the
# saving_filepath will be the same for all workers, but we use different
# ones here just to test out chief saves checkpoint but non-chief doesn't.
# TODO(b/134551335): Must save to hdf5 until bug with copying
# MirroredVariables is resolved.
saving_filepath = os.path.join(
test_obj.get_temp_dir(), 'checkpoint_%s_%d' %
test_obj.get_temp_dir(), 'checkpoint_%s_%d.h5' %
(test_base.get_task_type(), test_base.get_task_index()))
# The saving_filepath shouldn't exist at the beginning (as it's unique).

View File

@ -361,3 +361,7 @@ class Sequential(training.Model):
if self.layers and hasattr(self.layers[0], 'input_spec'):
return self.layers[0].input_spec
return None
@property
def _object_identifier(self):
return '_tf_keras_sequential'

View File

@ -40,7 +40,6 @@ class KerasIntegrationTest(keras_parameterized.TestCase):
fpath = os.path.join(self.temp_dir,
'test_model_%s' % (random.randint(0, 1e7),))
if context.executing_eagerly():
keras.saving.save._KERAS_SAVED_MODEL_STILL_EXPERIMENTAL = False
save_format = 'tf'
else:
if (not isinstance(model, keras.Sequential) and
@ -155,8 +154,19 @@ class SequentialIntegrationTest(KerasIntegrationTest):
validation_data=(x_train, y_train),
verbose=2)
model = self._save_and_reload_model(model)
# TODO(b/134537740): model.pop doesn't update model outputs properly when
# model.outputs is already defined, so just set to `None` for now.
model.inputs = None
model.outputs = None
model.pop()
model.add(keras.layers.Dense(y_train.shape[-1], activation='softmax'))
# TODO(b/134523282): There is an bug with Sequential models, so the model
# must be marked as compiled=False to ensure the next compile goes through.
model._is_compiled = False
model.compile(
loss='categorical_crossentropy',
optimizer=keras.optimizer_v2.adam.Adam(0.005),

View File

@ -26,7 +26,6 @@ import numpy as np
from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.saving import model_config as model_config_lib
from tensorflow.python.keras.saving import saving_utils
@ -146,31 +145,6 @@ def load_model_from_hdf5(filepath, custom_objects=None, compile=True): # pylint
if not custom_objects:
custom_objects = {}
def convert_custom_objects(obj):
"""Handles custom object lookup.
Arguments:
obj: object, dict, or list.
Returns:
The same structure, where occurrences
of a custom object name have been replaced
with the custom object.
"""
if isinstance(obj, list):
deserialized = []
for value in obj:
deserialized.append(convert_custom_objects(value))
return deserialized
if isinstance(obj, dict):
deserialized = {}
for key, value in obj.items():
deserialized[key] = convert_custom_objects(value)
return deserialized
if obj in custom_objects:
return custom_objects[obj]
return obj
opened_new_file = not isinstance(filepath, h5py.File)
if opened_new_file:
f = h5py.File(filepath, mode='r')
@ -198,29 +172,10 @@ def load_model_from_hdf5(filepath, custom_objects=None, compile=True): # pylint
'the model was *not* compiled. Compile it manually.')
return model
training_config = json.loads(training_config.decode('utf-8'))
optimizer_config = training_config['optimizer_config']
optimizer = optimizers.deserialize(
optimizer_config, custom_objects=custom_objects)
# Recover loss functions and metrics.
loss_config = training_config['loss'] # Deserialize loss class.
if isinstance(loss_config, dict) and 'class_name' in loss_config:
loss_config = losses.get(loss_config)
loss = convert_custom_objects(loss_config)
metrics = convert_custom_objects(training_config['metrics'])
weighted_metrics = convert_custom_objects(
training_config.get('weighted_metrics', None))
sample_weight_mode = training_config['sample_weight_mode']
loss_weights = training_config['loss_weights']
# Compile model.
model.compile(
optimizer=optimizer,
loss=loss,
metrics=metrics,
weighted_metrics=weighted_metrics,
loss_weights=loss_weights,
sample_weight_mode=sample_weight_mode)
model.compile(**saving_utils.compile_args_from_training_config(
training_config, custom_objects))
# Set optimizer weights.
if 'optimizer_weights' in f:

View File

@ -23,7 +23,6 @@ import os
import six
from tensorflow.python import tf2
from tensorflow.python.framework import ops
from tensorflow.python.keras.saving import hdf5_format
from tensorflow.python.keras.saving import saved_model
from tensorflow.python.saved_model import loader_impl
@ -77,32 +76,16 @@ def save_model(model,
location, or instead ask the user with a manual prompt.
include_optimizer: If True, save optimizer's state together.
save_format: Either 'tf' or 'h5', indicating whether to save the model
to Tensorflow SavedModel or HDF5. The 'tf' option is currently disabled,
and will be enabled when Keras SavedModel export is no longer
experimental. (The experimental function is
tf.keras.experimental.export_saved_model).
to Tensorflow SavedModel or HDF5. Defaults to 'tf' in TF 2.X, and 'h5'
in TF 1.X.
Raises:
ImportError: If save format is hdf5, and h5py is not available.
"""
from tensorflow.python.keras.engine import sequential # pylint: disable=g-import-not-at-top
if (not tf2.enabled() and
not ops.executing_eagerly_outside_functions()
and save_format == 'tf'):
raise NotImplementedError(
'Saving the model as SavedModel is not supported in TensorFlow 1.X'
'graph mode. Please enable eager execution or use the "h5" save format.'
)
if _KERAS_SAVED_MODEL_STILL_EXPERIMENTAL and save_format == 'tf':
raise NotImplementedError(
'Saving the model as SavedModel is still in experimental stages. '
'Please use tf.keras.experimental.export_saved_model, or use '
'save_format="h5" to save to HDF5.')
# TODO(kathywu): Remove this when Keras SavedModel is not experimental.
save_format = 'h5'
default_format = 'tf' if tf2.enabled() else 'h5'
save_format = save_format or default_format
if (save_format == 'h5' or
(h5py is not None and isinstance(filepath, h5py.File)) or
@ -119,7 +102,8 @@ def save_model(model,
'or using `save_weights`.')
hdf5_format.save_model_to_hdf5(
model, filepath, overwrite, include_optimizer)
return
else:
saved_model.save(model, filepath, overwrite, include_optimizer)
@keras_export('keras.models.load_model')
@ -148,14 +132,13 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=
ImportError: if loading from an hdf5 file and h5py is not available.
IOError: In case of an invalid savefile.
"""
if not tf2.enabled() or (
h5py is not None and (
if (h5py is not None and (
isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
if isinstance(filepath, six.string_types):
loader_impl.parse_saved_model(filepath)
return saved_model.load_from_saved_model_v2(filepath)
return saved_model.load_from_saved_model_v2(filepath, compile)
raise IOError(
'Unable to load model. Filepath is not an hdf5 file (or h5py is not '

View File

@ -24,6 +24,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.saving import save
from tensorflow.python.platform import test
from tensorflow.python.saved_model import loader_impl
try:
import h5py # pylint:disable=g-import-not-at-top
@ -43,28 +44,35 @@ class TestSaveModel(test.TestCase):
'Model saved at path {} is not a valid hdf5 file.'
.format(path))
def assert_saved_model(self, path):
loader_impl.parse_saved_model(path)
@test_util.run_v2_only
def test_save_format_defaults(self):
path = os.path.join(self.get_temp_dir(), 'model_path')
# The default is currently HDF5 no matter what the filepath is.
save.save_model(self.model, path)
self.assert_h5_format(path)
self.assert_saved_model(path)
@test_util.run_v2_only
def test_save_hdf5(self):
path = os.path.join(self.get_temp_dir(), 'model')
save.save_model(self.model, path, save_format='h5')
self.assert_h5_format(path)
with self.assertRaisesRegexp(
NotImplementedError,
'requires the model to be a Functional model or a Sequential model.'):
save.save_model(self.subclassed_model, path, save_format='h5')
@test_util.run_v2_only
def test_save_tf(self):
path = os.path.join(self.get_temp_dir(), 'model')
with self.assertRaisesRegexp(
NotImplementedError,
'Saving the model as SavedModel is still in experimental stages.'):
save.save_model(self.model, path, save_format='tf')
self.assert_saved_model(path)
with self.assertRaisesRegexp(ValueError, 'input shapes have not been set'):
save.save_model(self.subclassed_model, path, save_format='tf')
self.subclassed_model.predict(np.random.random((3, 5)))
save.save_model(self.subclassed_model, path, save_format='tf')
self.assert_saved_model(path)
if __name__ == '__main__':
test.main()

View File

@ -38,6 +38,7 @@ from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.saving import model_from_json
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.utils import mode_keys
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
@ -54,6 +55,7 @@ from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
from tensorflow.python.training.tracking.tracking import AutoTrackable
from tensorflow.python.training.tracking.tracking import delete_tracking
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.lazy_loader import LazyLoader
@ -81,6 +83,10 @@ sequential = LazyLoader(
training_lib = LazyLoader(
"training_lib", globals(),
"tensorflow.python.keras.engine.training")
input_layer = LazyLoader(
"input_layer", globals(),
"tensorflow.python.keras.engine.input_layer")
# pylint:enable=g-inconsistent-quotes
@ -713,7 +719,7 @@ def serialize_all_attributes(layer, serialization_cache):
except (ValueError, TypeError) as e:
logging.warning('Skipping full serialization of object {}, because an '
'error occurred while tracing layer functions. Error '
'message: {}'.format(layer, e.message))
'message: {}'.format(layer, e))
else:
# Add checkpointable objects and functions to the SerializedAttribute object
# only if all functions are successfully traced.
@ -743,10 +749,6 @@ def _should_skip_serialization(layer):
else:
return False
else:
if not layer.input_spec:
logging.warning('Skipping full serialization of Keras layer {}, because '
'it does not have an input spec defined.'.format(layer))
return True
if not layer.built:
logging.warning('Skipping full serialization of Keras layer {}, because '
'it is not built.'.format(layer))
@ -771,8 +773,7 @@ def _wrap_layer_objects(layer, serialization_cache):
# First, generate list of all regularization losses in this layer and
# sublayers.
regularization_losses = layer._callable_losses[:] # pylint: disable=protected-access
for child_layer in (
trackable_layer_utils.filter_empty_layer_containers(layer._layers)): # pylint: disable=protected-access
for child_layer in _list_all_layers(layer):
regularization_losses.extend(child_layer._callable_losses) # pylint: disable=protected-access
# Next, wrap all loss functions as tf.functions. Use the serialization cache
# to store already-wrapped functions.
@ -791,9 +792,7 @@ def _wrap_layer_objects(layer, serialization_cache):
layer.trainable_variables),
non_trainable_variables=data_structures.ListWrapper(
layer.non_trainable_variables),
layers=data_structures.ListWrapper(
trackable_layer_utils.filter_empty_layer_containers(
layer._layers)), # pylint: disable=protected-access
layers=data_structures.ListWrapper(_list_all_layers(layer)),
metrics=data_structures.ListWrapper(layer.metrics),
regularization_losses=data_structures.ListWrapper(
wrapped_loss_functions))
@ -857,6 +856,13 @@ def _wrap_layer_functions(layer, serialization_cache,
return fns
def _list_all_layers(obj):
if isinstance(obj, training_lib.Model):
return obj.layers
else:
return trackable_layer_utils.filter_empty_layer_containers(obj._layers) # pylint: disable=protected-access
def _replace_child_layer_functions(layer, serialization_cache):
"""Replaces functions in the children layers with wrapped tf.functions.
@ -882,8 +888,7 @@ def _replace_child_layer_functions(layer, serialization_cache):
}
"""
original_attrs = {}
for child_layer in trackable_layer_utils.filter_empty_layer_containers(
layer._layers): # pylint: disable=protected-access
for child_layer in _list_all_layers(layer):
# Save symbolic layer losses, which will be restored to maintain the same
# state.
original_attrs[child_layer] = {'losses': child_layer._losses[:]} # pylint: disable=protected-access
@ -934,14 +939,15 @@ def _use_wrapped_call(layer, call_fn):
function that calls call_fn and returns the outputs. Losses returned by
call_fn are added to the layer losses.
"""
def wrapped_call(inputs, *args, **kwargs):
# TODO(kathywu): Support mask argument and multi-input call functions.
def wrapped_call(inputs, **kwargs):
"""Returns the outputs from the call_fn, and adds the losses."""
if layer._expects_training_arg: # pylint: disable=protected-access
training = kwargs.pop('training', None)
if training is None:
training = K.learning_phase()
training = math_ops.cast(training, dtypes.bool)
outputs, losses = call_fn(inputs, training=training, *args, **kwargs)
outputs, losses = call_fn(inputs, training=training)
else:
outputs, losses = call_fn(inputs)
layer.add_loss(losses, inputs)
@ -963,7 +969,7 @@ def _wrap_call_and_conditional_losses(layer):
activity regularizer
"""
if isinstance(layer, RevivedLayer):
return layer.call_and_return_conditional_losses
return layer.keras_api.call_and_return_conditional_losses
if (isinstance(layer.call, def_function.Function) and
layer.call.input_signature is not None):
@ -972,7 +978,7 @@ def _wrap_call_and_conditional_losses(layer):
if (isinstance(layer, training_lib.Model) and
saving_utils.model_input_signature(layer) is not None):
input_signature = saving_utils.model_input_signature(layer)
else:
elif layer.input_spec is not None:
input_signature = [nest.map_structure(
lambda x: input_spec.to_tensor_spec(x, layer.dtype),
layer.input_spec)]
@ -981,6 +987,8 @@ def _wrap_call_and_conditional_losses(layer):
if spec.shape == tensor_shape.TensorShape(None):
input_signature = None
break
else:
input_signature = None
if input_signature is not None and layer._expects_training_arg: # pylint: disable=protected-access
input_signature.append(
@ -1007,7 +1015,7 @@ def _wrap_call_and_conditional_losses(layer):
def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
"""Returns a function that returns only call function outputs."""
if isinstance(layer, RevivedLayer):
return layer._original_call # pylint: disable=protected-access
return layer.keras_api.__call__ # pylint: disable=protected-access
if layer._expects_training_arg: # pylint: disable=protected-access
def call(inputs, training):
return call_and_return_conditional_losses(inputs, training)[0]
@ -1076,7 +1084,7 @@ def _wrap_activity_regularizer(layer):
input_signature=[tensor_spec.TensorSpec(None, layer.dtype or K.floatx())])
def load_from_saved_model_v2(path):
def load_from_saved_model_v2(path, compile=True): # pylint: disable=redefined-builtin
"""Loads Keras objects from a SavedModel.
Any Keras layer or model saved to the SavedModel will be loaded back
@ -1092,13 +1100,27 @@ def load_from_saved_model_v2(path):
Args:
path: Path to SavedModel.
compile: If true, compile the model after loading it.
Returns:
Object loaded from SavedModel.
"""
# TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
# TODO(kathywu): Add code to load from objects that contain all endpoints
return load.load_internal(path, loader_cls=KerasObjectLoader)
model = load.load_internal(path, loader_cls=KerasObjectLoader)
if isinstance(model, RevivedModel) and compile:
# TODO(kathywu): Use compiled objects from SavedModel, instead of
# creating new objects from the training config.
if model._training_config is not None: # pylint: disable=protected-access
model.compile(**saving_utils.compile_args_from_training_config(
model._training_config)) # pylint: disable=protected-access
return model
PUBLIC_ATTRIBUTES = CommonEndpoints.all_functions.union(
CommonEndpoints.all_checkpointable_objects)
PUBLIC_ATTRIBUTES.add(_KERAS_ATTR)
class KerasObjectLoader(load.Loader):
@ -1111,6 +1133,20 @@ class KerasObjectLoader(load.Loader):
def _finalize(self):
# pylint: disable=protected-access
for node in self._nodes:
if isinstance(node, RevivedModel):
input_signature = (
node.keras_api.call_and_return_conditional_losses.input_signature[0]
)
if isinstance(node, RevivedSequential):
with trackable.no_automatic_dependency_tracking_scope(node):
node._layers = []
for layer in node.keras_api.layers:
node.add(layer)
if not node.inputs:
# Since this revived object is technically a subclassed model (even if
# the original model is functional/sequential), inputs should be set.
node._set_inputs(input_signature)
if isinstance(node, RevivedLayer):
losses = node._serialized_attributes.get('regularization_losses', [])
for loss in losses:
@ -1122,20 +1158,25 @@ class KerasObjectLoader(load.Loader):
node.activity_regularizer = getattr(node.keras_api,
'activity_regularizer_fn', None)
if isinstance(node, RevivedModel):
# Since this revived object is technically a subclassed model (even if
# the original model is functional/sequential), inputs should be set.
input_signature = (
node.keras_api.call_and_return_conditional_losses.input_signature[0]
)
node._set_inputs(input_signature)
# Now that the node object has been fully loaded and restored from the,
# checkpoint, the object no longer needs to track objects added from
# SerializedAttributes. (Note that saving a training checkpoint still
# functions correctly, because layers and variables are tracked
# separately by the Layer object.)
# TODO(kathywu): Instead of outright deleting these nodes (which would
# make restoring from a different checkpoint tricky), mark them as extra
# dependencies that are OK to overwrite.
for name in PUBLIC_ATTRIBUTES:
delete_tracking(node, name)
# pylint: enable=protected-access
def _recreate_base_user_object(self, proto):
revived_classes = {
'_tf_keras_layer': (RevivedLayer, base_layer.Layer),
'_tf_keras_network': (RevivedNetwork, network_lib.Network),
'_tf_keras_model': (RevivedModel, training_lib.Model)
'_tf_keras_model': (RevivedModel, training_lib.Model),
'_tf_keras_sequential': (RevivedSequential, models_lib.Sequential)
}
parent_classes = revived_classes.get(proto.identifier, None)
@ -1193,9 +1234,9 @@ class RevivedLayer(object):
def _revive_setter(self, name, value):
"""Reattaches attributes from the SavedModel to the newly revived object."""
if (name in CommonEndpoints.all_functions or
name in CommonEndpoints.all_checkpointable_objects or
name == _KERAS_ATTR):
if name in PUBLIC_ATTRIBUTES:
if isinstance(value, trackable.Trackable):
self._track_trackable(value, name=name)
self._serialized_attributes[name] = value
else:
setattr(self, name, value)
@ -1258,7 +1299,50 @@ class RevivedModel(RevivedNetwork):
revived_obj = super(RevivedModel, cls)._init_from_metadata(metadata)
with trackable.no_automatic_dependency_tracking_scope(revived_obj):
if 'training_config' in metadata:
revived_obj._training_config = metadata['training_config'] # pylint:disable=protected-access
revived_obj._training_config = metadata.get('training_config') # pylint:disable=protected-access
return revived_obj
class RevivedSequential(RevivedModel):
"""Keras sequential model loaded from a SavedModel."""
@classmethod
def _init_from_metadata(cls, metadata):
"""Create revived Sequential model from SavedModel metadata."""
revived_obj = super(RevivedSequential, cls)._init_from_metadata(metadata)
return revived_obj
def call(self, *args, **kwargs):
return models_lib.Sequential.call(self, *args, **kwargs)
def save(model, filepath, overwrite, include_optimizer):
"""Saves a model as a SavedModel to the filepath.
Args:
model: Keras model instance to be saved.
filepath: String path to save the model.
overwrite: whether to overwrite the existing filepath.
include_optimizer: If True, save the model's optimizer state.
Raises:
ValueError: if the model's inputs have not been defined.
"""
# If file exists and should not be overwritten.
if not overwrite and os.path.exists(filepath):
proceed = ask_to_proceed_with_overwrite(filepath)
if not proceed:
return
if _should_skip_serialization(model):
saving_utils.raise_model_input_error(model)
if not include_optimizer:
orig_optimizer = model.optimizer
model.optimizer = None
save_lib.save(model, filepath)
if not include_optimizer:
model.optimizer = orig_optimizer

View File

@ -703,10 +703,6 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
self.evaluate(getattr(loaded.keras_api, attr)))
self.assertLen(loaded.regularization_losses, 1)
expected_layers = len(model.layers)
if testing_utils.get_model_type() == 'sequential':
# The autogenerated Input layer is hidden in the model.layers list,
# but included in the loaded sub-layers.
expected_layers += 1
self.assertEqual(expected_layers, len(loaded.keras_api.layers))
input_arr = array_ops.ones((4, 3))
training_bool = constant_op.constant(False)
@ -718,5 +714,39 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
self.assertAllClose(self.evaluate(model(input_arr)),
self.evaluate(loaded(*call_args)))
@keras_parameterized.run_with_all_model_types
def test_compiled_model(self):
input_arr = np.random.random((1, 3))
target_arr = np.random.random((1, 4))
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
expected_predict = model.predict(input_arr)
# Compile and save model.
model.compile('rmsprop', 'mse')
saved_model_dir = self._save_model_dir()
tf_save.save(model, saved_model_dir)
# TODO(b/134519980): Issue with model.fit if the model call function uses
# a tf.function (Graph mode only).
with context.eager_mode():
loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir)
actual_predict = loaded.predict(input_arr)
self.assertAllClose(expected_predict, actual_predict)
loss_before = loaded.evaluate(input_arr, target_arr)
loaded.fit(input_arr, target_arr)
loss_after = loaded.evaluate(input_arr, target_arr)
self.assertLess(loss_after, loss_before)
predict = loaded.predict(input_arr)
ckpt_path = os.path.join(self.get_temp_dir(), 'weights')
loaded.save_weights(ckpt_path)
# Ensure that the checkpoint is compatible with the original model.
model.load_weights(ckpt_path)
self.assertAllClose(predict, model.predict(input_arr))
if __name__ == '__main__':
test.main()

View File

@ -18,11 +18,14 @@ from __future__ import division
from __future__ import print_function
import collections
import os
from tensorflow.python.eager import def_function
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
@ -94,6 +97,14 @@ def model_input_signature(model):
return [input_specs]
def raise_model_input_error(model):
raise ValueError(
'Model {} cannot be saved because the input shapes have not been '
'set. Usually, input shapes are automatically determined from calling'
' .fit() or .predict(). To manually set the shapes, call '
'model._set_inputs(inputs).'.format(model))
def trace_model_call(model, input_signature=None):
"""Trace the model call to create a tf.function for exporting a Keras model.
@ -116,11 +127,7 @@ def trace_model_call(model, input_signature=None):
input_signature = model_input_signature(model)
if input_signature is None:
raise ValueError(
'Model {} cannot be saved because the input shapes have not been '
'set. Usually, input shapes are automatically determined from calling'
' .fit() or .predict(). To manually set the shapes, call '
'model._set_inputs(inputs).'.format(model))
raise_model_input_error(model)
# TODO(mdan): Should the model's call be autographed by default?
@def_function.function(input_signature=input_signature, autograph=False)
@ -190,3 +197,43 @@ def model_metadata(model, include_optimizer=True, require_config=True):
'config': model.optimizer.get_config()}
metadata['training_config']['optimizer_config'] = optimizer_config
return metadata
def should_overwrite(filepath, overwrite):
"""Returns whether the filepath should be overwritten."""
# If file exists and should not be overwritten.
if not overwrite and os.path.isfile(filepath):
return ask_to_proceed_with_overwrite(filepath)
return True
def compile_args_from_training_config(training_config, custom_objects=None):
"""Return model.compile arguments from training config."""
if custom_objects is None:
custom_objects = {}
optimizer_config = training_config['optimizer_config']
optimizer = optimizers.deserialize(
optimizer_config, custom_objects=custom_objects)
# Recover loss functions and metrics.
loss_config = training_config['loss'] # Deserialize loss class.
if isinstance(loss_config, dict) and 'class_name' in loss_config:
loss_config = losses.get(loss_config)
loss = nest.map_structure(
lambda obj: custom_objects.get(obj, obj), loss_config)
metrics = nest.map_structure(
lambda obj: custom_objects.get(obj, obj), training_config['metrics'])
weighted_metrics = nest.map_structure(
lambda obj: custom_objects.get(obj, obj),
training_config.get('weighted_metrics', None))
sample_weight_mode = training_config['sample_weight_mode']
loss_weights = training_config['loss_weights']
return dict(
optimizer=optimizer,
loss=loss,
metrics=metrics,
weighted_metrics=weighted_metrics,
loss_weights=loss_weights,
sample_weight_mode=sample_weight_mode)

View File

@ -81,13 +81,7 @@ class AutoTrackable(base.Trackable):
def __delattr__(self, name):
self._maybe_initialize_trackable()
if name in self._unconditional_dependency_names:
del self._unconditional_dependency_names[name]
for index, (dep_name, _) in enumerate(
self._unconditional_checkpoint_dependencies):
if dep_name == name:
del self._unconditional_checkpoint_dependencies[index]
break
delete_tracking(self, name)
super(AutoTrackable, self).__delattr__(name)
def _no_dependency(self, value):
@ -110,6 +104,19 @@ class AutoTrackable(base.Trackable):
return functions
def delete_tracking(obj, name):
"""Removes the tracking of name from object."""
# pylint: disable=protected-access
if name in obj._unconditional_dependency_names:
del obj._unconditional_dependency_names[name]
for index, (dep_name, _) in enumerate(
obj._unconditional_checkpoint_dependencies):
if dep_name == name:
del obj._unconditional_checkpoint_dependencies[index]
break
# pylint: enable=protected-access
class ResourceTracker(object):
"""An object that tracks a list of resources."""