Allow weights to be loaded from saved model.
(This change includes modifications to training_v1, which caused the TAP failures from the last CL). PiperOrigin-RevId: 347838537 Change-Id: Ie05d51602eb770d9703ec9258576d925f139117e
This commit is contained in:
parent
4f2d645a80
commit
436808c7b6
@ -32,6 +32,8 @@
|
||||
* `tf.keras`:
|
||||
* Improvements to Keras preprocessing layers:
|
||||
* Discretization combiner implemented, with additional arg `epsilon`.
|
||||
* Improvements to model saving/loading:
|
||||
* `model.load_weights` now accepts paths to saved models.
|
||||
|
||||
* `tf.data`:
|
||||
* Exposing `tf.data.experimental.ExternalStatePolicy`, which can be used
|
||||
|
@ -83,6 +83,8 @@ py_library(
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
||||
"//tensorflow/python/ops/ragged:ragged_util",
|
||||
"//tensorflow/python/profiler:trace",
|
||||
"//tensorflow/python/saved_model:constants",
|
||||
"//tensorflow/python/saved_model:loader",
|
||||
"//tensorflow/python/tpu:tpu_lib",
|
||||
"//tensorflow/python/training/tracking:data_structures",
|
||||
"//tensorflow/tools/docs:doc_controls",
|
||||
|
@ -55,6 +55,7 @@ from tensorflow.python.keras.mixed_precision import loss_scale_optimizer as lso
|
||||
from tensorflow.python.keras.mixed_precision import policy
|
||||
from tensorflow.python.keras.saving import hdf5_format
|
||||
from tensorflow.python.keras.saving import save
|
||||
from tensorflow.python.keras.saving import saving_utils
|
||||
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||
from tensorflow.python.keras.saving.saved_model import model_serialization
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
@ -72,6 +73,8 @@ from tensorflow.python.ops import summary_ops_v2
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.profiler import trace
|
||||
from tensorflow.python.saved_model import constants as sm_constants
|
||||
from tensorflow.python.saved_model import loader_impl as sm_loader
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training import py_checkpoint_reader
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
@ -2114,7 +2117,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
"""
|
||||
self._assert_weights_created()
|
||||
filepath = path_to_string(filepath)
|
||||
filepath_is_h5 = _is_hdf5_filepath(filepath)
|
||||
filepath_is_h5 = saving_utils.is_hdf5_filepath(filepath)
|
||||
if save_format is None:
|
||||
if filepath_is_h5:
|
||||
save_format = 'h5'
|
||||
@ -2202,7 +2205,8 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
Arguments:
|
||||
filepath: String, path to the weights file to load. For weight files in
|
||||
TensorFlow format, this is the file prefix (the same as was passed
|
||||
to `save_weights`).
|
||||
to `save_weights`). This can also be a path to a SavedModel
|
||||
saved from `model.save`.
|
||||
by_name: Boolean, whether to load weights by name or by topological
|
||||
order. Only topological loading is supported for weight files in
|
||||
TensorFlow format.
|
||||
@ -2229,7 +2233,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
"""
|
||||
if dist_utils.is_tpu_strategy(self._distribution_strategy):
|
||||
if (self._distribution_strategy.extended.steps_per_run > 1 and
|
||||
(not _is_hdf5_filepath(filepath))):
|
||||
(not saving_utils.is_hdf5_filepath(filepath))):
|
||||
raise ValueError('Load weights is not yet supported with TPUStrategy '
|
||||
'with steps_per_run greater than 1.')
|
||||
if skip_mismatch and not by_name:
|
||||
@ -2237,16 +2241,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
'When calling model.load_weights, skip_mismatch can only be set to '
|
||||
'True when by_name is True.')
|
||||
|
||||
filepath = path_to_string(filepath)
|
||||
if _is_hdf5_filepath(filepath):
|
||||
save_format = 'h5'
|
||||
else:
|
||||
try:
|
||||
py_checkpoint_reader.NewCheckpointReader(filepath)
|
||||
save_format = 'tf'
|
||||
except errors_impl.DataLossError:
|
||||
# The checkpoint is not readable in TensorFlow format. Try HDF5.
|
||||
save_format = 'h5'
|
||||
filepath, save_format = _detect_save_format(filepath)
|
||||
if save_format == 'tf':
|
||||
status = self._trackable_saver.restore(filepath, options)
|
||||
if by_name:
|
||||
@ -2851,6 +2846,40 @@ def _disallow_inside_tf_function(method_name):
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
|
||||
def _is_hdf5_filepath(filepath):
|
||||
return (filepath.endswith('.h5') or filepath.endswith('.keras') or
|
||||
filepath.endswith('.hdf5'))
|
||||
def _detect_save_format(filepath):
|
||||
"""Returns path to weights file and save format."""
|
||||
|
||||
filepath = path_to_string(filepath)
|
||||
if saving_utils.is_hdf5_filepath(filepath):
|
||||
return filepath, 'h5'
|
||||
|
||||
# Filepath could be a TensorFlow checkpoint file prefix or SavedModel
|
||||
# directory. It's possible for filepath to be both a prefix and directory.
|
||||
# Prioritize checkpoint over SavedModel.
|
||||
if _is_readable_tf_checkpoint(filepath):
|
||||
save_format = 'tf'
|
||||
elif sm_loader.contains_saved_model(filepath):
|
||||
ckpt_path = os.path.join(filepath, sm_constants.VARIABLES_DIRECTORY,
|
||||
sm_constants.VARIABLES_FILENAME)
|
||||
if _is_readable_tf_checkpoint(ckpt_path):
|
||||
filepath = ckpt_path
|
||||
save_format = 'tf'
|
||||
else:
|
||||
raise ValueError('Unable to load weights. filepath {} appears to be a '
|
||||
'SavedModel directory, but checkpoint either doesn\'t '
|
||||
'exist, or is incorrectly formatted.'.format(filepath))
|
||||
else:
|
||||
# Not a TensorFlow checkpoint. This filepath is likely an H5 file that
|
||||
# doesn't have the hdf5/keras extensions.
|
||||
save_format = 'h5'
|
||||
return filepath, save_format
|
||||
|
||||
|
||||
def _is_readable_tf_checkpoint(filepath):
|
||||
try:
|
||||
py_checkpoint_reader.NewCheckpointReader(filepath)
|
||||
return True
|
||||
except errors_impl.DataLossError:
|
||||
# The checkpoint is not readable in TensorFlow format.
|
||||
return False
|
||||
|
||||
|
@ -55,6 +55,7 @@ from tensorflow.python.keras.engine import training_utils_v1
|
||||
from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
|
||||
from tensorflow.python.keras.mixed_precision import policy
|
||||
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||
from tensorflow.python.keras.saving import saving_utils
|
||||
from tensorflow.python.keras.saving.saved_model import model_serialization
|
||||
from tensorflow.python.keras.utils import data_utils
|
||||
from tensorflow.python.keras.utils import layer_utils
|
||||
@ -229,7 +230,7 @@ class Model(training_lib.Model):
|
||||
"""
|
||||
if distributed_training_utils.is_tpu_strategy(self._distribution_strategy):
|
||||
if (self._distribution_strategy.extended.steps_per_run > 1 and
|
||||
(not training_lib._is_hdf5_filepath(filepath))): # pylint: disable=protected-access
|
||||
(not saving_utils.is_hdf5_filepath(filepath))): # pylint: disable=protected-access
|
||||
raise ValueError('Load weights is not yet supported with TPUStrategy '
|
||||
'with steps_per_run greater than 1.')
|
||||
return super(Model, self).load_weights(filepath, by_name, skip_mismatch)
|
||||
|
@ -18,11 +18,11 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import six
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.keras.saving import hdf5_format
|
||||
from tensorflow.python.keras.saving import saving_utils
|
||||
from tensorflow.python.keras.saving.saved_model import load as saved_model_load
|
||||
from tensorflow.python.keras.saving.saved_model import load_context
|
||||
from tensorflow.python.keras.saving.saved_model import save as saved_model_save
|
||||
@ -39,12 +39,6 @@ except ImportError:
|
||||
h5py = None
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
_HDF5_EXTENSIONS = ['.h5', '.hdf5', '.keras']
|
||||
|
||||
|
||||
# TODO(kathywu): Remove this when Keras SavedModel is not experimental.
|
||||
_KERAS_SAVED_MODEL_STILL_EXPERIMENTAL = True
|
||||
|
||||
|
||||
@keras_export('keras.models.save_model')
|
||||
def save_model(model,
|
||||
@ -140,7 +134,7 @@ def save_model(model,
|
||||
|
||||
if (save_format == 'h5' or
|
||||
(h5py is not None and isinstance(filepath, h5py.File)) or
|
||||
os.path.splitext(filepath)[1] in _HDF5_EXTENSIONS):
|
||||
saving_utils.is_hdf5_filepath(filepath)):
|
||||
# TODO(b/130258301): add utility method for detecting model type.
|
||||
if (not model._is_graph_network and # pylint:disable=protected-access
|
||||
not isinstance(model, sequential.Sequential)):
|
||||
|
@ -54,11 +54,14 @@ except ImportError:
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def _save_model_dir(self, dirname='saved_model'):
|
||||
temp_dir = self.get_temp_dir()
|
||||
self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
|
||||
return os.path.join(temp_dir, dirname)
|
||||
|
||||
@keras_parameterized.run_with_all_weight_formats
|
||||
def test_weight_loading(self):
|
||||
temp_dir = self.get_temp_dir()
|
||||
self.addCleanup(shutil.rmtree, temp_dir)
|
||||
saved_model_dir = os.path.join(temp_dir, 'saved_model')
|
||||
saved_model_dir = self._save_model_dir()
|
||||
save_format = testing_utils.get_save_format()
|
||||
with self.cached_session():
|
||||
a = keras.layers.Input(shape=(2,))
|
||||
@ -213,9 +216,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
||||
if h5py is None:
|
||||
return
|
||||
|
||||
temp_dir = self.get_temp_dir()
|
||||
self.addCleanup(shutil.rmtree, temp_dir)
|
||||
h5_path = os.path.join(temp_dir, 'test.h5')
|
||||
h5_path = self._save_model_dir('test.h5')
|
||||
|
||||
num_hidden = 5
|
||||
input_dim = 3
|
||||
@ -244,9 +245,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
||||
exclude_formats=['tf_no_traces'])
|
||||
def test_nested_model_weight_loading(self):
|
||||
save_format = testing_utils.get_save_format()
|
||||
temp_dir = self.get_temp_dir()
|
||||
self.addCleanup(shutil.rmtree, temp_dir)
|
||||
saved_model_dir = os.path.join(temp_dir, 'saved_model')
|
||||
saved_model_dir = self._save_model_dir()
|
||||
|
||||
batch_size = 5
|
||||
shape = (None, None, 3)
|
||||
@ -284,9 +283,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
||||
if h5py is None:
|
||||
return
|
||||
|
||||
temp_dir = self.get_temp_dir()
|
||||
self.addCleanup(shutil.rmtree, temp_dir)
|
||||
h5_path = os.path.join(temp_dir, 'test.h5')
|
||||
h5_path = self._save_model_dir('test.h5')
|
||||
|
||||
num_hidden = 5
|
||||
input_dim = 3
|
||||
@ -326,9 +323,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
||||
if h5py is None:
|
||||
return
|
||||
|
||||
temp_dir = self.get_temp_dir()
|
||||
self.addCleanup(shutil.rmtree, temp_dir)
|
||||
h5_path = os.path.join(temp_dir, 'test.h5')
|
||||
h5_path = self._save_model_dir('test.h5')
|
||||
|
||||
num_hidden = 5
|
||||
input_dim = 3
|
||||
@ -367,6 +362,32 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllClose([3.5] * num_classes,
|
||||
keras.backend.get_value(model.layers[1].bias))
|
||||
|
||||
@keras_parameterized.run_with_all_saved_model_formats(
|
||||
exclude_formats=['tf_no_traces'])
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
def test_load_weights_from_saved_model(self):
|
||||
save_path = self._save_model_dir()
|
||||
save_format = testing_utils.get_save_format()
|
||||
|
||||
if save_format == 'h5' and testing_utils.get_model_type() == 'subclass':
|
||||
# TODO(b/173646281): HDF5 format currently does not allow saving
|
||||
# subclassed models.
|
||||
return
|
||||
|
||||
with self.cached_session():
|
||||
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
|
||||
data = np.random.random((1, 3))
|
||||
labels = np.random.random((1, 4))
|
||||
model.compile(loss='mse', optimizer='rmsprop')
|
||||
model.fit(data, labels)
|
||||
model.save(save_path, save_format=save_format)
|
||||
new_model = testing_utils.get_small_mlp(1, 4, input_dim=3)
|
||||
if testing_utils.get_model_type() == 'subclass':
|
||||
# Call on test data to build the model.
|
||||
new_model.predict(data)
|
||||
new_model.load_weights(save_path)
|
||||
self.assertAllClose(model.weights, new_model.weights)
|
||||
|
||||
|
||||
class SubclassedModel(training.Model):
|
||||
|
||||
|
@ -321,3 +321,8 @@ def try_build_compiled_arguments(model):
|
||||
'Compiled the loaded model, but the compiled metrics have yet to '
|
||||
'be built. `model.compile_metrics` will be empty until you train '
|
||||
'or evaluate the model.')
|
||||
|
||||
|
||||
def is_hdf5_filepath(filepath):
|
||||
return (filepath.endswith('.h5') or filepath.endswith('.keras') or
|
||||
filepath.endswith('.hdf5'))
|
||||
|
Loading…
Reference in New Issue
Block a user