Refactor deserialization of functions in object-based SavedModel.
Restore each FunctionDef into a FuncGraph and wrap it in a Function object. Replace the strategy of calling every function until one succeeds with an explicit input-signature matching code structure.

PiperOrigin-RevId: 225796290
commit 6decf0842b
parent 06303a8ea0
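For context, a minimal standalone sketch of the dispatch idea this change introduces: compare the flattened call arguments against each restored concrete function's expected dtypes and shapes, and invoke the first compatible one instead of optimistically calling every candidate and swallowing ValueErrors. The `dispatch` helper and the list of `(input_specs, fn)` pairs below are hypothetical illustrations, not the actual TensorFlow internals changed in this commit.

import tensorflow as tf


def _inputs_compatible(args, input_specs):
  # Compare the flattened arguments against the expected tensor specs.
  flat_args = tf.nest.flatten(args)
  if len(flat_args) != len(input_specs):
    return False
  return all(
      arg.dtype == spec.dtype and spec.shape.is_compatible_with(arg.shape)
      for arg, spec in zip(flat_args, input_specs))


def dispatch(args, concrete_functions):
  # `concrete_functions` is a hypothetical list of (input_specs, fn) pairs.
  for input_specs, fn in concrete_functions:
    if _inputs_compatible(args, input_specs):
      return fn(*tf.nest.flatten(args))
  raise AssertionError(
      "Could not find matching function to call for arguments: %s" % (args,))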
@@ -19,28 +19,50 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.eager import def_function
+from tensorflow.python.util import nest
+
+
+def _inputs_compatible(args, function):
+  # TODO(vbardiovsky): The compatibility check should be about the signature,
+  # not the flattened version of it.
+  flattened_inputs = nest.flatten(args)
+  if len(flattened_inputs) != len(function.inputs):
+    return False
+  for a, b in zip(flattened_inputs, function.inputs):
+    if a.dtype != b.dtype or not b.shape.is_compatible_with(a.shape):
+      return False
+  return True
 
 
 def recreate_polymorphic_function(
-    saved_polymorphic_function, defined_functions):
-  """Creates a PolymorphicFunction which runs restored function definitions."""
+    saved_polymorphic_function, functions):
+  """Creates a PolymorphicFunction from a SavedPolymorphicFunction.
+
+  Args:
+    saved_polymorphic_function: SavedPolymorphicFunction proto.
+    functions: map from function name to Function.
+
+  Returns:
+    A PolymorphicFunction.
+  """
   # TODO(andresp): Construct a PolymorphicFunction with the cache populated
   # instead of creating a new PolymorphicFunction backed by a Python layer to
   # glue things together. Current approach is nesting functions deeper for each
   # serialization cycle.
   @def_function.function
   def restored_function(*args):
     """Calls a restored function."""
-    # Try calling each function, return a value from the first one whose
-    # signature matches.
     # TODO(allenl): Consider re-populating the function cache directly.
     # TODO(allenl): Functions saved with input_signatures should revive with
     # input_signatures.
     for monomorphic_function in saved_polymorphic_function.monomorphic_function:
-      try:
-        # TODO(allenl): Passing an explicit name here prevents invalid name
-        # errors. We should replace this with something based on the actual
-        # Python function name.
-        return defined_functions[monomorphic_function.concrete_function](
-            *args, name="imported_function")
-      except ValueError:
-        continue
+      function_obj = functions[monomorphic_function.concrete_function]
+      if _inputs_compatible(args, function_obj):
+        flattened_inputs = nest.flatten(args)
+        flattened_outputs = function_obj._call_flat(flattened_inputs)  # pylint: disable=protected-access
+        # TODO(vbardiovsky): rebuild output structure.
+        single_output, = flattened_outputs
+        return single_output
+
     raise AssertionError(
         "Could not find matching function to call for arguments: %s" % (args,))
   return restored_function
@@ -20,7 +20,8 @@ from __future__ import print_function
 
 import os
 
-from tensorflow.python.framework import function as function_lib
+from tensorflow.python.eager import function
+from tensorflow.python.framework import function_def_to_graph as function_def_lib
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import variables
@@ -42,16 +43,20 @@ class _Loader(object):
     self._asset_file_def = meta_graph.asset_file_def
     self._proto = object_graph_proto
     self._export_dir = export_dir
-    self._defined_functions = {}
-    for defined_function in function_lib.from_library(
-        meta_graph.graph_def.library):
-      # TODO(allenl): Do we need to do name mapping here? Not quite sure what
-      # happens when loaded names collide with existing names.
-      defined_function.add_to_graph(None)
-      self._defined_functions[defined_function.name] = defined_function
+    self._load_func_graphs(meta_graph.graph_def.library)
     self._load_all()
     self._restore_checkpoint()
 
+  def _load_func_graphs(self, function_library):
+    # TODO(allenl): Do we need to do name mapping here? Not quite sure what
+    # happens when loaded names collide with existing names.
+    # TODO(andresp): Look into gradient functions and the need to restore
+    # functions in the right order.
+    self._functions = {}
+    for fdef in function_library.function:
+      self._functions[fdef.signature.name] = function.Function(
+          function_def_lib.function_def_to_graph(fdef))
+
   def _load_all(self):
     self._nodes = [self._recreate(proto) for proto in self._proto.nodes]
     # After creating the objects, construct the edges between the objects.
@@ -92,7 +97,7 @@ class _Loader(object):
 
   def _recreate_function(self, proto):
     return function_deserialization.recreate_polymorphic_function(
-        proto, self._defined_functions)
+        proto, self._functions)
 
   def _recreate_variable(self, proto):
     # TODO(andresp): Can we use the checkpointed value as initializer?