Fix DFS in quantize_graph.py ()

Mark nodes as visited before recursing into their children to avoid infinite loops (e.g. when quantizing graphs using tf.dynamic_rnn construction).
This commit is contained in:
Ilya Edrenkin 2016-08-26 21:02:02 +02:00 committed by Rasmus Munk Larsen
parent d19c7ec962
commit da7eb65ba5

View File

@ -66,12 +66,12 @@ flags.DEFINE_boolean("load_quantization_so", True,
def print_input_nodes(current_node, nodes_map, indent, already_visited):
print(" " * indent + current_node.op + ":" + current_node.name)
already_visited[current_node.name] = True
for input_node_name in current_node.input:
if input_node_name in already_visited:
continue
input_node = nodes_map[input_node_name]
print_input_nodes(input_node, nodes_map, indent + 1, already_visited)
already_visited[current_node.name] = True
def create_node(op, name, inputs):
@ -350,13 +350,13 @@ class GraphRewriter(object):
def round_nodes_recursively(self, current_node):
"""The entry point for simple rounding quantization."""
self.already_visited[current_node.name] = True
for input_node_name in current_node.input:
input_node_name = node_name_from_input(input_node_name)
if input_node_name in self.already_visited:
continue
input_node = self.nodes_map[input_node_name]
self.round_nodes_recursively(input_node)
self.already_visited[current_node.name] = True
nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"]
if any(current_node.op in s for s in nodes_to_quantize):
new_node = tf.NodeDef()
@ -381,13 +381,13 @@ class GraphRewriter(object):
def quantize_nodes_recursively(self, current_node):
"""The entry point for quantizing nodes to eight bit and back."""
self.already_visited[current_node.name] = True
for input_node_name in current_node.input:
input_node_name = node_name_from_input(input_node_name)
if input_node_name in self.already_visited:
continue
input_node = self.nodes_map[input_node_name]
self.quantize_nodes_recursively(input_node)
self.already_visited[current_node.name] = True
nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"]
if any(current_node.op in s for s in nodes_to_quantize):
for input_name in current_node.input:
@ -448,13 +448,13 @@ class GraphRewriter(object):
def eightbitize_nodes_recursively(self, current_node):
"""The entry point for transforming a graph into full eight bit."""
self.already_visited[current_node.name] = True
for input_node_name in current_node.input:
input_node_name = node_name_from_input(input_node_name)
if input_node_name in self.already_visited:
continue
input_node = self.nodes_map[input_node_name]
self.eightbitize_nodes_recursively(input_node)
self.already_visited[current_node.name] = True
if current_node.op == "MatMul":
self.eightbitize_mat_mul_node(current_node)
elif current_node.op == "Conv2D":