# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Options for saving SavedModels."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import enum
import six
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
@tf_export("saved_model.experimental.VariablePolicy")
class VariablePolicy(enum.Enum):
  """Enum defining options for variable handling when saving.

  NONE
    No policy applied: Distributed variables are saved as one variable, with
    no device attached.

  SAVE_VARIABLE_DEVICES
    When saving variables, also save their device assignment. This is useful
    if one wants to hardcode devices in saved models, but it also makes them
    non-portable if soft device placement is disabled (more details in
    `tf.config.set_soft_device_placement`). This is currently not fully
    supported by `saved_model.load`, and is mainly intended to be used when
    one will be reading the saved model at a lower API level. In the example
    below, the graph saved by the call to `saved_model.save` will have the
    variable devices correctly specified:
    ```python
    exported = tf.train.Checkpoint()
    with tf.device('/GPU:0'):
      exported.x_gpu = tf.Variable(1.0)
    with tf.device('/CPU:0'):
      exported.x_cpu = tf.Variable(1.0)
    tf.saved_model.save(exported, export_dir,
        options=tf.saved_model.SaveOptions(
            experimental_variable_policy=
            tf.saved_model.experimental.VariablePolicy.SAVE_VARIABLE_DEVICES))
    ```
    Distributed variables are still saved as one variable under this policy.

  EXPAND_DISTRIBUTED_VARIABLES
    Distributed variables will be saved with information about their
    components, allowing for their restoration on load. Also, the saved graph
    will contain references to those variables. This is useful when one wants
    to use the model for training in environments where the original
    distribution strategy is not available.
  """

  NONE = None

  SAVE_VARIABLE_DEVICES = "save_variable_devices"

  EXPAND_DISTRIBUTED_VARIABLES = "expand_distributed_variables"

  def _save_variable_devices(self):
    """Checks whether variable devices should be saved."""
    return self != VariablePolicy.NONE

  def _expand_distributed_variables(self):
    """Checks whether distributed variables should be expanded."""
    return self == VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES

  @staticmethod
  def from_obj(obj):
    """Tries to convert `obj` to a VariablePolicy instance."""
    if obj is None:
      return VariablePolicy.NONE
    if isinstance(obj, VariablePolicy):
      return obj
    key = str(obj).lower()
    for policy in VariablePolicy:
      if key == policy.value:
        return policy
    raise ValueError('Invalid VariablePolicy value "%s".' % obj)
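
# Illustrative sketch (editor's note, not part of the original module):
# `from_obj` accepts `None`, a `VariablePolicy` instance, or a
# case-insensitive value string, so the following all resolve to the same
# policy:
#
#   VariablePolicy.from_obj("SAVE_VARIABLE_DEVICES")
#   VariablePolicy.from_obj("save_variable_devices")
#   VariablePolicy.from_obj(VariablePolicy.SAVE_VARIABLE_DEVICES)
#
# while an unrecognized value such as "expand_everything" raises ValueError.
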
@tf_export("saved_model.SaveOptions")
class SaveOptions(object):
  """Options for saving to SavedModel.

  This class may be used in the `options` argument in functions that
  save a SavedModel (`tf.saved_model.save`, `tf.keras.models.save_model`).
  """

  # Define object attributes in __slots__ for improved memory and performance.
  __slots__ = ("namespace_whitelist", "save_debug_info", "function_aliases",
               "experimental_io_device", "experimental_variable_policy")

  def __init__(self,
               namespace_whitelist=None,
               save_debug_info=False,
               function_aliases=None,
               experimental_io_device=None,
               experimental_variable_policy=None):
    """Creates an object that stores options for SavedModel saving.

    Args:
      namespace_whitelist: List of strings containing op namespaces to
        whitelist when saving a model. Saving an object that uses namespaced
        ops must explicitly add all namespaces to the whitelist. The
        namespaced ops must be registered into the framework when loading the
        SavedModel.
      save_debug_info: Boolean indicating whether debug information is saved.
        If True, then a debug/saved_model_debug_info.pb file will be written
        with the contents of a GraphDebugInfo binary protocol buffer
        containing stack trace information for all ops and functions that are
        saved.
      function_aliases: Python dict. Mapping from string to object returned by
        @tf.function. A single tf.function can generate many ConcreteFunctions.
        If a downstream tool wants to refer to all concrete functions generated
        by a single tf.function, you can use the `function_aliases` argument to
        store a map from the alias name to all concrete function names.
        E.g.
        ```python
        class MyModel:

          @tf.function
          def func(self):
            ...

          @tf.function
          def serve(self):
            ...
            self.func()

        model = MyModel()
        signatures = {
            'serving_default': model.serve.get_concrete_function(),
        }
        options = tf.saved_model.SaveOptions(function_aliases={
            'my_func': model.func,
        })
        tf.saved_model.save(model, export_dir, signatures, options)
        ```
      experimental_io_device: string. Applies in a distributed setting.
        TensorFlow device to use to access the filesystem. If `None` (default)
        then for each variable the filesystem is accessed from the CPU:0
        device of the host where that variable is assigned. If specified, the
        filesystem is instead accessed from that device for all variables.

        This is for example useful if you want to save to a local directory,
        such as "/tmp" when running in a distributed setting. In that case
        pass a device for the host where the "/tmp" directory is accessible.
      experimental_variable_policy: The policy to apply to variables when
        saving. This is either a `saved_model.experimental.VariablePolicy`
        enum instance or one of its value strings (case is not important).
        See that enum's documentation for details. A value of `None`
        corresponds to the default policy.
    """
    self.namespace_whitelist = _validate_namespace_whitelist(
        namespace_whitelist)
    self.save_debug_info = save_debug_info
    self.function_aliases = function_aliases if function_aliases else dict()
    self.experimental_io_device = experimental_io_device
    self.experimental_variable_policy = (
        VariablePolicy.from_obj(experimental_variable_policy))
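
# Illustrative sketch (editor's note, not part of the original module):
# constructing `SaveOptions` with the arguments documented above. The variable
# policy may be passed as a value string instead of an enum instance and is
# normalized by `VariablePolicy.from_obj`. The namespace and device strings
# below are placeholders, not values required by the API:
#
#   options = tf.saved_model.SaveOptions(
#       namespace_whitelist=["CustomOps"],
#       save_debug_info=True,
#       experimental_io_device="/job:localhost",
#       experimental_variable_policy="save_variable_devices")
#   assert (options.experimental_variable_policy ==
#           VariablePolicy.SAVE_VARIABLE_DEVICES)
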

def _validate_namespace_whitelist(namespace_whitelist):
  """Validates namespace whitelist argument."""
  if namespace_whitelist is None:
    return []
  if not isinstance(namespace_whitelist, list):
    raise TypeError("Namespace whitelist must be a list of strings.")

  processed = []
  for namespace in namespace_whitelist:
    if not isinstance(namespace, six.string_types):
      raise ValueError("Whitelisted namespace must be a string. Got: {} of type"
                       " {}.".format(namespace, type(namespace)))
    processed.append(compat.as_str(namespace))
  return processed
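
# Illustrative sketch (editor's note, not part of the original module): the
# validation above accepts `None` (normalized to an empty list) or a list of
# strings, and rejects anything else. "CustomOps" is a hypothetical namespace
# used only for illustration:
#
#   _validate_namespace_whitelist(None)            # -> []
#   _validate_namespace_whitelist(["CustomOps"])   # -> ["CustomOps"]
#   _validate_namespace_whitelist("CustomOps")     # raises TypeError
#   _validate_namespace_whitelist([42])            # raises ValueError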