When saving to SavedModel, functions that capture cached values are now wrapped with functions that capture the original variable.
PiperOrigin-RevId: 292041916 Change-Id: I69a3174c52d526ad0ec48d7cd533ff95cb44747a
This commit is contained in:
parent
b862920996
commit
7705f70c14
tensorflow/python
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import weakref
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import variable_pb2
|
||||
@ -347,6 +348,7 @@ class BaseResourceVariable(variables.VariableV1):
|
||||
cached_value=None,
|
||||
save_slice_info=None,
|
||||
handle_deleter=None,
|
||||
caching_device=None,
|
||||
**unused_kwargs):
|
||||
"""Creates a variable from a handle.
|
||||
|
||||
@ -387,6 +389,11 @@ class BaseResourceVariable(variables.VariableV1):
|
||||
save_slice_info: Metadata for variable partitioning.
|
||||
handle_deleter: EagerResourceDeleter responsible for cleaning up the
|
||||
handle.
|
||||
caching_device: Optional device string or function describing where the
|
||||
Variable should be cached for reading. Defaults to the Variable's
|
||||
device. If not `None`, caches on another device. Typical use is to
|
||||
cache on the device where the Ops using the Variable reside, to
|
||||
deduplicate copying through `Switch` and other conditional statements.
|
||||
"""
|
||||
with ops.init_scope():
|
||||
self._in_graph_mode = not context.executing_eagerly()
|
||||
@ -401,6 +408,7 @@ class BaseResourceVariable(variables.VariableV1):
|
||||
self._initializer_op = initializer_op
|
||||
self._is_initialized_op = is_initialized_op
|
||||
self._graph_element = graph_element
|
||||
self._caching_device = caching_device
|
||||
self._cached_value = cached_value
|
||||
self._distribute_strategy = distribute_strategy
|
||||
# Store the graph key so optimizers know how to only retrieve variables from
|
||||
@ -612,9 +620,19 @@ class BaseResourceVariable(variables.VariableV1):
|
||||
|
||||
def _read_variable_op(self):
|
||||
variable_accessed(self)
|
||||
result = gen_resource_variable_ops.read_variable_op(self._handle,
|
||||
self._dtype)
|
||||
_maybe_set_handle_data(self._dtype, self._handle, result)
|
||||
|
||||
def read_and_set_handle():
|
||||
result = gen_resource_variable_ops.read_variable_op(self._handle,
|
||||
self._dtype)
|
||||
_maybe_set_handle_data(self._dtype, self._handle, result)
|
||||
return result
|
||||
|
||||
if getattr(self, "_caching_device", None) is not None:
|
||||
with ops.colocate_with(None, ignore_existing=True):
|
||||
with ops.device(self._caching_device):
|
||||
result = read_and_set_handle()
|
||||
else:
|
||||
result = read_and_set_handle()
|
||||
|
||||
if not context.executing_eagerly():
|
||||
# Note that if a control flow context is active the input of the read op
|
||||
@ -1614,6 +1632,12 @@ class ResourceVariable(BaseResourceVariable):
|
||||
_maybe_set_handle_data(dtype, handle, cached_value)
|
||||
else:
|
||||
cached_value = None
|
||||
|
||||
if cached_value is not None:
|
||||
# Store the variable object so that the original variable can be
|
||||
# accessed to generate functions that are compatible with SavedModel.
|
||||
cached_value._cached_variable = weakref.ref(self) # pylint: disable=protected-access
|
||||
|
||||
if not context.executing_eagerly():
|
||||
# Eager variables are only added to collections if they are part of an
|
||||
# eager variable store (otherwise in an interactive session they would
|
||||
@ -1629,7 +1653,7 @@ class ResourceVariable(BaseResourceVariable):
|
||||
name=name, unique_id=unique_id, handle_name=handle_name,
|
||||
graph_element=graph_element, initial_value=initial_value,
|
||||
initializer_op=initializer_op, is_initialized_op=is_initialized_op,
|
||||
cached_value=cached_value)
|
||||
cached_value=cached_value, caching_device=caching_device)
|
||||
|
||||
def _init_from_proto(self, variable_def, import_scope=None):
|
||||
"""Initializes from `VariableDef` proto."""
|
||||
|
@ -19,8 +19,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.protobuf import saved_object_graph_pb2
|
||||
from tensorflow.python.eager import function as defun
|
||||
from tensorflow.python.framework import func_graph as func_graph_module
|
||||
from tensorflow.python.saved_model import nested_structure_coder
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
def _serialize_function_spec(function_spec, coder):
|
||||
@ -72,17 +75,19 @@ def serialize_concrete_function(concrete_function, node_ids, coder):
|
||||
return concrete_function_proto
|
||||
|
||||
|
||||
def serialize_bare_concrete_function(concrete_function):
|
||||
def serialize_bare_concrete_function(concrete_function, name_map):
|
||||
"""Build a SavedBareConcreteFunction."""
|
||||
# pylint: disable=protected-access
|
||||
name = name_map.get(compat.as_text(concrete_function.name),
|
||||
concrete_function.name)
|
||||
return saved_object_graph_pb2.SavedBareConcreteFunction(
|
||||
concrete_function_name=concrete_function.name,
|
||||
concrete_function_name=name,
|
||||
allowed_positional_arguments=concrete_function._num_positional_args,
|
||||
argument_keywords=concrete_function._arg_keywords)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
def serialize_function(function):
|
||||
def serialize_function(function, name_map):
|
||||
"""Build a SavedFunction proto."""
|
||||
coder = nested_structure_coder.StructureCoder()
|
||||
proto = saved_object_graph_pb2.SavedFunction()
|
||||
@ -92,5 +97,65 @@ def serialize_function(function):
|
||||
all_concrete_functions = \
|
||||
function._list_all_concrete_functions_for_serialization() # pylint: disable=protected-access
|
||||
for concrete_function in all_concrete_functions:
|
||||
proto.concrete_functions.append(concrete_function.name)
|
||||
proto.concrete_functions.append(
|
||||
name_map.get(compat.as_text(concrete_function.name),
|
||||
concrete_function.name))
|
||||
return proto
|
||||
|
||||
|
||||
def wrap_cached_variables(concrete_function):
|
||||
"""Wraps the concrete function if it uses cached read tensors.
|
||||
|
||||
This function creates a new concrete function that captures variables
|
||||
instead of the cached read tensors.
|
||||
|
||||
Args:
|
||||
concrete_function: A Concrete function that maybe captures cached read
|
||||
tensors.
|
||||
|
||||
Returns:
|
||||
A concrete function that wraps the original concrete function, which
|
||||
captures variables instead. If the original function did not capture any
|
||||
cached values, then the function is not wrapped and the original object is
|
||||
returned.
|
||||
"""
|
||||
outer_graph = func_graph_module.FuncGraph(
|
||||
"{}_no_cache".format(concrete_function.graph.name))
|
||||
captures = concrete_function.graph._captures # pylint: disable=protected-access
|
||||
mapped_captures = None
|
||||
remapped_captures = {}
|
||||
|
||||
# Update the external captures to use read tensors generated in the outer
|
||||
# graph.
|
||||
with outer_graph.as_default():
|
||||
for capture, placeholder in concrete_function.graph.captures:
|
||||
cached_variable = getattr(capture, "_cached_variable", None)
|
||||
if cached_variable is None:
|
||||
continue
|
||||
cached_variable = cached_variable()
|
||||
new_cached_value = cached_variable.read_value()
|
||||
remapped_captures[id(capture)] = captures[id(capture)]
|
||||
captures[id(capture)] = (new_cached_value, placeholder)
|
||||
mapped_captures = True
|
||||
|
||||
if not mapped_captures:
|
||||
return concrete_function
|
||||
|
||||
inner_concrete = defun.ConcreteFunction(concrete_function.graph)
|
||||
|
||||
def wrap_function(*args):
|
||||
return inner_concrete._call_flat(args, inner_concrete.captured_inputs) # pylint:disable=protected-access
|
||||
|
||||
args = nest.flatten(concrete_function.structured_input_signature,
|
||||
expand_composites=True)
|
||||
func_graph_module.func_graph_from_py_func(
|
||||
None, wrap_function, args=tuple(args), kwargs={},
|
||||
func_graph=outer_graph)
|
||||
fn = defun.ConcreteFunction(outer_graph)
|
||||
fn._arg_keywords = concrete_function._arg_keywords # pylint: disable=protected-access
|
||||
fn._num_positional_args = concrete_function._num_positional_args # pylint: disable=protected-access
|
||||
|
||||
# Return the captures to their original values
|
||||
for key, capture in remapped_captures.items():
|
||||
captures[key] = capture
|
||||
return fn
|
||||
|
@ -27,6 +27,7 @@ import weakref
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.client import session as session_lib
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import def_function
|
||||
@ -1826,6 +1827,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
||||
rt = ragged_factory_ops.constant([[1, 2], [3]])
|
||||
self.assertAllEqual(imported2.f(rt), [[2, 3], [4]])
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name="ReloadOnce", cycles=1),
|
||||
@ -1943,6 +1945,53 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
|
||||
"object has an attribute named a, which is reserved."):
|
||||
save.save(root, path)
|
||||
|
||||
def test_save_cached_variable(self):
|
||||
with ops.Graph().as_default(), session_lib.Session() as session:
|
||||
obj = tracking.AutoTrackable()
|
||||
obj.v = variables.Variable(2., caching_device=lambda op: op.device)
|
||||
obj.w = variables.Variable(3.)
|
||||
session.run([obj.v.initializer, obj.w.initializer])
|
||||
|
||||
@def_function.function
|
||||
def total():
|
||||
return obj.v + obj.w
|
||||
|
||||
@def_function.function(input_signature=[tensor_spec.TensorSpec([])])
|
||||
def wrapped_total(x):
|
||||
return total() + x
|
||||
|
||||
@def_function.function
|
||||
def increment_v(x):
|
||||
obj.v.assign_add(x)
|
||||
|
||||
session.run(increment_v(constant_op.constant(3.))) # generate signatures
|
||||
self.assertAllClose(8, total())
|
||||
self.assertAllClose(13, wrapped_total(constant_op.constant(5.)))
|
||||
|
||||
obj.total = total
|
||||
obj.wrapped_total = wrapped_total.get_concrete_function()
|
||||
obj.increment_v = increment_v
|
||||
|
||||
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
|
||||
save.save(obj, save_dir, signatures=total.get_concrete_function())
|
||||
imported = load.load(save_dir)
|
||||
session.run(variables.global_variables_initializer())
|
||||
self.assertAllClose(8, imported.total())
|
||||
session.run(imported.increment_v(4))
|
||||
self.assertAllClose(12, imported.total())
|
||||
self.assertAllClose(15, imported.wrapped_total(constant_op.constant(3.)))
|
||||
self.assertAllClose({"output_0": 12},
|
||||
imported.signatures["serving_default"]())
|
||||
|
||||
# Try loading and running the function in eager mode
|
||||
imported = load.load(save_dir)
|
||||
self.assertAllClose(8, imported.total())
|
||||
imported.increment_v(5)
|
||||
self.assertAllClose(13, imported.total())
|
||||
self.assertAllClose(13.5, imported.wrapped_total(constant_op.constant(.5)))
|
||||
self.assertAllClose({"output_0": 13},
|
||||
imported.signatures["serving_default"]())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -159,7 +159,14 @@ class _SaveableView(object):
|
||||
ignored.
|
||||
"""
|
||||
|
||||
def __init__(self, checkpoint_view):
|
||||
def __init__(self, checkpoint_view, wrapped_functions=None):
|
||||
"""Initializes a SaveableView.
|
||||
|
||||
Args:
|
||||
checkpoint_view: A GraphView object.
|
||||
wrapped_functions: Dictionary that maps concrete functions to functions
|
||||
that do not capture cached variable values.
|
||||
"""
|
||||
self.checkpoint_view = checkpoint_view
|
||||
trackable_objects, node_ids, slot_variables = (
|
||||
self.checkpoint_view.objects_ids_and_slot_variables())
|
||||
@ -169,6 +176,15 @@ class _SaveableView(object):
|
||||
self.slot_variables = slot_variables
|
||||
self.concrete_functions = []
|
||||
|
||||
# Maps functions -> wrapped functions that capture variables
|
||||
self.wrapped_functions = wrapped_functions or {}
|
||||
# Maps names of concrete functions in the object to names of wrapped
|
||||
# functions. When writing the SavedFunction protos, the names of the
|
||||
# wrapped functions should be used in place of the original functions.
|
||||
self.function_name_map = {
|
||||
compat.as_text(original.name): compat.as_text(wrapped.name)
|
||||
for original, wrapped in self.wrapped_functions.items()}
|
||||
|
||||
# Also add `Function`s as nodes.
|
||||
nodes_without_functions = list(self.nodes)
|
||||
seen_function_names = set()
|
||||
@ -287,6 +303,14 @@ class _SaveableView(object):
|
||||
if (tensor_util.is_tensor(capture) and
|
||||
capture.dtype not in _UNCOPIABLE_DTYPES and
|
||||
capture not in self.captured_tensor_node_ids):
|
||||
if hasattr(capture, "_cached_variable"):
|
||||
if concrete_function not in self.wrapped_functions:
|
||||
wrapped = self.wrapped_functions[concrete_function] = (
|
||||
function_serialization.wrap_cached_variables(
|
||||
concrete_function))
|
||||
self.function_name_map[compat.as_text(concrete_function.name)] = (
|
||||
compat.as_text(wrapped.name))
|
||||
continue
|
||||
capture_constant_value = tensor_util.constant_value(capture)
|
||||
if capture_constant_value is None:
|
||||
bad_functions.append(concrete_function)
|
||||
@ -302,7 +326,8 @@ class _SaveableView(object):
|
||||
resource_map[capture] = copied_tensor
|
||||
|
||||
self.concrete_functions = [
|
||||
x for x in self.concrete_functions if x not in bad_functions
|
||||
self.wrapped_functions.get(x, x) for x in self.concrete_functions
|
||||
if x not in bad_functions
|
||||
]
|
||||
return object_map, resource_map, asset_info
|
||||
|
||||
@ -465,8 +490,8 @@ def _generate_signatures(signature_functions, resource_map):
|
||||
mapped_inputs, exterior_argument_placeholders = (
|
||||
_map_function_arguments_to_created_inputs(argument_inputs,
|
||||
signature_key, function.name))
|
||||
outputs = _call_function_with_mapped_captures(function, mapped_inputs,
|
||||
resource_map)
|
||||
outputs = _call_function_with_mapped_captures(
|
||||
function, mapped_inputs, resource_map)
|
||||
signatures[signature_key] = signature_def_utils.build_signature_def(
|
||||
_tensor_dict_to_tensorinfo(exterior_argument_placeholders),
|
||||
_tensor_dict_to_tensorinfo(outputs),
|
||||
@ -657,17 +682,20 @@ def _serialize_object_graph(saveable_view, asset_file_def_index):
|
||||
|
||||
coder = nested_structure_coder.StructureCoder()
|
||||
for concrete_function in saveable_view.concrete_functions:
|
||||
name = compat.as_text(concrete_function.name)
|
||||
name = saveable_view.function_name_map.get(name, name)
|
||||
serialized = function_serialization.serialize_concrete_function(
|
||||
concrete_function, saveable_view.captured_tensor_node_ids, coder)
|
||||
if serialized is not None:
|
||||
proto.concrete_functions[concrete_function.name].CopyFrom(serialized)
|
||||
proto.concrete_functions[name].CopyFrom(serialized)
|
||||
|
||||
for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
|
||||
_write_object_proto(obj, obj_proto, asset_file_def_index)
|
||||
_write_object_proto(obj, obj_proto, asset_file_def_index,
|
||||
saveable_view.function_name_map)
|
||||
return proto
|
||||
|
||||
|
||||
def _write_object_proto(obj, proto, asset_file_def_index):
|
||||
def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
|
||||
"""Saves an object into SavedObject proto."""
|
||||
if isinstance(obj, tracking.Asset):
|
||||
proto.asset.SetInParent()
|
||||
@ -684,10 +712,12 @@ def _write_object_proto(obj, proto, asset_file_def_index):
|
||||
proto.variable.aggregation = obj.aggregation.value
|
||||
proto.variable.shape.CopyFrom(obj.shape.as_proto())
|
||||
elif isinstance(obj, def_function.Function):
|
||||
proto.function.CopyFrom(function_serialization.serialize_function(obj))
|
||||
proto.function.CopyFrom(function_serialization.serialize_function(
|
||||
obj, function_name_map))
|
||||
elif isinstance(obj, defun.ConcreteFunction):
|
||||
proto.bare_concrete_function.CopyFrom(
|
||||
function_serialization.serialize_bare_concrete_function(obj))
|
||||
function_serialization.serialize_bare_concrete_function(
|
||||
obj, function_name_map))
|
||||
elif isinstance(obj, _CapturedConstant):
|
||||
proto.constant.operation = obj.graph_tensor.op.name
|
||||
elif isinstance(obj, tracking.CapturableResource):
|
||||
@ -924,7 +954,8 @@ def save(obj, export_dir, signatures=None, options=None):
|
||||
signatures = signature_serialization.find_function_to_export(
|
||||
checkpoint_graph_view)
|
||||
|
||||
signatures = signature_serialization.canonicalize_signatures(signatures)
|
||||
signatures, wrapped_functions = (
|
||||
signature_serialization.canonicalize_signatures(signatures))
|
||||
signature_serialization.validate_saveable_view(checkpoint_graph_view)
|
||||
signature_map = signature_serialization.create_signature_map(signatures)
|
||||
checkpoint_graph_view.add_object(
|
||||
@ -936,7 +967,7 @@ def save(obj, export_dir, signatures=None, options=None):
|
||||
# Note we run this twice since, while constructing the view the first time
|
||||
# there can be side effects of creating variables.
|
||||
_ = _SaveableView(checkpoint_graph_view)
|
||||
saveable_view = _SaveableView(checkpoint_graph_view)
|
||||
saveable_view = _SaveableView(checkpoint_graph_view, wrapped_functions)
|
||||
|
||||
# TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
|
||||
# compatible (no sessions) and share it with this export API rather than
|
||||
|
@ -472,6 +472,22 @@ class SaveTest(test.TestCase):
|
||||
for node in f.node_def:
|
||||
assert_correct_number_of_output_shapes(node)
|
||||
|
||||
def test_save_cached_variable(self):
|
||||
with ops.Graph().as_default(), session_lib.Session() as session:
|
||||
obj = tracking.AutoTrackable()
|
||||
obj.v = variables.Variable(2., caching_device=lambda op: op.device)
|
||||
obj.w = variables.Variable(3.)
|
||||
session.run([obj.v.initializer, obj.w.initializer])
|
||||
|
||||
@def_function.function(input_signature=[])
|
||||
def f():
|
||||
return obj.v + obj.w
|
||||
|
||||
obj.f = f
|
||||
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
|
||||
save.save(obj, save_dir, signatures=obj.f)
|
||||
self.assertAllClose({"output_0": 5}, _import_and_infer(save_dir, {}))
|
||||
|
||||
|
||||
class SavingOptionsTest(test.TestCase):
|
||||
|
||||
|
@ -23,6 +23,7 @@ from tensorflow.python.eager import function as defun
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.saved_model import function_serialization
|
||||
from tensorflow.python.saved_model import revived_types
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.training.tracking import base
|
||||
@ -95,19 +96,24 @@ def find_function_to_export(saveable_view):
|
||||
def canonicalize_signatures(signatures):
|
||||
"""Converts `signatures` into a dictionary of concrete functions."""
|
||||
if signatures is None:
|
||||
return {}
|
||||
return {}, {}
|
||||
if not isinstance(signatures, collections_abc.Mapping):
|
||||
signatures = {
|
||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures}
|
||||
concrete_signatures = {}
|
||||
wrapped_functions = {}
|
||||
for signature_key, function in signatures.items():
|
||||
signature_function = _get_signature(function)
|
||||
original_function = signature_function = _get_signature(function)
|
||||
|
||||
if signature_function is None:
|
||||
raise ValueError(
|
||||
("Expected a TensorFlow function to generate a signature for, but "
|
||||
"got {}. Only `tf.functions` with an input signature or "
|
||||
"concrete functions can be used as a signature.").format(function))
|
||||
|
||||
wrapped_functions[original_function] = signature_function = (
|
||||
wrapped_functions.get(original_function) or
|
||||
function_serialization.wrap_cached_variables(original_function))
|
||||
_validate_inputs(signature_function)
|
||||
|
||||
# Re-wrap the function so that it returns a dictionary of Tensors. This
|
||||
@ -141,7 +147,7 @@ def canonicalize_signatures(signatures):
|
||||
# pylint: enable=protected-access
|
||||
concrete_signatures[signature_key] = final_concrete
|
||||
# pylint: enable=cell-var-from-loop
|
||||
return concrete_signatures
|
||||
return concrete_signatures, wrapped_functions
|
||||
|
||||
|
||||
def _is_flat(sequence):
|
||||
|
Loading…
Reference in New Issue
Block a user