Skeleton for PolymorphicFunction serialization
Missing things like variables, function/argument names, support for arguments that aren't a flat list of Tensors, and many other things. But it does manage to save, restore, and call a function. Starts saving a bit of extra metadata when a new function trace is created. Since this does not have to be computed each time the function is called, I expect the performance impact to be minimal. PiperOrigin-RevId: 225072712
This commit is contained in:
parent
40345bd2c3
commit
2a87c2df92
@ -242,6 +242,7 @@ class PolymorphicFunction(object):
|
||||
raise NotImplementedError()
|
||||
self._created_variables = None
|
||||
self._stateful_fn = None
|
||||
self._stateless_fn = None
|
||||
self._descriptor_cache = weakref.WeakKeyDictionary()
|
||||
self._name = name
|
||||
|
||||
@ -382,6 +383,26 @@ class PolymorphicFunction(object):
|
||||
|
||||
return initialize_variables.get_concrete_function()
|
||||
|
||||
@property
|
||||
def _cached_input_signatures(self):
|
||||
"""All input signatures used to call this PolymorphicFunction."""
|
||||
seen = set()
|
||||
# Preserves signature ordering rather than returning a set() so that we
|
||||
# don't need to re-sort signatures later to work around Python 2's set
|
||||
# nondeterminism.
|
||||
# pylint: disable=protected-access
|
||||
concrete_functions = []
|
||||
if self._stateful_fn:
|
||||
concrete_functions.extend(self._stateful_fn._function_cache.values())
|
||||
if self._stateless_fn:
|
||||
concrete_functions.extend(self._stateless_fn._function_cache.values())
|
||||
for concrete_function in concrete_functions:
|
||||
signature = concrete_function._python_call_signature
|
||||
if signature not in seen:
|
||||
yield signature
|
||||
seen.add(signature)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def get_concrete_function(self, *args, **kwargs):
|
||||
"""Returns a `Function` object specialized to inputs and execution context.
|
||||
|
||||
|
@ -238,6 +238,25 @@ class DefFunctionTest(test.TestCase):
|
||||
concrete = compute.get_concrete_function(
|
||||
tensor_spec.TensorSpec(None, dtypes.float32))
|
||||
self.assertAllClose(4., concrete(constant_op.constant(2.)))
|
||||
input_signature, = compute._cached_input_signatures
|
||||
self.assertEqual(
|
||||
tuple(input_signature),
|
||||
(tensor_spec.TensorSpec(None, dtypes.float32),))
|
||||
|
||||
def test_serialization_signature_cache(self):
|
||||
|
||||
@def_function.function
|
||||
def f(x, y):
|
||||
return x, y
|
||||
|
||||
f(constant_op.constant([[3., 4.]]), constant_op.constant([2.]))
|
||||
f(constant_op.constant([[3, 4, 5]]), constant_op.constant([2]))
|
||||
self.assertEqual(
|
||||
set(f._cached_input_signatures),
|
||||
set(((tensor_spec.TensorSpec([1, 2], dtypes.float32),
|
||||
tensor_spec.TensorSpec([1], dtypes.float32)),
|
||||
(tensor_spec.TensorSpec([1, 3], dtypes.int32),
|
||||
tensor_spec.TensorSpec([1], dtypes.int32)))))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -748,6 +748,19 @@ class Function(object):
|
||||
return ret
|
||||
|
||||
|
||||
class UnknownArgument(object):
|
||||
"""Signifies an argument which is not currently handled."""
|
||||
pass
|
||||
|
||||
|
||||
def _encode_arg_for_serialization(arg):
|
||||
"""A representation for this argument, for serializing signatures."""
|
||||
if isinstance(arg, ops.Tensor):
|
||||
return tensor_spec.TensorSpec(arg.shape, arg.dtype)
|
||||
else:
|
||||
return UnknownArgument()
|
||||
|
||||
|
||||
pywrap_tensorflow.RegisterType("Tensor", ops.Tensor)
|
||||
pywrap_tensorflow.RegisterType("IndexedSlices", ops.IndexedSlices)
|
||||
|
||||
@ -1163,6 +1176,14 @@ class PolymorphicFunction(object):
|
||||
autograph=self._autograph,
|
||||
arg_names=arg_names),
|
||||
self._function_attributes)
|
||||
if self._input_signature:
|
||||
python_call_signature = self._input_signature
|
||||
else:
|
||||
python_call_signature = tuple(
|
||||
_encode_arg_for_serialization(arg) for arg in args)
|
||||
# Save information about non-Tensor arguments with the concrete
|
||||
# function. Used to serialize PolymorphicFunctions.
|
||||
graph_function._python_call_signature = python_call_signature # pylint: disable=protected-access
|
||||
self._function_cache[cache_key] = graph_function
|
||||
return graph_function, args, kwargs
|
||||
|
||||
|
@ -993,17 +993,18 @@ def _call(sig, *inputs, **kwargs):
|
||||
name = kwargs.pop("name", None)
|
||||
g = ops.get_default_graph()
|
||||
func_name = sig.name
|
||||
if name is None:
|
||||
name = func_name
|
||||
attrs = _parse_kwargs_as_attrs(func_name, **kwargs)
|
||||
output_types = [dtypes.DType(x.type) for x in sig.output_arg]
|
||||
with ops.name_scope(name, func_name, inputs) as name:
|
||||
op = g.create_op(
|
||||
func_name,
|
||||
list(inputs),
|
||||
output_types,
|
||||
name=name,
|
||||
attrs=attrs,
|
||||
op_def=sig,
|
||||
compute_shapes=False)
|
||||
op = g.create_op(
|
||||
func_name,
|
||||
list(inputs),
|
||||
output_types,
|
||||
name=name,
|
||||
attrs=attrs,
|
||||
op_def=sig,
|
||||
compute_shapes=False)
|
||||
if op.outputs:
|
||||
if len(op.outputs) == 1:
|
||||
ret = op.outputs[0]
|
||||
@ -1046,12 +1047,13 @@ def _from_definition(fdef, grad_func=None):
|
||||
c_func = c_api.TF_FunctionImportFunctionDef(serialized)
|
||||
result._c_func = c_api_util.ScopedTFFunction(c_func)
|
||||
result._extra_inputs = []
|
||||
result._op_def = fdef.signature
|
||||
# pylint: enable=protected-access
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _from_library(lib):
|
||||
def from_library(lib):
|
||||
"""Creates _DefinedFunctions initialized from a FunctionDefLibrary proto.
|
||||
|
||||
This method handles assigning the correct gradient functions to each
|
||||
|
@ -1287,7 +1287,7 @@ class FunctionsFromProtos(test.TestCase):
|
||||
gradients_impl.gradients([f1, f2, f3, f4], c)
|
||||
|
||||
library = g.as_graph_def().library
|
||||
new_funcs = function._from_library(library)
|
||||
new_funcs = function.from_library(library)
|
||||
|
||||
def CheckNewFunc(func):
|
||||
new_func = [f for f in new_funcs if f.name == func.name]
|
||||
@ -1303,7 +1303,7 @@ class FunctionsFromProtos(test.TestCase):
|
||||
|
||||
def testFromLibraryEmptyLib(self):
|
||||
library = function_pb2.FunctionDefLibrary()
|
||||
self.assertEqual(len(function._from_library(library)), 0)
|
||||
self.assertEqual(len(function.from_library(library)), 0)
|
||||
|
||||
def testFromLibraryMissingFuncDef(self):
|
||||
|
||||
@ -1327,7 +1327,7 @@ class FunctionsFromProtos(test.TestCase):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
"FunctionDefLibrary missing 'G1_[0-9a-zA-Z]{8,11}' FunctionDef"):
|
||||
function._from_library(library)
|
||||
function.from_library(library)
|
||||
|
||||
# Create invalid function def that is missing F1 function def
|
||||
library = function_pb2.FunctionDefLibrary()
|
||||
@ -1337,7 +1337,7 @@ class FunctionsFromProtos(test.TestCase):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
"FunctionDefLibrary missing 'F1_[0-9a-zA-Z]{8,11}' FunctionDef"):
|
||||
function._from_library(library)
|
||||
function.from_library(library)
|
||||
|
||||
def testFromLibraryCyclicGradFuncs(self):
|
||||
|
||||
@ -1366,7 +1366,7 @@ class FunctionsFromProtos(test.TestCase):
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "FunctionDefLibrary contains cyclic gradient functions!"):
|
||||
function._from_library(library)
|
||||
function.from_library(library)
|
||||
|
||||
def testExperimentalAttrs(self):
|
||||
|
||||
|
@ -442,11 +442,9 @@ def import_graph_def(graph_def,
|
||||
_ProcessNewOps(graph)
|
||||
|
||||
if graph_def.library and graph_def.library.function:
|
||||
# pylint: disable=protected-access
|
||||
functions = function._from_library(graph_def.library)
|
||||
functions = function.from_library(graph_def.library)
|
||||
for f in functions:
|
||||
f.add_to_graph(graph)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# Treat input mappings that don't appear in the graph as an error, because
|
||||
# they are likely to be due to a typo.
|
||||
|
@ -287,7 +287,7 @@ py_library(
|
||||
deps = [
|
||||
":builder",
|
||||
":constants",
|
||||
":loader",
|
||||
":function_serialization",
|
||||
":saved_object_graph_py",
|
||||
":signature_constants",
|
||||
":signature_def_utils",
|
||||
@ -295,15 +295,20 @@ py_library(
|
||||
":utils",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/eager:function",
|
||||
"//tensorflow/python/training/checkpointable:base",
|
||||
"//tensorflow/python/training/checkpointable:tracking",
|
||||
"//tensorflow/python/training/checkpointable:util",
|
||||
],
|
||||
)
|
||||
@ -330,8 +335,12 @@ py_library(
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":constants",
|
||||
":function_deserialization",
|
||||
":loader",
|
||||
":saved_object_graph_py",
|
||||
":utils",
|
||||
"//tensorflow/python:function",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/training/checkpointable:tracking",
|
||||
@ -345,10 +354,34 @@ py_test(
|
||||
deps = [
|
||||
":load",
|
||||
":save",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python/training/checkpointable:tracking",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "function_serialization",
|
||||
srcs = [
|
||||
"function_serialization.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":saved_object_graph_py",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/eager:function",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "function_deserialization",
|
||||
srcs = [
|
||||
"function_deserialization.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = ["//tensorflow/python/eager:def_function"],
|
||||
)
|
||||
|
46
tensorflow/python/saved_model/function_deserialization.py
Normal file
46
tensorflow/python/saved_model/function_deserialization.py
Normal file
@ -0,0 +1,46 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tools for deserializing PolymorphicFunctions."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
|
||||
|
||||
def recreate_polymorphic_function(
|
||||
saved_polymorphic_function, defined_functions):
|
||||
"""Creates a PolymorphicFunction which runs restored function definitions."""
|
||||
@def_function.function
|
||||
def restored_function(*args):
|
||||
"""Calls a restored function."""
|
||||
# Try calling each function, return a value from the first one whose
|
||||
# signature matches.
|
||||
# TODO(allenl): Consider re-populating the function cache directly.
|
||||
# TODO(allenl): Functions saved with input_signatures should revive with
|
||||
# input_signatures.
|
||||
for monomorphic_function in saved_polymorphic_function.monomorphic_function:
|
||||
try:
|
||||
# TODO(allenl): Passing an explicit name here prevents invalid name
|
||||
# errors. We should replace this with something based on the actual
|
||||
# Python function name.
|
||||
return defined_functions[monomorphic_function.concrete_function](
|
||||
*args, name="imported_function")
|
||||
except ValueError:
|
||||
continue
|
||||
raise AssertionError(
|
||||
"Could not find matching function to call for arguments: %s" % (args,))
|
||||
return restored_function
|
71
tensorflow/python/saved_model/function_serialization.py
Normal file
71
tensorflow/python/saved_model/function_serialization.py
Normal file
@ -0,0 +1,71 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tools for serializing PolymorphicFunctions."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function as defun_lib
|
||||
from tensorflow.python.saved_model import saved_object_graph_pb2
|
||||
|
||||
|
||||
def _serialize_polymorphic_function(function):
|
||||
"""Represents a PolymorphicFunction in a SavedModel.
|
||||
|
||||
Adds `function`'s concrete functions to the current graph.
|
||||
|
||||
Args:
|
||||
function: A `PolymorphicFunction` to serialize.
|
||||
|
||||
Returns:
|
||||
An unserialized `SavedPolymorphicFunction` protocol buffer object.
|
||||
"""
|
||||
monomorphic_functions = []
|
||||
for signature in function._cached_input_signatures: # pylint: disable=protected-access
|
||||
if any(isinstance(arg, defun_lib.UnknownArgument) for arg in signature):
|
||||
continue
|
||||
concrete_function = function.get_concrete_function(*signature)
|
||||
concrete_function.add_to_graph()
|
||||
monomorphic_functions.append(
|
||||
saved_object_graph_pb2.SavedMonomorphicFunction(
|
||||
concrete_function=concrete_function.name))
|
||||
return saved_object_graph_pb2.SavedPolymorphicFunction(
|
||||
monomorphic_function=monomorphic_functions)
|
||||
|
||||
|
||||
def add_polymorphic_functions_to_object_graph_proto(
|
||||
checkpointable_objects, saved_object_graph):
|
||||
"""Finds PolymorphicFunctions attached to objects and saves them."""
|
||||
existing_objects = list(zip(checkpointable_objects, saved_object_graph.nodes))
|
||||
for obj, obj_proto in existing_objects:
|
||||
for attribute_name in dir(obj):
|
||||
try:
|
||||
attribute_value = getattr(obj, attribute_name, None)
|
||||
except: # pylint: disable=bare-except
|
||||
# We really don't want to throw an exception just because some object's
|
||||
# attribute accessor is broken.
|
||||
attribute_value = None
|
||||
# TODO(allenl): Consider de-duplicating functions which are referenced
|
||||
# from multiple attributes.
|
||||
if isinstance(attribute_value, def_function.PolymorphicFunction):
|
||||
function_node_id = len(saved_object_graph.nodes)
|
||||
function_node = saved_object_graph.nodes.add()
|
||||
function_node.function.CopyFrom(
|
||||
_serialize_polymorphic_function(attribute_value))
|
||||
reference = obj_proto.children.add()
|
||||
reference.node_id = function_node_id
|
||||
reference.local_name = attribute_name
|
@ -20,8 +20,10 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.python.framework import function as function_lib
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.saved_model import constants
|
||||
from tensorflow.python.saved_model import function_deserialization
|
||||
from tensorflow.python.saved_model import loader_impl
|
||||
from tensorflow.python.saved_model import saved_object_graph_pb2
|
||||
from tensorflow.python.saved_model import utils_impl as saved_model_utils
|
||||
@ -33,9 +35,17 @@ class _Loader(object):
|
||||
"""Helper class to load an object-based SavedModel."""
|
||||
|
||||
def __init__(self, object_graph_proto, saved_model_proto, export_dir):
|
||||
self._asset_file_def = saved_model_proto.meta_graphs[0].asset_file_def
|
||||
meta_graph = saved_model_proto.meta_graphs[0]
|
||||
self._asset_file_def = meta_graph.asset_file_def
|
||||
self._proto = object_graph_proto
|
||||
self._export_dir = export_dir
|
||||
self._defined_functions = {}
|
||||
for defined_function in function_lib.from_library(
|
||||
meta_graph.graph_def.library):
|
||||
# TODO(allenl): Do we need to do name mapping here? Not quite sure what
|
||||
# happens when loaded names collide with existing names.
|
||||
defined_function.add_to_graph(None)
|
||||
self._defined_functions[defined_function.name] = defined_function
|
||||
self._load_all()
|
||||
|
||||
def _load_all(self):
|
||||
@ -52,6 +62,7 @@ class _Loader(object):
|
||||
factory = {
|
||||
"user_object": lambda: self._recreate_user_object(proto.user_object),
|
||||
"asset": lambda: self._recreate_asset(proto.asset),
|
||||
"function": lambda: self._recreate_function(proto.function)
|
||||
}
|
||||
kind = proto.WhichOneof("kind")
|
||||
if kind not in factory:
|
||||
@ -68,6 +79,10 @@ class _Loader(object):
|
||||
self._asset_file_def[proto.asset_file_def_index].filename)
|
||||
return tracking.TrackableAsset(filename)
|
||||
|
||||
def _recreate_function(self, proto):
|
||||
return function_deserialization.recreate_polymorphic_function(
|
||||
proto, self._defined_functions)
|
||||
|
||||
|
||||
def _load_saved_object_graph_proto(filename):
|
||||
with file_io.FileIO(filename, "rb") as f:
|
||||
@ -92,5 +107,4 @@ def load(export_dir):
|
||||
raise NotImplementedError(
|
||||
"Currently only SavedModels exported with `tf.saved_model.save` may be "
|
||||
"imported. Other SavedModels may eventually be supported via load().")
|
||||
# TODO(allenl): load functions from the SavedModel into the eager context
|
||||
return root
|
||||
|
@ -23,6 +23,7 @@ import tempfile
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.lib.io import file_io
|
||||
@ -47,6 +48,7 @@ class LoadTest(test.TestCase):
|
||||
imported = load.load(save_dir)
|
||||
self.assertIs(imported.dep_three, imported.dep_two.dep)
|
||||
self.assertIsNot(imported.dep_one, imported.dep_two)
|
||||
self.assertEqual(4., imported.f(constant_op.constant(2.)).numpy())
|
||||
|
||||
def _make_asset(self, contents):
|
||||
filename = tempfile.mktemp(prefix=self.get_temp_dir())
|
||||
|
@ -37,6 +37,7 @@ from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.saved_model import builder_impl
|
||||
from tensorflow.python.saved_model import constants
|
||||
from tensorflow.python.saved_model import function_serialization
|
||||
from tensorflow.python.saved_model import saved_object_graph_pb2
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.saved_model import signature_def_utils
|
||||
@ -511,6 +512,9 @@ def _write_object_graph(root, export_dir, asset_file_def_index):
|
||||
for obj, obj_proto in zip(checkpointable_objects, proto.nodes):
|
||||
_write_object_proto(obj, obj_proto, asset_file_def_index)
|
||||
|
||||
function_serialization.add_polymorphic_functions_to_object_graph_proto(
|
||||
checkpointable_objects, proto)
|
||||
|
||||
extra_asset_dir = os.path.join(
|
||||
compat.as_bytes(export_dir),
|
||||
compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
|
||||
|
@ -48,6 +48,7 @@ message SavedObject {
|
||||
oneof kind {
|
||||
SavedUserObject user_object = 4;
|
||||
SavedAsset asset = 5;
|
||||
SavedPolymorphicFunction function = 6;
|
||||
}
|
||||
}
|
||||
|
||||
@ -71,3 +72,13 @@ message SavedAsset {
|
||||
// `AssetFileDef.tensor_info`, MUST be ignored.
|
||||
uint32 asset_file_def_index = 1;
|
||||
}
|
||||
|
||||
// A function with multiple signatures, possibly with non-Tensor arguments.
|
||||
message SavedPolymorphicFunction {
|
||||
repeated SavedMonomorphicFunction monomorphic_function = 1;
|
||||
}
|
||||
|
||||
message SavedMonomorphicFunction {
|
||||
// A reference to a TensorFlow function in the MetaGraph's FunctionDefLibrary
|
||||
string concrete_function = 1;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user