Skeleton for PolymorphicFunction serialization

Missing things like variables, function/argument names, support for arguments that aren't a flat list of Tensors, and many other things. But it does manage to save, restore, and call a function.

Starts saving a bit of extra metadata when a new function trace is created. Since this does not have to be computed each time the function is called, I expect the performance impact to be minimal.

PiperOrigin-RevId: 225072712
This commit is contained in:
Allen Lavoie 2018-12-11 14:36:14 -08:00 committed by TensorFlower Gardener
parent 40345bd2c3
commit 2a87c2df92
13 changed files with 263 additions and 21 deletions

View File

@ -242,6 +242,7 @@ class PolymorphicFunction(object):
raise NotImplementedError() raise NotImplementedError()
self._created_variables = None self._created_variables = None
self._stateful_fn = None self._stateful_fn = None
self._stateless_fn = None
self._descriptor_cache = weakref.WeakKeyDictionary() self._descriptor_cache = weakref.WeakKeyDictionary()
self._name = name self._name = name
@ -382,6 +383,26 @@ class PolymorphicFunction(object):
return initialize_variables.get_concrete_function() return initialize_variables.get_concrete_function()
@property
def _cached_input_signatures(self):
"""All input signatures used to call this PolymorphicFunction."""
seen = set()
# Preserves signature ordering rather than returning a set() so that we
# don't need to re-sort signatures later to work around Python 2's set
# nondeterminism.
# pylint: disable=protected-access
concrete_functions = []
if self._stateful_fn:
concrete_functions.extend(self._stateful_fn._function_cache.values())
if self._stateless_fn:
concrete_functions.extend(self._stateless_fn._function_cache.values())
for concrete_function in concrete_functions:
signature = concrete_function._python_call_signature
if signature not in seen:
yield signature
seen.add(signature)
# pylint: enable=protected-access
def get_concrete_function(self, *args, **kwargs): def get_concrete_function(self, *args, **kwargs):
"""Returns a `Function` object specialized to inputs and execution context. """Returns a `Function` object specialized to inputs and execution context.

View File

@ -238,6 +238,25 @@ class DefFunctionTest(test.TestCase):
concrete = compute.get_concrete_function( concrete = compute.get_concrete_function(
tensor_spec.TensorSpec(None, dtypes.float32)) tensor_spec.TensorSpec(None, dtypes.float32))
self.assertAllClose(4., concrete(constant_op.constant(2.))) self.assertAllClose(4., concrete(constant_op.constant(2.)))
input_signature, = compute._cached_input_signatures
self.assertEqual(
tuple(input_signature),
(tensor_spec.TensorSpec(None, dtypes.float32),))
def test_serialization_signature_cache(self):
@def_function.function
def f(x, y):
return x, y
f(constant_op.constant([[3., 4.]]), constant_op.constant([2.]))
f(constant_op.constant([[3, 4, 5]]), constant_op.constant([2]))
self.assertEqual(
set(f._cached_input_signatures),
set(((tensor_spec.TensorSpec([1, 2], dtypes.float32),
tensor_spec.TensorSpec([1], dtypes.float32)),
(tensor_spec.TensorSpec([1, 3], dtypes.int32),
tensor_spec.TensorSpec([1], dtypes.int32)))))
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -748,6 +748,19 @@ class Function(object):
return ret return ret
class UnknownArgument(object):
"""Signifies an argument which is not currently handled."""
pass
def _encode_arg_for_serialization(arg):
"""A representation for this argument, for serializing signatures."""
if isinstance(arg, ops.Tensor):
return tensor_spec.TensorSpec(arg.shape, arg.dtype)
else:
return UnknownArgument()
pywrap_tensorflow.RegisterType("Tensor", ops.Tensor) pywrap_tensorflow.RegisterType("Tensor", ops.Tensor)
pywrap_tensorflow.RegisterType("IndexedSlices", ops.IndexedSlices) pywrap_tensorflow.RegisterType("IndexedSlices", ops.IndexedSlices)
@ -1163,6 +1176,14 @@ class PolymorphicFunction(object):
autograph=self._autograph, autograph=self._autograph,
arg_names=arg_names), arg_names=arg_names),
self._function_attributes) self._function_attributes)
if self._input_signature:
python_call_signature = self._input_signature
else:
python_call_signature = tuple(
_encode_arg_for_serialization(arg) for arg in args)
# Save information about non-Tensor arguments with the concrete
# function. Used to serialize PolymorphicFunctions.
graph_function._python_call_signature = python_call_signature # pylint: disable=protected-access
self._function_cache[cache_key] = graph_function self._function_cache[cache_key] = graph_function
return graph_function, args, kwargs return graph_function, args, kwargs

