diff --git a/tensorflow/python/keras/distribute/multi_worker_callback_test.py b/tensorflow/python/keras/distribute/multi_worker_callback_test.py
index 719eb08e28f..7e441e2ee09 100644
--- a/tensorflow/python/keras/distribute/multi_worker_callback_test.py
+++ b/tensorflow/python/keras/distribute/multi_worker_callback_test.py
@@ -145,8 +145,11 @@ class KerasMultiWorkerCallbackTest(test_base.IndependentWorkerTestBase,
     # ensure every worker has a unique path. Note that in normal use case the
     # saving_filepath will be the same for all workers, but we use different
     # ones here just to test out chief saves checkpoint but non-chief doesn't.
+
+    # TODO(b/134551335): Must save to hdf5 until bug with copying
+    # MirroredVariables is resolved.
     saving_filepath = os.path.join(
-        test_obj.get_temp_dir(), 'checkpoint_%s_%d' %
+        test_obj.get_temp_dir(), 'checkpoint_%s_%d.h5' %
         (test_base.get_task_type(), test_base.get_task_index()))
 
     # The saving_filepath shouldn't exist at the beginning (as it's unique).
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index 7d84bc2a440..67359a0aff4 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -361,3 +361,7 @@ class Sequential(training.Model):
     if self.layers and hasattr(self.layers[0], 'input_spec'):
       return self.layers[0].input_spec
     return None
+
+  @property
+  def _object_identifier(self):
+    return '_tf_keras_sequential'
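The `_object_identifier` property is the string key the SavedModel loader uses to pick a revived class (see the `revived_classes` table added to `_recreate_base_user_object` later in this patch). A minimal sketch of that dispatch pattern, using hypothetical stand-in classes rather than the real Keras ones:

```python
# Hypothetical stand-ins illustrating identifier-based revival dispatch.
class BaseThing(object):
  """Plays the role of the framework base class (e.g. Sequential)."""


class RevivedThing(object):
  """Plays the role of the revived mixin (e.g. RevivedSequential)."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    obj = cls.__new__(cls)
    obj.metadata = metadata
    return obj


REVIVED_CLASSES = {
    '_tf_keras_sequential': (RevivedThing, BaseThing),
}


def recreate(identifier, metadata):
  if identifier not in REVIVED_CLASSES:
    raise ValueError('Unknown identifier: %s' % identifier)
  revived_cls, base_cls = REVIVED_CLASSES[identifier]
  # Combine the revived mixin with the base class, roughly how the loader
  # pairs RevivedSequential with models_lib.Sequential.
  cls = type('Revived', (revived_cls, base_cls), {})
  return cls._init_from_metadata(metadata)


print(type(recreate('_tf_keras_sequential', {'name': 'seq'})))
```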
diff --git a/tensorflow/python/keras/integration_test.py b/tensorflow/python/keras/integration_test.py
index 86ccbfe8160..5295111461a 100644
--- a/tensorflow/python/keras/integration_test.py
+++ b/tensorflow/python/keras/integration_test.py
@@ -40,7 +40,6 @@ class KerasIntegrationTest(keras_parameterized.TestCase):
     fpath = os.path.join(self.temp_dir,
                          'test_model_%s' % (random.randint(0, 1e7),))
     if context.executing_eagerly():
-      keras.saving.save._KERAS_SAVED_MODEL_STILL_EXPERIMENTAL = False
       save_format = 'tf'
     else:
       if (not isinstance(model, keras.Sequential) and
@@ -155,8 +154,19 @@ class SequentialIntegrationTest(KerasIntegrationTest):
                         validation_data=(x_train, y_train),
                         verbose=2)
     model = self._save_and_reload_model(model)
+
+    # TODO(b/134537740): model.pop doesn't update model outputs properly when
+    # model.outputs is already defined, so just set to `None` for now.
+    model.inputs = None
+    model.outputs = None
+
     model.pop()
     model.add(keras.layers.Dense(y_train.shape[-1], activation='softmax'))
+
+    # TODO(b/134523282): There is a bug with Sequential models, so the model
+    # must be marked as compiled=False to ensure the next compile goes through.
+    model._is_compiled = False
+
     model.compile(
         loss='categorical_crossentropy',
         optimizer=keras.optimizer_v2.adam.Adam(0.005),
diff --git a/tensorflow/python/keras/saving/hdf5_format.py b/tensorflow/python/keras/saving/hdf5_format.py
index 630dc6b9b6a..ec329cd0bca 100644
--- a/tensorflow/python/keras/saving/hdf5_format.py
+++ b/tensorflow/python/keras/saving/hdf5_format.py
@@ -26,7 +26,6 @@ import numpy as np
 from six.moves import zip  # pylint: disable=redefined-builtin
 
 from tensorflow.python.keras import backend as K
-from tensorflow.python.keras import losses
 from tensorflow.python.keras import optimizers
 from tensorflow.python.keras.saving import model_config as model_config_lib
 from tensorflow.python.keras.saving import saving_utils
@@ -146,31 +145,6 @@ def load_model_from_hdf5(filepath, custom_objects=None, compile=True):  # pylint
   if not custom_objects:
     custom_objects = {}
 
-  def convert_custom_objects(obj):
-    """Handles custom object lookup.
-
-    Arguments:
-        obj: object, dict, or list.
-
-    Returns:
-        The same structure, where occurrences
-            of a custom object name have been replaced
-            with the custom object.
-    """
-    if isinstance(obj, list):
-      deserialized = []
-      for value in obj:
-        deserialized.append(convert_custom_objects(value))
-      return deserialized
-    if isinstance(obj, dict):
-      deserialized = {}
-      for key, value in obj.items():
-        deserialized[key] = convert_custom_objects(value)
-      return deserialized
-    if obj in custom_objects:
-      return custom_objects[obj]
-    return obj
-
   opened_new_file = not isinstance(filepath, h5py.File)
   if opened_new_file:
     f = h5py.File(filepath, mode='r')
@@ -198,29 +172,10 @@ def load_model_from_hdf5(filepath, custom_objects=None, compile=True):  # pylint
           'the model was *not* compiled. Compile it manually.')
       return model
     training_config = json.loads(training_config.decode('utf-8'))
-    optimizer_config = training_config['optimizer_config']
-    optimizer = optimizers.deserialize(
-        optimizer_config, custom_objects=custom_objects)
-
-    # Recover loss functions and metrics.
-    loss_config = training_config['loss']  # Deserialize loss class.
-    if isinstance(loss_config, dict) and 'class_name' in loss_config:
-      loss_config = losses.get(loss_config)
-    loss = convert_custom_objects(loss_config)
-    metrics = convert_custom_objects(training_config['metrics'])
-    weighted_metrics = convert_custom_objects(
-        training_config.get('weighted_metrics', None))
-    sample_weight_mode = training_config['sample_weight_mode']
-    loss_weights = training_config['loss_weights']
 
     # Compile model.
-    model.compile(
-        optimizer=optimizer,
-        loss=loss,
-        metrics=metrics,
-        weighted_metrics=weighted_metrics,
-        loss_weights=loss_weights,
-        sample_weight_mode=sample_weight_mode)
+    model.compile(**saving_utils.compile_args_from_training_config(
+        training_config, custom_objects))
 
     # Set optimizer weights.
     if 'optimizer_weights' in f:
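Consolidating this logic in `saving_utils.compile_args_from_training_config` means both the HDF5 loader and the SavedModel loader (later in this patch) rebuild `model.compile` arguments the same way. A hedged usage sketch; the config dict mirrors what `model_metadata` in `saving_utils.py` serializes, and the exact keys here are an assumption:

```python
from tensorflow.python.keras.saving import saving_utils

# A training config shaped like the one model_metadata writes out.
training_config = {
    'optimizer_config': {'class_name': 'SGD',
                         'config': {'learning_rate': 0.01}},
    'loss': 'mse',
    'metrics': ['mae'],
    'weighted_metrics': None,
    'sample_weight_mode': None,
    'loss_weights': None,
}

compile_args = saving_utils.compile_args_from_training_config(training_config)
# `model` is assumed to be an uncompiled tf.keras model.
model.compile(**compile_args)
```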
diff --git a/tensorflow/python/keras/saving/save.py b/tensorflow/python/keras/saving/save.py
index 068102cd2fc..d8d81557970 100644
--- a/tensorflow/python/keras/saving/save.py
+++ b/tensorflow/python/keras/saving/save.py
@@ -23,7 +23,6 @@ import os
 import six
 
 from tensorflow.python import tf2
-from tensorflow.python.framework import ops
 from tensorflow.python.keras.saving import hdf5_format
 from tensorflow.python.keras.saving import saved_model
 from tensorflow.python.saved_model import loader_impl
@@ -77,32 +76,16 @@ def save_model(model,
        location, or instead ask the user with a manual prompt.
     include_optimizer: If True, save optimizer's state together.
     save_format: Either 'tf' or 'h5', indicating whether to save the model
-      to Tensorflow SavedModel or HDF5. The 'tf' option is currently disabled,
-      and will be enabled when Keras SavedModel export is no longer
-      experimental. (The experimental function is
-      tf.keras.experimental.export_saved_model).
+      to TensorFlow SavedModel or HDF5. Defaults to 'tf' in TF 2.X, and 'h5'
+      in TF 1.X.
 
   Raises:
     ImportError: If save format is hdf5, and h5py is not available.
   """
   from tensorflow.python.keras.engine import sequential  # pylint: disable=g-import-not-at-top
 
-  if (not tf2.enabled() and
-      not ops.executing_eagerly_outside_functions()
-      and save_format == 'tf'):
-    raise NotImplementedError(
-        'Saving the model as SavedModel is not supported in TensorFlow 1.X'
-        'graph mode. Please enable eager execution or use the "h5" save format.'
-    )
-
-  if _KERAS_SAVED_MODEL_STILL_EXPERIMENTAL and save_format == 'tf':
-    raise NotImplementedError(
-        'Saving the model as SavedModel is still in experimental stages. '
-        'Please use tf.keras.experimental.export_saved_model, or use '
-        'save_format="h5" to save to HDF5.')
-
-  # TODO(kathywu): Remove this when Keras SavedModel is not experimental.
-  save_format = 'h5'
+  default_format = 'tf' if tf2.enabled() else 'h5'
+  save_format = save_format or default_format
 
   if (save_format == 'h5' or
       (h5py is not None and isinstance(filepath, h5py.File)) or
@@ -119,7 +102,8 @@ def save_model(model,
         'or using `save_weights`.')
     hdf5_format.save_model_to_hdf5(
         model, filepath, overwrite, include_optimizer)
-    return
+  else:
+    saved_model.save(model, filepath, overwrite, include_optimizer)
 
 
 @keras_export('keras.models.load_model')
@@ -148,14 +132,13 @@ def load_model(filepath, custom_objects=None, compile=True):  # pylint: disable=
       ImportError: if loading from an hdf5 file and h5py is not available.
       IOError: In case of an invalid savefile.
   """
-  if not tf2.enabled() or (
-      h5py is not None and (
-          isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
+  if (h5py is not None and (
+      isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
    return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
 
   if isinstance(filepath, six.string_types):
     loader_impl.parse_saved_model(filepath)
-    return saved_model.load_from_saved_model_v2(filepath)
+    return saved_model.load_from_saved_model_v2(filepath, compile)
 
   raise IOError(
       'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
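With the experimental gate removed, the format is inferred from the TF major version when `save_format` is omitted. A sketch of the resulting public-API behavior (paths illustrative):

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile('sgd', 'mse')

# In TF 2.x this now writes a SavedModel directory by default.
model.save('/tmp/my_model')

# HDF5 is still available explicitly (requires h5py and a
# Functional/Sequential model).
model.save('/tmp/my_model.h5', save_format='h5')
```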
diff --git a/tensorflow/python/keras/saving/save_test.py b/tensorflow/python/keras/saving/save_test.py
index 0ac1a172d7d..c600cf6defd 100644
--- a/tensorflow/python/keras/saving/save_test.py
+++ b/tensorflow/python/keras/saving/save_test.py
@@ -24,6 +24,7 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.saving import save
 from tensorflow.python.platform import test
+from tensorflow.python.saved_model import loader_impl
 
 try:
   import h5py  # pylint:disable=g-import-not-at-top
@@ -43,28 +44,35 @@ class TestSaveModel(test.TestCase):
                       'Model saved at path {} is not a valid hdf5 file.'
                       .format(path))
 
+  def assert_saved_model(self, path):
+    loader_impl.parse_saved_model(path)
+
   @test_util.run_v2_only
   def test_save_format_defaults(self):
     path = os.path.join(self.get_temp_dir(), 'model_path')
-
-    # The default is currently HDF5 no matter what the filepath is.
     save.save_model(self.model, path)
-    self.assert_h5_format(path)
+    self.assert_saved_model(path)
 
   @test_util.run_v2_only
   def test_save_hdf5(self):
     path = os.path.join(self.get_temp_dir(), 'model')
     save.save_model(self.model, path, save_format='h5')
-    self.assert_h5_format(path)
+    with self.assertRaisesRegexp(
+        NotImplementedError,
+        'requires the model to be a Functional model or a Sequential model.'):
+      save.save_model(self.subclassed_model, path, save_format='h5')
 
   @test_util.run_v2_only
   def test_save_tf(self):
     path = os.path.join(self.get_temp_dir(), 'model')
-    with self.assertRaisesRegexp(
-        NotImplementedError,
-        'Saving the model as SavedModel is still in experimental stages.'):
-      save.save_model(self.model, path, save_format='tf')
+    save.save_model(self.model, path, save_format='tf')
+    self.assert_saved_model(path)
+    with self.assertRaisesRegexp(ValueError, 'input shapes have not been set'):
+      save.save_model(self.subclassed_model, path, save_format='tf')
+    self.subclassed_model.predict(np.random.random((3, 5)))
+    save.save_model(self.subclassed_model, path, save_format='tf')
+    self.assert_saved_model(path)
 
 
 if __name__ == '__main__':
   test.main()
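The updated tests pin down the subclassed-model failure modes: HDF5 export requires a Functional or Sequential model, and SavedModel export requires input shapes. The workaround the test demonstrates is to run data through the model first. A hedged sketch, with a toy subclassed model standing in for the test fixture:

```python
import numpy as np
import tensorflow as tf


class Subclassed(tf.keras.Model):
  """Toy subclassed model; shapes are unknown until it is called."""

  def __init__(self):
    super(Subclassed, self).__init__()
    self.dense = tf.keras.layers.Dense(2)

  def call(self, inputs):
    return self.dense(inputs)


model = Subclassed()
# Saving here would raise ValueError: input shapes have not been set.
model.predict(np.random.random((3, 5)))  # builds the model and sets inputs
model.save('/tmp/subclassed_model', save_format='tf')  # now succeeds
```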
diff --git a/tensorflow/python/keras/saving/saved_model.py b/tensorflow/python/keras/saving/saved_model.py
index a97420569c8..92e53882b57 100644
--- a/tensorflow/python/keras/saving/saved_model.py
+++ b/tensorflow/python/keras/saving/saved_model.py
@@ -38,6 +38,7 @@ from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.keras.saving import model_from_json
 from tensorflow.python.keras.saving import saving_utils
 from tensorflow.python.keras.utils import mode_keys
+from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
@@ -54,6 +55,7 @@ from tensorflow.python.training.tracking import data_structures
 from tensorflow.python.training.tracking import graph_view
 from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
 from tensorflow.python.training.tracking.tracking import AutoTrackable
+from tensorflow.python.training.tracking.tracking import delete_tracking
 from tensorflow.python.util import compat
 from tensorflow.python.util import nest
 from tensorflow.python.util.lazy_loader import LazyLoader
@@ -81,6 +83,10 @@ sequential = LazyLoader(
 training_lib = LazyLoader(
     "training_lib", globals(),
     "tensorflow.python.keras.engine.training")
+input_layer = LazyLoader(
+    "input_layer", globals(),
+    "tensorflow.python.keras.engine.input_layer")
+
 # pylint:enable=g-inconsistent-quotes
 
 
@@ -713,7 +719,7 @@ def serialize_all_attributes(layer, serialization_cache):
   except (ValueError, TypeError) as e:
     logging.warning('Skipping full serialization of object {}, because an '
                     'error occurred while tracing layer functions. Error '
-                    'message: {}'.format(layer, e.message))
+                    'message: {}'.format(layer, e))
   else:
     # Add checkpointable objects and functions to the SerializedAttribute object
     # only if all functions are successfully traced.
@@ -743,10 +749,6 @@ def _should_skip_serialization(layer):
     else:
       return False
   else:
-    if not layer.input_spec:
-      logging.warning('Skipping full serialization of Keras layer {}, because '
-                      'it does not have an input spec defined.'.format(layer))
-      return True
     if not layer.built:
       logging.warning('Skipping full serialization of Keras layer {}, because '
                       'it is not built.'.format(layer))
@@ -771,8 +773,7 @@ def _wrap_layer_objects(layer, serialization_cache):
   # First, generate list of all regularization losses in this layer and
   # sublayers.
   regularization_losses = layer._callable_losses[:]  # pylint: disable=protected-access
-  for child_layer in (
-      trackable_layer_utils.filter_empty_layer_containers(layer._layers)):  # pylint: disable=protected-access
+  for child_layer in _list_all_layers(layer):
     regularization_losses.extend(child_layer._callable_losses)  # pylint: disable=protected-access
   # Next, wrap all loss functions as tf.functions. Use the serialization cache
   # to store already-wrapped functions.
@@ -791,9 +792,7 @@ def _wrap_layer_objects(layer, serialization_cache):
           layer.trainable_variables),
       non_trainable_variables=data_structures.ListWrapper(
           layer.non_trainable_variables),
-      layers=data_structures.ListWrapper(
-          trackable_layer_utils.filter_empty_layer_containers(
-              layer._layers)),  # pylint: disable=protected-access
+      layers=data_structures.ListWrapper(_list_all_layers(layer)),
       metrics=data_structures.ListWrapper(layer.metrics),
       regularization_losses=data_structures.ListWrapper(
           wrapped_loss_functions))
@@ -857,6 +856,13 @@ def _wrap_layer_functions(layer, serialization_cache,
   return fns
 
 
+def _list_all_layers(obj):
+  if isinstance(obj, training_lib.Model):
+    return obj.layers
+  else:
+    return trackable_layer_utils.filter_empty_layer_containers(obj._layers)  # pylint: disable=protected-access
+
+
 def _replace_child_layer_functions(layer, serialization_cache):
   """Replaces functions in the children layers with wrapped tf.functions.
@@ -882,8 +888,7 @@ def _replace_child_layer_functions(layer, serialization_cache):
     }
   """
   original_attrs = {}
-  for child_layer in trackable_layer_utils.filter_empty_layer_containers(
-      layer._layers):  # pylint: disable=protected-access
+  for child_layer in _list_all_layers(layer):
     # Save symbolic layer losses, which will be restored to maintain the same
     # state.
     original_attrs[child_layer] = {'losses': child_layer._losses[:]}  # pylint: disable=protected-access
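`_list_all_layers` papers over a real asymmetry: for a `Model`, the public `layers` property hides the auto-generated `InputLayer` of Sequential models, while the raw `_layers` attribute includes it. Serializing `model.layers` keeps the saved layer list aligned with what users see, which is also why the `expected_layers += 1` special case disappears from `saved_model_test.py` below. A rough illustration (the second count peeks at a private attribute, so treat it as an implementation detail of this TF version):

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
print(len(model.layers))   # 1: just the Dense layer; the InputLayer is hidden
print(len(model._layers))  # typically 2: InputLayer + Dense (internal detail)
```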
@@ -934,14 +939,15 @@ def _use_wrapped_call(layer, call_fn):
     function that calls call_fn and returns the outputs. Losses returned by
     call_fn are added to the layer losses.
   """
-  def wrapped_call(inputs, *args, **kwargs):
+  # TODO(kathywu): Support mask argument and multi-input call functions.
+  def wrapped_call(inputs, **kwargs):
     """Returns the outputs from the call_fn, and adds the losses."""
     if layer._expects_training_arg:  # pylint: disable=protected-access
       training = kwargs.pop('training', None)
       if training is None:
         training = K.learning_phase()
       training = math_ops.cast(training, dtypes.bool)
-      outputs, losses = call_fn(inputs, training=training, *args, **kwargs)
+      outputs, losses = call_fn(inputs, training=training)
     else:
       outputs, losses = call_fn(inputs)
     layer.add_loss(losses, inputs)
@@ -963,7 +969,7 @@ def _wrap_call_and_conditional_losses(layer):
       activity regularizer
   """
   if isinstance(layer, RevivedLayer):
-    return layer.call_and_return_conditional_losses
+    return layer.keras_api.call_and_return_conditional_losses
 
   if (isinstance(layer.call, def_function.Function) and
       layer.call.input_signature is not None):
@@ -972,7 +978,8 @@ def _wrap_call_and_conditional_losses(layer):
   if (isinstance(layer, training_lib.Model) and
       saving_utils.model_input_signature(layer) is not None):
     input_signature = saving_utils.model_input_signature(layer)
-  else:
+  elif layer.input_spec is not None:
     input_signature = [nest.map_structure(
         lambda x: input_spec.to_tensor_spec(x, layer.dtype),
         layer.input_spec)]
@@ -981,6 +987,8 @@ def _wrap_call_and_conditional_losses(layer):
       if spec.shape == tensor_shape.TensorShape(None):
         input_signature = None
         break
+  else:
+    input_signature = None
 
   if input_signature is not None and layer._expects_training_arg:  # pylint: disable=protected-access
     input_signature.append(
@@ -1007,7 +1015,7 @@ def _wrap_call_and_conditional_losses(layer):
 def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
   """Returns a function that returns only call function outputs."""
   if isinstance(layer, RevivedLayer):
-    return layer._original_call  # pylint: disable=protected-access
+    return layer.keras_api.__call__  # pylint: disable=protected-access
   if layer._expects_training_arg:  # pylint: disable=protected-access
     def call(inputs, training):
       return call_and_return_conditional_losses(inputs, training)[0]
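The signature selection above falls back from the model-level signature, to one derived from `layer.input_spec`, and finally to `None` (retrace on each distinct call). When a signature is available and the layer expects a `training` argument, a scalar boolean spec is appended. For reference, tracing with such a fixed signature looks roughly like this (shapes illustrative):

```python
import tensorflow as tf

specs = [tf.TensorSpec(shape=[None, 3], dtype=tf.float32),  # inputs
         tf.TensorSpec(shape=[], dtype=tf.bool)]            # training flag


@tf.function(input_signature=specs)
def call_and_return_conditional_losses(inputs, training):
  outputs = inputs * 2.0  # stand-in for the layer's computation
  losses = []             # conditional losses would be collected here
  return outputs, losses


concrete = call_and_return_conditional_losses.get_concrete_function()
```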
@@ -1076,7 +1084,7 @@ def _wrap_activity_regularizer(layer):
       input_signature=[tensor_spec.TensorSpec(None, layer.dtype or K.floatx())])
 
 
-def load_from_saved_model_v2(path):
+def load_from_saved_model_v2(path, compile=True):  # pylint: disable=redefined-builtin
   """Loads Keras objects from a SavedModel.
 
   Any Keras layer or model saved to the SavedModel will be loaded back
@@ -1092,13 +1100,27 @@
 
   Args:
     path: Path to SavedModel.
+    compile: If true, compile the model after loading it.
 
   Returns:
     Object loaded from SavedModel.
   """
   # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
   # TODO(kathywu): Add code to load from objects that contain all endpoints
-  return load.load_internal(path, loader_cls=KerasObjectLoader)
+  model = load.load_internal(path, loader_cls=KerasObjectLoader)
+
+  if isinstance(model, RevivedModel) and compile:
+    # TODO(kathywu): Use compiled objects from SavedModel, instead of
+    # creating new objects from the training config.
+    if model._training_config is not None:  # pylint: disable=protected-access
+      model.compile(**saving_utils.compile_args_from_training_config(
+          model._training_config))  # pylint: disable=protected-access
+
+  return model
+
+PUBLIC_ATTRIBUTES = CommonEndpoints.all_functions.union(
+    CommonEndpoints.all_checkpointable_objects)
+PUBLIC_ATTRIBUTES.add(_KERAS_ATTR)
 
 
 class KerasObjectLoader(load.Loader):
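In practice this means a model saved with its training config can be loaded and trained immediately. A hedged sketch against the internal entry point (the public `tf.keras.models.load_model` now routes SavedModel directories here, per the `save.py` change above):

```python
import numpy as np
from tensorflow.python.keras.saving import saved_model as keras_saved_model

# compile=True is the default; optimizer and loss are rebuilt from the
# training config stored in the SavedModel.
loaded = keras_saved_model.load_from_saved_model_v2('/tmp/my_model')
x = np.random.random((8, 3))  # shapes assumed to match the saved model
y = np.random.random((8, 4))
loaded.fit(x, y, epochs=1)

# Pass compile=False to skip recompilation, e.g. to compile differently.
inference_only = keras_saved_model.load_from_saved_model_v2(
    '/tmp/my_model', compile=False)
```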
@@ -1111,6 +1133,20 @@ class KerasObjectLoader(load.Loader):
   def _finalize(self):
     # pylint: disable=protected-access
     for node in self._nodes:
+      if isinstance(node, RevivedModel):
+        input_signature = (
+            node.keras_api.call_and_return_conditional_losses.input_signature[0]
+        )
+        if isinstance(node, RevivedSequential):
+          with trackable.no_automatic_dependency_tracking_scope(node):
+            node._layers = []
+          for layer in node.keras_api.layers:
+            node.add(layer)
+
+        if not node.inputs:
+          # Since this revived object is technically a subclassed model (even if
+          # the original model is functional/sequential), inputs should be set.
+          node._set_inputs(input_signature)
       if isinstance(node, RevivedLayer):
         losses = node._serialized_attributes.get('regularization_losses', [])
         for loss in losses:
@@ -1122,20 +1158,25 @@ class KerasObjectLoader(load.Loader):
         node.activity_regularizer = getattr(node.keras_api,
                                             'activity_regularizer_fn', None)
 
-      if isinstance(node, RevivedModel):
-        # Since this revived object is technically a subclassed model (even if
-        # the original model is functional/sequential), inputs should be set.
-        input_signature = (
-            node.keras_api.call_and_return_conditional_losses.input_signature[0]
-        )
-        node._set_inputs(input_signature)
+      # Now that the node object has been fully loaded and restored from the
+      # checkpoint, the object no longer needs to track objects added from
+      # SerializedAttributes. (Note that saving a training checkpoint still
+      # functions correctly, because layers and variables are tracked
+      # separately by the Layer object.)
+      # TODO(kathywu): Instead of outright deleting these nodes (which would
+      # make restoring from a different checkpoint tricky), mark them as extra
+      # dependencies that are OK to overwrite.
+      for name in PUBLIC_ATTRIBUTES:
+        delete_tracking(node, name)
+
     # pylint: enable=protected-access
 
   def _recreate_base_user_object(self, proto):
     revived_classes = {
         '_tf_keras_layer': (RevivedLayer, base_layer.Layer),
         '_tf_keras_network': (RevivedNetwork, network_lib.Network),
-        '_tf_keras_model': (RevivedModel, training_lib.Model)
+        '_tf_keras_model': (RevivedModel, training_lib.Model),
+        '_tf_keras_sequential': (RevivedSequential, models_lib.Sequential)
     }
 
     parent_classes = revived_classes.get(proto.identifier, None)
@@ -1193,9 +1234,9 @@
 
   def _revive_setter(self, name, value):
     """Reattaches attributes from the SavedModel to the newly revived object."""
-    if (name in CommonEndpoints.all_functions or
-        name in CommonEndpoints.all_checkpointable_objects or
-        name == _KERAS_ATTR):
+    if name in PUBLIC_ATTRIBUTES:
+      if isinstance(value, trackable.Trackable):
+        self._track_trackable(value, name=name)
       self._serialized_attributes[name] = value
     else:
       setattr(self, name, value)
@@ -1258,7 +1299,50 @@ class RevivedModel(RevivedNetwork):
     revived_obj = super(RevivedModel, cls)._init_from_metadata(metadata)
 
     with trackable.no_automatic_dependency_tracking_scope(revived_obj):
-      if 'training_config' in metadata:
-        revived_obj._training_config = metadata['training_config']  # pylint:disable=protected-access
+      revived_obj._training_config = metadata.get('training_config')  # pylint:disable=protected-access
 
     return revived_obj
+
+
+class RevivedSequential(RevivedModel):
+  """Keras sequential model loaded from a SavedModel."""
+
+  @classmethod
+  def _init_from_metadata(cls, metadata):
+    """Create revived Sequential model from SavedModel metadata."""
+    revived_obj = super(RevivedSequential, cls)._init_from_metadata(metadata)
+    return revived_obj
+
+  def call(self, *args, **kwargs):
+    return models_lib.Sequential.call(self, *args, **kwargs)
+
+
+def save(model, filepath, overwrite, include_optimizer):
+  """Saves a model as a SavedModel to the filepath.
+
+  Args:
+    model: Keras model instance to be saved.
+    filepath: String path to save the model.
+    overwrite: whether to overwrite the existing filepath.
+    include_optimizer: If True, save the model's optimizer state.
+
+  Raises:
+    ValueError: if the model's inputs have not been defined.
+  """
+  # If file exists and should not be overwritten.
+  if not overwrite and os.path.exists(filepath):
+    proceed = ask_to_proceed_with_overwrite(filepath)
+    if not proceed:
+      return
+
+  if _should_skip_serialization(model):
+    saving_utils.raise_model_input_error(model)
+
+  if not include_optimizer:
+    orig_optimizer = model.optimizer
+    model.optimizer = None
+
+  save_lib.save(model, filepath)
+
+  if not include_optimizer:
+    model.optimizer = orig_optimizer
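This module-level `save` is what `save_model` in `save.py` now dispatches to for the 'tf' path. Note how the optimizer is detached and reattached around `save_lib.save` when `include_optimizer=False`. A minimal sketch of calling it directly (internal API; the model must already have its inputs set):

```python
import tensorflow as tf
from tensorflow.python.keras.saving import saved_model as keras_saved_model

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(3,))])
model.compile('rmsprop', 'mse')

# Writes a SavedModel directory; raises ValueError via
# raise_model_input_error if the input shapes were never set.
keras_saved_model.save(model, '/tmp/exported_model', overwrite=True,
                       include_optimizer=False)
```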
diff --git a/tensorflow/python/keras/saving/saved_model_test.py b/tensorflow/python/keras/saving/saved_model_test.py
index b42874ca097..f3b6eba96a5 100644
--- a/tensorflow/python/keras/saving/saved_model_test.py
+++ b/tensorflow/python/keras/saving/saved_model_test.py
@@ -703,10 +703,6 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
                         self.evaluate(getattr(loaded.keras_api, attr)))
     self.assertLen(loaded.regularization_losses, 1)
     expected_layers = len(model.layers)
-    if testing_utils.get_model_type() == 'sequential':
-      # The autogenerated Input layer is hidden in the model.layers list,
-      # but included in the loaded sub-layers.
-      expected_layers += 1
     self.assertEqual(expected_layers, len(loaded.keras_api.layers))
     input_arr = array_ops.ones((4, 3))
     training_bool = constant_op.constant(False)
@@ -718,5 +714,39 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
           self.assertAllClose(self.evaluate(model(input_arr)),
                               self.evaluate(loaded(*call_args)))
 
+  @keras_parameterized.run_with_all_model_types
+  def test_compiled_model(self):
+    input_arr = np.random.random((1, 3))
+    target_arr = np.random.random((1, 4))
+
+    model = testing_utils.get_small_mlp(1, 4, input_dim=3)
+    expected_predict = model.predict(input_arr)
+
+    # Compile and save model.
+    model.compile('rmsprop', 'mse')
+    saved_model_dir = self._save_model_dir()
+    tf_save.save(model, saved_model_dir)
+
+    # TODO(b/134519980): Issue with model.fit if the model call function uses
+    # a tf.function (Graph mode only).
+    with context.eager_mode():
+      loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir)
+      actual_predict = loaded.predict(input_arr)
+      self.assertAllClose(expected_predict, actual_predict)
+
+      loss_before = loaded.evaluate(input_arr, target_arr)
+      loaded.fit(input_arr, target_arr)
+      loss_after = loaded.evaluate(input_arr, target_arr)
+      self.assertLess(loss_after, loss_before)
+      predict = loaded.predict(input_arr)
+
+      ckpt_path = os.path.join(self.get_temp_dir(), 'weights')
+      loaded.save_weights(ckpt_path)
+
+      # Ensure that the checkpoint is compatible with the original model.
+      model.load_weights(ckpt_path)
+      self.assertAllClose(predict, model.predict(input_arr))
+
+
 if __name__ == '__main__':
   test.main()
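The tail of `test_compiled_model` is checking checkpoint compatibility: weights saved from the revived model must restore into the original, which is exactly what the `delete_tracking` cleanup in `_finalize` is meant to preserve. The same check outside the test harness, assuming `model`, `loaded`, and an input batch `x` from a prior save/load round trip:

```python
import numpy as np

# `model` is the original; `loaded` came from load_from_saved_model_v2.
loaded.save_weights('/tmp/weights')  # object-based checkpoint of the revived model
model.load_weights('/tmp/weights')   # restores into the original object graph
np.testing.assert_allclose(model.predict(x), loaded.predict(x), rtol=1e-5)
```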
diff --git a/tensorflow/python/keras/saving/saving_utils.py b/tensorflow/python/keras/saving/saving_utils.py
index ea4d60edf5e..866f596fca3 100644
--- a/tensorflow/python/keras/saving/saving_utils.py
+++ b/tensorflow/python/keras/saving/saving_utils.py
@@ -18,11 +18,14 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
+import os
 
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import losses
 from tensorflow.python.keras import optimizers
+from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
 
@@ -94,6 +97,14 @@ def model_input_signature(model):
   return [input_specs]
 
 
+def raise_model_input_error(model):
+  raise ValueError(
+      'Model {} cannot be saved because the input shapes have not been '
+      'set. Usually, input shapes are automatically determined from calling'
+      ' .fit() or .predict(). To manually set the shapes, call '
+      'model._set_inputs(inputs).'.format(model))
+
+
 def trace_model_call(model, input_signature=None):
   """Trace the model call to create a tf.function for exporting a Keras model.
 
@@ -116,11 +127,7 @@ def trace_model_call(model, input_signature=None):
     input_signature = model_input_signature(model)
 
   if input_signature is None:
-    raise ValueError(
-        'Model {} cannot be saved because the input shapes have not been '
-        'set. Usually, input shapes are automatically determined from calling'
-        ' .fit() or .predict(). To manually set the shapes, call '
-        'model._set_inputs(inputs).'.format(model))
+    raise_model_input_error(model)
 
   # TODO(mdan): Should the model's call be autographed by default?
   @def_function.function(input_signature=input_signature, autograph=False)
@@ -190,3 +197,43 @@ def model_metadata(model, include_optimizer=True, require_config=True):
                           'config': model.optimizer.get_config()}
     metadata['training_config']['optimizer_config'] = optimizer_config
   return metadata
+
+
+def should_overwrite(filepath, overwrite):
+  """Returns whether the filepath should be overwritten."""
+  # If file exists and should not be overwritten.
+  if not overwrite and os.path.isfile(filepath):
+    return ask_to_proceed_with_overwrite(filepath)
+  return True
+
+
+def compile_args_from_training_config(training_config, custom_objects=None):
+  """Return model.compile arguments from training config."""
+  if custom_objects is None:
+    custom_objects = {}
+
+  optimizer_config = training_config['optimizer_config']
+  optimizer = optimizers.deserialize(
+      optimizer_config, custom_objects=custom_objects)
+
+  # Recover loss functions and metrics.
+  loss_config = training_config['loss']  # Deserialize loss class.
+  if isinstance(loss_config, dict) and 'class_name' in loss_config:
+    loss_config = losses.get(loss_config)
+  loss = nest.map_structure(
+      lambda obj: custom_objects.get(obj, obj), loss_config)
+  metrics = nest.map_structure(
+      lambda obj: custom_objects.get(obj, obj), training_config['metrics'])
+  weighted_metrics = nest.map_structure(
+      lambda obj: custom_objects.get(obj, obj),
+      training_config.get('weighted_metrics', None))
+  sample_weight_mode = training_config['sample_weight_mode']
+  loss_weights = training_config['loss_weights']
+
+  return dict(
+      optimizer=optimizer,
+      loss=loss,
+      metrics=metrics,
+      weighted_metrics=weighted_metrics,
+      loss_weights=loss_weights,
+      sample_weight_mode=sample_weight_mode)
diff --git a/tensorflow/python/training/tracking/tracking.py b/tensorflow/python/training/tracking/tracking.py
index 5838566b852..b90f7f2af8d 100644
--- a/tensorflow/python/training/tracking/tracking.py
+++ b/tensorflow/python/training/tracking/tracking.py
@@ -81,13 +81,7 @@ class AutoTrackable(base.Trackable):
 
   def __delattr__(self, name):
     self._maybe_initialize_trackable()
-    if name in self._unconditional_dependency_names:
-      del self._unconditional_dependency_names[name]
-      for index, (dep_name, _) in enumerate(
-          self._unconditional_checkpoint_dependencies):
-        if dep_name == name:
-          del self._unconditional_checkpoint_dependencies[index]
-          break
+    delete_tracking(self, name)
     super(AutoTrackable, self).__delattr__(name)
 
   def _no_dependency(self, value):
@@ -110,6 +104,19 @@ class AutoTrackable(base.Trackable):
     return functions
 
 
+def delete_tracking(obj, name):
+  """Removes the tracking of name from object."""
+  # pylint: disable=protected-access
+  if name in obj._unconditional_dependency_names:
+    del obj._unconditional_dependency_names[name]
+    for index, (dep_name, _) in enumerate(
+        obj._unconditional_checkpoint_dependencies):
+      if dep_name == name:
+        del obj._unconditional_checkpoint_dependencies[index]
+        break
+  # pylint: enable=protected-access
+
+
 class ResourceTracker(object):
   """An object that tracks a list of resources."""
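Factoring the dependency removal out of `__delattr__` into `delete_tracking` lets the Keras loader drop a checkpoint dependency without deleting the Python attribute itself. A small sketch against these internal tracking APIs:

```python
import tensorflow as tf
from tensorflow.python.training.tracking import tracking

obj = tracking.AutoTrackable()
obj.v = tf.Variable(1.0)  # attribute is auto-tracked as a checkpoint dependency

tracking.delete_tracking(obj, 'v')  # the dependency is removed...
assert obj.v is not None            # ...but the attribute itself survives
```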