From 6bb91a0712aa9e312fafb8a81bf7d891b6e064dc Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Thu, 7 Dec 2017 15:36:40 -0800 Subject: [PATCH] Make import_graph_def import functions with C API enabled. PiperOrigin-RevId: 178306667 --- tensorflow/python/framework/importer.py | 16 ++++++++++++++++ tensorflow/python/framework/importer_test.py | 4 ---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index 860e3fe7158..62765aff00e 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -462,6 +462,22 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, _ProcessNewOps(graph) + # Create _DefinedFunctions for any imported functions. + # + # We do this by creating _DefinedFunctions directly from `graph_def`, and + # adding them to `graph`. Adding an existing function to a TF_Graph is a + # no-op, so this only has the effect of updating the Python state (usually + # _DefinedFunction.add_to_graph also adds the function to the TF_Graph). + # + # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph + # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph + if graph_def.library and graph_def.library.function: + # pylint: disable=protected-access + functions = function._from_library(graph_def.library) + for f in functions: + f.add_to_graph(graph) + # pylint: enable=protected-access + # TODO(skyewm): error if unused input map key if return_elements is None: diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index ee3cfbbd057..7bf13ba93d0 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -1110,8 +1110,6 @@ class ImportGraphDefTest(test.TestCase): self.assertEqual(987, a[0].get_attr("default_int")) def testFunctions(self): - if ops._USE_C_API: return # TODO(skyewm): make this work with C API - dtype = dtypes.float32 @function.Defun(dtype, dtype, dtype, dtype) def Grad(x, y, dout1, dout2): # pylint: disable=unused-argument @@ -1189,8 +1187,6 @@ class ImportGraphDefTest(test.TestCase): self.assertEqual(sess.run("outer:0"), 21) def testImportInsideDefun(self): - if ops._USE_C_API: return # TODO(skyewm): make this work with C API - g = ops.Graph() with g.as_default(): @function.Defun()