check invalid string type for dest_nodes in extract_sub_graph (#13057)
* BUG: check str type * TST: add unit test * CLN: remove list check * CLN: use warning * CLN: 2 indent * CLN: raise TypeError if not list * CLN: check string only
This commit is contained in:
parent
d2d42ee8b3
commit
fe3a2e65cc
@ -21,6 +21,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
import copy
|
||||
import re
|
||||
import six
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
@ -123,6 +124,9 @@ def extract_sub_graph(graph_def, dest_nodes):
|
||||
if not isinstance(graph_def, graph_pb2.GraphDef):
|
||||
raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")
|
||||
|
||||
if isinstance(dest_nodes, six.string_types):
|
||||
raise TypeError("dest_nodes must be a list.")
|
||||
|
||||
edges = {} # Keyed by the dest node name.
|
||||
name_to_node_map = {} # Keyed by node name.
|
||||
|
||||
|
@ -188,6 +188,13 @@ class DeviceFunctionsTest(test.TestCase):
|
||||
self.assertEqual("n3", sub_graph.node[2].name)
|
||||
self.assertEqual("n5", sub_graph.node[3].name)
|
||||
|
||||
def testExtractSubGraphWithInvalidDestNodes(self):
|
||||
graph_def = graph_pb2.GraphDef()
|
||||
n1 = graph_def.node.add()
|
||||
n1.name = "n1"
|
||||
with self.assertRaisesRegexp(TypeError, "must be a list"):
|
||||
graph_util.extract_sub_graph(graph_def, "n1")
|
||||
|
||||
def testConvertVariablesToConstsWithFunctions(self):
|
||||
@function.Defun(dtypes.float32)
|
||||
def plus_one(x):
|
||||
|
Loading…
Reference in New Issue
Block a user