View File

@ -993,17 +993,18 @@ def _call(sig, *inputs, **kwargs):
name = kwargs.pop("name", None) name = kwargs.pop("name", None)
g = ops.get_default_graph() g = ops.get_default_graph()
func_name = sig.name func_name = sig.name
if name is None:
name = func_name
attrs = _parse_kwargs_as_attrs(func_name, **kwargs) attrs = _parse_kwargs_as_attrs(func_name, **kwargs)
output_types = [dtypes.DType(x.type) for x in sig.output_arg] output_types = [dtypes.DType(x.type) for x in sig.output_arg]
with ops.name_scope(name, func_name, inputs) as name: op = g.create_op(
op = g.create_op( func_name,
func_name, list(inputs),
list(inputs), output_types,
output_types, name=name,
name=name, attrs=attrs,
attrs=attrs, op_def=sig,
op_def=sig, compute_shapes=False)
compute_shapes=False)
if op.outputs: if op.outputs:
if len(op.outputs) == 1: if len(op.outputs) == 1:
ret = op.outputs[0] ret = op.outputs[0]
@ -1046,12 +1047,13 @@ def _from_definition(fdef, grad_func=None):
c_func = c_api.TF_FunctionImportFunctionDef(serialized) c_func = c_api.TF_FunctionImportFunctionDef(serialized)
result._c_func = c_api_util.ScopedTFFunction(c_func) result._c_func = c_api_util.ScopedTFFunction(c_func)
result._extra_inputs = [] result._extra_inputs = []
result._op_def = fdef.signature
# pylint: enable=protected-access # pylint: enable=protected-access
return result return result
def _from_library(lib): def from_library(lib):
"""Creates _DefinedFunctions initialized from a FunctionDefLibrary proto. """Creates _DefinedFunctions initialized from a FunctionDefLibrary proto.
This method handles assigning the correct gradient functions to each This method handles assigning the correct gradient functions to each

View File

@ -1287,7 +1287,7 @@ class FunctionsFromProtos(test.TestCase):
gradients_impl.gradients([f1, f2, f3, f4], c) gradients_impl.gradients([f1, f2, f3, f4], c)
library = g.as_graph_def().library library = g.as_graph_def().library
new_funcs = function._from_library(library) new_funcs = function.from_library(library)
def CheckNewFunc(func): def CheckNewFunc(func):
new_func = [f for f in new_funcs if f.name == func.name] new_func = [f for f in new_funcs if f.name == func.name]
@ -1303,7 +1303,7 @@ class FunctionsFromProtos(test.TestCase):
def testFromLibraryEmptyLib(self): def testFromLibraryEmptyLib(self):
library = function_pb2.FunctionDefLibrary() library = function_pb2.FunctionDefLibrary()
self.assertEqual(len(function._from_library(library)), 0) self.assertEqual(len(function.from_library(library)), 0)
def testFromLibraryMissingFuncDef(self): def testFromLibraryMissingFuncDef(self):
@ -1327,7 +1327,7 @@ class FunctionsFromProtos(test.TestCase):
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
ValueError, ValueError,
"FunctionDefLibrary missing 'G1_[0-9a-zA-Z]{8,11}' FunctionDef"): "FunctionDefLibrary missing 'G1_[0-9a-zA-Z]{8,11}' FunctionDef"):
function._from_library(library) function.from_library(library)
# Create invalid function def that is missing F1 function def # Create invalid function def that is missing F1 function def
library = function_pb2.FunctionDefLibrary() library = function_pb2.FunctionDefLibrary()
@ -1337,7 +1337,7 @@ class FunctionsFromProtos(test.TestCase):
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
ValueError, ValueError,
"FunctionDefLibrary missing 'F1_[0-9a-zA-Z]{8,11}' FunctionDef"): "FunctionDefLibrary missing 'F1_[0-9a-zA-Z]{8,11}' FunctionDef"):
function._from_library(library) function.from_library(library)
def testFromLibraryCyclicGradFuncs(self): def testFromLibraryCyclicGradFuncs(self):
@ -1366,7 +1366,7 @@ class FunctionsFromProtos(test.TestCase):
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
ValueError, "FunctionDefLibrary contains cyclic gradient functions!"): ValueError, "FunctionDefLibrary contains cyclic gradient functions!"):
function._from_library(library) function.from_library(library)
def testExperimentalAttrs(self): def testExperimentalAttrs(self):

