When saving to SavedModel, functions that capture cached values are now wrapped with functions that capture the original variable.

PiperOrigin-RevId: 292041916
Change-Id: I69a3174c52d526ad0ec48d7cd533ff95cb44747a
This commit is contained in:
Katherine Wu 2020-01-28 16:48:06 -08:00 committed by TensorFlower Gardener
parent b862920996
commit 7705f70c14
6 changed files with 213 additions and 22 deletions

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import contextlib
import functools
import weakref
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
@ -347,6 +348,7 @@ class BaseResourceVariable(variables.VariableV1):
cached_value=None,
save_slice_info=None,
handle_deleter=None,
caching_device=None,
**unused_kwargs):
"""Creates a variable from a handle.
@ -387,6 +389,11 @@ class BaseResourceVariable(variables.VariableV1):
save_slice_info: Metadata for variable partitioning.
handle_deleter: EagerResourceDeleter responsible for cleaning up the
handle.
caching_device: Optional device string or function describing where the
Variable should be cached for reading. Defaults to the Variable's
device. If not `None`, caches on another device. Typical use is to
cache on the device where the Ops using the Variable reside, to
deduplicate copying through `Switch` and other conditional statements.
"""
with ops.init_scope():
self._in_graph_mode = not context.executing_eagerly()
@ -401,6 +408,7 @@ class BaseResourceVariable(variables.VariableV1):
self._initializer_op = initializer_op
self._is_initialized_op = is_initialized_op
self._graph_element = graph_element
self._caching_device = caching_device
self._cached_value = cached_value
self._distribute_strategy = distribute_strategy
# Store the graph key so optimizers know how to only retrieve variables from
@ -612,9 +620,19 @@ class BaseResourceVariable(variables.VariableV1):
def _read_variable_op(self):
variable_accessed(self)
result = gen_resource_variable_ops.read_variable_op(self._handle,
self._dtype)
_maybe_set_handle_data(self._dtype, self._handle, result)
def read_and_set_handle():
result = gen_resource_variable_ops.read_variable_op(self._handle,
self._dtype)
_maybe_set_handle_data(self._dtype, self._handle, result)
return result
if getattr(self, "_caching_device", None) is not None:
with ops.colocate_with(None, ignore_existing=True):
with ops.device(self._caching_device):
result = read_and_set_handle()
else:
result = read_and_set_handle()
if not context.executing_eagerly():
# Note that if a control flow context is active the input of the read op
@ -1614,6 +1632,12 @@ class ResourceVariable(BaseResourceVariable):
_maybe_set_handle_data(dtype, handle, cached_value)
else:
cached_value = None
if cached_value is not None:
# Store the variable object so that the original variable can be
# accessed to generate functions that are compatible with SavedModel.
cached_value._cached_variable = weakref.ref(self) # pylint: disable=protected-access
if not context.executing_eagerly():
# Eager variables are only added to collections if they are part of an
# eager variable store (otherwise in an interactive session they would
@ -1629,7 +1653,7 @@ class ResourceVariable(BaseResourceVariable):
name=name, unique_id=unique_id, handle_name=handle_name,
graph_element=graph_element, initial_value=initial_value,
initializer_op=initializer_op, is_initialized_op=is_initialized_op,
cached_value=cached_value)
cached_value=cached_value, caching_device=caching_device)
def _init_from_proto(self, variable_def, import_scope=None):
"""Initializes from `VariableDef` proto."""

View File

