Refactor deserialization of functions in object-based SavedModel.

Restore a FunctionDef into a FuncGraph and wrap it into Function object.
Replace calling of all functions until something fits with a input signature
match code structure.

PiperOrigin-RevId: 225796290
This commit is contained in:
A. Unique TensorFlower 2018-12-17 02:55:37 -08:00 committed by TensorFlower Gardener
parent 06303a8ea0
commit 6decf0842b
2 changed files with 49 additions and 22 deletions

View File

@ -19,28 +19,50 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import def_function
from tensorflow.python.util import nest
def _inputs_compatible(args, function):
# TODO(vbardiovsky): The compatibility check should be about the signature,
# not the flattened version of it.
flattened_inputs = nest.flatten(args)
if len(flattened_inputs) != len(function.inputs):
return False
for a, b in zip(flattened_inputs, function.inputs):
if a.dtype != b.dtype or not b.shape.is_compatible_with(a.shape):
return False
return True
def recreate_polymorphic_function(
saved_polymorphic_function, defined_functions):
"""Creates a PolymorphicFunction which runs restored function definitions."""
saved_polymorphic_function, functions):
"""Creates a PolymorphicFunction from a SavedPolymorphicFunction.
Args:
saved_polymorphic_function: SavedPolymorphicFunction proto.
functions: map from function name to Function.
Returns:
A PolymorphicFunction.
"""
# TODO(andresp): Construct a PolymorphicFunction with the cache populated
# instead of creating a new PolymorphicFunction backed by a Python layer to
# glue things together. Current approach is nesting functions deeper for each
# serialization cycle.
@def_function.function
def restored_function(*args):
"""Calls a restored function."""
# Try calling each function, return a value from the first one whose
# signature matches.
# TODO(allenl): Consider re-populating the function cache directly.
# TODO(allenl): Functions saved with input_signatures should revive with
# input_signatures.
for monomorphic_function in saved_polymorphic_function.monomorphic_function:
try:
# TODO(allenl): Passing an explicit name here prevents invalid name
# errors. We should replace this with something based on the actual
# Python function name.
return defined_functions[monomorphic_function.concrete_function](
*args, name="imported_function")
except ValueError:
continue
function_obj = functions[monomorphic_function.concrete_function]
if _inputs_compatible(args, function_obj):
flattened_inputs = nest.flatten(args)
flattened_outputs = function_obj._call_flat(flattened_inputs) # pylint: disable=protected-access
# TODO(vbardiovsky): rebuild output structure.
single_output, = flattened_outputs
return single_output
raise AssertionError(
"Could not find matching function to call for arguments: %s" % (args,))
return restored_function

View File

@ -20,7 +20,8 @@ from __future__ import print_function
import os
from tensorflow.python.framework import function as function_lib
from tensorflow.python.eager import function
from tensorflow.python.framework import function_def_to_graph as function_def_lib
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variables
@ -42,16 +43,20 @@ class _Loader(object):
self._asset_file_def = meta_graph.asset_file_def
self._proto = object_graph_proto
self._export_dir = export_dir
self._defined_functions = {}
for defined_function in function_lib.from_library(
meta_graph.graph_def.library):
# TODO(allenl): Do we need to do name mapping here? Not quite sure what
# happens when loaded names collide with existing names.
defined_function.add_to_graph(None)
self._defined_functions[defined_function.name] = defined_function
self._load_func_graphs(meta_graph.graph_def.library)
self._load_all()
self._restore_checkpoint()
def _load_func_graphs(self, function_library):
# TODO(allenl): Do we need to do name mapping here? Not quite sure what
# happens when loaded names collide with existing names.
# TODO(andresp): Look into gradient functions and the need to restore
# functions in the right order.
self._functions = {}
for fdef in function_library.function:
self._functions[fdef.signature.name] = function.Function(
function_def_lib.function_def_to_graph(fdef))
def _load_all(self):
self._nodes = [self._recreate(proto) for proto in self._proto.nodes]
# After creating the objects, construct the edges between the objects.
@ -92,7 +97,7 @@ class _Loader(object):
def _recreate_function(self, proto):
return function_deserialization.recreate_polymorphic_function(
proto, self._defined_functions)
proto, self._functions)
def _recreate_variable(self, proto):
# TODO(andresp): Can we use the checkpointed value as initializer?