Skeleton for PolymorphicFunction serialization
Missing things like variables, function/argument names, support for arguments that aren't a flat list of Tensors, and many other things. But it does manage to save, restore, and call a function. Starts saving a bit of extra metadata when a new function trace is created. Since this does not have to be computed each time the function is called, I expect the performance impact to be minimal. PiperOrigin-RevId: 225072712
This commit is contained in:
parent
40345bd2c3
commit
2a87c2df92
@ -242,6 +242,7 @@ class PolymorphicFunction(object):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
self._created_variables = None
|
self._created_variables = None
|
||||||
self._stateful_fn = None
|
self._stateful_fn = None
|
||||||
|
self._stateless_fn = None
|
||||||
self._descriptor_cache = weakref.WeakKeyDictionary()
|
self._descriptor_cache = weakref.WeakKeyDictionary()
|
||||||
self._name = name
|
self._name = name
|
||||||
|
|
||||||
@ -382,6 +383,26 @@ class PolymorphicFunction(object):
|
|||||||
|
|
||||||
return initialize_variables.get_concrete_function()
|
return initialize_variables.get_concrete_function()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _cached_input_signatures(self):
|
||||||
|
"""All input signatures used to call this PolymorphicFunction."""
|
||||||
|
seen = set()
|
||||||
|
# Preserves signature ordering rather than returning a set() so that we
|
||||||
|
# don't need to re-sort signatures later to work around Python 2's set
|
||||||
|
# nondeterminism.
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
concrete_functions = []
|
||||||
|
if self._stateful_fn:
|
||||||
|
concrete_functions.extend(self._stateful_fn._function_cache.values())
|
||||||
|
if self._stateless_fn:
|
||||||
|
concrete_functions.extend(self._stateless_fn._function_cache.values())
|
||||||
|
for concrete_function in concrete_functions:
|
||||||
|
signature = concrete_function._python_call_signature
|
||||||
|
if signature not in seen:
|
||||||
|
yield signature
|
||||||
|
seen.add(signature)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
def get_concrete_function(self, *args, **kwargs):
|
def get_concrete_function(self, *args, **kwargs):
|
||||||
"""Returns a `Function` object specialized to inputs and execution context.
|
"""Returns a `Function` object specialized to inputs and execution context.
|
||||||
|
|
||||||
|
@ -238,6 +238,25 @@ class DefFunctionTest(test.TestCase):
|
|||||||
concrete = compute.get_concrete_function(
|
concrete = compute.get_concrete_function(
|
||||||
tensor_spec.TensorSpec(None, dtypes.float32))
|
tensor_spec.TensorSpec(None, dtypes.float32))
|
||||||
self.assertAllClose(4., concrete(constant_op.constant(2.)))
|
self.assertAllClose(4., concrete(constant_op.constant(2.)))
|
||||||
|
input_signature, = compute._cached_input_signatures
|
||||||
|
self.assertEqual(
|
||||||
|
tuple(input_signature),
|
||||||
|
(tensor_spec.TensorSpec(None, dtypes.float32),))
|
||||||
|
|
||||||
|
def test_serialization_signature_cache(self):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def f(x, y):
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
f(constant_op.constant([[3., 4.]]), constant_op.constant([2.]))
|
||||||
|
f(constant_op.constant([[3, 4, 5]]), constant_op.constant([2]))
|
||||||
|
self.assertEqual(
|
||||||
|
set(f._cached_input_signatures),
|
||||||
|
set(((tensor_spec.TensorSpec([1, 2], dtypes.float32),
|
||||||
|
tensor_spec.TensorSpec([1], dtypes.float32)),
|
||||||
|
(tensor_spec.TensorSpec([1, 3], dtypes.int32),
|
||||||
|
tensor_spec.TensorSpec([1], dtypes.int32)))))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -748,6 +748,19 @@ class Function(object):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class UnknownArgument(object):
|
||||||
|
"""Signifies an argument which is not currently handled."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_arg_for_serialization(arg):
|
||||||
|
"""A representation for this argument, for serializing signatures."""
|
||||||
|
if isinstance(arg, ops.Tensor):
|
||||||
|
return tensor_spec.TensorSpec(arg.shape, arg.dtype)
|
||||||
|
else:
|
||||||
|
return UnknownArgument()
|
||||||
|
|
||||||
|
|
||||||
pywrap_tensorflow.RegisterType("Tensor", ops.Tensor)
|
pywrap_tensorflow.RegisterType("Tensor", ops.Tensor)
|
||||||
pywrap_tensorflow.RegisterType("IndexedSlices", ops.IndexedSlices)
|
pywrap_tensorflow.RegisterType("IndexedSlices", ops.IndexedSlices)
|
||||||
|
|
||||||
@ -1163,6 +1176,14 @@ class PolymorphicFunction(object):
|
|||||||
autograph=self._autograph,
|
autograph=self._autograph,
|
||||||
arg_names=arg_names),
|
arg_names=arg_names),
|
||||||
self._function_attributes)
|
self._function_attributes)
|
||||||
|
if self._input_signature:
|
||||||
|
python_call_signature = self._input_signature
|
||||||
|
else:
|
||||||
|
python_call_signature = tuple(
|
||||||
|
_encode_arg_for_serialization(arg) for arg in args)
|
||||||
|
# Save information about non-Tensor arguments with the concrete
|
||||||
|
# function. Used to serialize PolymorphicFunctions.
|
||||||
|
graph_function._python_call_signature = python_call_signature # pylint: disable=protected-access
|
||||||
self._function_cache[cache_key] = graph_function
|
self._function_cache[cache_key] = graph_function
|
||||||
return graph_function, args, kwargs
|
return graph_function, args, kwargs
|
||||||
|
|
||||||
|
@ -993,9 +993,10 @@ def _call(sig, *inputs, **kwargs):
|
|||||||
name = kwargs.pop("name", None)
|
name = kwargs.pop("name", None)
|
||||||
g = ops.get_default_graph()
|
g = ops.get_default_graph()
|
||||||
func_name = sig.name
|
func_name = sig.name
|
||||||
|
if name is None:
|
||||||
|
name = func_name
|
||||||
attrs = _parse_kwargs_as_attrs(func_name, **kwargs)
|
attrs = _parse_kwargs_as_attrs(func_name, **kwargs)
|
||||||
output_types = [dtypes.DType(x.type) for x in sig.output_arg]
|
output_types = [dtypes.DType(x.type) for x in sig.output_arg]
|
||||||
with ops.name_scope(name, func_name, inputs) as name:
|
|
||||||
op = g.create_op(
|
op = g.create_op(
|
||||||
func_name,
|
func_name,
|
||||||
list(inputs),
|
list(inputs),
|
||||||
@ -1046,12 +1047,13 @@ def _from_definition(fdef, grad_func=None):
|
|||||||
c_func = c_api.TF_FunctionImportFunctionDef(serialized)
|
c_func = c_api.TF_FunctionImportFunctionDef(serialized)
|
||||||
result._c_func = c_api_util.ScopedTFFunction(c_func)
|
result._c_func = c_api_util.ScopedTFFunction(c_func)
|
||||||
result._extra_inputs = []
|
result._extra_inputs = []
|
||||||
|
result._op_def = fdef.signature
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _from_library(lib):
|
def from_library(lib):
|
||||||
"""Creates _DefinedFunctions initialized from a FunctionDefLibrary proto.
|
"""Creates _DefinedFunctions initialized from a FunctionDefLibrary proto.
|
||||||
|
|
||||||
This method handles assigning the correct gradient functions to each
|
This method handles assigning the correct gradient functions to each
|
||||||
|
@ -1287,7 +1287,7 @@ class FunctionsFromProtos(test.TestCase):
|
|||||||
gradients_impl.gradients([f1, f2, f3, f4], c)
|
gradients_impl.gradients([f1, f2, f3, f4], c)
|
||||||
|
|
||||||
library = g.as_graph_def().library
|
library = g.as_graph_def().library
|
||||||
new_funcs = function._from_library(library)
|
new_funcs = function.from_library(library)
|
||||||
|
|
||||||
def CheckNewFunc(func):
|
def CheckNewFunc(func):
|
||||||
new_func = [f for f in new_funcs if f.name == func.name]
|
new_func = [f for f in new_funcs if f.name == func.name]
|
||||||
@ -1303,7 +1303,7 @@ class FunctionsFromProtos(test.TestCase):
|
|||||||
|
|
||||||
def testFromLibraryEmptyLib(self):
|
def testFromLibraryEmptyLib(self):
|
||||||
library = function_pb2.FunctionDefLibrary()
|
library = function_pb2.FunctionDefLibrary()
|
||||||
self.assertEqual(len(function._from_library(library)), 0)
|
self.assertEqual(len(function.from_library(library)), 0)
|
||||||
|
|
||||||
def testFromLibraryMissingFuncDef(self):
|
def testFromLibraryMissingFuncDef(self):
|
||||||
|
|
||||||
@ -1327,7 +1327,7 @@ class FunctionsFromProtos(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError,
|
ValueError,
|
||||||
"FunctionDefLibrary missing 'G1_[0-9a-zA-Z]{8,11}' FunctionDef"):
|
"FunctionDefLibrary missing 'G1_[0-9a-zA-Z]{8,11}' FunctionDef"):
|
||||||
function._from_library(library)
|
function.from_library(library)
|
||||||
|
|
||||||
# Create invalid function def that is missing F1 function def
|
# Create invalid function def that is missing F1 function def
|
||||||
library = function_pb2.FunctionDefLibrary()
|
library = function_pb2.FunctionDefLibrary()
|
||||||
@ -1337,7 +1337,7 @@ class FunctionsFromProtos(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError,
|
ValueError,
|
||||||
"FunctionDefLibrary missing 'F1_[0-9a-zA-Z]{8,11}' FunctionDef"):
|
"FunctionDefLibrary missing 'F1_[0-9a-zA-Z]{8,11}' FunctionDef"):
|
||||||
function._from_library(library)
|
function.from_library(library)
|
||||||
|
|
||||||
def testFromLibraryCyclicGradFuncs(self):
|
def testFromLibraryCyclicGradFuncs(self):
|
||||||
|
|
||||||
@ -1366,7 +1366,7 @@ class FunctionsFromProtos(test.TestCase):
|
|||||||
|
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError, "FunctionDefLibrary contains cyclic gradient functions!"):
|
ValueError, "FunctionDefLibrary contains cyclic gradient functions!"):
|
||||||
function._from_library(library)
|
function.from_library(library)
|
||||||
|
|
||||||
def testExperimentalAttrs(self):
|
def testExperimentalAttrs(self):
|
||||||
|
|
||||||
|
@ -442,11 +442,9 @@ def import_graph_def(graph_def,
|
|||||||
_ProcessNewOps(graph)
|
_ProcessNewOps(graph)
|
||||||
|
|
||||||
if graph_def.library and graph_def.library.function:
|
if graph_def.library and graph_def.library.function:
|
||||||
# pylint: disable=protected-access
|
functions = function.from_library(graph_def.library)
|
||||||
functions = function._from_library(graph_def.library)
|
|
||||||
for f in functions:
|
for f in functions:
|
||||||
f.add_to_graph(graph)
|
f.add_to_graph(graph)
|
||||||
# pylint: enable=protected-access
|
|
||||||
|
|
||||||
# Treat input mappings that don't appear in the graph as an error, because
|
# Treat input mappings that don't appear in the graph as an error, because
|
||||||
# they are likely to be due to a typo.
|
# they are likely to be due to a typo.
|
||||||
|
@ -287,7 +287,7 @@ py_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":builder",
|
":builder",
|
||||||
":constants",
|
":constants",
|
||||||
":loader",
|
":function_serialization",
|
||||||
":saved_object_graph_py",
|
":saved_object_graph_py",
|
||||||
":signature_constants",
|
":signature_constants",
|
||||||
":signature_def_utils",
|
":signature_def_utils",
|
||||||
@ -295,15 +295,20 @@ py_library(
|
|||||||
":utils",
|
":utils",
|
||||||
"//tensorflow/core:protos_all_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
|
"//tensorflow/python:constant_op",
|
||||||
|
"//tensorflow/python:control_flow_ops",
|
||||||
|
"//tensorflow/python:dtypes",
|
||||||
"//tensorflow/python:framework",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
"//tensorflow/python:lib",
|
"//tensorflow/python:lib",
|
||||||
"//tensorflow/python:resource_variable_ops",
|
"//tensorflow/python:resource_variable_ops",
|
||||||
|
"//tensorflow/python:tensor_spec",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
"//tensorflow/python/eager:def_function",
|
"//tensorflow/python/eager:def_function",
|
||||||
"//tensorflow/python/eager:function",
|
"//tensorflow/python/eager:function",
|
||||||
"//tensorflow/python/training/checkpointable:base",
|
"//tensorflow/python/training/checkpointable:base",
|
||||||
|
"//tensorflow/python/training/checkpointable:tracking",
|
||||||
"//tensorflow/python/training/checkpointable:util",
|
"//tensorflow/python/training/checkpointable:util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -330,8 +335,12 @@ py_library(
|
|||||||
],
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
":constants",
|
||||||
|
":function_deserialization",
|
||||||
":loader",
|
":loader",
|
||||||
":saved_object_graph_py",
|
":saved_object_graph_py",
|
||||||
|
":utils",
|
||||||
|
"//tensorflow/python:function",
|
||||||
"//tensorflow/python:lib",
|
"//tensorflow/python:lib",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
"//tensorflow/python/training/checkpointable:tracking",
|
"//tensorflow/python/training/checkpointable:tracking",
|
||||||
@ -345,10 +354,34 @@ py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":load",
|
":load",
|
||||||
":save",
|
":save",
|
||||||
|
"//tensorflow/python:constant_op",
|
||||||
"//tensorflow/python:dtypes",
|
"//tensorflow/python:dtypes",
|
||||||
|
"//tensorflow/python:lib",
|
||||||
"//tensorflow/python:tensor_spec",
|
"//tensorflow/python:tensor_spec",
|
||||||
"//tensorflow/python/eager:def_function",
|
"//tensorflow/python/eager:def_function",
|
||||||
"//tensorflow/python/eager:test",
|
"//tensorflow/python/eager:test",
|
||||||
"//tensorflow/python/training/checkpointable:tracking",
|
"//tensorflow/python/training/checkpointable:tracking",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "function_serialization",
|
||||||
|
srcs = [
|
||||||
|
"function_serialization.py",
|
||||||
|
],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":saved_object_graph_py",
|
||||||
|
"//tensorflow/python/eager:def_function",
|
||||||
|
"//tensorflow/python/eager:function",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "function_deserialization",
|
||||||
|
srcs = [
|
||||||
|
"function_deserialization.py",
|
||||||
|
],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = ["//tensorflow/python/eager:def_function"],
|
||||||
|
)
|
||||||
|
46
tensorflow/python/saved_model/function_deserialization.py
Normal file
46
tensorflow/python/saved_model/function_deserialization.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tools for deserializing PolymorphicFunctions."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
|
|
||||||
|
|
||||||
|
def recreate_polymorphic_function(
|
||||||
|
saved_polymorphic_function, defined_functions):
|
||||||
|
"""Creates a PolymorphicFunction which runs restored function definitions."""
|
||||||
|
@def_function.function
|
||||||
|
def restored_function(*args):
|
||||||
|
"""Calls a restored function."""
|
||||||
|
# Try calling each function, return a value from the first one whose
|
||||||
|
# signature matches.
|
||||||
|
# TODO(allenl): Consider re-populating the function cache directly.
|
||||||
|
# TODO(allenl): Functions saved with input_signatures should revive with
|
||||||
|
# input_signatures.
|
||||||
|
for monomorphic_function in saved_polymorphic_function.monomorphic_function:
|
||||||
|
try:
|
||||||
|
# TODO(allenl): Passing an explicit name here prevents invalid name
|
||||||
|
# errors. We should replace this with something based on the actual
|
||||||
|
# Python function name.
|
||||||
|
return defined_functions[monomorphic_function.concrete_function](
|
||||||
|
*args, name="imported_function")
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
raise AssertionError(
|
||||||
|
"Could not find matching function to call for arguments: %s" % (args,))
|
||||||
|
return restored_function
|
71
tensorflow/python/saved_model/function_serialization.py
Normal file
71
tensorflow/python/saved_model/function_serialization.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tools for serializing PolymorphicFunctions."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
|
from tensorflow.python.eager import function as defun_lib
|
||||||
|
from tensorflow.python.saved_model import saved_object_graph_pb2
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_polymorphic_function(function):
|
||||||
|
"""Represents a PolymorphicFunction in a SavedModel.
|
||||||
|
|
||||||
|
Adds `function`'s concrete functions to the current graph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
function: A `PolymorphicFunction` to serialize.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An unserialized `SavedPolymorphicFunction` protocol buffer object.
|
||||||
|
"""
|
||||||
|
monomorphic_functions = []
|
||||||
|
for signature in function._cached_input_signatures: # pylint: disable=protected-access
|
||||||
|
if any(isinstance(arg, defun_lib.UnknownArgument) for arg in signature):
|
||||||
|
continue
|
||||||
|
concrete_function = function.get_concrete_function(*signature)
|
||||||
|
concrete_function.add_to_graph()
|
||||||
|
monomorphic_functions.append(
|
||||||
|
saved_object_graph_pb2.SavedMonomorphicFunction(
|
||||||
|
concrete_function=concrete_function.name))
|
||||||
|
return saved_object_graph_pb2.SavedPolymorphicFunction(
|
||||||
|
monomorphic_function=monomorphic_functions)
|
||||||
|
|
||||||
|
|
||||||
|
def add_polymorphic_functions_to_object_graph_proto(
|
||||||
|
checkpointable_objects, saved_object_graph):
|
||||||
|
"""Finds PolymorphicFunctions attached to objects and saves them."""
|
||||||
|
existing_objects = list(zip(checkpointable_objects, saved_object_graph.nodes))
|
||||||
|
for obj, obj_proto in existing_objects:
|
||||||
|
for attribute_name in dir(obj):
|
||||||
|
try:
|
||||||
|
attribute_value = getattr(obj, attribute_name, None)
|
||||||
|
except: # pylint: disable=bare-except
|
||||||
|
# We really don't want to throw an exception just because some object's
|
||||||
|
# attribute accessor is broken.
|
||||||
|
attribute_value = None
|
||||||
|
# TODO(allenl): Consider de-duplicating functions which are referenced
|
||||||
|
# from multiple attributes.
|
||||||
|
if isinstance(attribute_value, def_function.PolymorphicFunction):
|
||||||
|
function_node_id = len(saved_object_graph.nodes)
|
||||||
|
function_node = saved_object_graph.nodes.add()
|
||||||
|
function_node.function.CopyFrom(
|
||||||
|
_serialize_polymorphic_function(attribute_value))
|
||||||
|
reference = obj_proto.children.add()
|
||||||
|
reference.node_id = function_node_id
|
||||||
|
reference.local_name = attribute_name
|
@ -20,8 +20,10 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from tensorflow.python.framework import function as function_lib
|
||||||
from tensorflow.python.lib.io import file_io
|
from tensorflow.python.lib.io import file_io
|
||||||
from tensorflow.python.saved_model import constants
|
from tensorflow.python.saved_model import constants
|
||||||
|
from tensorflow.python.saved_model import function_deserialization
|
||||||
from tensorflow.python.saved_model import loader_impl
|
from tensorflow.python.saved_model import loader_impl
|
||||||
from tensorflow.python.saved_model import saved_object_graph_pb2
|
from tensorflow.python.saved_model import saved_object_graph_pb2
|
||||||
from tensorflow.python.saved_model import utils_impl as saved_model_utils
|
from tensorflow.python.saved_model import utils_impl as saved_model_utils
|
||||||
@ -33,9 +35,17 @@ class _Loader(object):
|
|||||||
"""Helper class to load an object-based SavedModel."""
|
"""Helper class to load an object-based SavedModel."""
|
||||||
|
|
||||||
def __init__(self, object_graph_proto, saved_model_proto, export_dir):
|
def __init__(self, object_graph_proto, saved_model_proto, export_dir):
|
||||||
self._asset_file_def = saved_model_proto.meta_graphs[0].asset_file_def
|
meta_graph = saved_model_proto.meta_graphs[0]
|
||||||
|
self._asset_file_def = meta_graph.asset_file_def
|
||||||
self._proto = object_graph_proto
|
self._proto = object_graph_proto
|
||||||
self._export_dir = export_dir
|
self._export_dir = export_dir
|
||||||
|
self._defined_functions = {}
|
||||||
|
for defined_function in function_lib.from_library(
|
||||||
|
meta_graph.graph_def.library):
|
||||||
|
# TODO(allenl): Do we need to do name mapping here? Not quite sure what
|
||||||
|
# happens when loaded names collide with existing names.
|
||||||
|
defined_function.add_to_graph(None)
|
||||||
|
self._defined_functions[defined_function.name] = defined_function
|
||||||
self._load_all()
|
self._load_all()
|
||||||
|
|
||||||
def _load_all(self):
|
def _load_all(self):
|
||||||
@ -52,6 +62,7 @@ class _Loader(object):
|
|||||||
factory = {
|
factory = {
|
||||||
"user_object": lambda: self._recreate_user_object(proto.user_object),
|
"user_object": lambda: self._recreate_user_object(proto.user_object),
|
||||||
"asset": lambda: self._recreate_asset(proto.asset),
|
"asset": lambda: self._recreate_asset(proto.asset),
|
||||||
|
"function": lambda: self._recreate_function(proto.function)
|
||||||
}
|
}
|
||||||
kind = proto.WhichOneof("kind")
|
kind = proto.WhichOneof("kind")
|
||||||
if kind not in factory:
|
if kind not in factory:
|
||||||
@ -68,6 +79,10 @@ class _Loader(object):
|
|||||||
self._asset_file_def[proto.asset_file_def_index].filename)
|
self._asset_file_def[proto.asset_file_def_index].filename)
|
||||||
return tracking.TrackableAsset(filename)
|
return tracking.TrackableAsset(filename)
|
||||||
|
|
||||||
|
def _recreate_function(self, proto):
|
||||||
|
return function_deserialization.recreate_polymorphic_function(
|
||||||
|
proto, self._defined_functions)
|
||||||
|
|
||||||
|
|
||||||
def _load_saved_object_graph_proto(filename):
|
def _load_saved_object_graph_proto(filename):
|
||||||
with file_io.FileIO(filename, "rb") as f:
|
with file_io.FileIO(filename, "rb") as f:
|
||||||
@ -92,5 +107,4 @@ def load(export_dir):
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Currently only SavedModels exported with `tf.saved_model.save` may be "
|
"Currently only SavedModels exported with `tf.saved_model.save` may be "
|
||||||
"imported. Other SavedModels may eventually be supported via load().")
|
"imported. Other SavedModels may eventually be supported via load().")
|
||||||
# TODO(allenl): load functions from the SavedModel into the eager context
|
|
||||||
return root
|
return root
|
||||||
|
@ -23,6 +23,7 @@ import tempfile
|
|||||||
|
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.lib.io import file_io
|
from tensorflow.python.lib.io import file_io
|
||||||
@ -47,6 +48,7 @@ class LoadTest(test.TestCase):
|
|||||||
imported = load.load(save_dir)
|
imported = load.load(save_dir)
|
||||||
self.assertIs(imported.dep_three, imported.dep_two.dep)
|
self.assertIs(imported.dep_three, imported.dep_two.dep)
|
||||||
self.assertIsNot(imported.dep_one, imported.dep_two)
|
self.assertIsNot(imported.dep_one, imported.dep_two)
|
||||||
|
self.assertEqual(4., imported.f(constant_op.constant(2.)).numpy())
|
||||||
|
|
||||||
def _make_asset(self, contents):
|
def _make_asset(self, contents):
|
||||||
filename = tempfile.mktemp(prefix=self.get_temp_dir())
|
filename = tempfile.mktemp(prefix=self.get_temp_dir())
|
||||||
|
@ -37,6 +37,7 @@ from tensorflow.python.ops import control_flow_ops
|
|||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.saved_model import builder_impl
|
from tensorflow.python.saved_model import builder_impl
|
||||||
from tensorflow.python.saved_model import constants
|
from tensorflow.python.saved_model import constants
|
||||||
|
from tensorflow.python.saved_model import function_serialization
|
||||||
from tensorflow.python.saved_model import saved_object_graph_pb2
|
from tensorflow.python.saved_model import saved_object_graph_pb2
|
||||||
from tensorflow.python.saved_model import signature_constants
|
from tensorflow.python.saved_model import signature_constants
|
||||||
from tensorflow.python.saved_model import signature_def_utils
|
from tensorflow.python.saved_model import signature_def_utils
|
||||||
@ -511,6 +512,9 @@ def _write_object_graph(root, export_dir, asset_file_def_index):
|
|||||||
for obj, obj_proto in zip(checkpointable_objects, proto.nodes):
|
for obj, obj_proto in zip(checkpointable_objects, proto.nodes):
|
||||||
_write_object_proto(obj, obj_proto, asset_file_def_index)
|
_write_object_proto(obj, obj_proto, asset_file_def_index)
|
||||||
|
|
||||||
|
function_serialization.add_polymorphic_functions_to_object_graph_proto(
|
||||||
|
checkpointable_objects, proto)
|
||||||
|
|
||||||
extra_asset_dir = os.path.join(
|
extra_asset_dir = os.path.join(
|
||||||
compat.as_bytes(export_dir),
|
compat.as_bytes(export_dir),
|
||||||
compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
|
compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
|
||||||
|
@ -48,6 +48,7 @@ message SavedObject {
|
|||||||
oneof kind {
|
oneof kind {
|
||||||
SavedUserObject user_object = 4;
|
SavedUserObject user_object = 4;
|
||||||
SavedAsset asset = 5;
|
SavedAsset asset = 5;
|
||||||
|
SavedPolymorphicFunction function = 6;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -71,3 +72,13 @@ message SavedAsset {
|
|||||||
// `AssetFileDef.tensor_info`, MUST be ignored.
|
// `AssetFileDef.tensor_info`, MUST be ignored.
|
||||||
uint32 asset_file_def_index = 1;
|
uint32 asset_file_def_index = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A function with multiple signatures, possibly with non-Tensor arguments.
|
||||||
|
message SavedPolymorphicFunction {
|
||||||
|
repeated SavedMonomorphicFunction monomorphic_function = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message SavedMonomorphicFunction {
|
||||||
|
// A reference to a TensorFlow function in the MetaGraph's FunctionDefLibrary
|
||||||
|
string concrete_function = 1;
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user