Use _DefinedFunction.add_to_graph() call inside importer.

Before, Graph._add_function() call was used to import function library from
GraphDef. This private function lacks hash value comparison and importing
the inner function library used by the function gets improted.

PiperOrigin-RevId: 156933480
This commit is contained in:
A. Unique TensorFlower 2017-05-23 17:15:51 -07:00 committed by TensorFlower Gardener
parent da92fbd8b0
commit afcd75baa4
2 changed files with 30 additions and 1 deletions

View File

@ -247,7 +247,7 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
# not have a scoped name or namespace scheme.
functions = function._from_library(graph_def.library)
for f in functions:
g._add_function(f)
f.add_to_graph(g)
op_dict[f.name] = f.definition.signature
# pylint: enable=protected-access

View File

@ -1015,6 +1015,35 @@ class ImportGraphDefTest(test.TestCase):
z_val = z.eval()
self.assertEqual(z_val, -2.0)
def testImportGraphWithFunctionTwice(self):
g = ops.Graph()
with g.as_default():
@function.Defun()
def Add2(x, y):
return math_ops.add(x, y)
x = array_ops.placeholder(dtype=dtypes.float32, name="x")
y = array_ops.placeholder(dtype=dtypes.float32, name="y")
_ = Add2(x, y, name="z") # pylint: disable=unexpected-keyword-arg
gdef = g.as_graph_def()
x = random_ops.random_uniform(dtype=dtypes.float32, shape=())
y = random_ops.random_uniform(dtype=dtypes.float32, shape=())
input_map = {"x:0": x, "y:0": y}
with ops.name_scope("first"):
z1 = importer.import_graph_def(gdef, return_elements=["z:0"],
input_map=input_map)[0]
with ops.name_scope("second"):
z2 = importer.import_graph_def(gdef, return_elements=["z:0"],
input_map=input_map)[0]
with self.test_session() as sess:
z1_val, z2_val = sess.run((z1, z2))
self.assertAllEqual(z1_val, z2_val)
if __name__ == "__main__":
test.main()