View File

@ -442,11 +442,9 @@ def import_graph_def(graph_def,
_ProcessNewOps(graph) _ProcessNewOps(graph)
if graph_def.library and graph_def.library.function: if graph_def.library and graph_def.library.function:
# pylint: disable=protected-access functions = function.from_library(graph_def.library)
functions = function._from_library(graph_def.library)
for f in functions: for f in functions:
f.add_to_graph(graph) f.add_to_graph(graph)
# pylint: enable=protected-access
# Treat input mappings that don't appear in the graph as an error, because # Treat input mappings that don't appear in the graph as an error, because
# they are likely to be due to a typo. # they are likely to be due to a typo.

View File

@ -287,7 +287,7 @@ py_library(
deps = [ deps = [
":builder", ":builder",
":constants", ":constants",
":loader", ":function_serialization",
":saved_object_graph_py", ":saved_object_graph_py",
":signature_constants", ":signature_constants",
":signature_def_utils", ":signature_def_utils",
@ -295,15 +295,20 @@ py_library(
":utils", ":utils",
"//tensorflow/core:protos_all_py", "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework", "//tensorflow/python:framework",
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:lib", "//tensorflow/python:lib",
"//tensorflow/python:resource_variable_ops", "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:util", "//tensorflow/python:util",
"//tensorflow/python/eager:context", "//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function", "//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:function", "//tensorflow/python/eager:function",
"//tensorflow/python/training/checkpointable:base", "//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/checkpointable:tracking",
"//tensorflow/python/training/checkpointable:util", "//tensorflow/python/training/checkpointable:util",
], ],
) )
@ -330,8 +335,12 @@ py_library(
], ],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":constants",
":function_deserialization",
":loader", ":loader",
":saved_object_graph_py", ":saved_object_graph_py",
":utils",
"//tensorflow/python:function",
"//tensorflow/python:lib", "//tensorflow/python:lib",
"//tensorflow/python:util", "//tensorflow/python:util",
"//tensorflow/python/training/checkpointable:tracking", "//tensorflow/python/training/checkpointable:tracking",
@ -345,10 +354,34 @@ py_test(
deps = [ deps = [
":load", ":load",
":save", ":save",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes", "//tensorflow/python:dtypes",
"//tensorflow/python:lib",
"//tensorflow/python:tensor_spec", "//tensorflow/python:tensor_spec",
"//tensorflow/python/eager:def_function", "//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test", "//tensorflow/python/eager:test",
"//tensorflow/python/training/checkpointable:tracking", "//tensorflow/python/training/checkpointable:tracking",
], ],
) )
py_library(
name = "function_serialization",
srcs = [
"function_serialization.py",
],
srcs_version = "PY2AND3",
deps = [
":saved_object_graph_py",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:function",
],
)
py_library(
name = "function_deserialization",
srcs = [
"function_deserialization.py",
],
srcs_version = "PY2AND3",
deps = ["//tensorflow/python/eager:def_function"],
)

View File