@ -19,8 +19,11 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util import compat
from tensorflow.python.util import nest
def _serialize_function_spec(function_spec, coder):
@ -72,17 +75,19 @@ def serialize_concrete_function(concrete_function, node_ids, coder):
return concrete_function_proto
def serialize_bare_concrete_function(concrete_function):
def serialize_bare_concrete_function(concrete_function, name_map):
"""Build a SavedBareConcreteFunction."""
# pylint: disable=protected-access
name = name_map.get(compat.as_text(concrete_function.name),
concrete_function.name)
return saved_object_graph_pb2.SavedBareConcreteFunction(
concrete_function_name=concrete_function.name,
concrete_function_name=name,
allowed_positional_arguments=concrete_function._num_positional_args,
argument_keywords=concrete_function._arg_keywords)
# pylint: enable=protected-access
def serialize_function(function):
def serialize_function(function, name_map):
"""Build a SavedFunction proto."""
coder = nested_structure_coder.StructureCoder()
proto = saved_object_graph_pb2.SavedFunction()
@ -92,5 +97,65 @@ def serialize_function(function):
all_concrete_functions = \
function._list_all_concrete_functions_for_serialization() # pylint: disable=protected-access
for concrete_function in all_concrete_functions:
proto.concrete_functions.append(concrete_function.name)
proto.concrete_functions.append(
name_map.get(compat.as_text(concrete_function.name),
concrete_function.name))
return proto
def wrap_cached_variables(concrete_function):
"""Wraps the concrete function if it uses cached read tensors.
This function creates a new concrete function that captures variables
instead of the cached read tensors.
Args:
concrete_function: A Concrete function that maybe captures cached read
tensors.
Returns:
A concrete function that wraps the original concrete function, which
captures variables instead. If the original function did not capture any
cached values, then the function is not wrapped and the original object is
returned.
"""
outer_graph = func_graph_module.FuncGraph(
"{}_no_cache".format(concrete_function.graph.name))
captures = concrete_function.graph._captures # pylint: disable=protected-access
mapped_captures = None
remapped_captures = {}
# Update the external captures to use read tensors generated in the outer
# graph.
with outer_graph.as_default():
for capture, placeholder in concrete_function.graph.captures:
cached_variable = getattr(capture, "_cached_variable", None)
if cached_variable is None:
continue
cached_variable = cached_variable()
new_cached_value = cached_variable.read_value()
remapped_captures[id(capture)] = captures[id(capture)]
captures[id(capture)] = (new_cached_value, placeholder)
mapped_captures = True
if not mapped_captures:
return concrete_function
inner_concrete = defun.ConcreteFunction(concrete_function.graph)
def wrap_function(*args):
return inner_concrete._call_flat(args, inner_concrete.captured_inputs) # pylint:disable=protected-access
args = nest.flatten(concrete_function.structured_input_signature,
expand_composites=True)
func_graph_module.func_graph_from_py_func(
None, wrap_function, args=tuple(args), kwargs={},
func_graph=outer_graph)
fn = defun.ConcreteFunction(outer_graph)
fn._arg_keywords = concrete_function._arg_keywords # pylint: disable=protected-access
fn._num_positional_args = concrete_function._num_positional_args # pylint: disable=protected-access
# Return the captures to their original values
for key, capture in remapped_captures.items():
captures[key] = capture
return fn

View File

@ -27,6 +27,7 @@ import weakref
from absl.testing import parameterized
from tensorflow.python.client import session as session_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
@ -1826,6 +1827,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
rt = ragged_factory_ops.constant([[1, 2], [3]])
self.assertAllEqual(imported2.f(rt), [[2, 3], [4]])
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
@parameterized.named_parameters(
dict(testcase_name="ReloadOnce", cycles=1),
@ -1943,6 +1945,53 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
"object has an attribute named a, which is reserved."):
save.save(root, path)
def test_save_cached_variable(self):
with ops.Graph().as_default(), session_lib.Session() as session:
obj = tracking.AutoTrackable()
obj.v = variables.Variable(2., caching_device=lambda op: op.device)
obj.w = variables.Variable(3.)
session.run([obj.v.initializer, obj.w.initializer])
@def_function.function
def total():
return obj.v + obj.w
@def_function.function(input_signature=[tensor_spec.TensorSpec([])])
def wrapped_total(x):
return total() + x
@def_function.function
def increment_v(x):
obj.v.assign_add(x)
session.run(increment_v(constant_op.constant(3.))) # generate signatures
self.assertAllClose(8, total())
self.assertAllClose(13, wrapped_total(constant_op.constant(5.)))
obj.total = total
obj.wrapped_total = wrapped_total.get_concrete_function()
obj.increment_v = increment_v
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
save.save(obj, save_dir, signatures=total.get_concrete_function())
imported = load.load(save_dir)
session.run(variables.global_variables_initializer())
self.assertAllClose(8, imported.total())
session.run(imported.increment_v(4))
self.assertAllClose(12, imported.total())
self.assertAllClose(15, imported.wrapped_total(constant_op.constant(3.)))
self.assertAllClose({"output_0": 12},
imported.signatures["serving_default"]())
# Try loading and running the function in eager mode
imported = load.load(save_dir)
self.assertAllClose(8, imported.total())
imported.increment_v(5)
self.assertAllClose(13, imported.total())
self.assertAllClose(13.5, imported.wrapped_total(constant_op.constant(.5)))
self.assertAllClose({"output_0": 13},
imported.signatures["serving_default"]())
if __name__ == "__main__":
test.main()

