Allow weights to be loaded from a SavedModel.

(This change includes modifications to training_v1, which caused the TAP failures from the last CL).

PiperOrigin-RevId: 347838537
Change-Id: Ie05d51602eb770d9703ec9258576d925f139117e
Katherine Wu 2020-12-16 09:36:57 -08:00 committed by TensorFlower Gardener
parent 4f2d645a80
commit 436808c7b6
7 changed files with 94 additions and 40 deletions

View File

@@ -32,6 +32,8 @@
* `tf.keras`:
* Improvements to Keras preprocessing layers:
* Discretization combiner implemented, with additional arg `epsilon`.
* Improvements to model saving/loading:
* `model.load_weights` now accepts paths to saved models.
* `tf.data`:
* Exposing `tf.data.experimental.ExternalStatePolicy`, which can be used
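To illustrate the release note above, here is a minimal usage sketch of the new behavior, assuming a TF build that includes this change; the model architecture and the `/tmp` paths are hypothetical.

```python
import tensorflow as tf

# Hypothetical small model; any Keras model works the same way.
def build_model():
  return tf.keras.Sequential([
      tf.keras.layers.Dense(4, activation='relu', input_shape=(3,)),
      tf.keras.layers.Dense(1),
  ])

model = build_model()
model.compile(optimizer='rmsprop', loss='mse')
model.save('/tmp/demo_saved_model')  # SavedModel directory (hypothetical path)

# With this CL, load_weights accepts the SavedModel directory directly,
# in addition to checkpoint prefixes and HDF5 files.
restored = build_model()
restored.load_weights('/tmp/demo_saved_model')
```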

View File

@@ -83,6 +83,8 @@ py_library(
"//tensorflow/python/ops/ragged:ragged_tensor",
"//tensorflow/python/ops/ragged:ragged_util",
"//tensorflow/python/profiler:trace",
"//tensorflow/python/saved_model:constants",
"//tensorflow/python/saved_model:loader",
"//tensorflow/python/tpu:tpu_lib",
"//tensorflow/python/training/tracking:data_structures",
"//tensorflow/tools/docs:doc_controls",

View File

@@ -55,6 +55,7 @@ from tensorflow.python.keras.mixed_precision import loss_scale_optimizer as lso
from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.saving import hdf5_format
from tensorflow.python.keras.saving import save
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.keras.saving.saved_model import model_serialization
from tensorflow.python.keras.utils import generic_utils
@@ -72,6 +73,8 @@ from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import trace
from tensorflow.python.saved_model import constants as sm_constants
from tensorflow.python.saved_model import loader_impl as sm_loader
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training.tracking import base as trackable
@@ -2114,7 +2117,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
"""
self._assert_weights_created()
filepath = path_to_string(filepath)
filepath_is_h5 = _is_hdf5_filepath(filepath)
filepath_is_h5 = saving_utils.is_hdf5_filepath(filepath)
if save_format is None:
if filepath_is_h5:
save_format = 'h5'
@@ -2202,7 +2205,8 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
Arguments:
filepath: String, path to the weights file to load. For weight files in
TensorFlow format, this is the file prefix (the same as was passed
to `save_weights`).
to `save_weights`). This can also be a path to a SavedModel
saved from `model.save`.
by_name: Boolean, whether to load weights by name or by topological
order. Only topological loading is supported for weight files in
TensorFlow format.
@@ -2229,7 +2233,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
"""
if dist_utils.is_tpu_strategy(self._distribution_strategy):
if (self._distribution_strategy.extended.steps_per_run > 1 and
(not _is_hdf5_filepath(filepath))):
(not saving_utils.is_hdf5_filepath(filepath))):
raise ValueError('Load weights is not yet supported with TPUStrategy '
'with steps_per_run greater than 1.')
if skip_mismatch and not by_name:
@@ -2237,16 +2241,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
'When calling model.load_weights, skip_mismatch can only be set to '
'True when by_name is True.')
filepath = path_to_string(filepath)
if _is_hdf5_filepath(filepath):
save_format = 'h5'
else:
try:
py_checkpoint_reader.NewCheckpointReader(filepath)
save_format = 'tf'
except errors_impl.DataLossError:
# The checkpoint is not readable in TensorFlow format. Try HDF5.
save_format = 'h5'
filepath, save_format = _detect_save_format(filepath)
if save_format == 'tf':
status = self._trackable_saver.restore(filepath, options)
if by_name:
@@ -2851,6 +2846,40 @@ def _disallow_inside_tf_function(method_name):
raise RuntimeError(error_msg)
def _is_hdf5_filepath(filepath):
return (filepath.endswith('.h5') or filepath.endswith('.keras') or
filepath.endswith('.hdf5'))
def _detect_save_format(filepath):
"""Returns path to weights file and save format."""
filepath = path_to_string(filepath)
if saving_utils.is_hdf5_filepath(filepath):
return filepath, 'h5'
# Filepath could be a TensorFlow checkpoint file prefix or SavedModel
# directory. It's possible for filepath to be both a prefix and directory.
# Prioritize checkpoint over SavedModel.
if _is_readable_tf_checkpoint(filepath):
save_format = 'tf'
elif sm_loader.contains_saved_model(filepath):
ckpt_path = os.path.join(filepath, sm_constants.VARIABLES_DIRECTORY,
sm_constants.VARIABLES_FILENAME)
if _is_readable_tf_checkpoint(ckpt_path):
filepath = ckpt_path
save_format = 'tf'
else:
raise ValueError('Unable to load weights. filepath {} appears to be a '
'SavedModel directory, but checkpoint either doesn\'t '
'exist, or is incorrectly formatted.'.format(filepath))
else:
# Not a TensorFlow checkpoint. This filepath is likely an H5 file that
# doesn't have the hdf5/keras extensions.
save_format = 'h5'
return filepath, save_format
def _is_readable_tf_checkpoint(filepath):
try:
py_checkpoint_reader.NewCheckpointReader(filepath)
return True
except errors_impl.DataLossError:
# The checkpoint is not readable in TensorFlow format.
return False
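The helpers added above prioritize an HDF5 extension, then a readable TF checkpoint prefix, then a SavedModel directory (whose weights checkpoint lives at variables/variables). Below is a rough approximation of that detection order using only public APIs, for illustration; `detect_save_format`, `_is_readable_checkpoint`, and the exception set caught are a sketch of the idea, not the functions added in this CL.

```python
import os

import tensorflow as tf


def detect_save_format(filepath):
  """Rough public-API approximation of the detection order above."""
  if filepath.endswith(('.h5', '.hdf5', '.keras')):
    return filepath, 'h5'
  # A readable TF checkpoint prefix takes priority over a SavedModel dir.
  if _is_readable_checkpoint(filepath):
    return filepath, 'tf'
  if tf.saved_model.contains_saved_model(filepath):
    # Weights inside a SavedModel live under <dir>/variables/variables.*
    ckpt_path = os.path.join(filepath, 'variables', 'variables')
    if _is_readable_checkpoint(ckpt_path):
      return ckpt_path, 'tf'
    raise ValueError('{} looks like a SavedModel directory but its variables '
                     'checkpoint is missing or unreadable.'.format(filepath))
  # Anything else is assumed to be an HDF5 file without a known extension.
  return filepath, 'h5'


def _is_readable_checkpoint(path):
  # Sketch: catches a broader exception set than the CL's DataLossError-only
  # check, to keep the example robust to missing paths.
  try:
    tf.train.load_checkpoint(path)
    return True
  except (tf.errors.DataLossError, tf.errors.NotFoundError, ValueError):
    return False
```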

View File

@@ -55,6 +55,7 @@ from tensorflow.python.keras.engine import training_utils_v1
from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.saving.saved_model import model_serialization
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
@@ -229,7 +230,7 @@ class Model(training_lib.Model):
"""
if distributed_training_utils.is_tpu_strategy(self._distribution_strategy):
if (self._distribution_strategy.extended.steps_per_run > 1 and
(not training_lib._is_hdf5_filepath(filepath))): # pylint: disable=protected-access
(not saving_utils.is_hdf5_filepath(filepath))): # pylint: disable=protected-access
raise ValueError('Load weights is not yet supported with TPUStrategy '
'with steps_per_run greater than 1.')
return super(Model, self).load_weights(filepath, by_name, skip_mismatch)

View File

@@ -18,11 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import six
from tensorflow.python import tf2
from tensorflow.python.keras.saving import hdf5_format
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.saving.saved_model import load as saved_model_load
from tensorflow.python.keras.saving.saved_model import load_context
from tensorflow.python.keras.saving.saved_model import save as saved_model_save
@@ -39,12 +39,6 @@ except ImportError:
h5py = None
# pylint: enable=g-import-not-at-top
_HDF5_EXTENSIONS = ['.h5', '.hdf5', '.keras']
# TODO(kathywu): Remove this when Keras SavedModel is not experimental.
_KERAS_SAVED_MODEL_STILL_EXPERIMENTAL = True
@keras_export('keras.models.save_model')
def save_model(model,
@@ -140,7 +134,7 @@ def save_model(model,
if (save_format == 'h5' or
(h5py is not None and isinstance(filepath, h5py.File)) or
os.path.splitext(filepath)[1] in _HDF5_EXTENSIONS):
saving_utils.is_hdf5_filepath(filepath)):
# TODO(b/130258301): add utility method for detecting model type.
if (not model._is_graph_network and # pylint:disable=protected-access
not isinstance(model, sequential.Sequential)):
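For context on the dispatch above, a small usage sketch of how the filepath extension selects between the HDF5 and SavedModel branches of `save_model`; the model and paths are hypothetical, and the extension behavior follows the list consolidated into `saving_utils.is_hdf5_filepath`.

```python
import tensorflow as tf

# Hypothetical model and paths for illustration only.
model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])

# A .h5/.hdf5/.keras extension (or an explicit save_format='h5')
# takes the HDF5 branch shown above.
tf.keras.models.save_model(model, '/tmp/demo_weights.h5')

# Other paths default to the TF SavedModel format.
tf.keras.models.save_model(model, '/tmp/demo_saved_model')
```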

View File

@@ -54,11 +54,14 @@ except ImportError:
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
def _save_model_dir(self, dirname='saved_model'):
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
return os.path.join(temp_dir, dirname)
@keras_parameterized.run_with_all_weight_formats
def test_weight_loading(self):
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir)
saved_model_dir = os.path.join(temp_dir, 'saved_model')
saved_model_dir = self._save_model_dir()
save_format = testing_utils.get_save_format()
with self.cached_session():
a = keras.layers.Input(shape=(2,))
@@ -213,9 +216,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
if h5py is None:
return
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir)
h5_path = os.path.join(temp_dir, 'test.h5')
h5_path = self._save_model_dir('test.h5')
num_hidden = 5
input_dim = 3
@@ -244,9 +245,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
exclude_formats=['tf_no_traces'])
def test_nested_model_weight_loading(self):
save_format = testing_utils.get_save_format()
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir)
saved_model_dir = os.path.join(temp_dir, 'saved_model')
saved_model_dir = self._save_model_dir()
batch_size = 5
shape = (None, None, 3)
@@ -284,9 +283,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
if h5py is None:
return
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir)
h5_path = os.path.join(temp_dir, 'test.h5')
h5_path = self._save_model_dir('test.h5')
num_hidden = 5
input_dim = 3
@@ -326,9 +323,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
if h5py is None:
return
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir)
h5_path = os.path.join(temp_dir, 'test.h5')
h5_path = self._save_model_dir('test.h5')
num_hidden = 5
input_dim = 3
@@ -367,6 +362,32 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
self.assertAllClose([3.5] * num_classes,
keras.backend.get_value(model.layers[1].bias))
@keras_parameterized.run_with_all_saved_model_formats(
exclude_formats=['tf_no_traces'])
@keras_parameterized.run_with_all_model_types
def test_load_weights_from_saved_model(self):
save_path = self._save_model_dir()
save_format = testing_utils.get_save_format()
if save_format == 'h5' and testing_utils.get_model_type() == 'subclass':
# TODO(b/173646281): HDF5 format currently does not allow saving
# subclassed models.
return
with self.cached_session():
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
data = np.random.random((1, 3))
labels = np.random.random((1, 4))
model.compile(loss='mse', optimizer='rmsprop')
model.fit(data, labels)
model.save(save_path, save_format=save_format)
new_model = testing_utils.get_small_mlp(1, 4, input_dim=3)
if testing_utils.get_model_type() == 'subclass':
# Call on test data to build the model.
new_model.predict(data)
new_model.load_weights(save_path)
self.assertAllClose(model.weights, new_model.weights)
class SubclassedModel(training.Model):

View File

@@ -321,3 +321,8 @@ def try_build_compiled_arguments(model):
'Compiled the loaded model, but the compiled metrics have yet to '
'be built. `model.compile_metrics` will be empty until you train '
'or evaluate the model.')
def is_hdf5_filepath(filepath):
return (filepath.endswith('.h5') or filepath.endswith('.keras') or
filepath.endswith('.hdf5'))
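A quick usage sketch of the consolidated helper; the results follow directly from the extension checks above, assuming a build that contains this change.

```python
from tensorflow.python.keras.saving import saving_utils

# The helper treats .h5, .hdf5, and .keras as HDF5 paths.
assert saving_utils.is_hdf5_filepath('weights.h5')
assert saving_utils.is_hdf5_filepath('weights.hdf5')
assert saving_utils.is_hdf5_filepath('weights.keras')
# Checkpoint prefixes and SavedModel directories are not HDF5 paths.
assert not saving_utils.is_hdf5_filepath('training_checkpoints/ckpt')
```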