Write MirroredVariable components to the newly introduced `experimental_distributed_variable_components` protobuf field when the EXPAND_DISTRIBUTED_VARIABLES SaveOption is set. This is currently not supported by any loader. PiperOrigin-RevId: 335670847 Change-Id: I1c38ae132e4b2cda52adafa819c1779488031f20
187 lines · 7.5 KiB · Python
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Options for saving SavedModels."""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import enum
|
|
import six
|
|
|
|
from tensorflow.python.util import compat
|
|
from tensorflow.python.util.tf_export import tf_export
|
|
|
|
|
|
@tf_export("saved_model.experimental.VariablePolicy")
class VariablePolicy(enum.Enum):
  """Enum defining options for variable handling when saving.

  NONE
    No policy applied: Distributed variables are saved as one variable, with no
    device attached.

  SAVE_VARIABLE_DEVICES
    When saving variables, also save their device assignment.
    This is useful if one wants to hardcode devices in saved models, but it also
    makes them non-portable if soft device placement is disabled (more details
    in `tf.config.set_soft_device_placement`). This is currently not
    fully supported by `saved_model.load`, and is mainly intended to be used
    when one will be reading the saved model at a lower API level. In the
    example below, the graph saved by the call to `saved_model.save` will have
    the variable devices correctly specified:
    ```python
    exported = tf.train.Checkpoint()
    with tf.device('/GPU:0'):
      exported.x_gpu = tf.Variable(1.0)
    with tf.device('/CPU:0'):
      exported.x_cpu = tf.Variable(1.0)
    tf.saved_model.save(exported, export_dir,
        options = tf.saved_model.SaveOptions(
            experimental_variable_policy=
              tf.saved_model.experimental.VariablePolicy.SAVE_VARIABLE_DEVICES))
    ```
    Distributed variables are still saved as one variable under this policy.

  EXPAND_DISTRIBUTED_VARIABLES
    Distributed variables will be saved with information about their components,
    allowing for their restoration on load. Also, the saved graph will contain
    references to those variables. This is useful when one wants to use the
    model for training in environments where the original distribution strategy
    is not available.
  """

  NONE = None

  SAVE_VARIABLE_DEVICES = "save_variable_devices"

  EXPAND_DISTRIBUTED_VARIABLES = "expand_distributed_variables"

  def _save_variable_devices(self):
    """Checks whether variable devices should be saved."""
    # Every policy except NONE preserves device placement. Enum equality is
    # identity, so `is not` matches the original `!=` exactly.
    return self is not VariablePolicy.NONE

  def _expand_distributed_variables(self):
    """Checks whether distributed variables should be expanded."""
    return self is VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES

  @staticmethod
  def from_obj(obj):
    """Tries to convert `obj` to a VariablePolicy instance."""
    # Fast paths: the default policy and an already-converted value.
    if obj is None:
      return VariablePolicy.NONE
    if isinstance(obj, VariablePolicy):
      return obj
    # Otherwise match the (case-insensitive) string form against the enum
    # values declared above.
    normalized = str(obj).lower()
    matched = next(
        (policy for policy in VariablePolicy if policy.value == normalized),
        None)
    if matched is None:
      raise ValueError('Invalid VariablePolicy value "%s".' % obj)
    return matched
|
|
|
|
@tf_export("saved_model.SaveOptions")
class SaveOptions(object):
  """Options for saving to SavedModel.

  This function may be used in the `options` argument in functions that
  save a SavedModel (`tf.saved_model.save`, `tf.keras.models.save_model`).
  """

  # Define object attributes in __slots__ for improved memory and performance.
  __slots__ = ("namespace_whitelist", "save_debug_info", "function_aliases",
               "experimental_io_device", "experimental_variable_policy")

  def __init__(self,
               namespace_whitelist=None,
               save_debug_info=False,
               function_aliases=None,
               experimental_io_device=None,
               experimental_variable_policy=None):
    """Creates an object that stores options for SavedModel saving.

    Args:
      namespace_whitelist: List of strings containing op namespaces to whitelist
        when saving a model. Saving an object that uses namespaced ops must
        explicitly add all namespaces to the whitelist. The namespaced ops must
        be registered into the framework when loading the SavedModel.
      save_debug_info: Boolean indicating whether debug information is saved. If
        True, then a debug/saved_model_debug_info.pb file will be written with
        the contents of a GraphDebugInfo binary protocol buffer containing stack
        trace information for all ops and functions that are saved.
      function_aliases: Python dict. Mapping from string to object returned by
        @tf.function. A single tf.function can generate many ConcreteFunctions.
        If a downstream tool wants to refer to all concrete functions generated
        by a single tf.function you can use the `function_aliases` argument to
        store a map from the alias name to all concrete function names.
        E.g.
        ```python
        class MyModel:
          @tf.function
          def func():
            ...

          @tf.function
          def serve():
            ...
            func()

        model = MyModel()
        signatures = {
            'serving_default': model.serve.get_concrete_function(),
        }
        options = tf.saved_model.SaveOptions(function_aliases={
            'my_func': func,
        })
        tf.saved_model.save(model, export_dir, signatures, options)
        ```
      experimental_io_device: string. Applies in a distributed setting.
        Tensorflow device to use to access the filesystem. If `None` (default)
        then for each variable the filesystem is accessed from the CPU:0 device
        of the host where that variable is assigned. If specified, the
        filesystem is instead accessed from that device for all variables.

        This is for example useful if you want to save to a local directory,
        such as "/tmp" when running in a distributed setting. In that case pass
        a device for the host where the "/tmp" directory is accessible.
      experimental_variable_policy: The policy to apply to variables when
        saving. This is either a `saved_model.experimental.VariablePolicy` enum
        instance or one of its value strings (case is not important). See that
        enum documentation for details. A value of `None` corresponds to the
        default policy.
    """
    self.namespace_whitelist = _validate_namespace_whitelist(
        namespace_whitelist)
    self.save_debug_info = save_debug_info
    # Any falsy value (None, {}) becomes a fresh empty dict, as before.
    self.function_aliases = function_aliases or dict()
    self.experimental_io_device = experimental_io_device
    # Normalize strings / None into a VariablePolicy enum member up front so
    # downstream code can rely on a single type.
    self.experimental_variable_policy = VariablePolicy.from_obj(
        experimental_variable_policy)
|
|
|
|
def _validate_namespace_whitelist(namespace_whitelist):
  """Validates namespace whitelist argument.

  Args:
    namespace_whitelist: value passed by the caller; `None`, or a list of
      strings (or byte strings convertible via `compat.as_str`).

  Returns:
    A (possibly empty) list of `str` namespaces.

  Raises:
    TypeError: if the argument is neither `None` nor a list.
    ValueError: if any list element is not a string.
  """
  if namespace_whitelist is None:
    return []
  if not isinstance(namespace_whitelist, list):
    raise TypeError("Namespace whitelist must be a list of strings.")

  # Validate every entry first, then convert in one pass.
  for namespace in namespace_whitelist:
    if not isinstance(namespace, six.string_types):
      raise ValueError("Whitelisted namespace must be a string. Got: {} of type"
                       " {}.".format(namespace, type(namespace)))
  return [compat.as_str(namespace) for namespace in namespace_whitelist]