View File

@ -159,7 +159,14 @@ class _SaveableView(object):
ignored.
"""
def __init__(self, checkpoint_view):
def __init__(self, checkpoint_view, wrapped_functions=None):
"""Initializes a SaveableView.
Args:
checkpoint_view: A GraphView object.
wrapped_functions: Dictionary that maps concrete functions to functions
that do not capture cached variable values.
"""
self.checkpoint_view = checkpoint_view
trackable_objects, node_ids, slot_variables = (
self.checkpoint_view.objects_ids_and_slot_variables())
@ -169,6 +176,15 @@ class _SaveableView(object):
self.slot_variables = slot_variables
self.concrete_functions = []
# Maps functions -> wrapped functions that capture variables
self.wrapped_functions = wrapped_functions or {}
# Maps names of concrete functions in the object to names of wrapped
# functions. When writing the SavedFunction protos, the names of the
# wrapped functions should be used in place of the original functions.
self.function_name_map = {
compat.as_text(original.name): compat.as_text(wrapped.name)
for original, wrapped in self.wrapped_functions.items()}
# Also add `Function`s as nodes.
nodes_without_functions = list(self.nodes)
seen_function_names = set()
@ -287,6 +303,14 @@ class _SaveableView(object):
if (tensor_util.is_tensor(capture) and
capture.dtype not in _UNCOPIABLE_DTYPES and
capture not in self.captured_tensor_node_ids):
if hasattr(capture, "_cached_variable"):
if concrete_function not in self.wrapped_functions:
wrapped = self.wrapped_functions[concrete_function] = (
function_serialization.wrap_cached_variables(
concrete_function))
self.function_name_map[compat.as_text(concrete_function.name)] = (
compat.as_text(wrapped.name))
continue
capture_constant_value = tensor_util.constant_value(capture)
if capture_constant_value is None:
bad_functions.append(concrete_function)
@ -302,7 +326,8 @@ class _SaveableView(object):
resource_map[capture] = copied_tensor
self.concrete_functions = [
x for x in self.concrete_functions if x not in bad_functions
self.wrapped_functions.get(x, x) for x in self.concrete_functions
if x not in bad_functions
]
return object_map, resource_map, asset_info
@ -465,8 +490,8 @@ def _generate_signatures(signature_functions, resource_map):
mapped_inputs, exterior_argument_placeholders = (
_map_function_arguments_to_created_inputs(argument_inputs,
signature_key, function.name))
outputs = _call_function_with_mapped_captures(function, mapped_inputs,
resource_map)
outputs = _call_function_with_mapped_captures(
function, mapped_inputs, resource_map)
signatures[signature_key] = signature_def_utils.build_signature_def(
_tensor_dict_to_tensorinfo(exterior_argument_placeholders),
_tensor_dict_to_tensorinfo(outputs),
@ -657,17 +682,20 @@ def _serialize_object_graph(saveable_view, asset_file_def_index):
coder = nested_structure_coder.StructureCoder()
for concrete_function in saveable_view.concrete_functions:
name = compat.as_text(concrete_function.name)
name = saveable_view.function_name_map.get(name, name)
serialized = function_serialization.serialize_concrete_function(
concrete_function, saveable_view.captured_tensor_node_ids, coder)
if serialized is not None:
proto.concrete_functions[concrete_function.name].CopyFrom(serialized)
proto.concrete_functions[name].CopyFrom(serialized)
for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
_write_object_proto(obj, obj_proto, asset_file_def_index)
_write_object_proto(obj, obj_proto, asset_file_def_index,
saveable_view.function_name_map)
return proto
def _write_object_proto(obj, proto, asset_file_def_index):
def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
"""Saves an object into SavedObject proto."""
if isinstance(obj, tracking.Asset):
proto.asset.SetInParent()
@ -684,10 +712,12 @@ def _write_object_proto(obj, proto, asset_file_def_index):
proto.variable.aggregation = obj.aggregation.value
proto.variable.shape.CopyFrom(obj.shape.as_proto())
elif isinstance(obj, def_function.Function):
proto.function.CopyFrom(function_serialization.serialize_function(obj))
proto.function.CopyFrom(function_serialization.serialize_function(
obj, function_name_map))
elif isinstance(obj, defun.ConcreteFunction):
proto.bare_concrete_function.CopyFrom(
function_serialization.serialize_bare_concrete_function(obj))
function_serialization.serialize_bare_concrete_function(
obj, function_name_map))
elif isinstance(obj, _CapturedConstant):
proto.constant.operation = obj.graph_tensor.op.name
elif isinstance(obj, tracking.CapturableResource):
@ -924,7 +954,8 @@ def save(obj, export_dir, signatures=None, options=None):
signatures = signature_serialization.find_function_to_export(
checkpoint_graph_view)
signatures = signature_serialization.canonicalize_signatures(signatures)
signatures, wrapped_functions = (
signature_serialization.canonicalize_signatures(signatures))
signature_serialization.validate_saveable_view(checkpoint_graph_view)
signature_map = signature_serialization.create_signature_map(signatures)
checkpoint_graph_view.add_object(
@ -936,7 +967,7 @@ def save(obj, export_dir, signatures=None, options=None):
# Note we run this twice since, while constructing the view the first time
# there can be side effects of creating variables.
_ = _SaveableView(checkpoint_graph_view)
saveable_view = _SaveableView(checkpoint_graph_view)
saveable_view = _SaveableView(checkpoint_graph_view, wrapped_functions)
# TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
# compatible (no sessions) and share it with this export API rather than

View File

@ -472,6 +472,22 @@ class SaveTest(test.TestCase):
for node in f.node_def:
assert_correct_number_of_output_shapes(node)
def test_save_cached_variable(self):
with ops.Graph().as_default(), session_lib.Session() as session:
obj = tracking.AutoTrackable()
obj.v = variables.Variable(2., caching_device=lambda op: op.device)
obj.w = variables.Variable(3.)
session.run([obj.v.initializer, obj.w.initializer])
@def_function.function(input_signature=[])
def f():
return obj.v + obj.w
obj.f = f
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
save.save(obj, save_dir, signatures=obj.f)
self.assertAllClose({"output_0": 5}, _import_and_infer(save_dir, {}))
class SavingOptionsTest(test.TestCase):

View File

@ -23,6 +23,7 @@ from tensorflow.python.eager import function as defun
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import function_serialization
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training.tracking import base
@ -95,19 +96,24 @@ def find_function_to_export(saveable_view):
def canonicalize_signatures(signatures):
"""Converts `signatures` into a dictionary of concrete functions."""
if signatures is None:
return {}
return {}, {}
if not isinstance(signatures, collections_abc.Mapping):
signatures = {
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures}
concrete_signatures = {}
wrapped_functions = {}
for signature_key, function in signatures.items():
signature_function = _get_signature(function)
original_function = signature_function = _get_signature(function)
if signature_function is None:
raise ValueError(
("Expected a TensorFlow function to generate a signature for, but "
"got {}. Only `tf.functions` with an input signature or "
"concrete functions can be used as a signature.").format(function))
wrapped_functions[original_function] = signature_function = (
wrapped_functions.get(original_function) or
function_serialization.wrap_cached_variables(original_function))
_validate_inputs(signature_function)
# Re-wrap the function so that it returns a dictionary of Tensors. This
@ -141,7 +147,7 @@ def canonicalize_signatures(signatures):
# pylint: enable=protected-access
concrete_signatures[signature_key] = final_concrete
# pylint: enable=cell-var-from-loop
return concrete_signatures
return concrete_signatures, wrapped_functions
def _is_flat(sequence):