@ -0,0 +1,46 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for deserializing PolymorphicFunctions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import def_function
def recreate_polymorphic_function(
saved_polymorphic_function, defined_functions):
"""Creates a PolymorphicFunction which runs restored function definitions."""
@def_function.function
def restored_function(*args):
"""Calls a restored function."""
# Try calling each function, return a value from the first one whose
# signature matches.
# TODO(allenl): Consider re-populating the function cache directly.
# TODO(allenl): Functions saved with input_signatures should revive with
# input_signatures.
for monomorphic_function in saved_polymorphic_function.monomorphic_function:
try:
# TODO(allenl): Passing an explicit name here prevents invalid name
# errors. We should replace this with something based on the actual
# Python function name.
return defined_functions[monomorphic_function.concrete_function](
*args, name="imported_function")
except ValueError:
continue
raise AssertionError(
"Could not find matching function to call for arguments: %s" % (args,))
return restored_function

View File

@ -0,0 +1,71 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for serializing PolymorphicFunctions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun_lib
from tensorflow.python.saved_model import saved_object_graph_pb2
def _serialize_polymorphic_function(function):
"""Represents a PolymorphicFunction in a SavedModel.
Adds `function`'s concrete functions to the current graph.
Args:
function: A `PolymorphicFunction` to serialize.
Returns:
An unserialized `SavedPolymorphicFunction` protocol buffer object.
"""
monomorphic_functions = []
for signature in function._cached_input_signatures: # pylint: disable=protected-access
if any(isinstance(arg, defun_lib.UnknownArgument) for arg in signature):
continue
concrete_function = function.get_concrete_function(*signature)
concrete_function.add_to_graph()
monomorphic_functions.append(
saved_object_graph_pb2.SavedMonomorphicFunction(
concrete_function=concrete_function.name))
return saved_object_graph_pb2.SavedPolymorphicFunction(
monomorphic_function=monomorphic_functions)
def add_polymorphic_functions_to_object_graph_proto(
checkpointable_objects, saved_object_graph):
"""Finds PolymorphicFunctions attached to objects and saves them."""
existing_objects = list(zip(checkpointable_objects, saved_object_graph.nodes))
for obj, obj_proto in existing_objects:
for attribute_name in dir(obj):
try:
attribute_value = getattr(obj, attribute_name, None)
except: # pylint: disable=bare-except
# We really don't want to throw an exception just because some object's
# attribute accessor is broken.
attribute_value = None
# TODO(allenl): Consider de-duplicating functions which are referenced
# from multiple attributes.
if isinstance(attribute_value, def_function.PolymorphicFunction):
function_node_id = len(saved_object_graph.nodes)
function_node = saved_object_graph.nodes.add()
function_node.function.CopyFrom(
_serialize_polymorphic_function(attribute_value))
reference = obj_proto.children.add()
reference.node_id = function_node_id
reference.local_name = attribute_name

View File

@ -20,8 +20,10 @@ from __future__ import print_function
import os import os
from tensorflow.python.framework import function as function_lib
from tensorflow.python.lib.io import file_io from tensorflow.python.lib.io import file_io
from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import function_deserialization
from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import saved_object_graph_pb2 from tensorflow.python.saved_model import saved_object_graph_pb2
from tensorflow.python.saved_model import utils_impl as saved_model_utils from tensorflow.python.saved_model import utils_impl as saved_model_utils
@ -33,9 +35,17 @@ class _Loader(object):
"""Helper class to load an object-based SavedModel.""" """Helper class to load an object-based SavedModel."""
def __init__(self, object_graph_proto, saved_model_proto, export_dir): def __init__(self, object_graph_proto, saved_model_proto, export_dir):
self._asset_file_def = saved_model_proto.meta_graphs[0].asset_file_def meta_graph = saved_model_proto.meta_graphs[0]
self._asset_file_def = meta_graph.asset_file_def
self._proto = object_graph_proto self._proto = object_graph_proto
self._export_dir = export_dir self._export_dir = export_dir
self._defined_functions = {}
for defined_function in function_lib.from_library(
meta_graph.graph_def.library):
# TODO(allenl): Do we need to do name mapping here? Not quite sure what
# happens when loaded names collide with existing names.
defined_function.add_to_graph(None)
self._defined_functions[defined_function.name] = defined_function
self._load_all() self._load_all()
def _load_all(self): def _load_all(self):
@ -52,6 +62,7 @@ class _Loader(object):
factory = { factory = {
"user_object": lambda: self._recreate_user_object(proto.user_object), "user_object": lambda: self._recreate_user_object(proto.user_object),
"asset": lambda: self._recreate_asset(proto.asset), "asset": lambda: self._recreate_asset(proto.asset),
"function": lambda: self._recreate_function(proto.function)
} }
kind = proto.WhichOneof("kind") kind = proto.WhichOneof("kind")
if kind not in factory: if kind not in factory:
@ -68,6 +79,10 @@ class _Loader(object):
self._asset_file_def[proto.asset_file_def_index].filename) self._asset_file_def[proto.asset_file_def_index].filename)
return tracking.TrackableAsset(filename) return tracking.TrackableAsset(filename)
def _recreate_function(self, proto):
return function_deserialization.recreate_polymorphic_function(
proto, self._defined_functions)
def _load_saved_object_graph_proto(filename): def _load_saved_object_graph_proto(filename):
with file_io.FileIO(filename, "rb") as f: with file_io.FileIO(filename, "rb") as f:
@ -92,5 +107,4 @@ def load(export_dir):
raise NotImplementedError( raise NotImplementedError(
"Currently only SavedModels exported with `tf.saved_model.save` may be " "Currently only SavedModels exported with `tf.saved_model.save` may be "
"imported. Other SavedModels may eventually be supported via load().") "imported. Other SavedModels may eventually be supported via load().")
# TODO(allenl): load functions from the SavedModel into the eager context
return root return root

