# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.framework.importer."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import numpy as np
|
|
|
|
from google.protobuf import text_format
|
|
|
|
from tensorflow.core.framework import graph_pb2
|
|
from tensorflow.core.framework import op_def_pb2
|
|
from tensorflow.python.framework import constant_op
|
|
from tensorflow.python.framework import device
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import function
|
|
from tensorflow.python.framework import importer
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import tensor_shape
|
|
from tensorflow.python.framework import test_ops # pylint: disable=unused-import
|
|
from tensorflow.python.framework import test_util
|
|
from tensorflow.python.framework import versions
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import control_flow_ops
|
|
from tensorflow.python.ops import gradients_impl
|
|
from tensorflow.python.ops import math_ops
|
|
from tensorflow.python.ops import nn_ops
|
|
from tensorflow.python.ops import random_ops
|
|
from tensorflow.python.ops import resource_variable_ops
|
|
from tensorflow.python.ops import variables
|
|
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
|
|
from tensorflow.python.platform import test
|
|
|
|
|
|
class ImportGraphDefTest(test.TestCase):
|
|
|
|
  def _MakeGraphDef(self,
                    text,
                    producer=versions.GRAPH_DEF_VERSION,
                    min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER):
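    """Parses text into a GraphDef, prepending a versions stanza."""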
text = "versions: { producer: %d min_consumer: %d };\n%s" % (producer,
|
|
min_consumer,
|
|
text)
|
|
ret = graph_pb2.GraphDef()
|
|
text_format.Merge(text, ret)
|
|
return ret
|
|
|
|
  def testBasic(self):
    with ops.Graph().as_default():
      a, b, c, d = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'IntOutputFloatOutput' }
          node { name: 'B' op: 'ListOutput'
                 attr { key: 'T'
                        value { list { type: DT_INT32 type: DT_FLOAT } } } }
          node { name: 'C' op: 'ListInput'
                 attr { key: 'N' value { i: 2 } }
                 attr { key: 'T' value { type: DT_INT32 } }
                 input: 'A:0' input: 'B:0' }
          node { name: 'D' op: 'ListInput'
                 attr { key: 'N' value { i: 2 } }
                 attr { key: 'T' value { type: DT_FLOAT } }
                 input: 'A:1' input: 'B:1' }
          """),
          return_elements=["A", "B", "C", "D"],
          name="import")

      # Assert that the import process creates distinct tensors.
      self.assertNotEqual(a.outputs[0].name, a.outputs[1].name)
      self.assertNotEqual(b.outputs[0].name, b.outputs[1].name)
      self.assertNotEqual(a.outputs[0].name, b.outputs[0].name)
      self.assertNotEqual(a.outputs[0].name, b.outputs[1].name)
      self.assertNotEqual(a.outputs[1].name, b.outputs[0].name)
      self.assertNotEqual(a.outputs[1].name, b.outputs[1].name)

      # Assert that the ops are connected according to the GraphDef topology.
      self.assertEqual(c.inputs[0], a.outputs[0])
      self.assertEqual(c.inputs[1], b.outputs[0])
      self.assertEqual(d.inputs[0], a.outputs[1])
      self.assertEqual(d.inputs[1], b.outputs[1])

      # Check the types of the returned ops and tensors.
      self.assertEqual(a.type, "IntOutputFloatOutput")
      self.assertEqual(b.type, "ListOutput")
      self.assertEqual(c.type, "ListInput")
      self.assertEqual(d.type, "ListInput")
      self.assertEqual(a.outputs[0].dtype, dtypes.int32)
      self.assertEqual(a.outputs[1].dtype, dtypes.float32)
      self.assertEqual(b.outputs[0].dtype, dtypes.int32)
      self.assertEqual(b.outputs[1].dtype, dtypes.float32)

      # Check the names of the returned ops.
      self.assertEqual(a.name, "import/A")
      self.assertEqual(b.name, "import/B")
      self.assertEqual(c.name, "import/C")
      self.assertEqual(d.name, "import/D")

      # Check that the op_def is still available.
      self.assertNotEqual(None, a.op_def)

  def testMultipleImport(self):
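    """Importing the same GraphDef repeatedly de-duplicates node names."""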
    graph_def = self._MakeGraphDef("""
    node { name: 'A' op: 'IntOutput' }
    node { name: 'B' op: 'IntInput' input: 'A:0' }
    """)

    with ops.Graph().as_default():
      # Initial import
      a, b = importer.import_graph_def(
          graph_def,
          return_elements=["A", "B"],
          name="")
      self.assertEqual(a.name, "A")
      self.assertEqual(b.name, "B")
      self.assertEqual(list(b.inputs), [a.outputs[0]])

      # Repeat the same import
      a1, b1 = importer.import_graph_def(
          graph_def,
          return_elements=["A", "B"],
          name="")
      self.assertEqual(a1.name, "A_1")
      self.assertEqual(b1.name, "B_1")
      self.assertEqual(list(b1.inputs), [a1.outputs[0]])

      # Repeat the same import again
      a2, b2 = importer.import_graph_def(
          graph_def,
          return_elements=["A", "B"],
          name="")
      self.assertEqual(a2.name, "A_2")
      self.assertEqual(b2.name, "B_2")
      self.assertEqual(list(b2.inputs), [a2.outputs[0]])

      # Import with an already-used name
      a3, b3 = importer.import_graph_def(
          graph_def,
          return_elements=["A", "B"],
          name="A")
      self.assertEqual(a3.name, "A_3/A")
      self.assertEqual(b3.name, "A_3/B")
      self.assertEqual(list(b3.inputs), [a3.outputs[0]])

      # Import with an already-used name but with a '/' to indicate an
      # "absolute" name scope (see the Graph.name_scope docstring).
      a_a, a_b = importer.import_graph_def(
          graph_def,
          return_elements=["A", "B"],
          name="A/")
      self.assertEqual(a_a.name, "A/A")
      self.assertEqual(a_b.name, "A/B")
      self.assertEqual(list(a_b.inputs), [a_a.outputs[0]])

      # Repeat the same import.
      a_a1, a_b1 = importer.import_graph_def(
          graph_def,
          return_elements=["A", "B"],
          name="A/")
      self.assertEqual(a_a1.name, "A/A_1")
      self.assertEqual(a_b1.name, "A/B_1")
      self.assertEqual(list(a_b1.inputs), [a_a1.outputs[0]])

      # Import with existing de-duped node names
      a1_1, b1_1 = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A_1' op: 'IntOutput' }
          node { name: 'B_1' op: 'IntInput' input: 'A_1:0' }
          """),
          return_elements=["A_1", "B_1"],
          name="")
      self.assertEqual(a1_1.name, "A_1_1")
      self.assertEqual(b1_1.name, "B_1_1")
      self.assertEqual(list(b1_1.inputs), [a1_1.outputs[0]])

      # Create a name scope and then import a node with the same name.
      with ops.name_scope("foo"):
        constant_op.constant(1)
      foo, = importer.import_graph_def(
          self._MakeGraphDef("node { name: 'foo' op: 'IntOutput' }"),
          return_elements=["foo"],
          name="")
      self.assertEqual(foo.name, "foo_1")

      # Imported node name can't conflict with intermediate name scope (but can
      # conflict with outer scope and full name scope)
      with ops.name_scope("outer"):
        with ops.name_scope("inner"):
          c = constant_op.constant(1, name="c")
          self.assertEqual(c.op.name, "outer/inner/c")
      outer, inner, new_c, outer_inner, outer_inner_c = (
          importer.import_graph_def(
              self._MakeGraphDef(
                  "node { name: 'outer' op: 'IntOutput' }"
                  "node { name: 'inner' op: 'IntOutput' }"
                  "node { name: 'c' op: 'IntOutput' }"
                  "node { name: 'outer/inner' op: 'IntOutput' }"
                  "node { name: 'outer/inner/c' op: 'IntOutput' }"),
              return_elements=["outer", "inner", "c", "outer/inner",
                               "outer/inner/c"],
              name=""))
      self.assertEqual(outer.name, "outer_1")
      self.assertEqual(inner.name, "inner")
      self.assertEqual(new_c.name, "c")
      self.assertEqual(outer_inner.name, "outer/inner_1")
      self.assertEqual(outer_inner_c.name, "outer/inner/c_1")

  def testEmptyNameScope(self):
    with ops.Graph().as_default():
      # Create a name scope but don't create any ops in it.
      with ops.name_scope("foo"):
        pass

      # Import a graph def that uses the name scope's name.
      op, = importer.import_graph_def(
          self._MakeGraphDef("node { name: 'foo' op: 'IntOutput' }"),
          return_elements=["foo"],
          name="")

      self.assertEqual(op.name, "foo")

  def testInputMap(self):
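    """Tensors in input_map replace the mapped GraphDef endpoints."""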
    with ops.Graph().as_default():
      feed_a_0 = constant_op.constant(0, dtype=dtypes.int32)
      feed_b_1 = constant_op.constant(1, dtype=dtypes.int32)

      a, b, c, d = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'TwoIntOutputs' }
          node { name: 'B' op: 'TwoIntOutputs' }
          node { name: 'C' op: 'ListInput'
                 attr { key: 'N' value { i: 2 } }
                 attr { key: 'T' value { type: DT_INT32 } }
                 input: 'A:0' input: 'B:0' }
          node { name: 'D' op: 'ListInput'
                 attr { key: 'N' value { i: 2 } }
                 attr { key: 'T' value { type: DT_INT32 } }
                 input: 'A:1' input: 'B:1' }
          """),
          input_map={"A:0": feed_a_0,
                     "B:1": feed_b_1},
          return_elements=["A", "B", "C", "D"])

      self.assertEqual(c.inputs[0], feed_a_0)
      self.assertEqual(c.inputs[1], b.outputs[0])
      self.assertEqual(d.inputs[0], a.outputs[1])
      self.assertEqual(d.inputs[1], feed_b_1)

  def testInputMapBytes(self):
    with ops.Graph().as_default():
      feed_a_0 = constant_op.constant(0, dtype=dtypes.int32)
      feed_b_1 = constant_op.constant(1, dtype=dtypes.int32)

      a, b, c, d = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'TwoIntOutputs' }
          node { name: 'B' op: 'TwoIntOutputs' }
          node { name: 'C' op: 'ListInput'
                 attr { key: 'N' value { i: 2 } }
                 attr { key: 'T' value { type: DT_INT32 } }
                 input: 'A:0' input: 'B:0' }
          node { name: 'D' op: 'ListInput'
                 attr { key: 'N' value { i: 2 } }
                 attr { key: 'T' value { type: DT_INT32 } }
                 input: 'A:1' input: 'B:1' }
          """),
          input_map={b"A:0": feed_a_0,
                     b"B:1": feed_b_1},
          return_elements=[b"A", b"B", b"C", b"D"])

      self.assertEqual(c.inputs[0], feed_a_0)
      self.assertEqual(c.inputs[1], b.outputs[0])
      self.assertEqual(d.inputs[0], a.outputs[1])
      self.assertEqual(d.inputs[1], feed_b_1)

  def testInputMapUnicode(self):
    with ops.Graph().as_default():
      feed_a_0 = constant_op.constant(0, dtype=dtypes.int32)
      feed_b_1 = constant_op.constant(1, dtype=dtypes.int32)

      a, b, c, d = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'TwoIntOutputs' }
          node { name: 'B' op: 'TwoIntOutputs' }
          node { name: 'C' op: 'ListInput'
                 attr { key: 'N' value { i: 2 } }
                 attr { key: 'T' value { type: DT_INT32 } }
                 input: 'A:0' input: 'B:0' }
          node { name: 'D' op: 'ListInput'
                 attr { key: 'N' value { i: 2 } }
                 attr { key: 'T' value { type: DT_INT32 } }
                 input: 'A:1' input: 'B:1' }
          """),
          input_map={u"A:0": feed_a_0,
                     u"B:1": feed_b_1},
          return_elements=[u"A", u"B", u"C", u"D"])

      self.assertEqual(c.inputs[0], feed_a_0)
      self.assertEqual(c.inputs[1], b.outputs[0])
      self.assertEqual(d.inputs[0], a.outputs[1])
      self.assertEqual(d.inputs[1], feed_b_1)

  def testImplicitZerothOutput(self):
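    """An input of 'A' with no index refers to output 0 of node A."""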
    with ops.Graph().as_default():
      a, b = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'TwoIntOutputs' }
          node { name: 'B' op: 'IntInput' input: 'A' }
          """),
          return_elements=["A", "B"])

      self.assertEqual(b.inputs[0], a.outputs[0])

  def testInputMapImplicitZerothOutput(self):
    with ops.Graph().as_default():
      feed_a_0 = constant_op.constant(0, dtype=dtypes.int32)
      b, = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'TwoIntOutputs' }
          node { name: 'B' op: 'IntInput' input: 'A:0' }
          """),
          input_map={"A": feed_a_0},
          return_elements=["B"])

      self.assertEqual(b.inputs[0], feed_a_0)

  def testWithControlDependency(self):
    with ops.Graph().as_default():
      a, b = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'None' }
          node { name: 'B' op: 'None' input: '^A' }
          """),
          return_elements=["A", "B"])

      self.assertEqual(b.control_inputs, [a])

  def testWithRefs(self):
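    """Ref-type outputs can feed both ref and non-ref inputs."""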
    with ops.Graph().as_default():
      a, b, c, d = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'RefOutput' }
          node { name: 'B' op: 'IntOutput' }
          node { name: 'C' op: 'TwoIntInputs' input: 'A:0' input: 'B:0' }
          node { name: 'D' op: 'RefInputIntInput' input: 'A:0' input: 'B:0' }
          """),
          return_elements=["A", "B", "C", "D"])

      self.assertEqual(c.inputs[0], a.outputs[0])
      self.assertEqual(c.inputs[1], b.outputs[0])
      self.assertEqual(d.inputs[0], a.outputs[0])
      self.assertEqual(d.inputs[1], b.outputs[0])

      self.assertEqual(a.outputs[0].dtype, dtypes.int32_ref)
      self.assertEqual(c._input_types, [dtypes.int32, dtypes.int32])
      self.assertEqual(c.outputs, [])
      self.assertEqual(d._input_types, [dtypes.int32_ref, dtypes.int32])
      self.assertEqual(d.outputs, [])

  def testResources(self):
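    """Handle shapes of resource variables survive the import."""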
    # Produce a GraphDef containing ops that produce and consume resources.
    graph = ops.Graph()
    with graph.as_default():
      var = resource_variable_ops.ResourceVariable(1.0)
      var_assign = var.assign(2.0)
      # Use an op that requires the handle shape to be set.
      var_shape = resource_variable_ops.variable_shape(var.handle)
      init = variables.global_variables_initializer()
    graph_def = graph.as_graph_def()

    # Import the GraphDef.
    with ops.Graph().as_default():
      # pylint: disable=unused-variable
      imported_var, imported_assign, imported_shape, imported_init = (
          importer.import_graph_def(
              graph_def,
              return_elements=[var.name, var_assign.name, var_shape.name,
                               init.name]))

      # Make sure the handle shape is set on the imported variable.
      new_var_shape = resource_variable_ops.variable_shape(imported_var)
      # pylint: enable=unused-variable

      # Run the imported graph.
      # TODO(b/76173421): make this work (currently DCHECKS)
      # with self.cached_session() as sess:
      #   self.evaluate(imported_init)
      #   self.assertEqual(self.evaluate(imported_var), 1.0)
      #   self.assertEqual(self.evaluate(imported_assign), 2.0)
      #   self.assertEqual(list(self.evaluate(imported_shape)), [])
      #   self.assertEqual(list(self.evaluate(new_var_shape)), [])

  def testWhileLoop(self):
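    """A GraphDef containing a while loop imports and runs correctly."""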
    # Produce a GraphDef containing a while loop.
    graph = ops.Graph()
    with graph.as_default():
      r = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [0])
      # Add an op that consumes the while loop output.
      math_ops.add(r, 1)
    graph_def = graph.as_graph_def()

    # Import the GraphDef and make sure it runs.
    with ops.Graph().as_default():
      imported_r, = importer.import_graph_def(graph_def,
                                              return_elements=[r.name])
      self.assertEqual(imported_r.name, "import/" + r.name)
      with self.cached_session() as sess:
        self.assertEqual(self.evaluate(imported_r), 10)

  def testImportWhileLoopInCond(self):
    # Produce a GraphDef containing a while loop.
    graph = ops.Graph()
    with graph.as_default():
      r = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [0])
    graph_def = graph.as_graph_def()

    # Import the GraphDef inside a cond and make sure it runs.
    with ops.Graph().as_default():

      def ImportFn():
        return importer.import_graph_def(graph_def, return_elements=[r.name])[0]

      pred = array_ops.placeholder(dtypes.bool)
      out = control_flow_ops.cond(pred, ImportFn,
                                  lambda: constant_op.constant(1))
      with self.cached_session() as sess:
        self.assertEqual(sess.run(out, {pred: True}), 10)
        self.assertEqual(sess.run(out, {pred: False}), 1)

  def testImportWhileLoopInWhileLoop(self):
    self.skipTest("b/111757448")
    # Produce a GraphDef containing a while loop.
    graph = ops.Graph()
    with graph.as_default():
      r = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [0])
    graph_def = graph.as_graph_def()

    # Import the GraphDef inside another while loop and make sure it runs.
    with ops.Graph().as_default():

      def ImportFn(_):
        return importer.import_graph_def(graph_def, return_elements=[r.name])[0]

      out = control_flow_ops.while_loop(
          lambda i: i < 2, ImportFn, [0],
          shape_invariants=[tensor_shape.TensorShape(None)])
      with self.cached_session() as sess:
        self.assertEqual(self.evaluate(out), 10)

  def testTypeMismatchInGraphDef(self):
    # TODO(skyewm): improve error message
    error_msg = ("Input 0 of node import/B was passed int32 from import/A:0 "
                 "incompatible with expected float.")
    with ops.Graph().as_default():
      with self.assertRaisesRegex(ValueError, error_msg):
        importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'IntOutput' }
            node { name: 'B' op: 'FloatInput' input: 'A:0' }
            """))

  def testShapeAllowlistViolation(self):
    # L2Loss produces a scalar, but this GraphDef declares the wrong output
    # shape, so importing raises an error.
    with ops.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        _ = importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'FloatOutput' }
            node { name: 'B' op: 'L2Loss'
                   input: 'A:0'
                   attr { key: 'T' value { type: DT_FLOAT } }
                   attr { key: '_output_shapes'
                          value { list { shape { dim { size: 43 } } } } } }
            """),
            return_elements=["B"],
            name="import")
      self.assertTrue(
          "Shapes () and (43,) are not compatible" in str(e.exception))

  def testInvalidSignatureTooManyInputsInGraphDef(self):
    with ops.Graph().as_default():
      # TODO(skyewm): improve error message
      with self.assertRaisesRegex(
          ValueError,
          "NodeDef expected inputs '' do not match 1 inputs specified"):
        importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'IntOutput' }
            node { name: 'B' op: 'None' input: 'A:0' }
            """))

  def testInvalidSignatureNotEnoughInputsInGraphDef(self):
    with ops.Graph().as_default():
      # TODO(skyewm): improve error message
      with self.assertRaisesRegex(
          ValueError,
          "NodeDef expected inputs 'int32, float' do not match 1 inputs "
          "specified"):
        importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'IntOutput' }
            node { name: 'B' op: 'IntInputFloatInput' input: 'A:0' }
            """))

  def testMissingInputOpInGraphDef(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(ValueError,
                                  "Node 'B': Unknown input node 'A:0'"):
        importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'B' op: 'FloatInput' input: 'A:0' }
            """))

  def testMissingInputOpInGraphDefButAppearsInInputMap(self):
    with ops.Graph().as_default():
      feed_a_0 = constant_op.constant(5.0)
      b, = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'B' op: 'FloatInput' input: 'A:0' }
          """),
          input_map={"A:0": feed_a_0},
          return_elements=["B"])
      self.assertEqual(b.inputs[0], feed_a_0)

  def testMissingInputTensorInGraphDef(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(
          ValueError,
          "Node 'B': Connecting to invalid output 1 of source node A "
          "which has 1 outputs"):
        importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'FloatOutput' }
            node { name: 'B' op: 'FloatInput' input: 'A:1' }
            """))

  def testMissingControlInputInGraphDef(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(ValueError,
                                  r"Node 'B': Unknown input node '\^A'"):
        importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'B' op: 'None' input: '^A' }
            """))

  def testInvalidTensorNameOutputIndexInGraphDef(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(ValueError,
                                  "Node 'B': Unknown input node 'A:B'"):
        importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'B' op: 'None' input: 'A:B' }
            """))

  def testInvalidTensorNameInGraphDef(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(ValueError,
                                  "Node 'B': Unknown input node 'A:B:0'"):
        importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'B' op: 'None' input: 'A:B:0' }
            """))

  def testMissingReturnOperation(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(
          ValueError, "Requested return node 'B' not found in graph def"):
        importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'None' }
            """),
            return_elements=["B"])

  def testMissingReturnTensor(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(
          ValueError,
          r"Invalid return output 1 of node 'A', which has 1 output\(s\)"):
        importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'IntOutput' }
            """),
            return_elements=["A:1"])

      with self.assertRaisesRegex(
          ValueError, "Requested return tensor 'B:0' not found in graph def"):
        importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'IntOutput' }
            """),
            return_elements=["B:0"])

      with self.assertRaisesRegex(ValueError,
                                  "Cannot convert 'A:B:0' to a tensor name."):
        importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'IntOutput' }
            """),
            return_elements=["A:B:0"])

  def testMissingInputMap(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(
          ValueError,
          r"Attempted to map inputs that were not found in graph_def: \[B:0\]"):
        importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'None' }
            """),
            input_map={"B:0": constant_op.constant(5.0)})

  def testInputMapUnusedAsInput(self):
    with ops.Graph().as_default():
      # Mapping an unused node output should succeed.
      importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'IntOutput' }
          """),
          input_map={"A:0": constant_op.constant(5.0)})

      # Mapping a non-existent output of an existing node should fail.
      with self.assertRaisesRegex(
          ValueError,
          r"Attempted to map inputs that were not found in graph_def: \[A:2\]"):
        importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'IntOutput' }
            """),
            input_map={"A:2": constant_op.constant(5.0)})

  def testInputMapTypeMismatch(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(
          ValueError, "Input 0 of node import/B was passed float from Const:0 "
          "incompatible with expected int32."):
        importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'IntOutput' }
            node { name: 'B' op: 'IntInput' input: 'A:0' }
            """),
            input_map={"A:0": constant_op.constant(5.0)})

  def testNoReturns(self):
    with ops.Graph().as_default() as g:
      ret = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'None' }
          """))
      self.assertEqual(ret, None)

      a = g.get_operation_by_name("import/A")
      self.assertEqual(a.type, "None")

  def testOverrideNamePrefix(self):
    with ops.Graph().as_default():
      a, = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'None' }
          """),
          return_elements=["A"],
          name="imported_graph")
      self.assertEqual(a.name, "imported_graph/A")

  def testDefaultNamePrefix(self):
    with ops.Graph().as_default():
      a, = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'None' }
          """),
          return_elements=["A"],
          name=None)
      self.assertEqual(a.name, "import/A")

  def testNamePrefixColocationAttrs(self):
    original_graph_def = self._MakeGraphDef("""
          node { name: 'A' op: 'None' }
          node { name: 'B' op: 'None' attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }""")

    with ops.Graph().as_default():
      b, = importer.import_graph_def(
          original_graph_def, return_elements=["B"], name="imported_graph")
      self.assertTrue("_class" in b.node_def.attr)
      self.assertProtoEquals(
          "list { s: 'loc:@imported_graph/A' }",
          b.node_def.attr["_class"])

  def testColocationAndDevice(self):
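    """Checks how explicit devices interact with colocation attrs."""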
    # A and B are colocated, device set on A.
    original_graph_def = self._MakeGraphDef("""
          node { name: 'A' op: 'None' device: '/device:CPU:0' attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }
          node { name: 'B' op: 'None' attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }""")

    with ops.Graph().as_default():
      a, b = importer.import_graph_def(original_graph_def,
                                       return_elements=["A", "B"],
                                       name="")
      self.assertEqual(a.device, "/device:CPU:0")
      self.assertEqual(b.device, "/device:CPU:0")
      self.assertEqual(a.colocation_groups(), [b"loc:@A"])
      self.assertEqual(b.colocation_groups(), [b"loc:@A"])

    # A and B are colocated, device set on B.
    original_graph_def = self._MakeGraphDef("""
          node { name: 'A' op: 'None' attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }
          node { name: 'B' op: 'None' device: '/device:CPU:0' attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }""")

    with ops.Graph().as_default():
      a, b = importer.import_graph_def(original_graph_def,
                                       return_elements=["A", "B"],
                                       name="")
      # TODO(skyewm): this behavior seems inconsistent with the above. Why is
      # B's device ignored?
      self.assertEqual(a.device, "")
      self.assertEqual(b.device, "")
      self.assertEqual(a.colocation_groups(), [b"loc:@A"])
      self.assertEqual(b.colocation_groups(), [b"loc:@A"])

  def testColocationWithDeviceFn(self):
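    """Colocation constraints take precedence over device functions."""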
    original_graph_def = self._MakeGraphDef("""
          node { name: 'A' op: 'None' attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }
          node { name: 'B' op: 'None' attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }""")

    # A device function that places "A" on one device and "B" on another
    # device. Because B is colocated with A, we test that B's device
    # assignment is overridden by A's.
    def CustomDeviceFn(op):
      if "A" in op.name:
        return "/device:A:0"
      else:
        return "/device:B:0"

    with ops.Graph().as_default():
      with ops.device(CustomDeviceFn):
        a, b = importer.import_graph_def(original_graph_def,
                                         return_elements=["A", "B"],
                                         name="imported_graph")
      self.assertEqual(a.device, "/device:A:0")
      self.assertEqual(b.device, "/device:A:0")
      self.assertEqual(a.colocation_groups(), [b"loc:@imported_graph/A"])
      self.assertEqual(b.colocation_groups(), [b"loc:@imported_graph/A"])

    # Test a scenario where 'A' doesn't get a device; 'A' should not have a
    # device, but at runtime it will be colocated with 'B' because of the
    # colocation attribute. B's device function is still overridden by A's.
    def BDeviceFn(op):
      if "B" in op.name:
        return "/device:B:0"
      return ""

    with ops.Graph().as_default():
      with ops.device(BDeviceFn):
        a, b = importer.import_graph_def(original_graph_def,
                                         return_elements=["A", "B"],
                                         name="imported_graph")
      self.assertEqual(a.device, "")
      self.assertEqual(b.device, "")
      self.assertEqual(a.colocation_groups(), [b"loc:@imported_graph/A"])
      self.assertEqual(b.colocation_groups(), [b"loc:@imported_graph/A"])

    # Only A gets a device, so B inherits it implicitly.
    def ADeviceFn(op):
      if "A" in op.name:
        return "/device:A:0"
      return ""

    with ops.Graph().as_default():
      with ops.device(ADeviceFn):
        a, b = importer.import_graph_def(original_graph_def,
                                         return_elements=["A", "B"],
                                         name="imported_graph")
      self.assertEqual(a.device, "/device:A:0")
      self.assertEqual(b.device, "/device:A:0")
      self.assertEqual(a.colocation_groups(), [b"loc:@imported_graph/A"])
      self.assertEqual(b.colocation_groups(), [b"loc:@imported_graph/A"])

  def testMultipleColocationWithDeviceFn(self):
    original_graph_def = self._MakeGraphDef("""
          node { name: 'A' op: 'None'}
          node { name: 'B' op: 'None'}
          node { name: 'C' op: 'None' attr {
            key: '_class'
            value { list { s: 'loc:@A' s: 'loc:@B' } }
          } }""")

    # A device function that places "B" on a device and leaves "A"
    # unassigned.
    #
    # B and C should be assigned "/device:B:0". A will not be for now, but
    # because of the colocation attribute it will be placed with B and C at
    # runtime.
    def CustomDeviceFn(op):
      if "B" in op.name:
        return "/device:B:0"
      return ""

    with ops.Graph().as_default():
      with ops.device(CustomDeviceFn):
        a, b, c = importer.import_graph_def(original_graph_def,
                                            return_elements=["A", "B", "C"],
                                            name="imported_graph")
      self.assertEqual(a.device, "")
      self.assertEqual(b.device, "/device:B:0")
      self.assertEqual(c.device, "/device:B:0")
      self.assertEqual(a.colocation_groups(), [b"loc:@imported_graph/A"])
      self.assertEqual(b.colocation_groups(), [b"loc:@imported_graph/B"])
      self.assertEqual(c.colocation_groups(),
                       [b"loc:@imported_graph/A", b"loc:@imported_graph/B"])

  def testNamePrefixColocationAttrsMultipleImport(self):
    original_graph_def = self._MakeGraphDef("""
          node { name: 'A' op: 'None' }
          node { name: 'B' op: 'None' attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }""")

    with ops.Graph().as_default():
      a, b = importer.import_graph_def(
          original_graph_def, return_elements=["A", "B"], name="")
      a_1, b_1 = importer.import_graph_def(
          original_graph_def, return_elements=["A", "B"], name="")

      self.assertEqual(a.name, "A")
      self.assertEqual(b.name, "B")
      self.assertEqual(b.colocation_groups(), [b"loc:@A"])

      self.assertEqual(a_1.name, "A_1")
      self.assertEqual(b_1.name, "B_1")
      self.assertEqual(b_1.colocation_groups(), [b"loc:@A_1"])

  def testNamePrefixColocationAttrsNotFound(self):
    original_graph_def = self._MakeGraphDef("""
          node { name: 'B' op: 'None' attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }""")

    with ops.Graph().as_default():
      with self.assertRaisesRegex(
          ValueError, "Node 'B' expects to be colocated with unknown node 'A'"):
        importer.import_graph_def(
            original_graph_def, return_elements=["B"], name="imported_graph")

  def testEmptyGraph(self):
    with ops.Graph().as_default() as g:
      init_version = g.version
      importer.import_graph_def(self._MakeGraphDef(""))
      self.assertEqual(init_version, g.version)

  def testInvalidInputForGraphDef(self):
    with ops.Graph().as_default():
      with self.assertRaises(TypeError) as e:
        importer.import_graph_def("")
      self.assertEqual("graph_def must be a GraphDef proto.", str(e.exception))

  def testInvalidInputForInputMap(self):
    with ops.Graph().as_default():
      with self.assertRaises(TypeError) as e:
        importer.import_graph_def(
            self._MakeGraphDef(""), input_map=[constant_op.constant(5.0)])
      self.assertEqual("input_map must be a dictionary mapping strings to "
                       "Tensor objects.", str(e.exception))
    graph_def = self._MakeGraphDef("""
         node { name: 'a' op: 'Placeholder'
                attr { key: 'dtype' value { type: DT_FLOAT } }}
         node { name: 'id' op: 'Identity' input: 'a:0'
                attr { key: 'T' value { type: DT_FLOAT } }}""")
    with ops.Graph().as_default():
      with self.assertRaises(ValueError) as e:
        importer.import_graph_def(
            graph_def,
            input_map={"a:0": variables.Variable(5.0)},
            name="")
      self.assertStartsWith(str(e.exception),
                            "tf.import_graph_def() requires a non-empty `name` "
                            "if `input_map` contains non-Tensor values.")
    with ops.Graph().as_default():
      t, = importer.import_graph_def(
          graph_def,
          input_map={"a:0": constant_op.constant(5.0)},
          name="",
          return_elements=["id:0"])
      with self.cached_session():
        self.assertEqual(5.0, self.evaluate(t))

  def testInvalidInputForReturnOperations(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(TypeError,
                                  "return_elements must be a list of strings."):
        importer.import_graph_def(self._MakeGraphDef(""), return_elements=[7])

      with self.assertRaisesRegex(ValueError,
                                  "Cannot convert 'a:b:c' to a tensor name."):
        importer.import_graph_def(
            self._MakeGraphDef(""), return_elements=["a:b:c"])

  def testDuplicateOperationNames(self):
    with self.assertRaisesRegex(ValueError, "Node 'A' is not unique"):
      importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'IntOutput' }
          node { name: 'B' op: 'IntOutput' }
          node { name: 'A' op: 'IntOutput' }
          """))

  @test_util.run_v1_only("v1 Tensor doesn't have attribute 'numpy'")
  def testWithExtensionAndAttr(self):
    with ops.Graph().as_default() as g:
      c = constant_op.constant(5.0, dtype=dtypes.float32, name="c")
      array_ops.stack([c, c], name="pack")
    gdef = g.as_graph_def()

    with self.cached_session():
      pack, = importer.import_graph_def(gdef, return_elements=["pack"])
      self.assertAllEqual(pack.outputs[0], [5.0, 5.0])

  def testWithDevice(self):
    with ops.Graph().as_default() as g:
      # No device.
      a = constant_op.constant(3.0, name="a")

      with ops.device("/cpu:0"):
        b = constant_op.constant(4.0, name="b")
      with ops.device("/job:worker"):
        c = constant_op.constant(5.0, name="c")

    gdef = g.as_graph_def()

    with ops.Graph().as_default():
      a2, b2, c2 = importer.import_graph_def(
          gdef, return_elements=["a", "b", "c"])
      self.assertEqual(a.device, a2.device)
      self.assertEqual(b.device, b2.device)
      self.assertEqual(c.device, c2.device)

    with ops.Graph().as_default():
      with ops.device(device.merge_device("/task:0")):
        a3, b3, c3 = importer.import_graph_def(
            gdef, return_elements=["a", "b", "c"])
        self.assertEqual("/task:0", a3.device)
        self.assertEqual("/task:0/device:CPU:0", b3.device)  # canonicalized.
        self.assertEqual(c.device + "/task:0", c3.device)

    with ops.Graph().as_default():
      with ops.device(device.merge_device("/job:ps")):
        a4, b4, c4 = importer.import_graph_def(
            gdef, return_elements=["a", "b", "c"])
        self.assertEqual("/job:ps", a4.device)
        self.assertEqual("/job:ps/device:CPU:0", b4.device)  # canonicalized.
        self.assertEqual(c.device, c4.device)  # worker overrides ps.

    with ops.Graph().as_default():
      with ops.device(device.merge_device("/device:GPU:0")):
        a5, b5, c5 = importer.import_graph_def(
            gdef, return_elements=["a", "b", "c"])
        self.assertEqual("/device:GPU:0", a5.device)
        self.assertEqual("/device:CPU:0", b5.device)  # cpu overrides gpu.
        self.assertEqual(c.device + "/device:GPU:0", c5.device)

  def testWithDeviceFunctionDependingOnInputs(self):
    with ops.Graph().as_default() as g:
      with ops.device("/job:ps"):
        v1 = constant_op.constant(1.0)
        v2 = constant_op.constant(1.0)
      _ = v1 + v2
      _ = v1 - v2
      _ = array_ops.identity(v1)
    gdef = g.as_graph_def()

    # We'll use the following device function to observe ops with two inputs.
    ops_with_two_inputs = []

    def InputCounter(op):
      if len(op.inputs) == 2:
        ops_with_two_inputs.append(op)
      return ""

    with ops.Graph().as_default() as g:
      with ops.device(InputCounter):
        importer.import_graph_def(gdef)

    # We expect to see the add and subtract, but not identity.
    self.assertEqual(2, len(ops_with_two_inputs))

  def testGradient(self):
    with ops.Graph().as_default() as g:
      inputs = array_ops.placeholder(
          dtypes.float32, shape=[None, 100], name="input")
      weights = array_ops.placeholder(
          dtypes.float32, shape=[100, 10], name="weights")
      biases = array_ops.placeholder(dtypes.float32, shape=[10], name="biases")
      activations = nn_ops.relu(
          math_ops.matmul(inputs, weights) + biases, name="activations")
      loss = math_ops.reduce_mean(activations, name="loss")
    gdef = g.as_graph_def()

    with ops.Graph().as_default() as g:
      input_placeholder = array_ops.placeholder(dtypes.float32, shape=[32, 100])
      weights_var = variables.Variable(
          random_ops.truncated_normal([100, 10]), name="weights")
      biases_var = variables.Variable(array_ops.zeros([10]), name="biases")
      activations, loss = importer.import_graph_def(
          gdef,
          input_map={
              "input:0": input_placeholder,
              "weights:0": weights_var,
              "biases:0": biases_var
          },
          return_elements=["activations:0", "loss:0"])
      self.assertEqual([32, 10], activations.get_shape())
      self.assertEqual([], loss.get_shape())
      weights_grad, biases_grad = gradients_impl.gradients(
          loss, [weights_var, biases_var])
      self.assertEqual([100, 10], weights_grad.get_shape())
      self.assertEqual([10], biases_grad.get_shape())

  def testLargeGraph(self):
    with self.cached_session():
      # The default message byte limit is 64M. Ours is 2G, with a warning at
      # 512M. Adding a float32 tensor with 130M entries should exceed the
      # warning threshold but not the hard limit.
      input_shape = [130, 1000, 1000]
      tensor_input = np.ones(input_shape, dtype=np.float32)
      t = constant_op.constant(tensor_input, shape=input_shape)
      g = array_ops.identity(t)
      self.evaluate(g)

  def testVersion(self):
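    """Producer and min_consumer versions are preserved on import."""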
    v0 = versions.GRAPH_DEF_VERSION_MIN_CONSUMER
    v2 = versions.GRAPH_DEF_VERSION
    v1 = (v0 + v2) // 2
    for producer in v0, v1, v2:
      for min_consumer in v0, v1, v2:
        with ops.Graph().as_default():
          a, = importer.import_graph_def(
              self._MakeGraphDef(
                  "node { name: 'A' op: 'TwoIntOutputs' }",
                  producer=producer,
                  min_consumer=min_consumer),
              return_elements=["A"])
          self.assertEqual(a.graph.graph_def_versions.producer, producer)
          self.assertEqual(a.graph.graph_def_versions.min_consumer,
                           min_consumer)

  def testVersionLow(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(
          Exception,
          r"GraphDef producer version -1 below min producer %d supported "
          r"by TensorFlow \S+\. Please regenerate your graph.$" %
          versions.GRAPH_DEF_VERSION_MIN_PRODUCER):
        importer.import_graph_def(self._MakeGraphDef("", producer=-1))

  def testVersionHigh(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegex(
          ValueError,
          r"GraphDef min consumer version %d above current version %d "
          r"for TensorFlow \S+\. Please upgrade TensorFlow\.$" %
          (1 << 30, versions.GRAPH_DEF_VERSION)):
        importer.import_graph_def(self._MakeGraphDef("", min_consumer=1 << 30))

  def testVersionAppliesToOpConstruction(self):
    """These tests rely on shape fns in test_ops.cc."""
    with ops.Graph().as_default():
      importer.import_graph_def(
          self._MakeGraphDef(
              "node { name: 'A' op: 'RequiresOlderGraphVersion' }",
              producer=versions.GRAPH_DEF_VERSION - 1),
          return_elements=["A"])

    with ops.Graph().as_default():
      with self.assertRaisesWithPredicateMatch(ValueError,
                                               "Wrong graph version.*"):
        importer.import_graph_def(
            self._MakeGraphDef(
                "node { name: 'A' op: 'RequiresOlderGraphVersion' }",
                producer=versions.GRAPH_DEF_VERSION),
            return_elements=["A"])

  def testDefaultAttrsAdded(self):
    with ops.Graph().as_default():
      a = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'OpWithDefaultAttr' }
          """),
          return_elements=["A"])
      self.assertEqual(123.0, a[0].get_attr("default_float"))

  def testDefaultAttrsRemoved(self):
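    """Attrs matching a default in producer_op_list are stripped."""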
    producer_op_list = op_def_pb2.OpList()
    text_format.Merge("""
      op {
        name: 'OpWithFutureDefaultAttr'
        attr { name: 'default_int' type: 'int' default_value { i: 456 } }
      }
    """, producer_op_list)
    # An attr that appears only in producer_op_list with its default value
    # is removed on import.
    with ops.Graph().as_default():
      a = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'OpWithFutureDefaultAttr'
                 attr { key: 'default_int' value { i: 456 } } }
          """),
          return_elements=["A"],
          producer_op_list=producer_op_list)
      with self.assertRaisesRegex(
          ValueError, "Operation 'import/A' has no attr named 'default_int'."):
        a[0].get_attr("default_int")

  def testFunctions(self):
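    """Functions and their gradients survive export and re-import."""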
    dtype = dtypes.float32

    @function.Defun(dtype, dtype, dtype, dtype)
    def Grad(x, y, dout1, dout2):  # pylint: disable=unused-argument
      # Return the inputs for simplicity of testing. The correct return value
      # would be (dout1 + dout2, dout1 - dout2)
      return x, y

    @function.Defun(dtype, dtype, grad_func=Grad)
    def FuncWithGrad(x, y):
      return x + y, x - y

    @function.Defun(dtypes.int32)
    def ExternalTensorFunc(x):
      # c must be defined in the containing graph
      return x + c

    @function.Defun(dtypes.int32, dtypes.int32)
    def OuterFunc(x, y):

      @function.Defun(dtypes.int32)
      def InnerFunc(x):
        return x + x

      return InnerFunc(x) + y

    # Create a graph with function calls and export it to a GraphDef.
    with ops.Graph().as_default() as g1:
      p1 = array_ops.placeholder(dtype, name="p1")
      p2 = array_ops.placeholder(dtype, name="p2")
      # pylint: disable=unexpected-keyword-arg
      a, b = FuncWithGrad(p1, p2, name="f")

      c = constant_op.constant(10, dtype=dtypes.int32)
      ExternalTensorFunc(1, name="external")

      OuterFunc(10, 1, name="outer")
      # pylint: enable=unexpected-keyword-arg

    gdef = g1.as_graph_def()

    # Import the GraphDef into a new graph, add imported gradients, and test
    # that the imported functions can be run.
    with ops.Graph().as_default() as g2:
      p1, p2, a, b = importer.import_graph_def(
          gdef, return_elements=["p1:0", "p2:0", "f:0", "f:1"], name="")
      grad = gradients_impl.gradients([a], [p1, p2])

    with self.session(graph=g2) as sess:
      feed_dict = {p1: 1, p2: 2}
      a_val, b_val, grad_val = sess.run([a, b, grad], feed_dict=feed_dict)
      self.assertEqual(a_val, 3.0)
      self.assertEqual(b_val, -1.0)
      # The gradient function returns its input values for testing.
      self.assertEqual(grad_val, [1.0, 2.0])
      self.assertEqual(sess.run("external:0"), 11)
      self.assertEqual(sess.run("outer:0"), 21)

    # Export the new graph and re-import it to test that the imported
    # functions can be successfully exported and imported again.
    gdef = g2.as_graph_def()
    with ops.Graph().as_default() as g3:
      p1, p2, a, b = importer.import_graph_def(
          gdef, return_elements=["p1:0", "p2:0", "f:0", "f:1"], name="")
      # Create new gradient functions (in addition to the imported gradient
      # functions created in g2).
      grad = gradients_impl.gradients([a], [p1, p2])

    with self.session(graph=g3) as sess:
      feed_dict = {p1: 1, p2: 2}
      a_val, b_val, grad_val = sess.run([a, b, grad], feed_dict=feed_dict)
      self.assertEqual(a_val, 3.0)
      self.assertEqual(b_val, -1.0)
      self.assertEqual(grad_val, [1.0, 2.0])
      self.assertEqual(sess.run("external:0"), 11)
      self.assertEqual(sess.run("outer:0"), 21)

  @test_util.run_v1_only("import inside defun not supported when eager "
                         "execution is enabled.")
  def testImportInsideDefun(self):
    g = ops.Graph()
    with g.as_default():

      @function.Defun()
      def Add2(x, y):
        return math_ops.add(x, y)

      x = constant_op.constant(3.0, dtype=dtypes.float32)
      y = constant_op.constant(-5.0, dtype=dtypes.float32)
      z = Add2(x, y, name="z")  # pylint: disable=unexpected-keyword-arg

    gdef = g.as_graph_def()

    @function.Defun()
    def TestFunc():
      return importer.import_graph_def(gdef, return_elements=["z:0"])[0]

    z = TestFunc()

    with self.cached_session():
      z_val = self.evaluate(z)
      self.assertEqual(z_val, -2.0)

  @test_util.run_v1_only("_as_tf_output not supported when eager execution "
                         "is enabled.")
  def testImportGraphWithFunctionTwice(self):
    g = ops.Graph()
    with g.as_default():

      @function.Defun()
      def Add2(x, y):
        return math_ops.add(x, y)

      x = array_ops.placeholder(dtype=dtypes.float32, name="x")
      y = array_ops.placeholder(dtype=dtypes.float32, name="y")
      _ = Add2(x, y, name="z")  # pylint: disable=unexpected-keyword-arg

    gdef = g.as_graph_def()

    x = random_ops.random_uniform(dtype=dtypes.float32, shape=())
    y = random_ops.random_uniform(dtype=dtypes.float32, shape=())
    input_map = {"x:0": x, "y:0": y}

    with ops.name_scope("first"):
      z1 = importer.import_graph_def(gdef, return_elements=["z:0"],
                                     input_map=input_map)[0]

    with ops.name_scope("second"):
      z2 = importer.import_graph_def(gdef, return_elements=["z:0"],
                                     input_map=input_map)[0]

    with self.cached_session() as sess:
      z1_val, z2_val = sess.run((z1, z2))
      self.assertAllEqual(z1_val, z2_val)


if __name__ == "__main__":
  test.main()