From 6decf0842b1f7ec17c7d8957d453cd5132b7a128 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 17 Dec 2018 02:55:37 -0800 Subject: [PATCH] Refactor deserialization of functions in object-based SavedModel. Restore a FunctionDef into a FuncGraph and wrap it into Function object. Replace calling of all functions until something fits with a input signature match code structure. PiperOrigin-RevId: 225796290 --- .../saved_model/function_deserialization.py | 48 ++++++++++++++----- tensorflow/python/saved_model/load.py | 23 +++++---- 2 files changed, 49 insertions(+), 22 deletions(-) diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py index 46bd69ad031..51e23574ca6 100644 --- a/tensorflow/python/saved_model/function_deserialization.py +++ b/tensorflow/python/saved_model/function_deserialization.py @@ -19,28 +19,50 @@ from __future__ import division from __future__ import print_function from tensorflow.python.eager import def_function +from tensorflow.python.util import nest + + +def _inputs_compatible(args, function): + # TODO(vbardiovsky): The compatibility check should be about the signature, + # not the flattened version of it. + flattened_inputs = nest.flatten(args) + if len(flattened_inputs) != len(function.inputs): + return False + for a, b in zip(flattened_inputs, function.inputs): + if a.dtype != b.dtype or not b.shape.is_compatible_with(a.shape): + return False + return True def recreate_polymorphic_function( - saved_polymorphic_function, defined_functions): - """Creates a PolymorphicFunction which runs restored function definitions.""" + saved_polymorphic_function, functions): + """Creates a PolymorphicFunction from a SavedPolymorphicFunction. + + Args: + saved_polymorphic_function: SavedPolymorphicFunction proto. + functions: map from function name to Function. + + Returns: + A PolymorphicFunction. + """ + # TODO(andresp): Construct a PolymorphicFunction with the cache populated + # instead of creating a new PolymorphicFunction backed by a Python layer to + # glue things together. Current approach is nesting functions deeper for each + # serialization cycle. @def_function.function def restored_function(*args): """Calls a restored function.""" - # Try calling each function, return a value from the first one whose - # signature matches. - # TODO(allenl): Consider re-populating the function cache directly. # TODO(allenl): Functions saved with input_signatures should revive with # input_signatures. for monomorphic_function in saved_polymorphic_function.monomorphic_function: - try: - # TODO(allenl): Passing an explicit name here prevents invalid name - # errors. We should replace this with something based on the actual - # Python function name. - return defined_functions[monomorphic_function.concrete_function]( - *args, name="imported_function") - except ValueError: - continue + function_obj = functions[monomorphic_function.concrete_function] + if _inputs_compatible(args, function_obj): + flattened_inputs = nest.flatten(args) + flattened_outputs = function_obj._call_flat(flattened_inputs) # pylint: disable=protected-access + # TODO(vbardiovsky): rebuild output structure. + single_output, = flattened_outputs + return single_output + raise AssertionError( "Could not find matching function to call for arguments: %s" % (args,)) return restored_function diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py index 9d9f60c69dd..3ebc08caef6 100644 --- a/tensorflow/python/saved_model/load.py +++ b/tensorflow/python/saved_model/load.py @@ -20,7 +20,8 @@ from __future__ import print_function import os -from tensorflow.python.framework import function as function_lib +from tensorflow.python.eager import function +from tensorflow.python.framework import function_def_to_graph as function_def_lib from tensorflow.python.lib.io import file_io from tensorflow.python.ops import init_ops from tensorflow.python.ops import variables @@ -42,16 +43,20 @@ class _Loader(object): self._asset_file_def = meta_graph.asset_file_def self._proto = object_graph_proto self._export_dir = export_dir - self._defined_functions = {} - for defined_function in function_lib.from_library( - meta_graph.graph_def.library): - # TODO(allenl): Do we need to do name mapping here? Not quite sure what - # happens when loaded names collide with existing names. - defined_function.add_to_graph(None) - self._defined_functions[defined_function.name] = defined_function + self._load_func_graphs(meta_graph.graph_def.library) self._load_all() self._restore_checkpoint() + def _load_func_graphs(self, function_library): + # TODO(allenl): Do we need to do name mapping here? Not quite sure what + # happens when loaded names collide with existing names. + # TODO(andresp): Look into gradient functions and the need to restore + # functions in the right order. + self._functions = {} + for fdef in function_library.function: + self._functions[fdef.signature.name] = function.Function( + function_def_lib.function_def_to_graph(fdef)) + def _load_all(self): self._nodes = [self._recreate(proto) for proto in self._proto.nodes] # After creating the objects, construct the edges between the objects. @@ -92,7 +97,7 @@ class _Loader(object): def _recreate_function(self, proto): return function_deserialization.recreate_polymorphic_function( - proto, self._defined_functions) + proto, self._functions) def _recreate_variable(self, proto): # TODO(andresp): Can we use the checkpointed value as initializer?