Make import_graph_def import functions with C API enabled.

PiperOrigin-RevId: 178306667
This commit is contained in:
Skye Wanderman-Milne 2017-12-07 15:36:40 -08:00 committed by TensorFlower Gardener
parent 51fa3f7fef
commit 6bb91a0712
2 changed files with 16 additions and 4 deletions

View File

@ -462,6 +462,22 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
_ProcessNewOps(graph)
# Create _DefinedFunctions for any imported functions.
#
# We do this by creating _DefinedFunctions directly from `graph_def`, and
# adding them to `graph`. Adding an existing function to a TF_Graph is a
# no-op, so this only has the effect of updating the Python state (usually
# _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
#
# TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
# TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
if graph_def.library and graph_def.library.function:
# pylint: disable=protected-access
functions = function._from_library(graph_def.library)
for f in functions:
f.add_to_graph(graph)
# pylint: enable=protected-access
# TODO(skyewm): error if unused input map key
if return_elements is None:

View File

@ -1110,8 +1110,6 @@ class ImportGraphDefTest(test.TestCase):
self.assertEqual(987, a[0].get_attr("default_int"))
def testFunctions(self):
if ops._USE_C_API: return # TODO(skyewm): make this work with C API
dtype = dtypes.float32
@function.Defun(dtype, dtype, dtype, dtype)
def Grad(x, y, dout1, dout2): # pylint: disable=unused-argument
@ -1189,8 +1187,6 @@ class ImportGraphDefTest(test.TestCase):
self.assertEqual(sess.run("outer:0"), 21)
def testImportInsideDefun(self):
if ops._USE_C_API: return # TODO(skyewm): make this work with C API
g = ops.Graph()
with g.as_default():
@function.Defun()