Support limited forward compatibility when importing a MetaGraphDef.

Change: 121880284
This commit is contained in:
Josh Levenberg 2016-05-09 13:53:57 -08:00 committed by TensorFlower Gardener
parent bf7e5fe193
commit 6e02bf0299
3 changed files with 155 additions and 85 deletions

View File

@ -23,8 +23,8 @@ import contextlib
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import compat
@ -142,8 +142,15 @@ def _MaybeDevice(device):
yield
def _FindAttrInOpDef(attr_name, op_def):
for attr_def in op_def.attr:
if attr_name == attr_def.name:
return attr_def
return None
def import_graph_def(graph_def, input_map=None, return_elements=None,
name=None, op_dict=None):
name=None, op_dict=None, producer_op_list=None):
"""Imports the TensorFlow graph in `graph_def` into the Python `Graph`.
This function provides a way to import a serialized TensorFlow
@ -167,6 +174,12 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
Must contain an `OpDef` proto for each op type named in `graph_def`.
If omitted, uses the `OpDef` protos registered in the global registry.
producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
list of `OpDef`s used by the producer of the graph. If provided, attrs
for ops in `graph_def` that are not in `op_dict` and that have their
default value according to `producer_op_list` will be removed. This
allows some `GraphDef`s produced by later binaries to be accepted by
earlier binaries.
Returns:
A list of `Operation` and/or `Tensor` objects from the imported graph,
@ -213,6 +226,11 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
if op_dict is None:
op_dict = op_def_registry.get_registered_ops()
if producer_op_list is None:
producer_op_dict = None
else:
producer_op_dict = {op.name: op for op in producer_op_list.op}
with ops.op_scope(input_map.values(), name, 'import'):
g = ops.get_default_graph()
g.graph_def_versions.CopyFrom(graph_def.versions)
@ -233,6 +251,21 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
value = node.attr[key]
if value is None or value.WhichOneof('value') is None:
node.attr[key].CopyFrom(attr_def.default_value)
if producer_op_dict:
# Remove any default attr values that aren't in op_def.
if node.op in producer_op_dict:
producer_op_def = producer_op_dict[node.op]
# We make a copy of node.attr to iterate through since we
# may modify node.attr inside the loop.
for key in list(node.attr):
if _FindAttrInOpDef(key, op_def) is None:
# No attr_def in consumer, look in producer.
attr_def = _FindAttrInOpDef(key, producer_op_def)
if (attr_def and attr_def.HasField('default_value') and
node.attr[key] == attr_def.default_value):
# Unknown attr had default value in producer, delete it
# so it can be understood by consumer.
del node.attr[key]
output_types = _OutputTypes(node, op_dict)
name_to_op[node.name] = g.create_op(
@ -326,8 +359,8 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
_InvalidNodeMessage(
node,
'Input types mismatch (expected %r but got %r)'
% (", ".join(dtypes.as_dtype(x).name for x in input_types),
", ".join(x.name for x in op._input_dtypes))))
% (', '.join(dtypes.as_dtype(x).name for x in input_types),
', '.join(x.name for x in op._input_dtypes))))
# pylint: enable=protected_access
# Execute shape inference for this op.

View File

@ -106,12 +106,16 @@ text_format.Merge("""
output_arg { name: 'a' type: DT_INT32 }
attr { name: 'default_float' type: 'float' default_value { f: 123.0 } }
}
op {
name: 'OpWithFutureDefaultAttr'
}
""", _op_list)
op_def_registry.register_op_list(_op_list)
# NOTE(mrry): Dummy shape registrations for ops used in the tests.
for op_def in _op_list.op:
tf.RegisterShape(op_def.name)(None)
class ImportGraphDefTest(tf.test.TestCase):
def _MakeGraphDef(self, text, producer=tf.GRAPH_DEF_VERSION,
@ -139,8 +143,8 @@ class ImportGraphDefTest(tf.test.TestCase):
attr { key: 'T' value { type: DT_FLOAT } }
input: 'A:1' input: 'B:1' }
"""),
return_elements=['A', 'B', 'C', 'D'],
name='import')
return_elements=["A", "B", "C", "D"],
name="import")
# Assert that the import process creates distinct tensors.
self.assertNotEqual(a.outputs[0].name, a.outputs[1].name)
@ -157,20 +161,20 @@ class ImportGraphDefTest(tf.test.TestCase):
self.assertEqual(d.inputs[1], b.outputs[1])
# Check the types of the returned ops and tensors.
self.assertEqual(a.type, 'Oif')
self.assertEqual(b.type, 'Otl')
self.assertEqual(c.type, 'In')
self.assertEqual(d.type, 'In')
self.assertEqual(a.type, "Oif")
self.assertEqual(b.type, "Otl")
self.assertEqual(c.type, "In")
self.assertEqual(d.type, "In")
self.assertEqual(a.outputs[0].dtype, tf.int32)
self.assertEqual(a.outputs[1].dtype, tf.float32)
self.assertEqual(b.outputs[0].dtype, tf.int32)
self.assertEqual(b.outputs[1].dtype, tf.float32)
# Check the names of the returned ops.
self.assertEqual(a.name, 'import/A')
self.assertEqual(b.name, 'import/B')
self.assertEqual(c.name, 'import/C')
self.assertEqual(d.name, 'import/D')
self.assertEqual(a.name, "import/A")
self.assertEqual(b.name, "import/B")
self.assertEqual(c.name, "import/C")
self.assertEqual(d.name, "import/D")
# Check that the op_def is still available.
self.assertNotEqual(None, a.op_def)
@ -193,8 +197,8 @@ class ImportGraphDefTest(tf.test.TestCase):
attr { key: 'T' value { type: DT_INT32 } }
input: 'A:1' input: 'B:1' }
"""),
input_map={'A:0': feed_a_0, 'B:1': feed_b_1},
return_elements=['A', 'B', 'C', 'D'])
input_map={"A:0": feed_a_0, "B:1": feed_b_1},
return_elements=["A", "B", "C", "D"])
self.assertEqual(c.inputs[0], feed_a_0)
self.assertEqual(c.inputs[1], b.outputs[0])
@ -219,8 +223,8 @@ class ImportGraphDefTest(tf.test.TestCase):
attr { key: 'T' value { type: DT_INT32 } }
input: 'A:1' input: 'B:1' }
"""),
input_map={b'A:0': feed_a_0, b'B:1': feed_b_1},
return_elements=[b'A', b'B', b'C', b'D'])
input_map={b"A:0": feed_a_0, b"B:1": feed_b_1},
return_elements=[b"A", b"B", b"C", b"D"])
self.assertEqual(c.inputs[0], feed_a_0)
self.assertEqual(c.inputs[1], b.outputs[0])
@ -245,8 +249,8 @@ class ImportGraphDefTest(tf.test.TestCase):
attr { key: 'T' value { type: DT_INT32 } }
input: 'A:1' input: 'B:1' }
"""),
input_map={u'A:0': feed_a_0, u'B:1': feed_b_1},
return_elements=[u'A', u'B', u'C', u'D'])
input_map={u"A:0": feed_a_0, u"B:1": feed_b_1},
return_elements=[u"A", u"B", u"C", u"D"])
self.assertEqual(c.inputs[0], feed_a_0)
self.assertEqual(c.inputs[1], b.outputs[0])
@ -260,7 +264,7 @@ class ImportGraphDefTest(tf.test.TestCase):
node { name: 'A' op: 'Oii' }
node { name: 'B' op: 'Ii' input: 'A' }
"""),
return_elements=['A', 'B'])
return_elements=["A", "B"])
self.assertEqual(b.inputs[0], a.outputs[0])
@ -272,8 +276,8 @@ class ImportGraphDefTest(tf.test.TestCase):
node { name: 'A' op: 'Oii' }
node { name: 'B' op: 'Ii' input: 'A:0' }
"""),
input_map={'A': feed_a_0},
return_elements=['B'])
input_map={"A": feed_a_0},
return_elements=["B"])
self.assertEqual(b.inputs[0], feed_a_0)
@ -284,7 +288,7 @@ class ImportGraphDefTest(tf.test.TestCase):
node { name: 'A' op: 'None' }
node { name: 'B' op: 'None' input: '^A' }
"""),
return_elements=['A', 'B'])
return_elements=["A", "B"])
self.assertEqual(b.control_inputs, [a])
@ -297,7 +301,7 @@ class ImportGraphDefTest(tf.test.TestCase):
node { name: 'C' op: 'Iii' input: 'A:0' input: 'B:0' }
node { name: 'D' op: 'Iri' input: 'A:0' input: 'B:0' }
"""),
return_elements=['A', 'B', 'C', 'D'])
return_elements=["A", "B", "C", "D"])
self.assertEqual(c.inputs[0], a.outputs[0])
self.assertEqual(c.inputs[1], b.outputs[0])
@ -320,7 +324,7 @@ class ImportGraphDefTest(tf.test.TestCase):
node { name: 'B' op: 'Unary'
attr { key: 'T' value { type: DT_INT32 } } input: 'A:0' }
"""),
return_elements=['A', 'B'])
return_elements=["A", "B"])
self.assertEqual(a.inputs[0], b.outputs[0])
self.assertEqual(b.inputs[0], a.outputs[0])
@ -334,7 +338,7 @@ class ImportGraphDefTest(tf.test.TestCase):
node { name: 'B' op: 'If' input: 'A:0' }
"""))
self.assertTrue(
'Cannot convert a tensor of type int32 to an input of type float' in
"Cannot convert a tensor of type int32 to an input of type float" in
str(e.exception))
def testInvalidSignatureTooManyInputsInGraphDef(self):
@ -345,7 +349,7 @@ class ImportGraphDefTest(tf.test.TestCase):
node { name: 'A' op: 'Oi' }
node { name: 'B' op: 'None' input: 'A:0' }
"""))
self.assertTrue('More inputs specified (\'A:0\') than the op expects' in
self.assertTrue("More inputs specified ('A:0') than the op expects" in
str(e.exception))
def testInvalidSignatureNotEnoughInputsInGraphDef(self):
@ -356,8 +360,8 @@ class ImportGraphDefTest(tf.test.TestCase):
node { name: 'A' op: 'Oi' }
node { name: 'B' op: 'Iif' input: 'A:0' }
"""))
self.assertTrue('Input types mismatch (expected \'int32, float32\' but '
'got \'int32\')' in str(e.exception))
self.assertTrue("Input types mismatch (expected 'int32, float32' but "
"got 'int32')" in str(e.exception))
def testMissingInputOpInGraphDef(self):
with tf.Graph().as_default():
@ -375,8 +379,8 @@ class ImportGraphDefTest(tf.test.TestCase):
self._MakeGraphDef("""
node { name: 'B' op: 'If' input: 'A:0' }
"""),
input_map={'A:0': feed_a_0},
return_elements=['B'])
input_map={"A:0": feed_a_0},
return_elements=["B"])
self.assertEqual(b.inputs[0], feed_a_0)
def testMissingInputTensorInGraphDef(self):
@ -425,7 +429,7 @@ class ImportGraphDefTest(tf.test.TestCase):
self._MakeGraphDef("""
node { name: 'A' op: 'None' }
"""),
return_elements=['B'])
return_elements=["B"])
self.assertTrue("return_element 'B' not found in graph_def." in
str(e.exception))
@ -436,7 +440,7 @@ class ImportGraphDefTest(tf.test.TestCase):
self._MakeGraphDef("""
node { name: 'A' op: 'Oi' }
"""),
return_elements=['A:1'])
return_elements=["A:1"])
self.assertTrue("return_element 'A:1' not found in graph_def." in
str(e.exception))
@ -445,7 +449,7 @@ class ImportGraphDefTest(tf.test.TestCase):
self._MakeGraphDef("""
node { name: 'A' op: 'Oi' }
"""),
return_elements=['B:0'])
return_elements=["B:0"])
self.assertTrue("return_element 'B:0' not found in graph_def." in
str(e.exception))
@ -454,7 +458,7 @@ class ImportGraphDefTest(tf.test.TestCase):
self._MakeGraphDef("""
node { name: 'A' op: 'Oi' }
"""),
return_elements=['A:B:0'])
return_elements=["A:B:0"])
self.assertTrue("return_element 'A:B:0' not found in graph_def." in
str(e.exception))
@ -465,8 +469,8 @@ class ImportGraphDefTest(tf.test.TestCase):
self._MakeGraphDef("""
node { name: 'A' op: 'None' }
"""),
input_map={'B:0': tf.constant(5.0)})
self.assertTrue('not found in graph_def: [B:0]' in str(e.exception))
input_map={"B:0": tf.constant(5.0)})
self.assertTrue("not found in graph_def: [B:0]" in str(e.exception))
def testInputMapTypeMismatch(self):
with tf.Graph().as_default():
@ -476,9 +480,9 @@ class ImportGraphDefTest(tf.test.TestCase):
node { name: 'A' op: 'Oi' }
node { name: 'B' op: 'Ii' input: 'A:0' }
"""),
input_map={'A:0': tf.constant(5.0)})
input_map={"A:0": tf.constant(5.0)})
self.assertTrue(
'Cannot convert a tensor of type float32 to an input of type int32.'
"Cannot convert a tensor of type float32 to an input of type int32."
in str(e.exception))
def testNoReturns(self):
@ -489,8 +493,8 @@ class ImportGraphDefTest(tf.test.TestCase):
"""))
self.assertEqual(ret, None)
a = g.get_operation_by_name('import/A')
self.assertEqual(a.type, 'None')
a = g.get_operation_by_name("import/A")
self.assertEqual(a.type, "None")
def testOverrideNamePrefix(self):
with tf.Graph().as_default():
@ -498,8 +502,8 @@ class ImportGraphDefTest(tf.test.TestCase):
self._MakeGraphDef("""
node { name: 'A' op: 'None' }
"""),
return_elements=['A'], name='imported_graph')
self.assertEqual(a.name, 'imported_graph/A')
return_elements=["A"], name="imported_graph")
self.assertEqual(a.name, "imported_graph/A")
def testNamePrefixColocationAttrs(self):
original_graph_def = self._MakeGraphDef("""
@ -511,7 +515,7 @@ class ImportGraphDefTest(tf.test.TestCase):
with tf.Graph().as_default():
b, = tf.import_graph_def(original_graph_def,
return_elements=['B'], name='imported_graph')
return_elements=["B"], name="imported_graph")
self.assertProtoEqualsVersion("""
node { name: 'imported_graph/A' op: 'None' }
node { name: 'imported_graph/B' op: 'None' attr {
@ -529,9 +533,9 @@ class ImportGraphDefTest(tf.test.TestCase):
with tf.Graph().as_default():
b, = tf.import_graph_def(original_graph_def,
return_elements=['B'], name='')
return_elements=["B"], name="")
_, = tf.import_graph_def(original_graph_def,
return_elements=['B'], name='')
return_elements=["B"], name="")
self.assertProtoEqualsVersion("""
node { name: 'A' op: 'None' }
node { name: 'B' op: 'None' attr {
@ -551,90 +555,90 @@ class ImportGraphDefTest(tf.test.TestCase):
value { list { s: 'loc:@A' } }
} }""")
with tf.Graph().as_default():
with self.assertRaisesRegexp(ValueError, 'does not exist during import'):
with self.assertRaisesRegexp(ValueError, "does not exist during import"):
tf.import_graph_def(original_graph_def,
return_elements=['B'], name='imported_graph')
return_elements=["B"], name="imported_graph")
def testEmptyGraph(self):
with tf.Graph().as_default() as g:
init_version = g.version
tf.import_graph_def(self._MakeGraphDef(''))
tf.import_graph_def(self._MakeGraphDef(""))
self.assertEqual(init_version, g.version)
def testInvalidInputForGraphDef(self):
with tf.Graph().as_default():
with self.assertRaises(TypeError) as e:
tf.import_graph_def('')
tf.import_graph_def("")
self.assertEqual(
'graph_def must be a GraphDef proto.', str(e.exception))
"graph_def must be a GraphDef proto.", str(e.exception))
def testInvalidInputForInputMap(self):
with tf.Graph().as_default():
with self.assertRaises(TypeError) as e:
tf.import_graph_def(self._MakeGraphDef(''),
input_map=[tf.constant(5.0)])
self.assertEqual('input_map must be a dictionary mapping strings to '
'Tensor objects.', str(e.exception))
tf.import_graph_def(self._MakeGraphDef(""),
input_map=[tf.constant(5.0)])
self.assertEqual("input_map must be a dictionary mapping strings to "
"Tensor objects.", str(e.exception))
def testInvalidInputForReturnOperations(self):
with tf.Graph().as_default():
with self.assertRaises(TypeError) as e:
tf.import_graph_def(self._MakeGraphDef(''), return_elements=[7])
tf.import_graph_def(self._MakeGraphDef(""), return_elements=[7])
self.assertEqual(
'return_elements must be a list of strings.', str(e.exception))
"return_elements must be a list of strings.", str(e.exception))
def testWithExtensionAndAttr(self):
with tf.Graph().as_default() as g:
c = tf.constant(5.0, dtype=tf.float32, name='c')
tf.pack([c, c], name='pack')
c = tf.constant(5.0, dtype=tf.float32, name="c")
tf.pack([c, c], name="pack")
gdef = g.as_graph_def()
with self.test_session():
pack, = tf.import_graph_def(gdef, return_elements=['pack'])
pack, = tf.import_graph_def(gdef, return_elements=["pack"])
self.assertAllEqual(pack.outputs[0].eval(), [5.0, 5.0])
def testWithDevice(self):
with tf.Graph().as_default() as g:
# No device.
a = tf.constant(3.0, name='a')
a = tf.constant(3.0, name="a")
with tf.device('/cpu:0'):
b = tf.constant(4.0, name='b')
with tf.device('/job:worker'):
c = tf.constant(5.0, name='c')
with tf.device("/cpu:0"):
b = tf.constant(4.0, name="b")
with tf.device("/job:worker"):
c = tf.constant(5.0, name="c")
gdef = g.as_graph_def()
with tf.Graph().as_default():
a2, b2, c2 = tf.import_graph_def(
gdef, return_elements=['a', 'b', 'c'])
gdef, return_elements=["a", "b", "c"])
self.assertEqual(a.device, a2.device)
self.assertEqual(b.device, b2.device)
self.assertEqual(c.device, c2.device)
with tf.Graph().as_default():
with tf.device(device.merge_device('/task:0')):
with tf.device(device.merge_device("/task:0")):
a3, b3, c3 = tf.import_graph_def(
gdef, return_elements=['a', 'b', 'c'])
self.assertEqual('/task:0', a3.device)
self.assertEqual('/task:0/device:CPU:0', b3.device) # canonicalized.
self.assertEqual(c.device + '/task:0', c3.device)
gdef, return_elements=["a", "b", "c"])
self.assertEqual("/task:0", a3.device)
self.assertEqual("/task:0/device:CPU:0", b3.device) # canonicalized.
self.assertEqual(c.device + "/task:0", c3.device)
with tf.Graph().as_default():
with tf.device(device.merge_device('/job:ps')):
with tf.device(device.merge_device("/job:ps")):
a4, b4, c4 = tf.import_graph_def(
gdef, return_elements=['a', 'b', 'c'])
self.assertEqual('/job:ps', a4.device)
self.assertEqual('/job:ps/device:CPU:0', b4.device) # canonicalized.
gdef, return_elements=["a", "b", "c"])
self.assertEqual("/job:ps", a4.device)
self.assertEqual("/job:ps/device:CPU:0", b4.device) # canonicalized.
self.assertEqual(c.device, c4.device) # worker overrides ps.
with tf.Graph().as_default():
with tf.device(device.merge_device('/gpu:0')):
with tf.device(device.merge_device("/gpu:0")):
a5, b5, c5 = tf.import_graph_def(
gdef, return_elements=['a', 'b', 'c'])
self.assertEqual('/device:GPU:0', a5.device)
self.assertEqual('/device:CPU:0', b5.device) # cpu overrides gpu.
self.assertEqual(c.device + '/device:GPU:0', c5.device)
gdef, return_elements=["a", "b", "c"])
self.assertEqual("/device:GPU:0", a5.device)
self.assertEqual("/device:CPU:0", b5.device) # cpu overrides gpu.
self.assertEqual(c.device + "/device:GPU:0", c5.device)
def testWithDeviceFunctionDependingOnInputs(self):
with tf.Graph().as_default() as g:
@ -706,7 +710,7 @@ class ImportGraphDefTest(tf.test.TestCase):
a, = tf.import_graph_def(
self._MakeGraphDef("node { name: 'A' op: 'Oii' }",
producer=producer, min_consumer=min_consumer),
return_elements=['A'])
return_elements=["A"])
self.assertEqual(a.graph.graph_def_versions.producer, producer)
self.assertEqual(a.graph.graph_def_versions.min_consumer,
min_consumer)
@ -739,9 +743,38 @@ class ImportGraphDefTest(tf.test.TestCase):
self._MakeGraphDef("""
node { name: 'A' op: 'OpWithDefaultAttr' }
"""),
return_elements=['A'])
return_elements=["A"])
self.assertEqual(123.0, a[0].get_attr("default_float"))
def testDefaultAttrsRemoved(self):
producer_op_list = op_def_pb2.OpList()
text_format.Merge("""
op {
name: 'OpWithFutureDefaultAttr'
attr { name: 'default_int' type: 'int' default_value { i: 456 } }
}
""", producer_op_list)
# Attr only in producer_op_list with default value gets removed.
with tf.Graph().as_default():
a = tf.import_graph_def(
self._MakeGraphDef("""
node { name: 'A' op: 'OpWithFutureDefaultAttr'
attr { key: 'default_int' value { i: 456 } } }
"""),
return_elements=["A"], producer_op_list=producer_op_list)
with self.assertRaisesRegexp(ValueError, "No attr named 'default_int'"):
a[0].get_attr("default_int")
if __name__ == '__main__':
# Attr only in producer_op_list with non-default value is preserved.
with tf.Graph().as_default():
a = tf.import_graph_def(
self._MakeGraphDef("""
node { name: 'A' op: 'OpWithFutureDefaultAttr'
attr { key: 'default_int' value { i: 987 } } }
"""),
return_elements=["A"], producer_op_list=producer_op_list)
self.assertEqual(987, a[0].get_attr("default_int"))
if __name__ == "__main__":
tf.test.main()

View File

@ -1317,7 +1317,11 @@ def _import_meta_graph_def(meta_graph_def):
(i.e., no variables to restore).
"""
# Gathers the list of nodes we are interested in.
importer.import_graph_def(meta_graph_def.graph_def, name="")
producer_op_list = None
if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
importer.import_graph_def(meta_graph_def.graph_def, name="",
producer_op_list=producer_op_list)
# Restores all the other collections.
for key, col_def in meta_graph_def.collection_def.items():