From 2a87c2df921753fb8c1cba585f78bd3ab6087be2 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Tue, 11 Dec 2018 14:36:14 -0800 Subject: [PATCH] Skeleton for PolymorphicFunction serialization Missing things like variables, function/argument names, support for arguments that aren't a flat list of Tensors, and many other things. But it does manage to save, restore, and call a function. Starts saving a bit of extra metadata when a new function trace is created. Since this does not have to be computed each time the function is called, I expect the performance impact to be minimal. PiperOrigin-RevId: 225072712 --- tensorflow/python/eager/def_function.py | 21 ++++++ tensorflow/python/eager/def_function_test.py | 19 +++++ tensorflow/python/eager/function.py | 21 ++++++ tensorflow/python/framework/function.py | 22 +++--- tensorflow/python/framework/function_test.py | 10 +-- tensorflow/python/framework/importer.py | 4 +- tensorflow/python/saved_model/BUILD | 35 ++++++++- .../saved_model/function_deserialization.py | 46 ++++++++++++ .../saved_model/function_serialization.py | 71 +++++++++++++++++++ tensorflow/python/saved_model/load.py | 18 ++++- tensorflow/python/saved_model/load_test.py | 2 + tensorflow/python/saved_model/save.py | 4 ++ .../saved_model/saved_object_graph.proto | 11 +++ 13 files changed, 263 insertions(+), 21 deletions(-) create mode 100644 tensorflow/python/saved_model/function_deserialization.py create mode 100644 tensorflow/python/saved_model/function_serialization.py diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index 3663d729996..cdbf39ddd57 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -242,6 +242,7 @@ class PolymorphicFunction(object): raise NotImplementedError() self._created_variables = None self._stateful_fn = None + self._stateless_fn = None self._descriptor_cache = weakref.WeakKeyDictionary() self._name = name @@ -382,6 +383,26 @@ class PolymorphicFunction(object): return initialize_variables.get_concrete_function() + @property + def _cached_input_signatures(self): + """All input signatures used to call this PolymorphicFunction.""" + seen = set() + # Preserves signature ordering rather than returning a set() so that we + # don't need to re-sort signatures later to work around Python 2's set + # nondeterminism. + # pylint: disable=protected-access + concrete_functions = [] + if self._stateful_fn: + concrete_functions.extend(self._stateful_fn._function_cache.values()) + if self._stateless_fn: + concrete_functions.extend(self._stateless_fn._function_cache.values()) + for concrete_function in concrete_functions: + signature = concrete_function._python_call_signature + if signature not in seen: + yield signature + seen.add(signature) + # pylint: enable=protected-access + def get_concrete_function(self, *args, **kwargs): """Returns a `Function` object specialized to inputs and execution context. diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py index 4100a10044c..8b4c40791a7 100644 --- a/tensorflow/python/eager/def_function_test.py +++ b/tensorflow/python/eager/def_function_test.py @@ -238,6 +238,25 @@ class DefFunctionTest(test.TestCase): concrete = compute.get_concrete_function( tensor_spec.TensorSpec(None, dtypes.float32)) self.assertAllClose(4., concrete(constant_op.constant(2.))) + input_signature, = compute._cached_input_signatures + self.assertEqual( + tuple(input_signature), + (tensor_spec.TensorSpec(None, dtypes.float32),)) + + def test_serialization_signature_cache(self): + + @def_function.function + def f(x, y): + return x, y + + f(constant_op.constant([[3., 4.]]), constant_op.constant([2.])) + f(constant_op.constant([[3, 4, 5]]), constant_op.constant([2])) + self.assertEqual( + set(f._cached_input_signatures), + set(((tensor_spec.TensorSpec([1, 2], dtypes.float32), + tensor_spec.TensorSpec([1], dtypes.float32)), + (tensor_spec.TensorSpec([1, 3], dtypes.int32), + tensor_spec.TensorSpec([1], dtypes.int32))))) if __name__ == '__main__': diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 520c85a2c20..0de0cd96ac4 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -748,6 +748,19 @@ class Function(object): return ret +class UnknownArgument(object): + """Signifies an argument which is not currently handled.""" + pass + + +def _encode_arg_for_serialization(arg): + """A representation for this argument, for serializing signatures.""" + if isinstance(arg, ops.Tensor): + return tensor_spec.TensorSpec(arg.shape, arg.dtype) + else: + return UnknownArgument() + + pywrap_tensorflow.RegisterType("Tensor", ops.Tensor) pywrap_tensorflow.RegisterType("IndexedSlices", ops.IndexedSlices) @@ -1163,6 +1176,14 @@ class PolymorphicFunction(object): autograph=self._autograph, arg_names=arg_names), self._function_attributes) + if self._input_signature: + python_call_signature = self._input_signature + else: + python_call_signature = tuple( + _encode_arg_for_serialization(arg) for arg in args) + # Save information about non-Tensor arguments with the concrete + # function. Used to serialize PolymorphicFunctions. + graph_function._python_call_signature = python_call_signature # pylint: disable=protected-access self._function_cache[cache_key] = graph_function return graph_function, args, kwargs diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index cfdc915a1b3..afc11b17bfd 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -993,17 +993,18 @@ def _call(sig, *inputs, **kwargs): name = kwargs.pop("name", None) g = ops.get_default_graph() func_name = sig.name + if name is None: + name = func_name attrs = _parse_kwargs_as_attrs(func_name, **kwargs) output_types = [dtypes.DType(x.type) for x in sig.output_arg] - with ops.name_scope(name, func_name, inputs) as name: - op = g.create_op( - func_name, - list(inputs), - output_types, - name=name, - attrs=attrs, - op_def=sig, - compute_shapes=False) + op = g.create_op( + func_name, + list(inputs), + output_types, + name=name, + attrs=attrs, + op_def=sig, + compute_shapes=False) if op.outputs: if len(op.outputs) == 1: ret = op.outputs[0] @@ -1046,12 +1047,13 @@ def _from_definition(fdef, grad_func=None): c_func = c_api.TF_FunctionImportFunctionDef(serialized) result._c_func = c_api_util.ScopedTFFunction(c_func) result._extra_inputs = [] + result._op_def = fdef.signature # pylint: enable=protected-access return result -def _from_library(lib): +def from_library(lib): """Creates _DefinedFunctions initialized from a FunctionDefLibrary proto. This method handles assigning the correct gradient functions to each diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 6ec71ba8e90..7543376bcf2 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -1287,7 +1287,7 @@ class FunctionsFromProtos(test.TestCase): gradients_impl.gradients([f1, f2, f3, f4], c) library = g.as_graph_def().library - new_funcs = function._from_library(library) + new_funcs = function.from_library(library) def CheckNewFunc(func): new_func = [f for f in new_funcs if f.name == func.name] @@ -1303,7 +1303,7 @@ class FunctionsFromProtos(test.TestCase): def testFromLibraryEmptyLib(self): library = function_pb2.FunctionDefLibrary() - self.assertEqual(len(function._from_library(library)), 0) + self.assertEqual(len(function.from_library(library)), 0) def testFromLibraryMissingFuncDef(self): @@ -1327,7 +1327,7 @@ class FunctionsFromProtos(test.TestCase): with self.assertRaisesRegexp( ValueError, "FunctionDefLibrary missing 'G1_[0-9a-zA-Z]{8,11}' FunctionDef"): - function._from_library(library) + function.from_library(library) # Create invalid function def that is missing F1 function def library = function_pb2.FunctionDefLibrary() @@ -1337,7 +1337,7 @@ class FunctionsFromProtos(test.TestCase): with self.assertRaisesRegexp( ValueError, "FunctionDefLibrary missing 'F1_[0-9a-zA-Z]{8,11}' FunctionDef"): - function._from_library(library) + function.from_library(library) def testFromLibraryCyclicGradFuncs(self): @@ -1366,7 +1366,7 @@ class FunctionsFromProtos(test.TestCase): with self.assertRaisesRegexp( ValueError, "FunctionDefLibrary contains cyclic gradient functions!"): - function._from_library(library) + function.from_library(library) def testExperimentalAttrs(self): diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index 98c7aeccc4b..c737bd48811 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -442,11 +442,9 @@ def import_graph_def(graph_def, _ProcessNewOps(graph) if graph_def.library and graph_def.library.function: - # pylint: disable=protected-access - functions = function._from_library(graph_def.library) + functions = function.from_library(graph_def.library) for f in functions: f.add_to_graph(graph) - # pylint: enable=protected-access # Treat input mappings that don't appear in the graph as an error, because # they are likely to be due to a typo. diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 53d0640542f..71d9e34592b 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -287,7 +287,7 @@ py_library( deps = [ ":builder", ":constants", - ":loader", + ":function_serialization", ":saved_object_graph_py", ":signature_constants", ":signature_def_utils", @@ -295,15 +295,20 @@ py_library( ":utils", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework", "//tensorflow/python:framework_ops", "//tensorflow/python:lib", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:tensor_spec", "//tensorflow/python:util", "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:function", "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:tracking", "//tensorflow/python/training/checkpointable:util", ], ) @@ -330,8 +335,12 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":constants", + ":function_deserialization", ":loader", ":saved_object_graph_py", + ":utils", + "//tensorflow/python:function", "//tensorflow/python:lib", "//tensorflow/python:util", "//tensorflow/python/training/checkpointable:tracking", @@ -345,10 +354,34 @@ py_test( deps = [ ":load", ":save", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", + "//tensorflow/python:lib", "//tensorflow/python:tensor_spec", "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:test", "//tensorflow/python/training/checkpointable:tracking", ], ) + +py_library( + name = "function_serialization", + srcs = [ + "function_serialization.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":saved_object_graph_py", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/eager:function", + ], +) + +py_library( + name = "function_deserialization", + srcs = [ + "function_deserialization.py", + ], + srcs_version = "PY2AND3", + deps = ["//tensorflow/python/eager:def_function"], +) diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py new file mode 100644 index 00000000000..46bd69ad031 --- /dev/null +++ b/tensorflow/python/saved_model/function_deserialization.py @@ -0,0 +1,46 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tools for deserializing PolymorphicFunctions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import def_function + + +def recreate_polymorphic_function( + saved_polymorphic_function, defined_functions): + """Creates a PolymorphicFunction which runs restored function definitions.""" + @def_function.function + def restored_function(*args): + """Calls a restored function.""" + # Try calling each function, return a value from the first one whose + # signature matches. + # TODO(allenl): Consider re-populating the function cache directly. + # TODO(allenl): Functions saved with input_signatures should revive with + # input_signatures. + for monomorphic_function in saved_polymorphic_function.monomorphic_function: + try: + # TODO(allenl): Passing an explicit name here prevents invalid name + # errors. We should replace this with something based on the actual + # Python function name. + return defined_functions[monomorphic_function.concrete_function]( + *args, name="imported_function") + except ValueError: + continue + raise AssertionError( + "Could not find matching function to call for arguments: %s" % (args,)) + return restored_function diff --git a/tensorflow/python/saved_model/function_serialization.py b/tensorflow/python/saved_model/function_serialization.py new file mode 100644 index 00000000000..7cf82776bdf --- /dev/null +++ b/tensorflow/python/saved_model/function_serialization.py @@ -0,0 +1,71 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tools for serializing PolymorphicFunctions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import def_function +from tensorflow.python.eager import function as defun_lib +from tensorflow.python.saved_model import saved_object_graph_pb2 + + +def _serialize_polymorphic_function(function): + """Represents a PolymorphicFunction in a SavedModel. + + Adds `function`'s concrete functions to the current graph. + + Args: + function: A `PolymorphicFunction` to serialize. + + Returns: + An unserialized `SavedPolymorphicFunction` protocol buffer object. + """ + monomorphic_functions = [] + for signature in function._cached_input_signatures: # pylint: disable=protected-access + if any(isinstance(arg, defun_lib.UnknownArgument) for arg in signature): + continue + concrete_function = function.get_concrete_function(*signature) + concrete_function.add_to_graph() + monomorphic_functions.append( + saved_object_graph_pb2.SavedMonomorphicFunction( + concrete_function=concrete_function.name)) + return saved_object_graph_pb2.SavedPolymorphicFunction( + monomorphic_function=monomorphic_functions) + + +def add_polymorphic_functions_to_object_graph_proto( + checkpointable_objects, saved_object_graph): + """Finds PolymorphicFunctions attached to objects and saves them.""" + existing_objects = list(zip(checkpointable_objects, saved_object_graph.nodes)) + for obj, obj_proto in existing_objects: + for attribute_name in dir(obj): + try: + attribute_value = getattr(obj, attribute_name, None) + except: # pylint: disable=bare-except + # We really don't want to throw an exception just because some object's + # attribute accessor is broken. + attribute_value = None + # TODO(allenl): Consider de-duplicating functions which are referenced + # from multiple attributes. + if isinstance(attribute_value, def_function.PolymorphicFunction): + function_node_id = len(saved_object_graph.nodes) + function_node = saved_object_graph.nodes.add() + function_node.function.CopyFrom( + _serialize_polymorphic_function(attribute_value)) + reference = obj_proto.children.add() + reference.node_id = function_node_id + reference.local_name = attribute_name diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py index e3095f4ee5e..28c0af2b657 100644 --- a/tensorflow/python/saved_model/load.py +++ b/tensorflow/python/saved_model/load.py @@ -20,8 +20,10 @@ from __future__ import print_function import os +from tensorflow.python.framework import function as function_lib from tensorflow.python.lib.io import file_io from tensorflow.python.saved_model import constants +from tensorflow.python.saved_model import function_deserialization from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import saved_object_graph_pb2 from tensorflow.python.saved_model import utils_impl as saved_model_utils @@ -33,9 +35,17 @@ class _Loader(object): """Helper class to load an object-based SavedModel.""" def __init__(self, object_graph_proto, saved_model_proto, export_dir): - self._asset_file_def = saved_model_proto.meta_graphs[0].asset_file_def + meta_graph = saved_model_proto.meta_graphs[0] + self._asset_file_def = meta_graph.asset_file_def self._proto = object_graph_proto self._export_dir = export_dir + self._defined_functions = {} + for defined_function in function_lib.from_library( + meta_graph.graph_def.library): + # TODO(allenl): Do we need to do name mapping here? Not quite sure what + # happens when loaded names collide with existing names. + defined_function.add_to_graph(None) + self._defined_functions[defined_function.name] = defined_function self._load_all() def _load_all(self): @@ -52,6 +62,7 @@ class _Loader(object): factory = { "user_object": lambda: self._recreate_user_object(proto.user_object), "asset": lambda: self._recreate_asset(proto.asset), + "function": lambda: self._recreate_function(proto.function) } kind = proto.WhichOneof("kind") if kind not in factory: @@ -68,6 +79,10 @@ class _Loader(object): self._asset_file_def[proto.asset_file_def_index].filename) return tracking.TrackableAsset(filename) + def _recreate_function(self, proto): + return function_deserialization.recreate_polymorphic_function( + proto, self._defined_functions) + def _load_saved_object_graph_proto(filename): with file_io.FileIO(filename, "rb") as f: @@ -92,5 +107,4 @@ def load(export_dir): raise NotImplementedError( "Currently only SavedModels exported with `tf.saved_model.save` may be " "imported. Other SavedModels may eventually be supported via load().") - # TODO(allenl): load functions from the SavedModel into the eager context return root diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py index a2971101cdb..6a10ac432d5 100644 --- a/tensorflow/python/saved_model/load_test.py +++ b/tensorflow/python/saved_model/load_test.py @@ -23,6 +23,7 @@ import tempfile from tensorflow.python.eager import def_function from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_spec from tensorflow.python.lib.io import file_io @@ -47,6 +48,7 @@ class LoadTest(test.TestCase): imported = load.load(save_dir) self.assertIs(imported.dep_three, imported.dep_two.dep) self.assertIsNot(imported.dep_one, imported.dep_two) + self.assertEqual(4., imported.f(constant_op.constant(2.)).numpy()) def _make_asset(self, contents): filename = tempfile.mktemp(prefix=self.get_temp_dir()) diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py index e2726087a5c..b065a5a265b 100644 --- a/tensorflow/python/saved_model/save.py +++ b/tensorflow/python/saved_model/save.py @@ -37,6 +37,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.saved_model import builder_impl from tensorflow.python.saved_model import constants +from tensorflow.python.saved_model import function_serialization from tensorflow.python.saved_model import saved_object_graph_pb2 from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils @@ -511,6 +512,9 @@ def _write_object_graph(root, export_dir, asset_file_def_index): for obj, obj_proto in zip(checkpointable_objects, proto.nodes): _write_object_proto(obj, obj_proto, asset_file_def_index) + function_serialization.add_polymorphic_functions_to_object_graph_proto( + checkpointable_objects, proto) + extra_asset_dir = os.path.join( compat.as_bytes(export_dir), compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY)) diff --git a/tensorflow/python/saved_model/saved_object_graph.proto b/tensorflow/python/saved_model/saved_object_graph.proto index 3991fbede42..ed5c63935ff 100644 --- a/tensorflow/python/saved_model/saved_object_graph.proto +++ b/tensorflow/python/saved_model/saved_object_graph.proto @@ -48,6 +48,7 @@ message SavedObject { oneof kind { SavedUserObject user_object = 4; SavedAsset asset = 5; + SavedPolymorphicFunction function = 6; } } @@ -71,3 +72,13 @@ message SavedAsset { // `AssetFileDef.tensor_info`, MUST be ignored. uint32 asset_file_def_index = 1; } + +// A function with multiple signatures, possibly with non-Tensor arguments. +message SavedPolymorphicFunction { + repeated SavedMonomorphicFunction monomorphic_function = 1; +} + +message SavedMonomorphicFunction { + // A reference to a TensorFlow function in the MetaGraph's FunctionDefLibrary + string concrete_function = 1; +}