diff --git a/RELEASE.md b/RELEASE.md
index e34dedd12a3..b1847b7d587 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -32,6 +32,8 @@
 *   `tf.keras`:
     *   Improvements to Keras preprocessing layers:
         *   Discretization combiner implemented, with additional arg `epsilon`.
+    *   Improvements to model saving/loading:
+        *   `model.load_weights` now accepts paths to saved models.
 *   `tf.data`:
     *   Exposing `tf.data.experimental.ExternalStatePolicy`, which can be used
diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD
index e6b7dc5ac20..49e3fcfb178 100644
--- a/tensorflow/python/keras/engine/BUILD
+++ b/tensorflow/python/keras/engine/BUILD
@@ -83,6 +83,8 @@ py_library(
         "//tensorflow/python/ops/ragged:ragged_tensor",
         "//tensorflow/python/ops/ragged:ragged_util",
         "//tensorflow/python/profiler:trace",
+        "//tensorflow/python/saved_model:constants",
+        "//tensorflow/python/saved_model:loader",
         "//tensorflow/python/tpu:tpu_lib",
         "//tensorflow/python/training/tracking:data_structures",
         "//tensorflow/tools/docs:doc_controls",
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 9dcac8ca9f3..91c1182eb5c 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -55,6 +55,7 @@ from tensorflow.python.keras.mixed_precision import loss_scale_optimizer as lso
 from tensorflow.python.keras.mixed_precision import policy
 from tensorflow.python.keras.saving import hdf5_format
 from tensorflow.python.keras.saving import save
+from tensorflow.python.keras.saving import saving_utils
 from tensorflow.python.keras.saving.saved_model import json_utils
 from tensorflow.python.keras.saving.saved_model import model_serialization
 from tensorflow.python.keras.utils import generic_utils
@@ -72,6 +73,8 @@ from tensorflow.python.ops import summary_ops_v2
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.profiler import trace
+from tensorflow.python.saved_model import constants as sm_constants
+from tensorflow.python.saved_model import loader_impl as sm_loader
 from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import py_checkpoint_reader
 from tensorflow.python.training.tracking import base as trackable
@@ -2114,7 +2117,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
     """
     self._assert_weights_created()
     filepath = path_to_string(filepath)
-    filepath_is_h5 = _is_hdf5_filepath(filepath)
+    filepath_is_h5 = saving_utils.is_hdf5_filepath(filepath)
     if save_format is None:
       if filepath_is_h5:
         save_format = 'h5'
@@ -2202,7 +2205,8 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
     Arguments:
         filepath: String, path to the weights file to load. For weight files in
             TensorFlow format, this is the file prefix (the same as was passed
-            to `save_weights`).
+            to `save_weights`). This can also be a path to a SavedModel
+            saved from `model.save`.
         by_name: Boolean, whether to load weights by name or by topological
             order. Only topological loading is supported for weight files in
             TensorFlow format.
@@ -2229,7 +2233,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
     """
     if dist_utils.is_tpu_strategy(self._distribution_strategy):
       if (self._distribution_strategy.extended.steps_per_run > 1 and
-          (not _is_hdf5_filepath(filepath))):
+          (not saving_utils.is_hdf5_filepath(filepath))):
         raise ValueError('Load weights is not yet supported with TPUStrategy '
                          'with steps_per_run greater than 1.')
     if skip_mismatch and not by_name:
@@ -2237,16 +2241,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
           'When calling model.load_weights, skip_mismatch can only be set to '
           'True when by_name is True.')

-    filepath = path_to_string(filepath)
-    if _is_hdf5_filepath(filepath):
-      save_format = 'h5'
-    else:
-      try:
-        py_checkpoint_reader.NewCheckpointReader(filepath)
-        save_format = 'tf'
-      except errors_impl.DataLossError:
-        # The checkpoint is not readable in TensorFlow format. Try HDF5.
-        save_format = 'h5'
+    filepath, save_format = _detect_save_format(filepath)
     if save_format == 'tf':
       status = self._trackable_saver.restore(filepath, options)
       if by_name:
@@ -2851,6 +2846,40 @@ def _disallow_inside_tf_function(method_name):
   raise RuntimeError(error_msg)


-def _is_hdf5_filepath(filepath):
-  return (filepath.endswith('.h5') or filepath.endswith('.keras') or
-          filepath.endswith('.hdf5'))
+def _detect_save_format(filepath):
+  """Returns path to weights file and save format."""
+
+  filepath = path_to_string(filepath)
+  if saving_utils.is_hdf5_filepath(filepath):
+    return filepath, 'h5'
+
+  # Filepath could be a TensorFlow checkpoint file prefix or SavedModel
+  # directory. It's possible for filepath to be both a prefix and directory.
+  # Prioritize checkpoint over SavedModel.
+  if _is_readable_tf_checkpoint(filepath):
+    save_format = 'tf'
+  elif sm_loader.contains_saved_model(filepath):
+    ckpt_path = os.path.join(filepath, sm_constants.VARIABLES_DIRECTORY,
+                             sm_constants.VARIABLES_FILENAME)
+    if _is_readable_tf_checkpoint(ckpt_path):
+      filepath = ckpt_path
+      save_format = 'tf'
+    else:
+      raise ValueError('Unable to load weights. filepath {} appears to be a '
+                       'SavedModel directory, but checkpoint either doesn\'t '
+                       'exist, or is incorrectly formatted.'.format(filepath))
+  else:
+    # Not a TensorFlow checkpoint. This filepath is likely an H5 file that
+    # doesn't have the hdf5/keras extensions.
+    save_format = 'h5'
+  return filepath, save_format
+
+
+def _is_readable_tf_checkpoint(filepath):
+  try:
+    py_checkpoint_reader.NewCheckpointReader(filepath)
+    return True
+  except errors_impl.DataLossError:
+    # The checkpoint is not readable in TensorFlow format.
+    return False
+
diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py
index 0faafa6992c..576e8c8469c 100644
--- a/tensorflow/python/keras/engine/training_v1.py
+++ b/tensorflow/python/keras/engine/training_v1.py
@@ -55,6 +55,7 @@ from tensorflow.python.keras.engine import training_utils_v1
 from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
 from tensorflow.python.keras.mixed_precision import policy
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.keras.saving import saving_utils
 from tensorflow.python.keras.saving.saved_model import model_serialization
 from tensorflow.python.keras.utils import data_utils
 from tensorflow.python.keras.utils import layer_utils
@@ -229,7 +230,7 @@ class Model(training_lib.Model):
     """
     if distributed_training_utils.is_tpu_strategy(self._distribution_strategy):
       if (self._distribution_strategy.extended.steps_per_run > 1 and
-          (not training_lib._is_hdf5_filepath(filepath))):  # pylint: disable=protected-access
+          (not saving_utils.is_hdf5_filepath(filepath))):  # pylint: disable=protected-access
         raise ValueError('Load weights is not yet supported with TPUStrategy '
                          'with steps_per_run greater than 1.')
     return super(Model, self).load_weights(filepath, by_name, skip_mismatch)
diff --git a/tensorflow/python/keras/saving/save.py b/tensorflow/python/keras/saving/save.py
index 2b1d7b507f3..4a4c345d3ec 100644
--- a/tensorflow/python/keras/saving/save.py
+++ b/tensorflow/python/keras/saving/save.py
@@ -18,11 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import os
 import six

 from tensorflow.python import tf2
 from tensorflow.python.keras.saving import hdf5_format
+from tensorflow.python.keras.saving import saving_utils
 from tensorflow.python.keras.saving.saved_model import load as saved_model_load
 from tensorflow.python.keras.saving.saved_model import load_context
 from tensorflow.python.keras.saving.saved_model import save as saved_model_save
@@ -39,12 +39,6 @@ except ImportError:
   h5py = None
 # pylint: enable=g-import-not-at-top

-_HDF5_EXTENSIONS = ['.h5', '.hdf5', '.keras']
-
-
-# TODO(kathywu): Remove this when Keras SavedModel is not experimental.
-_KERAS_SAVED_MODEL_STILL_EXPERIMENTAL = True
-

 @keras_export('keras.models.save_model')
 def save_model(model,
@@ -140,7 +134,7 @@ def save_model(model,

   if (save_format == 'h5' or
       (h5py is not None and isinstance(filepath, h5py.File)) or
-      os.path.splitext(filepath)[1] in _HDF5_EXTENSIONS):
+      saving_utils.is_hdf5_filepath(filepath)):
     # TODO(b/130258301): add utility method for detecting model type.
     if (not model._is_graph_network and  # pylint:disable=protected-access
         not isinstance(model, sequential.Sequential)):
diff --git a/tensorflow/python/keras/saving/save_weights_test.py b/tensorflow/python/keras/saving/save_weights_test.py
index 229a891b2b7..1f5fbb4542f 100644
--- a/tensorflow/python/keras/saving/save_weights_test.py
+++ b/tensorflow/python/keras/saving/save_weights_test.py
@@ -54,11 +54,14 @@ except ImportError:
 @combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):

+  def _save_model_dir(self, dirname='saved_model'):
+    temp_dir = self.get_temp_dir()
+    self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
+    return os.path.join(temp_dir, dirname)
+
   @keras_parameterized.run_with_all_weight_formats
   def test_weight_loading(self):
-    temp_dir = self.get_temp_dir()
-    self.addCleanup(shutil.rmtree, temp_dir)
-    saved_model_dir = os.path.join(temp_dir, 'saved_model')
+    saved_model_dir = self._save_model_dir()
     save_format = testing_utils.get_save_format()
     with self.cached_session():
       a = keras.layers.Input(shape=(2,))
@@ -213,9 +216,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
     if h5py is None:
       return

-    temp_dir = self.get_temp_dir()
-    self.addCleanup(shutil.rmtree, temp_dir)
-    h5_path = os.path.join(temp_dir, 'test.h5')
+    h5_path = self._save_model_dir('test.h5')

     num_hidden = 5
     input_dim = 3
@@ -244,9 +245,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
       exclude_formats=['tf_no_traces'])
   def test_nested_model_weight_loading(self):
     save_format = testing_utils.get_save_format()
-    temp_dir = self.get_temp_dir()
-    self.addCleanup(shutil.rmtree, temp_dir)
-    saved_model_dir = os.path.join(temp_dir, 'saved_model')
+    saved_model_dir = self._save_model_dir()

     batch_size = 5
     shape = (None, None, 3)
@@ -284,9 +283,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
     if h5py is None:
       return

-    temp_dir = self.get_temp_dir()
-    self.addCleanup(shutil.rmtree, temp_dir)
-    h5_path = os.path.join(temp_dir, 'test.h5')
+    h5_path = self._save_model_dir('test.h5')

     num_hidden = 5
     input_dim = 3
@@ -326,9 +323,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
     if h5py is None:
       return

-    temp_dir = self.get_temp_dir()
-    self.addCleanup(shutil.rmtree, temp_dir)
-    h5_path = os.path.join(temp_dir, 'test.h5')
+    h5_path = self._save_model_dir('test.h5')

     num_hidden = 5
     input_dim = 3
@@ -367,6 +362,32 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
     self.assertAllClose([3.5] * num_classes,
                         keras.backend.get_value(model.layers[1].bias))

+  @keras_parameterized.run_with_all_saved_model_formats(
+      exclude_formats=['tf_no_traces'])
+  @keras_parameterized.run_with_all_model_types
+  def test_load_weights_from_saved_model(self):
+    save_path = self._save_model_dir()
+    save_format = testing_utils.get_save_format()
+
+    if save_format == 'h5' and testing_utils.get_model_type() == 'subclass':
+      # TODO(b/173646281): HDF5 format currently does not allow saving
+      # subclassed models.
+      return
+
+    with self.cached_session():
+      model = testing_utils.get_small_mlp(1, 4, input_dim=3)
+      data = np.random.random((1, 3))
+      labels = np.random.random((1, 4))
+      model.compile(loss='mse', optimizer='rmsprop')
+      model.fit(data, labels)
+      model.save(save_path, save_format=save_format)
+      new_model = testing_utils.get_small_mlp(1, 4, input_dim=3)
+      if testing_utils.get_model_type() == 'subclass':
+        # Call on test data to build the model.
+        new_model.predict(data)
+      new_model.load_weights(save_path)
+      self.assertAllClose(model.weights, new_model.weights)
+

 class SubclassedModel(training.Model):
diff --git a/tensorflow/python/keras/saving/saving_utils.py b/tensorflow/python/keras/saving/saving_utils.py
index e459d174fa9..fc092dfa9cf 100644
--- a/tensorflow/python/keras/saving/saving_utils.py
+++ b/tensorflow/python/keras/saving/saving_utils.py
@@ -321,3 +321,8 @@ def try_build_compiled_arguments(model):
         'Compiled the loaded model, but the compiled metrics have yet to '
         'be built. `model.compile_metrics` will be empty until you train '
         'or evaluate the model.')
+
+
+def is_hdf5_filepath(filepath):
+  return (filepath.endswith('.h5') or filepath.endswith('.keras') or
+          filepath.endswith('.hdf5'))
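A minimal sketch of the behavior this change enables, assuming a TF 2.x environment; the model architecture, shapes, and path below are illustrative and not taken from the commit. After `model.save` writes a SavedModel directory, `load_weights` may be pointed at that directory, and `_detect_save_format` resolves it to the checkpoint stored under its variables/ subdirectory.

import numpy as np
import tensorflow as tf

# Build and train a small model (layer sizes here are arbitrary).
model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation='relu', input_shape=(3,)),
    tf.keras.layers.Dense(2),
])
model.compile(loss='mse', optimizer='rmsprop')
model.fit(np.random.random((8, 3)), np.random.random((8, 2)), verbose=0)

# Save the whole model in the TensorFlow SavedModel format (a directory).
model.save('/tmp/my_model', save_format='tf')

# With this change, load_weights accepts the SavedModel directory itself
# instead of only an H5 file or a checkpoint prefix.
fresh = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation='relu', input_shape=(3,)),
    tf.keras.layers.Dense(2),
])
fresh.load_weights('/tmp/my_model')

Because checkpoint prefixes are checked before SavedModel directories, a path that is both still loads as a plain checkpoint, matching the prioritization comment in `_detect_save_format`.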