check invalid string type for dest_nodes in extract_sub_graph (#13057)

* BUG: check str type

* TST: add unit test

* CLN: remove list check

* CLN: use warning

* CLN: 2 indent

* CLN: raise TypeError if not list

* CLN: check string only
This commit is contained in:
Yan Facai (颜发才) 2017-09-25 02:01:09 +08:00 committed by drpngx
parent d2d42ee8b3
commit fe3a2e65cc
2 changed files with 11 additions and 0 deletions

View File

@ -21,6 +21,7 @@ from __future__ import division
from __future__ import print_function
import copy
import re
import six
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
@ -123,6 +124,9 @@ def extract_sub_graph(graph_def, dest_nodes):
if not isinstance(graph_def, graph_pb2.GraphDef):
raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")
if isinstance(dest_nodes, six.string_types):
raise TypeError("dest_nodes must be a list.")
edges = {} # Keyed by the dest node name.
name_to_node_map = {} # Keyed by node name.

View File

@ -188,6 +188,13 @@ class DeviceFunctionsTest(test.TestCase):
self.assertEqual("n3", sub_graph.node[2].name)
self.assertEqual("n5", sub_graph.node[3].name)
def testExtractSubGraphWithInvalidDestNodes(self):
graph_def = graph_pb2.GraphDef()
n1 = graph_def.node.add()
n1.name = "n1"
with self.assertRaisesRegexp(TypeError, "must be a list"):
graph_util.extract_sub_graph(graph_def, "n1")
def testConvertVariablesToConstsWithFunctions(self):
@function.Defun(dtypes.float32)
def plus_one(x):