Make import_graph_def import functions with C API enabled.
PiperOrigin-RevId: 178306667
This commit is contained in:
parent
51fa3f7fef
commit
6bb91a0712
@ -462,6 +462,22 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
|
||||
|
||||
_ProcessNewOps(graph)
|
||||
|
||||
# Create _DefinedFunctions for any imported functions.
|
||||
#
|
||||
# We do this by creating _DefinedFunctions directly from `graph_def`, and
|
||||
# adding them to `graph`. Adding an existing function to a TF_Graph is a
|
||||
# no-op, so this only has the effect of updating the Python state (usually
|
||||
# _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
|
||||
#
|
||||
# TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
|
||||
# TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
|
||||
if graph_def.library and graph_def.library.function:
|
||||
# pylint: disable=protected-access
|
||||
functions = function._from_library(graph_def.library)
|
||||
for f in functions:
|
||||
f.add_to_graph(graph)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# TODO(skyewm): error if unused input map key
|
||||
|
||||
if return_elements is None:
|
||||
|
@ -1110,8 +1110,6 @@ class ImportGraphDefTest(test.TestCase):
|
||||
self.assertEqual(987, a[0].get_attr("default_int"))
|
||||
|
||||
def testFunctions(self):
|
||||
if ops._USE_C_API: return # TODO(skyewm): make this work with C API
|
||||
|
||||
dtype = dtypes.float32
|
||||
@function.Defun(dtype, dtype, dtype, dtype)
|
||||
def Grad(x, y, dout1, dout2): # pylint: disable=unused-argument
|
||||
@ -1189,8 +1187,6 @@ class ImportGraphDefTest(test.TestCase):
|
||||
self.assertEqual(sess.run("outer:0"), 21)
|
||||
|
||||
def testImportInsideDefun(self):
|
||||
if ops._USE_C_API: return # TODO(skyewm): make this work with C API
|
||||
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
@function.Defun()
|
||||
|
Loading…
Reference in New Issue
Block a user