From 6e02bf0299687b74557ff31b641b8d6d51cba4bf Mon Sep 17 00:00:00 2001 From: Josh Levenberg Date: Mon, 9 May 2016 13:53:57 -0800 Subject: [PATCH] Support limited forward compatibility when importing a MetaGraphDef. Change: 121880284 --- tensorflow/python/framework/importer.py | 41 +++- tensorflow/python/framework/importer_test.py | 193 +++++++++++-------- tensorflow/python/training/saver.py | 6 +- 3 files changed, 155 insertions(+), 85 deletions(-) diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index 3d8b4e460ee..afb3b55d839 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -23,8 +23,8 @@ import contextlib from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import types_pb2 -from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import dtypes +from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.util import compat @@ -142,8 +142,15 @@ def _MaybeDevice(device): yield +def _FindAttrInOpDef(attr_name, op_def): + for attr_def in op_def.attr: + if attr_name == attr_def.name: + return attr_def + return None + + def import_graph_def(graph_def, input_map=None, return_elements=None, - name=None, op_dict=None): + name=None, op_dict=None, producer_op_list=None): """Imports the TensorFlow graph in `graph_def` into the Python `Graph`. This function provides a way to import a serialized TensorFlow @@ -167,6 +174,12 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos. Must contain an `OpDef` proto for each op type named in `graph_def`. If omitted, uses the `OpDef` protos registered in the global registry. + producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped) + list of `OpDef`s used by the producer of the graph. If provided, attrs + for ops in `graph_def` that are not in `op_dict` that have their default + value according to `producer_op_list` will be removed. This will allow + some more `GraphDef`s produced by later binaries to be accepted by + earlier binaries. Returns: A list of `Operation` and/or `Tensor` objects from the imported graph, @@ -213,6 +226,11 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, if op_dict is None: op_dict = op_def_registry.get_registered_ops() + if producer_op_list is None: + producer_op_dict = None + else: + producer_op_dict = {op.name: op for op in producer_op_list.op} + with ops.op_scope(input_map.values(), name, 'import'): g = ops.get_default_graph() g.graph_def_versions.CopyFrom(graph_def.versions) @@ -233,6 +251,21 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, value = node.attr[key] if value is None or value.WhichOneof('value') is None: node.attr[key].CopyFrom(attr_def.default_value) + if producer_op_dict: + # Remove any default attr values that aren't in op_def. + if node.op in producer_op_dict: + producer_op_def = producer_op_dict[node.op] + # We make a copy of node.attr to iterate through since we + # may modify node.attr inside the loop. + for key in list(node.attr): + if _FindAttrInOpDef(key, op_def) is None: + # No attr_def in consumer, look in producer. + attr_def = _FindAttrInOpDef(key, producer_op_def) + if (attr_def and attr_def.HasField('default_value') and + node.attr[key] == attr_def.default_value): + # Unknown attr had default value in producer, delete it + # so it can be understood by consumer. + del node.attr[key] output_types = _OutputTypes(node, op_dict) name_to_op[node.name] = g.create_op( @@ -326,8 +359,8 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, _InvalidNodeMessage( node, 'Input types mismatch (expected %r but got %r)' - % (", ".join(dtypes.as_dtype(x).name for x in input_types), - ", ".join(x.name for x in op._input_dtypes)))) + % (', '.join(dtypes.as_dtype(x).name for x in input_types), + ', '.join(x.name for x in op._input_dtypes)))) # pylint: enable=protected_access # Execute shape inference for this op. diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index c6fd1edc026..7abb683839d 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -106,12 +106,16 @@ text_format.Merge(""" output_arg { name: 'a' type: DT_INT32 } attr { name: 'default_float' type: 'float' default_value { f: 123.0 } } } + op { + name: 'OpWithFutureDefaultAttr' + } """, _op_list) op_def_registry.register_op_list(_op_list) # NOTE(mrry): Dummy shape registrations for ops used in the tests. for op_def in _op_list.op: tf.RegisterShape(op_def.name)(None) + class ImportGraphDefTest(tf.test.TestCase): def _MakeGraphDef(self, text, producer=tf.GRAPH_DEF_VERSION, @@ -139,8 +143,8 @@ class ImportGraphDefTest(tf.test.TestCase): attr { key: 'T' value { type: DT_FLOAT } } input: 'A:1' input: 'B:1' } """), - return_elements=['A', 'B', 'C', 'D'], - name='import') + return_elements=["A", "B", "C", "D"], + name="import") # Assert that the import process creates distinct tensors. self.assertNotEqual(a.outputs[0].name, a.outputs[1].name) @@ -157,20 +161,20 @@ class ImportGraphDefTest(tf.test.TestCase): self.assertEqual(d.inputs[1], b.outputs[1]) # Check the types of the returned ops and tensors. - self.assertEqual(a.type, 'Oif') - self.assertEqual(b.type, 'Otl') - self.assertEqual(c.type, 'In') - self.assertEqual(d.type, 'In') + self.assertEqual(a.type, "Oif") + self.assertEqual(b.type, "Otl") + self.assertEqual(c.type, "In") + self.assertEqual(d.type, "In") self.assertEqual(a.outputs[0].dtype, tf.int32) self.assertEqual(a.outputs[1].dtype, tf.float32) self.assertEqual(b.outputs[0].dtype, tf.int32) self.assertEqual(b.outputs[1].dtype, tf.float32) # Check the names of the returned ops. - self.assertEqual(a.name, 'import/A') - self.assertEqual(b.name, 'import/B') - self.assertEqual(c.name, 'import/C') - self.assertEqual(d.name, 'import/D') + self.assertEqual(a.name, "import/A") + self.assertEqual(b.name, "import/B") + self.assertEqual(c.name, "import/C") + self.assertEqual(d.name, "import/D") # Check that the op_def is still available. self.assertNotEqual(None, a.op_def) @@ -193,8 +197,8 @@ class ImportGraphDefTest(tf.test.TestCase): attr { key: 'T' value { type: DT_INT32 } } input: 'A:1' input: 'B:1' } """), - input_map={'A:0': feed_a_0, 'B:1': feed_b_1}, - return_elements=['A', 'B', 'C', 'D']) + input_map={"A:0": feed_a_0, "B:1": feed_b_1}, + return_elements=["A", "B", "C", "D"]) self.assertEqual(c.inputs[0], feed_a_0) self.assertEqual(c.inputs[1], b.outputs[0]) @@ -219,8 +223,8 @@ class ImportGraphDefTest(tf.test.TestCase): attr { key: 'T' value { type: DT_INT32 } } input: 'A:1' input: 'B:1' } """), - input_map={b'A:0': feed_a_0, b'B:1': feed_b_1}, - return_elements=[b'A', b'B', b'C', b'D']) + input_map={b"A:0": feed_a_0, b"B:1": feed_b_1}, + return_elements=[b"A", b"B", b"C", b"D"]) self.assertEqual(c.inputs[0], feed_a_0) self.assertEqual(c.inputs[1], b.outputs[0]) @@ -245,8 +249,8 @@ class ImportGraphDefTest(tf.test.TestCase): attr { key: 'T' value { type: DT_INT32 } } input: 'A:1' input: 'B:1' } """), - input_map={u'A:0': feed_a_0, u'B:1': feed_b_1}, - return_elements=[u'A', u'B', u'C', u'D']) + input_map={u"A:0": feed_a_0, u"B:1": feed_b_1}, + return_elements=[u"A", u"B", u"C", u"D"]) self.assertEqual(c.inputs[0], feed_a_0) self.assertEqual(c.inputs[1], b.outputs[0]) @@ -260,7 +264,7 @@ class ImportGraphDefTest(tf.test.TestCase): node { name: 'A' op: 'Oii' } node { name: 'B' op: 'Ii' input: 'A' } """), - return_elements=['A', 'B']) + return_elements=["A", "B"]) self.assertEqual(b.inputs[0], a.outputs[0]) @@ -272,8 +276,8 @@ class ImportGraphDefTest(tf.test.TestCase): node { name: 'A' op: 'Oii' } node { name: 'B' op: 'Ii' input: 'A:0' } """), - input_map={'A': feed_a_0}, - return_elements=['B']) + input_map={"A": feed_a_0}, + return_elements=["B"]) self.assertEqual(b.inputs[0], feed_a_0) @@ -284,7 +288,7 @@ class ImportGraphDefTest(tf.test.TestCase): node { name: 'A' op: 'None' } node { name: 'B' op: 'None' input: '^A' } """), - return_elements=['A', 'B']) + return_elements=["A", "B"]) self.assertEqual(b.control_inputs, [a]) @@ -297,7 +301,7 @@ class ImportGraphDefTest(tf.test.TestCase): node { name: 'C' op: 'Iii' input: 'A:0' input: 'B:0' } node { name: 'D' op: 'Iri' input: 'A:0' input: 'B:0' } """), - return_elements=['A', 'B', 'C', 'D']) + return_elements=["A", "B", "C", "D"]) self.assertEqual(c.inputs[0], a.outputs[0]) self.assertEqual(c.inputs[1], b.outputs[0]) @@ -320,7 +324,7 @@ class ImportGraphDefTest(tf.test.TestCase): node { name: 'B' op: 'Unary' attr { key: 'T' value { type: DT_INT32 } } input: 'A:0' } """), - return_elements=['A', 'B']) + return_elements=["A", "B"]) self.assertEqual(a.inputs[0], b.outputs[0]) self.assertEqual(b.inputs[0], a.outputs[0]) @@ -334,7 +338,7 @@ class ImportGraphDefTest(tf.test.TestCase): node { name: 'B' op: 'If' input: 'A:0' } """)) self.assertTrue( - 'Cannot convert a tensor of type int32 to an input of type float' in + "Cannot convert a tensor of type int32 to an input of type float" in str(e.exception)) def testInvalidSignatureTooManyInputsInGraphDef(self): @@ -345,7 +349,7 @@ class ImportGraphDefTest(tf.test.TestCase): node { name: 'A' op: 'Oi' } node { name: 'B' op: 'None' input: 'A:0' } """)) - self.assertTrue('More inputs specified (\'A:0\') than the op expects' in + self.assertTrue("More inputs specified ('A:0') than the op expects" in str(e.exception)) def testInvalidSignatureNotEnoughInputsInGraphDef(self): @@ -356,8 +360,8 @@ class ImportGraphDefTest(tf.test.TestCase): node { name: 'A' op: 'Oi' } node { name: 'B' op: 'Iif' input: 'A:0' } """)) - self.assertTrue('Input types mismatch (expected \'int32, float32\' but ' - 'got \'int32\')' in str(e.exception)) + self.assertTrue("Input types mismatch (expected 'int32, float32' but " + "got 'int32')" in str(e.exception)) def testMissingInputOpInGraphDef(self): with tf.Graph().as_default(): @@ -375,8 +379,8 @@ class ImportGraphDefTest(tf.test.TestCase): self._MakeGraphDef(""" node { name: 'B' op: 'If' input: 'A:0' } """), - input_map={'A:0': feed_a_0}, - return_elements=['B']) + input_map={"A:0": feed_a_0}, + return_elements=["B"]) self.assertEqual(b.inputs[0], feed_a_0) def testMissingInputTensorInGraphDef(self): @@ -425,7 +429,7 @@ class ImportGraphDefTest(tf.test.TestCase): self._MakeGraphDef(""" node { name: 'A' op: 'None' } """), - return_elements=['B']) + return_elements=["B"]) self.assertTrue("return_element 'B' not found in graph_def." in str(e.exception)) @@ -436,7 +440,7 @@ class ImportGraphDefTest(tf.test.TestCase): self._MakeGraphDef(""" node { name: 'A' op: 'Oi' } """), - return_elements=['A:1']) + return_elements=["A:1"]) self.assertTrue("return_element 'A:1' not found in graph_def." in str(e.exception)) @@ -445,7 +449,7 @@ class ImportGraphDefTest(tf.test.TestCase): self._MakeGraphDef(""" node { name: 'A' op: 'Oi' } """), - return_elements=['B:0']) + return_elements=["B:0"]) self.assertTrue("return_element 'B:0' not found in graph_def." in str(e.exception)) @@ -454,7 +458,7 @@ class ImportGraphDefTest(tf.test.TestCase): self._MakeGraphDef(""" node { name: 'A' op: 'Oi' } """), - return_elements=['A:B:0']) + return_elements=["A:B:0"]) self.assertTrue("return_element 'A:B:0' not found in graph_def." in str(e.exception)) @@ -465,8 +469,8 @@ class ImportGraphDefTest(tf.test.TestCase): self._MakeGraphDef(""" node { name: 'A' op: 'None' } """), - input_map={'B:0': tf.constant(5.0)}) - self.assertTrue('not found in graph_def: [B:0]' in str(e.exception)) + input_map={"B:0": tf.constant(5.0)}) + self.assertTrue("not found in graph_def: [B:0]" in str(e.exception)) def testInputMapTypeMismatch(self): with tf.Graph().as_default(): @@ -476,9 +480,9 @@ class ImportGraphDefTest(tf.test.TestCase): node { name: 'A' op: 'Oi' } node { name: 'B' op: 'Ii' input: 'A:0' } """), - input_map={'A:0': tf.constant(5.0)}) + input_map={"A:0": tf.constant(5.0)}) self.assertTrue( - 'Cannot convert a tensor of type float32 to an input of type int32.' + "Cannot convert a tensor of type float32 to an input of type int32." in str(e.exception)) def testNoReturns(self): @@ -489,8 +493,8 @@ class ImportGraphDefTest(tf.test.TestCase): """)) self.assertEqual(ret, None) - a = g.get_operation_by_name('import/A') - self.assertEqual(a.type, 'None') + a = g.get_operation_by_name("import/A") + self.assertEqual(a.type, "None") def testOverrideNamePrefix(self): with tf.Graph().as_default(): @@ -498,8 +502,8 @@ class ImportGraphDefTest(tf.test.TestCase): self._MakeGraphDef(""" node { name: 'A' op: 'None' } """), - return_elements=['A'], name='imported_graph') - self.assertEqual(a.name, 'imported_graph/A') + return_elements=["A"], name="imported_graph") + self.assertEqual(a.name, "imported_graph/A") def testNamePrefixColocationAttrs(self): original_graph_def = self._MakeGraphDef(""" @@ -511,7 +515,7 @@ class ImportGraphDefTest(tf.test.TestCase): with tf.Graph().as_default(): b, = tf.import_graph_def(original_graph_def, - return_elements=['B'], name='imported_graph') + return_elements=["B"], name="imported_graph") self.assertProtoEqualsVersion(""" node { name: 'imported_graph/A' op: 'None' } node { name: 'imported_graph/B' op: 'None' attr { @@ -529,9 +533,9 @@ class ImportGraphDefTest(tf.test.TestCase): with tf.Graph().as_default(): b, = tf.import_graph_def(original_graph_def, - return_elements=['B'], name='') + return_elements=["B"], name="") _, = tf.import_graph_def(original_graph_def, - return_elements=['B'], name='') + return_elements=["B"], name="") self.assertProtoEqualsVersion(""" node { name: 'A' op: 'None' } node { name: 'B' op: 'None' attr { @@ -551,90 +555,90 @@ class ImportGraphDefTest(tf.test.TestCase): value { list { s: 'loc:@A' } } } }""") with tf.Graph().as_default(): - with self.assertRaisesRegexp(ValueError, 'does not exist during import'): + with self.assertRaisesRegexp(ValueError, "does not exist during import"): tf.import_graph_def(original_graph_def, - return_elements=['B'], name='imported_graph') + return_elements=["B"], name="imported_graph") def testEmptyGraph(self): with tf.Graph().as_default() as g: init_version = g.version - tf.import_graph_def(self._MakeGraphDef('')) + tf.import_graph_def(self._MakeGraphDef("")) self.assertEqual(init_version, g.version) def testInvalidInputForGraphDef(self): with tf.Graph().as_default(): with self.assertRaises(TypeError) as e: - tf.import_graph_def('') + tf.import_graph_def("") self.assertEqual( - 'graph_def must be a GraphDef proto.', str(e.exception)) + "graph_def must be a GraphDef proto.", str(e.exception)) def testInvalidInputForInputMap(self): with tf.Graph().as_default(): with self.assertRaises(TypeError) as e: - tf.import_graph_def(self._MakeGraphDef(''), - input_map=[tf.constant(5.0)]) - self.assertEqual('input_map must be a dictionary mapping strings to ' - 'Tensor objects.', str(e.exception)) + tf.import_graph_def(self._MakeGraphDef(""), + input_map=[tf.constant(5.0)]) + self.assertEqual("input_map must be a dictionary mapping strings to " + "Tensor objects.", str(e.exception)) def testInvalidInputForReturnOperations(self): with tf.Graph().as_default(): with self.assertRaises(TypeError) as e: - tf.import_graph_def(self._MakeGraphDef(''), return_elements=[7]) + tf.import_graph_def(self._MakeGraphDef(""), return_elements=[7]) self.assertEqual( - 'return_elements must be a list of strings.', str(e.exception)) + "return_elements must be a list of strings.", str(e.exception)) def testWithExtensionAndAttr(self): with tf.Graph().as_default() as g: - c = tf.constant(5.0, dtype=tf.float32, name='c') - tf.pack([c, c], name='pack') + c = tf.constant(5.0, dtype=tf.float32, name="c") + tf.pack([c, c], name="pack") gdef = g.as_graph_def() with self.test_session(): - pack, = tf.import_graph_def(gdef, return_elements=['pack']) + pack, = tf.import_graph_def(gdef, return_elements=["pack"]) self.assertAllEqual(pack.outputs[0].eval(), [5.0, 5.0]) def testWithDevice(self): with tf.Graph().as_default() as g: # No device. - a = tf.constant(3.0, name='a') + a = tf.constant(3.0, name="a") - with tf.device('/cpu:0'): - b = tf.constant(4.0, name='b') - with tf.device('/job:worker'): - c = tf.constant(5.0, name='c') + with tf.device("/cpu:0"): + b = tf.constant(4.0, name="b") + with tf.device("/job:worker"): + c = tf.constant(5.0, name="c") gdef = g.as_graph_def() with tf.Graph().as_default(): a2, b2, c2 = tf.import_graph_def( - gdef, return_elements=['a', 'b', 'c']) + gdef, return_elements=["a", "b", "c"]) self.assertEqual(a.device, a2.device) self.assertEqual(b.device, b2.device) self.assertEqual(c.device, c2.device) with tf.Graph().as_default(): - with tf.device(device.merge_device('/task:0')): + with tf.device(device.merge_device("/task:0")): a3, b3, c3 = tf.import_graph_def( - gdef, return_elements=['a', 'b', 'c']) - self.assertEqual('/task:0', a3.device) - self.assertEqual('/task:0/device:CPU:0', b3.device) # canonicalized. - self.assertEqual(c.device + '/task:0', c3.device) + gdef, return_elements=["a", "b", "c"]) + self.assertEqual("/task:0", a3.device) + self.assertEqual("/task:0/device:CPU:0", b3.device) # canonicalized. + self.assertEqual(c.device + "/task:0", c3.device) with tf.Graph().as_default(): - with tf.device(device.merge_device('/job:ps')): + with tf.device(device.merge_device("/job:ps")): a4, b4, c4 = tf.import_graph_def( - gdef, return_elements=['a', 'b', 'c']) - self.assertEqual('/job:ps', a4.device) - self.assertEqual('/job:ps/device:CPU:0', b4.device) # canonicalized. + gdef, return_elements=["a", "b", "c"]) + self.assertEqual("/job:ps", a4.device) + self.assertEqual("/job:ps/device:CPU:0", b4.device) # canonicalized. self.assertEqual(c.device, c4.device) # worker overrides ps. with tf.Graph().as_default(): - with tf.device(device.merge_device('/gpu:0')): + with tf.device(device.merge_device("/gpu:0")): a5, b5, c5 = tf.import_graph_def( - gdef, return_elements=['a', 'b', 'c']) - self.assertEqual('/device:GPU:0', a5.device) - self.assertEqual('/device:CPU:0', b5.device) # cpu overrides gpu. - self.assertEqual(c.device + '/device:GPU:0', c5.device) + gdef, return_elements=["a", "b", "c"]) + self.assertEqual("/device:GPU:0", a5.device) + self.assertEqual("/device:CPU:0", b5.device) # cpu overrides gpu. + self.assertEqual(c.device + "/device:GPU:0", c5.device) def testWithDeviceFunctionDependingOnInputs(self): with tf.Graph().as_default() as g: @@ -706,7 +710,7 @@ class ImportGraphDefTest(tf.test.TestCase): a, = tf.import_graph_def( self._MakeGraphDef("node { name: 'A' op: 'Oii' }", producer=producer, min_consumer=min_consumer), - return_elements=['A']) + return_elements=["A"]) self.assertEqual(a.graph.graph_def_versions.producer, producer) self.assertEqual(a.graph.graph_def_versions.min_consumer, min_consumer) @@ -739,9 +743,38 @@ class ImportGraphDefTest(tf.test.TestCase): self._MakeGraphDef(""" node { name: 'A' op: 'OpWithDefaultAttr' } """), - return_elements=['A']) + return_elements=["A"]) self.assertEqual(123.0, a[0].get_attr("default_float")) + def testDefaultAttrsRemoved(self): + producer_op_list = op_def_pb2.OpList() + text_format.Merge(""" + op { + name: 'OpWithFutureDefaultAttr' + attr { name: 'default_int' type: 'int' default_value { i: 456 } } + } + """, producer_op_list) + # Attr only in producer_op_list with default value gets removed. + with tf.Graph().as_default(): + a = tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'OpWithFutureDefaultAttr' + attr { key: 'default_int' value { i: 456 } } } + """), + return_elements=["A"], producer_op_list=producer_op_list) + with self.assertRaisesRegexp(ValueError, "No attr named 'default_int'"): + a[0].get_attr("default_int") -if __name__ == '__main__': + # Attr only in producer_op_list with non-default value is preserved. + with tf.Graph().as_default(): + a = tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'OpWithFutureDefaultAttr' + attr { key: 'default_int' value { i: 987 } } } + """), + return_elements=["A"], producer_op_list=producer_op_list) + self.assertEqual(987, a[0].get_attr("default_int")) + + +if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 2d90ff5d5ab..940941a099b 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -1317,7 +1317,11 @@ def _import_meta_graph_def(meta_graph_def): (i.e., no variables to restore). """ # Gathers the list of nodes we are interested in. - importer.import_graph_def(meta_graph_def.graph_def, name="") + producer_op_list = None + if meta_graph_def.meta_info_def.HasField("stripped_op_list"): + producer_op_list = meta_graph_def.meta_info_def.stripped_op_list + importer.import_graph_def(meta_graph_def.graph_def, name="", + producer_op_list=producer_op_list) # Restores all the other collections. for key, col_def in meta_graph_def.collection_def.items():