Add check for duplicate names in import_graph_def (#9605)
This commit is contained in:
parent
6adc38555f
commit
c9fccab57e
@ -275,6 +275,9 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
|
||||
|
||||
# 1. Add operations without their inputs.
|
||||
for node in graph_def.node:
|
||||
# Check to see if this op's name matches a previously seen op
|
||||
if node.name in name_to_op:
|
||||
raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name)
|
||||
# Set any default attr values that aren't present.
|
||||
if node.op not in op_dict:
|
||||
raise ValueError('No op named %s in defined operations.' % node.op)
|
||||
|
@ -685,6 +685,17 @@ class ImportGraphDefTest(test.TestCase):
|
||||
self.assertEqual("return_elements must be a list of strings.",
|
||||
str(e.exception))
|
||||
|
||||
def testDuplicateOperationNames(self):
|
||||
with ops.Graph().as_default():
|
||||
with self.assertRaises(ValueError) as e:
|
||||
importer.import_graph_def(
|
||||
self._MakeGraphDef("""
|
||||
node { name: 'A' op: 'Oi' }
|
||||
node { name: 'B' op: 'Oi' }
|
||||
node { name: 'A' op: 'Oi' }
|
||||
"""))
|
||||
self.assertEqual("Duplicate name 'A' in GraphDef.", str(e.exception))
|
||||
|
||||
def testWithExtensionAndAttr(self):
|
||||
with ops.Graph().as_default() as g:
|
||||
c = constant_op.constant(5.0, dtype=dtypes.float32, name="c")
|
||||
|
Loading…
Reference in New Issue
Block a user