Add check for duplicate names in import_graph_def (#9605)

This commit is contained in:
Sam Abrahams 2017-05-03 14:20:06 -07:00 committed by Vijay Vasudevan
parent 6adc38555f
commit c9fccab57e
2 changed files with 14 additions and 0 deletions

View File

@ -275,6 +275,9 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
# 1. Add operations without their inputs.
for node in graph_def.node:
# Check to see if this op's name matches a previously seen op
if node.name in name_to_op:
raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name)
# Set any default attr values that aren't present.
if node.op not in op_dict:
raise ValueError('No op named %s in defined operations.' % node.op)

View File

@ -685,6 +685,17 @@ class ImportGraphDefTest(test.TestCase):
self.assertEqual("return_elements must be a list of strings.",
str(e.exception))
def testDuplicateOperationNames(self):
with ops.Graph().as_default():
with self.assertRaises(ValueError) as e:
importer.import_graph_def(
self._MakeGraphDef("""
node { name: 'A' op: 'Oi' }
node { name: 'B' op: 'Oi' }
node { name: 'A' op: 'Oi' }
"""))
self.assertEqual("Duplicate name 'A' in GraphDef.", str(e.exception))
def testWithExtensionAndAttr(self):
with ops.Graph().as_default() as g:
c = constant_op.constant(5.0, dtype=dtypes.float32, name="c")