Fix DFS in quantize_graph.py (#4010)
Mark nodes as visited before recursing into their children to avoid infinite loops (e.g. when quantizing graphs using tf.dynamic_rnn construction).
This commit is contained in:
parent
d19c7ec962
commit
da7eb65ba5
@ -66,12 +66,12 @@ flags.DEFINE_boolean("load_quantization_so", True,
|
||||
|
||||
def print_input_nodes(current_node, nodes_map, indent, already_visited):
|
||||
print(" " * indent + current_node.op + ":" + current_node.name)
|
||||
already_visited[current_node.name] = True
|
||||
for input_node_name in current_node.input:
|
||||
if input_node_name in already_visited:
|
||||
continue
|
||||
input_node = nodes_map[input_node_name]
|
||||
print_input_nodes(input_node, nodes_map, indent + 1, already_visited)
|
||||
already_visited[current_node.name] = True
|
||||
|
||||
|
||||
def create_node(op, name, inputs):
|
||||
@ -350,13 +350,13 @@ class GraphRewriter(object):
|
||||
|
||||
def round_nodes_recursively(self, current_node):
|
||||
"""The entry point for simple rounding quantization."""
|
||||
self.already_visited[current_node.name] = True
|
||||
for input_node_name in current_node.input:
|
||||
input_node_name = node_name_from_input(input_node_name)
|
||||
if input_node_name in self.already_visited:
|
||||
continue
|
||||
input_node = self.nodes_map[input_node_name]
|
||||
self.round_nodes_recursively(input_node)
|
||||
self.already_visited[current_node.name] = True
|
||||
nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"]
|
||||
if any(current_node.op in s for s in nodes_to_quantize):
|
||||
new_node = tf.NodeDef()
|
||||
@ -381,13 +381,13 @@ class GraphRewriter(object):
|
||||
|
||||
def quantize_nodes_recursively(self, current_node):
|
||||
"""The entry point for quantizing nodes to eight bit and back."""
|
||||
self.already_visited[current_node.name] = True
|
||||
for input_node_name in current_node.input:
|
||||
input_node_name = node_name_from_input(input_node_name)
|
||||
if input_node_name in self.already_visited:
|
||||
continue
|
||||
input_node = self.nodes_map[input_node_name]
|
||||
self.quantize_nodes_recursively(input_node)
|
||||
self.already_visited[current_node.name] = True
|
||||
nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"]
|
||||
if any(current_node.op in s for s in nodes_to_quantize):
|
||||
for input_name in current_node.input:
|
||||
@ -448,13 +448,13 @@ class GraphRewriter(object):
|
||||
|
||||
def eightbitize_nodes_recursively(self, current_node):
|
||||
"""The entry point for transforming a graph into full eight bit."""
|
||||
self.already_visited[current_node.name] = True
|
||||
for input_node_name in current_node.input:
|
||||
input_node_name = node_name_from_input(input_node_name)
|
||||
if input_node_name in self.already_visited:
|
||||
continue
|
||||
input_node = self.nodes_map[input_node_name]
|
||||
self.eightbitize_nodes_recursively(input_node)
|
||||
self.already_visited[current_node.name] = True
|
||||
if current_node.op == "MatMul":
|
||||
self.eightbitize_mat_mul_node(current_node)
|
||||
elif current_node.op == "Conv2D":
|
||||
|
Loading…
Reference in New Issue
Block a user