Add load option for loading SavedModel from specific io_device for distributed training.

A new class LoadOptions is created, similar to the existing SaveOptions. The option experimental_io_device is the only option added at this time and is used to set the io_device when loading a SavedModel for distributed training.

PiperOrigin-RevId: 316557681
Change-Id: If3f1eae18b09085ff11dc8a6882fabcb18f5f48e
This commit is contained in:
Ken Franko 2020-06-15 15:28:17 -07:00 committed by TensorFlower Gardener
parent 5d4c6e105f
commit 67fb07ba9f
25 changed files with 144 additions and 32 deletions

View File

@ -2078,7 +2078,11 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
save_relative_paths=True,
all_model_checkpoint_paths=[filepath])
def load_weights(self, filepath, by_name=False, skip_mismatch=False):
def load_weights(self,
filepath,
by_name=False,
skip_mismatch=False,
options=None):
"""Loads all layer weights, either from a TensorFlow or an HDF5 weight file.
If `by_name` is False weights are loaded based on the network's
@ -2108,6 +2112,8 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
skip_mismatch: Boolean, whether to skip loading of layers where there is
a mismatch in the number of weights, or a mismatch in the shape of
the weight (only valid when `by_name=True`).
options: Optional `tf.train.CheckpointOptions` object that specifies
options for loading weights.
Returns:
When loading a weight file in TensorFlow format, returns the same status
@ -2145,7 +2151,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
# The checkpoint is not readable in TensorFlow format. Try HDF5.
save_format = 'h5'
if save_format == 'tf':
status = self._trackable_saver.restore(filepath)
status = self._trackable_saver.restore(filepath, options)
if by_name:
raise NotImplementedError(
'Weights may only be loaded based on topology into Models when '

View File

@ -135,7 +135,7 @@ def save_model(model,
@keras_export('keras.models.load_model')
def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=redefined-builtin
def load_model(filepath, custom_objects=None, compile=True, options=None): # pylint: disable=redefined-builtin
"""Loads a model saved via `model.save()`.
Usage:
@ -162,6 +162,8 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=
considered during deserialization.
compile: Boolean, whether to compile the model
after loading.
options: Optional `tf.saved_model.LoadOptions` object that specifies
options for loading from SavedModel.
Returns:
A Keras model instance. If the original model was compiled, and saved with
@ -182,7 +184,7 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=
filepath = path_to_string(filepath)
if isinstance(filepath, six.string_types):
loader_impl.parse_saved_model(filepath)
return saved_model_load.load(filepath, compile)
return saved_model_load.load(filepath, compile, options)
raise IOError(
'Unable to load model. Filepath is not an hdf5 file (or h5py is not '

View File

@ -90,7 +90,7 @@ KERAS_OBJECT_IDENTIFIERS = (
'_tf_keras_rnn_layer')
def load(path, compile=True): # pylint: disable=redefined-builtin
def load(path, compile=True, options=None): # pylint: disable=redefined-builtin
"""Loads Keras objects from a SavedModel.
Any Keras layer or model saved to the SavedModel will be loaded back
@ -107,13 +107,18 @@ def load(path, compile=True): # pylint: disable=redefined-builtin
Args:
path: Path to SavedModel.
compile: If true, compile the model after loading it.
options: Optional `tf.saved_model.LoadOptions` object that specifies
options for loading from SavedModel.
Returns:
Object loaded from SavedModel.
"""
# TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
# TODO(kathywu): Add code to load from objects that contain all endpoints
model = tf_load.load_internal(path, loader_cls=KerasObjectLoader)
model = tf_load.load_internal(
path, options=options, loader_cls=KerasObjectLoader)
# pylint: disable=protected-access
if isinstance(model, training_lib.Model) and compile:

View File

@ -348,6 +348,7 @@ py_library(
deps = [
":constants",
":function_deserialization",
":load_options",
":load_v1_in_v2",
":loader",
":nested_structure_coder",
@ -522,6 +523,13 @@ py_library(
],
)
py_library(
name = "load_options",
srcs = ["load_options.py"],
deps = [
],
)
py_library(
name = "method_name_updater",
srcs = ["method_name_updater.py"],

View File

@ -37,11 +37,13 @@ from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import function_deserialization
from tensorflow.python.saved_model import load_options
from tensorflow.python.saved_model import load_v1_in_v2
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training.saving import checkpoint_options
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import tracking
@ -105,7 +107,8 @@ class _WrapperFunction(function.ConcreteFunction):
class Loader(object):
"""Helper class to load an object-based SavedModel."""
def __init__(self, object_graph_proto, saved_model_proto, export_dir):
def __init__(self, object_graph_proto, saved_model_proto, export_dir,
ckpt_options):
meta_graph = saved_model_proto.meta_graphs[0]
self._asset_file_def = meta_graph.asset_file_def
self._operation_attributes = {
@ -115,6 +118,7 @@ class Loader(object):
self._concrete_functions = (
function_deserialization.load_function_def_library(
meta_graph.graph_def.library))
self._checkpoint_options = ckpt_options
for name, concrete_function in self._concrete_functions.items():
# Wrap all the concrete function so that they are capable of dealing with
@ -306,9 +310,10 @@ class Loader(object):
with ops.device("CPU"):
saver._file_prefix_placeholder = constant_op.constant(variables_path)
if self._expect_partial_checkpoint:
load_status = saver.restore(variables_path).expect_partial()
load_status = saver.restore(variables_path,
self._checkpoint_options).expect_partial()
else:
load_status = saver.restore(variables_path)
load_status = saver.restore(variables_path, self._checkpoint_options)
load_status.assert_existing_objects_matched()
checkpoint = load_status._checkpoint
@ -491,7 +496,7 @@ def _call_attribute(instance, *args, **kwargs):
@tf_export("saved_model.load", v1=["saved_model.load_v2"])
def load(export_dir, tags=None):
def load(export_dir, tags=None, options=None):
"""Load a SavedModel from `export_dir`.
Signatures associated with the SavedModel are available as functions:
@ -569,6 +574,8 @@ def load(export_dir, tags=None):
tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
if the SavedModel contains a single MetaGraph, as for those exported from
`tf.saved_model.save`.
options: Optional, `tf.saved_model.LoadOptions` object that specifies
options for loading.
Returns:
A trackable object with a `signatures` attribute mapping from signature
@ -579,11 +586,12 @@ def load(export_dir, tags=None):
Raises:
ValueError: If `tags` don't match a MetaGraph in the SavedModel.
"""
return load_internal(export_dir, tags)
return load_internal(export_dir, tags, options)
def load_internal(export_dir, tags=None, loader_cls=Loader):
def load_internal(export_dir, tags=None, options=None, loader_cls=Loader):
"""Loader implementation."""
options = options or load_options.LoadOptions()
if tags is not None and not isinstance(tags, set):
# Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
# sequences for nest.flatten, so we put those through as-is.
@ -602,10 +610,12 @@ def load_internal(export_dir, tags=None, loader_cls=Loader):
"it, pass 'None', or pass matching tags.")
.format(export_dir, meta_graph_def.meta_info_def.tags, tags))
object_graph_proto = meta_graph_def.object_graph_def
ckpt_options = checkpoint_options.CheckpointOptions(
experimental_io_device=options.experimental_io_device)
with ops.init_scope():
loader = loader_cls(object_graph_proto,
saved_model_proto,
export_dir)
loader = loader_cls(object_graph_proto, saved_model_proto, export_dir,
ckpt_options)
root = loader.get(0)
if isinstance(loader, Loader):
root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info)

View File

@ -0,0 +1,57 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Options for saving SavedModels."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.util.tf_export import tf_export
@tf_export("saved_model.LoadOptions", v1=[])
class LoadOptions(object):
  """Options for loading a SavedModel.

  An instance of this class may be passed as the `options` argument to
  APIs that load a SavedModel (e.g. `tf.saved_model.load`,
  `tf.keras.models.load_model`).
  """

  # __slots__ avoids a per-instance __dict__ (memory/perf) and rejects
  # assignment to undeclared attributes.
  __slots__ = ("experimental_io_device",)

  def __init__(self, experimental_io_device=None):
    """Creates an object that stores options for SavedModel loading.

    Args:
      experimental_io_device: string. Applies in a distributed setting.
        Tensorflow device to use to access the filesystem. If `None`
        (default) then for each variable the filesystem is accessed from the
        CPU:0 device of the host where that variable is assigned. If
        specified, the filesystem is instead accessed from that device for
        all variables. This is for example useful if you want to load from a
        local directory, such as "/tmp" when running in a distributed
        setting. In that case pass a device for the host where the "/tmp"
        directory is accessible.

    Example:

    ```python
    load_options = tf.saved_model.LoadOptions(experimental_io_device=
      '/job:localhost')
    restoredmodel = tf.keras.models.load_model(saved_model_path,
                                               options=load_options)
    ```
    """
    self.experimental_io_device = experimental_io_device

View File

@ -56,6 +56,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import load_options
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import monitored_session
@ -1788,6 +1789,12 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(imported2.f(rt, 2), [[3, 4], [5]])
self.assertAllEqual(imported2.f(rt, 3), [[4, 5], [6]])
def test_accepts_io_device(self, cycles):
  """LoadOptions stores `experimental_io_device` (None by default)."""
  # Default construction leaves the io_device unset.
  default_options = load_options.LoadOptions()
  self.assertIsNone(default_options.experimental_io_device)
  # An explicitly supplied io_device is stored verbatim.
  configured_options = load_options.LoadOptions(
      experimental_io_device="/job:localhost")
  self.assertEqual("/job:localhost", configured_options.experimental_io_device)
class SingleCycleTests(test.TestCase, parameterized.TestCase):

View File

@ -258,7 +258,7 @@ tf_class {
}
member_method {
name: "load_weights"
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\', \'options\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
}
member_method {
name: "make_predict_function"

View File

@ -264,7 +264,7 @@ tf_class {
}
member_method {
name: "load_weights"
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\', \'options\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
}
member_method {
name: "make_predict_function"

View File

@ -259,7 +259,7 @@ tf_class {
}
member_method {
name: "load_weights"
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\', \'options\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
}
member_method {
name: "make_predict_function"

View File

@ -259,7 +259,7 @@ tf_class {
}
member_method {
name: "load_weights"
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\', \'options\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
}
member_method {
name: "make_predict_function"

View File

@ -258,7 +258,7 @@ tf_class {
}
member_method {
name: "load_weights"
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\', \'options\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
}
member_method {
name: "make_predict_function"

View File

@ -264,7 +264,7 @@ tf_class {
}
member_method {
name: "load_weights"
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\', \'options\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
}
member_method {
name: "make_predict_function"

View File

@ -14,7 +14,7 @@ tf_module {
}
member_method {
name: "load_model"
argspec: "args=[\'filepath\', \'custom_objects\', \'compile\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
argspec: "args=[\'filepath\', \'custom_objects\', \'compile\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
}
member_method {
name: "model_from_config"

View File

@ -182,7 +182,7 @@ tf_module {
}
member_method {
name: "load_v2"
argspec: "args=[\'export_dir\', \'tags\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'export_dir\', \'tags\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "main_op_with_restore"

View File

@ -258,7 +258,7 @@ tf_class {
}
member_method {
name: "load_weights"
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\', \'options\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
}
member_method {
name: "make_predict_function"

View File

@ -264,7 +264,7 @@ tf_class {
}
member_method {
name: "load_weights"
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\', \'options\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
}
member_method {
name: "make_predict_function"

View File

@ -259,7 +259,7 @@ tf_class {
}
member_method {
name: "load_weights"
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\', \'options\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
}
member_method {
name: "make_predict_function"

View File

@ -259,7 +259,7 @@ tf_class {
}
member_method {
name: "load_weights"
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\', \'options\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
}
member_method {
name: "make_predict_function"

View File

@ -258,7 +258,7 @@ tf_class {
}
member_method {
name: "load_weights"
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\', \'options\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
}
member_method {
name: "make_predict_function"

View File

@ -264,7 +264,7 @@ tf_class {
}
member_method {
name: "load_weights"
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
argspec: "args=[\'self\', \'filepath\', \'by_name\', \'skip_mismatch\', \'options\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
}
member_method {
name: "make_predict_function"

View File

@ -14,7 +14,7 @@ tf_module {
}
member_method {
name: "load_model"
argspec: "args=[\'filepath\', \'custom_objects\', \'compile\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
argspec: "args=[\'filepath\', \'custom_objects\', \'compile\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
}
member_method {
name: "model_from_config"

View File

@ -0,0 +1,13 @@
path: "tensorflow.saved_model.LoadOptions"
tf_class {
is_instance: "<class \'tensorflow.python.saved_model.load_options.LoadOptions\'>"
is_instance: "<type \'object\'>"
member {
name: "experimental_io_device"
mtype: "<type \'member_descriptor\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'experimental_io_device\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}

View File

@ -44,6 +44,10 @@ tf_module {
name: "GPU"
mtype: "<type \'str\'>"
}
member {
name: "LoadOptions"
mtype: "<type \'type\'>"
}
member {
name: "PREDICT_INPUTS"
mtype: "<type \'str\'>"
@ -110,7 +114,7 @@ tf_module {
}
member_method {
name: "load"
argspec: "args=[\'export_dir\', \'tags\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'export_dir\', \'tags\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "save"