View File

@ -23,6 +23,7 @@ import tempfile
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.eager import test from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.lib.io import file_io from tensorflow.python.lib.io import file_io
@ -47,6 +48,7 @@ class LoadTest(test.TestCase):
imported = load.load(save_dir) imported = load.load(save_dir)
self.assertIs(imported.dep_three, imported.dep_two.dep) self.assertIs(imported.dep_three, imported.dep_two.dep)
self.assertIsNot(imported.dep_one, imported.dep_two) self.assertIsNot(imported.dep_one, imported.dep_two)
self.assertEqual(4., imported.f(constant_op.constant(2.)).numpy())
def _make_asset(self, contents): def _make_asset(self, contents):
filename = tempfile.mktemp(prefix=self.get_temp_dir()) filename = tempfile.mktemp(prefix=self.get_temp_dir())

View File

@ -37,6 +37,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import builder_impl from tensorflow.python.saved_model import builder_impl
from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import function_serialization
from tensorflow.python.saved_model import saved_object_graph_pb2 from tensorflow.python.saved_model import saved_object_graph_pb2
from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import signature_def_utils
@ -511,6 +512,9 @@ def _write_object_graph(root, export_dir, asset_file_def_index):
for obj, obj_proto in zip(checkpointable_objects, proto.nodes): for obj, obj_proto in zip(checkpointable_objects, proto.nodes):
_write_object_proto(obj, obj_proto, asset_file_def_index) _write_object_proto(obj, obj_proto, asset_file_def_index)
function_serialization.add_polymorphic_functions_to_object_graph_proto(
checkpointable_objects, proto)
extra_asset_dir = os.path.join( extra_asset_dir = os.path.join(
compat.as_bytes(export_dir), compat.as_bytes(export_dir),
compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY)) compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))

View File

@ -48,6 +48,7 @@ message SavedObject {
oneof kind { oneof kind {
SavedUserObject user_object = 4; SavedUserObject user_object = 4;
SavedAsset asset = 5; SavedAsset asset = 5;
SavedPolymorphicFunction function = 6;
} }
} }
@ -71,3 +72,13 @@ message SavedAsset {
// `AssetFileDef.tensor_info`, MUST be ignored. // `AssetFileDef.tensor_info`, MUST be ignored.
uint32 asset_file_def_index = 1; uint32 asset_file_def_index = 1;
} }
// A function with multiple signatures, possibly with non-Tensor arguments.
message SavedPolymorphicFunction {
repeated SavedMonomorphicFunction monomorphic_function = 1;
}
message SavedMonomorphicFunction {
// A reference to a TensorFlow function in the MetaGraph's FunctionDefLibrary
string concrete_function = 1;
}