Allow weights to be loaded from saved model.

(This change includes modifications to training_v1, which caused the TAP failures in the last CL.)

PiperOrigin-RevId: 347838537
Change-Id: Ie05d51602eb770d9703ec9258576d925f139117e
Katherine Wu 2020-12-16 09:36:57 -08:00 committed by TensorFlower Gardener
parent 4f2d645a80
commit 436808c7b6
7 changed files with 94 additions and 40 deletions

View File

@@ -32,6 +32,8 @@
 * `tf.keras`:
   * Improvements to Keras preprocessing layers:
     * Discretization combiner implemented, with additional arg `epsilon`.
+  * Improvements to model saving/loading:
+    * `model.load_weights` now accepts paths to saved models.
 * `tf.data`:
   * Exposing `tf.data.experimental.ExternalStatePolicy`, which can be used
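For context (not part of this CL), a minimal usage sketch of the behavior the release note describes; the path and layer sizes below are illustrative. `model.save` writes a SavedModel directory by default in TF2, and with this change `load_weights` can restore weights directly from that directory:

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(3,))])
model.compile(loss='mse', optimizer='rmsprop')
model.fit(np.random.random((1, 3)), np.random.random((1, 4)))
model.save('/tmp/my_model')  # SavedModel directory (the TF2 default format)

new_model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(3,))])
new_model.load_weights('/tmp/my_model')  # previously required an H5 file or checkpoint prefix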

View File

@@ -83,6 +83,8 @@ py_library(
         "//tensorflow/python/ops/ragged:ragged_tensor",
         "//tensorflow/python/ops/ragged:ragged_util",
         "//tensorflow/python/profiler:trace",
+        "//tensorflow/python/saved_model:constants",
+        "//tensorflow/python/saved_model:loader",
         "//tensorflow/python/tpu:tpu_lib",
         "//tensorflow/python/training/tracking:data_structures",
         "//tensorflow/tools/docs:doc_controls",

View File

@@ -55,6 +55,7 @@ from tensorflow.python.keras.mixed_precision import loss_scale_optimizer as lso
 from tensorflow.python.keras.mixed_precision import policy
 from tensorflow.python.keras.saving import hdf5_format
 from tensorflow.python.keras.saving import save
+from tensorflow.python.keras.saving import saving_utils
 from tensorflow.python.keras.saving.saved_model import json_utils
 from tensorflow.python.keras.saving.saved_model import model_serialization
 from tensorflow.python.keras.utils import generic_utils
@@ -72,6 +73,8 @@ from tensorflow.python.ops import summary_ops_v2
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.profiler import trace
+from tensorflow.python.saved_model import constants as sm_constants
+from tensorflow.python.saved_model import loader_impl as sm_loader
 from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import py_checkpoint_reader
 from tensorflow.python.training.tracking import base as trackable
@@ -2114,7 +2117,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
     """
     self._assert_weights_created()
     filepath = path_to_string(filepath)
-    filepath_is_h5 = _is_hdf5_filepath(filepath)
+    filepath_is_h5 = saving_utils.is_hdf5_filepath(filepath)
     if save_format is None:
       if filepath_is_h5:
         save_format = 'h5'
@@ -2202,7 +2205,8 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
     Arguments:
         filepath: String, path to the weights file to load. For weight files in
             TensorFlow format, this is the file prefix (the same as was passed
-            to `save_weights`).
+            to `save_weights`). This can also be a path to a SavedModel
+            saved from `model.save`.
         by_name: Boolean, whether to load weights by name or by topological
             order. Only topological loading is supported for weight files in
             TensorFlow format.
@@ -2229,7 +2233,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
     """
     if dist_utils.is_tpu_strategy(self._distribution_strategy):
       if (self._distribution_strategy.extended.steps_per_run > 1 and
-          (not _is_hdf5_filepath(filepath))):
+          (not saving_utils.is_hdf5_filepath(filepath))):
         raise ValueError('Load weights is not yet supported with TPUStrategy '
                          'with steps_per_run greater than 1.')
     if skip_mismatch and not by_name:
@@ -2237,16 +2241,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
           'When calling model.load_weights, skip_mismatch can only be set to '
           'True when by_name is True.')
 
-    filepath = path_to_string(filepath)
-    if _is_hdf5_filepath(filepath):
-      save_format = 'h5'
-    else:
-      try:
-        py_checkpoint_reader.NewCheckpointReader(filepath)
-        save_format = 'tf'
-      except errors_impl.DataLossError:
-        # The checkpoint is not readable in TensorFlow format. Try HDF5.
-        save_format = 'h5'
+    filepath, save_format = _detect_save_format(filepath)
     if save_format == 'tf':
       status = self._trackable_saver.restore(filepath, options)
       if by_name:
@@ -2851,6 +2846,40 @@ def _disallow_inside_tf_function(method_name):
   raise RuntimeError(error_msg)
 
 
-def _is_hdf5_filepath(filepath):
-  return (filepath.endswith('.h5') or filepath.endswith('.keras') or
-          filepath.endswith('.hdf5'))
+def _detect_save_format(filepath):
+  """Returns path to weights file and save format."""
+
+  filepath = path_to_string(filepath)
+  if saving_utils.is_hdf5_filepath(filepath):
+    return filepath, 'h5'
+
+  # Filepath could be a TensorFlow checkpoint file prefix or SavedModel
+  # directory. It's possible for filepath to be both a prefix and directory.
+  # Prioritize checkpoint over SavedModel.
+  if _is_readable_tf_checkpoint(filepath):
+    save_format = 'tf'
+  elif sm_loader.contains_saved_model(filepath):
+    ckpt_path = os.path.join(filepath, sm_constants.VARIABLES_DIRECTORY,
+                             sm_constants.VARIABLES_FILENAME)
+    if _is_readable_tf_checkpoint(ckpt_path):
+      filepath = ckpt_path
+      save_format = 'tf'
+    else:
+      raise ValueError('Unable to load weights. filepath {} appears to be a '
+                       'SavedModel directory, but checkpoint either doesn\'t '
+                       'exist, or is incorrectly formatted.'.format(filepath))
+  else:
+    # Not a TensorFlow checkpoint. This filepath is likely an H5 file that
+    # doesn't have the hdf5/keras extensions.
+    save_format = 'h5'
+  return filepath, save_format
+
+
+def _is_readable_tf_checkpoint(filepath):
+  try:
+    py_checkpoint_reader.NewCheckpointReader(filepath)
+    return True
+  except errors_impl.DataLossError:
+    # The checkpoint is not readable in TensorFlow format.
+    return False
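To make the new detection logic concrete, here is an illustrative sketch (paths are hypothetical, not from this CL); the `variables/variables` layout comes from `tensorflow.python.saved_model.constants`, where `VARIABLES_DIRECTORY` and `VARIABLES_FILENAME` are both `'variables'`:

import os

# Case 1: HDF5 by extension        -> ('weights.h5', 'h5')
# Case 2: readable TF checkpoint prefix written by save_weights
#                                   -> ('/tmp/ckpt/weights', 'tf')
# Case 3: SavedModel directory written by model.save: the weights are read
#         from the checkpoint prefix inside the directory.
ckpt_prefix = os.path.join('/tmp/my_model', 'variables', 'variables')
# _detect_save_format returns (ckpt_prefix, 'tf'), and load_weights then
# restores it through the trackable saver, just like a save_weights checkpoint.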

View File

@@ -55,6 +55,7 @@ from tensorflow.python.keras.engine import training_utils_v1
 from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
 from tensorflow.python.keras.mixed_precision import policy
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.keras.saving import saving_utils
 from tensorflow.python.keras.saving.saved_model import model_serialization
 from tensorflow.python.keras.utils import data_utils
 from tensorflow.python.keras.utils import layer_utils
@@ -229,7 +230,7 @@ class Model(training_lib.Model):
     """
     if distributed_training_utils.is_tpu_strategy(self._distribution_strategy):
       if (self._distribution_strategy.extended.steps_per_run > 1 and
-          (not training_lib._is_hdf5_filepath(filepath))):  # pylint: disable=protected-access
+          (not saving_utils.is_hdf5_filepath(filepath))):  # pylint: disable=protected-access
         raise ValueError('Load weights is not yet supported with TPUStrategy '
                          'with steps_per_run greater than 1.')
     return super(Model, self).load_weights(filepath, by_name, skip_mismatch)

View File

@@ -18,11 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import os
 import six
 
 from tensorflow.python import tf2
 from tensorflow.python.keras.saving import hdf5_format
+from tensorflow.python.keras.saving import saving_utils
 from tensorflow.python.keras.saving.saved_model import load as saved_model_load
 from tensorflow.python.keras.saving.saved_model import load_context
 from tensorflow.python.keras.saving.saved_model import save as saved_model_save
@@ -39,12 +39,6 @@ except ImportError:
   h5py = None
 # pylint: enable=g-import-not-at-top
 
-_HDF5_EXTENSIONS = ['.h5', '.hdf5', '.keras']
-
-# TODO(kathywu): Remove this when Keras SavedModel is not experimental.
-_KERAS_SAVED_MODEL_STILL_EXPERIMENTAL = True
-
 
 @keras_export('keras.models.save_model')
 def save_model(model,
@@ -140,7 +134,7 @@ def save_model(model,
   if (save_format == 'h5' or
       (h5py is not None and isinstance(filepath, h5py.File)) or
-      os.path.splitext(filepath)[1] in _HDF5_EXTENSIONS):
+      saving_utils.is_hdf5_filepath(filepath)):
     # TODO(b/130258301): add utility method for detecting model type.
     if (not model._is_graph_network and  # pylint:disable=protected-access
         not isinstance(model, sequential.Sequential)):

View File

@@ -54,11 +54,14 @@ except ImportError:
 
 @combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
 
+  def _save_model_dir(self, dirname='saved_model'):
+    temp_dir = self.get_temp_dir()
+    self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
+    return os.path.join(temp_dir, dirname)
+
   @keras_parameterized.run_with_all_weight_formats
   def test_weight_loading(self):
-    temp_dir = self.get_temp_dir()
-    self.addCleanup(shutil.rmtree, temp_dir)
-    saved_model_dir = os.path.join(temp_dir, 'saved_model')
+    saved_model_dir = self._save_model_dir()
     save_format = testing_utils.get_save_format()
     with self.cached_session():
       a = keras.layers.Input(shape=(2,))
@@ -213,9 +216,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
     if h5py is None:
       return
 
-    temp_dir = self.get_temp_dir()
-    self.addCleanup(shutil.rmtree, temp_dir)
-    h5_path = os.path.join(temp_dir, 'test.h5')
+    h5_path = self._save_model_dir('test.h5')
 
     num_hidden = 5
     input_dim = 3
@@ -244,9 +245,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
       exclude_formats=['tf_no_traces'])
   def test_nested_model_weight_loading(self):
     save_format = testing_utils.get_save_format()
-    temp_dir = self.get_temp_dir()
-    self.addCleanup(shutil.rmtree, temp_dir)
-    saved_model_dir = os.path.join(temp_dir, 'saved_model')
+    saved_model_dir = self._save_model_dir()
 
     batch_size = 5
     shape = (None, None, 3)
@@ -284,9 +283,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
     if h5py is None:
       return
 
-    temp_dir = self.get_temp_dir()
-    self.addCleanup(shutil.rmtree, temp_dir)
-    h5_path = os.path.join(temp_dir, 'test.h5')
+    h5_path = self._save_model_dir('test.h5')
 
     num_hidden = 5
     input_dim = 3
@@ -326,9 +323,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
     if h5py is None:
       return
 
-    temp_dir = self.get_temp_dir()
-    self.addCleanup(shutil.rmtree, temp_dir)
-    h5_path = os.path.join(temp_dir, 'test.h5')
+    h5_path = self._save_model_dir('test.h5')
 
     num_hidden = 5
     input_dim = 3
@@ -367,6 +362,32 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
     self.assertAllClose([3.5] * num_classes,
                         keras.backend.get_value(model.layers[1].bias))
 
+  @keras_parameterized.run_with_all_saved_model_formats(
+      exclude_formats=['tf_no_traces'])
+  @keras_parameterized.run_with_all_model_types
+  def test_load_weights_from_saved_model(self):
+    save_path = self._save_model_dir()
+    save_format = testing_utils.get_save_format()
+
+    if save_format == 'h5' and testing_utils.get_model_type() == 'subclass':
+      # TODO(b/173646281): HDF5 format currently does not allow saving
+      # subclassed models.
+      return
+
+    with self.cached_session():
+      model = testing_utils.get_small_mlp(1, 4, input_dim=3)
+      data = np.random.random((1, 3))
+      labels = np.random.random((1, 4))
+      model.compile(loss='mse', optimizer='rmsprop')
+      model.fit(data, labels)
+      model.save(save_path, save_format=save_format)
+
+      new_model = testing_utils.get_small_mlp(1, 4, input_dim=3)
+      if testing_utils.get_model_type() == 'subclass':
+        # Call on test data to build the model.
+        new_model.predict(data)
+      new_model.load_weights(save_path)
+      self.assertAllClose(model.weights, new_model.weights)
+
 
 class SubclassedModel(training.Model):

View File

@@ -321,3 +321,8 @@ def try_build_compiled_arguments(model):
         'Compiled the loaded model, but the compiled metrics have yet to '
         'be built. `model.compile_metrics` will be empty until you train '
         'or evaluate the model.')
+
+
+def is_hdf5_filepath(filepath):
+  return (filepath.endswith('.h5') or filepath.endswith('.keras') or
+          filepath.endswith('.hdf5'))