Change quantize_graph in eightbit mode to remove FakeQuantWithMinMaxVars
nodes and use the information they provide to set the min/max values on quantization-related nodes. In eightbit mode, also change how constant weights are quantized: instead of quantizing them in a separate step after the main recursion, do it during the main recursion, which allows the float inputs to FakeQuantWithMinMaxVars to be excluded from quantization. In eightbit mode, maintain more state on the stack during recursion. Also change the quantized reshape kernel registration to register unconditionally rather than via TF_CALL_xyz; this matches the other quantized ops.
Change: 137877226
parent 1dadfdd276
commit 95f7166b88
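The first hunk below touches the QuantizedReshape kernel registration; the rest change quantize_graph.py and its test. The core of the Python change is a small amount of recursion state that lets each node ask whether its sole consumer is a FakeQuantWithMinMaxVars node. A minimal standalone sketch of that merge rule, mirroring the should_merge_with_fake_quant_node logic in the diff (FakeNode here is an illustrative stand-in for tf.NodeDef, not part of the change):

    import collections

    # Mirrors the namedtuple introduced in quantize_graph.py below.
    EightbitizeRecursionState = collections.namedtuple(
        "EightbitizeRecursionState",
        ["already_visited", "output_node_stack", "merged_with_fake_quant"])

    # Illustrative stand-in for tf.NodeDef, carrying just what the rule needs.
    FakeNode = collections.namedtuple("FakeNode", ["name", "op", "input"])

    def should_merge_with_fake_quant_node(state):
      # Merge only when the current node feeds input 0 of a
      # FakeQuantWithMinMaxVars consumer; inputs 1 and 2 are its min/max vars.
      if not state.output_node_stack:
        return False
      top = state.output_node_stack[-1]
      return top[1] == 0 and top[0].op == "FakeQuantWithMinMaxVars"

    state = EightbitizeRecursionState(already_visited={}, output_node_stack=[],
                                      merged_with_fake_quant={})
    fake_quant = FakeNode("fake_quant", "FakeQuantWithMinMaxVars",
                          ["relu", "min", "max"])
    # Each stack entry is (consumer NodeDef, input index, quantize_input flag).
    state.output_node_stack.append((fake_quant, 0, True))
    assert should_merge_with_fake_quant_node(state)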
@@ -50,8 +50,8 @@ class QuantizedReshapeOp : public ReshapeOp {
                               .TypeConstraint<type>("T"), \
                           QuantizedReshapeOp)
 
-TF_CALL_quint8(REGISTER_CPU_KERNEL);
-TF_CALL_qint32(REGISTER_CPU_KERNEL);
+REGISTER_CPU_KERNEL(::tensorflow::quint8);
+REGISTER_CPU_KERNEL(::tensorflow::qint32);
 
 #undef REGISTER_CPU_KERNEL
@@ -28,6 +28,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import re
 
 import numpy as np
 import tensorflow as tf
@@ -283,6 +284,11 @@ def quantize_weight_eightbit(input_node, quantization_mode):
   return [quint8_const_node, min_node, max_node, dequantize_node]
 
 
+EightbitizeRecursionState = collections.namedtuple(
+    "EightbitizeRecursionState", ["already_visited", "output_node_stack",
+                                  "merged_with_fake_quant"])
+
+
 class GraphRewriter(object):
   """Takes a float graph, and rewrites it in quantized form."""
@@ -316,6 +322,9 @@ class GraphRewriter(object):
     else:
       self.input_range = None
 
+    # Data that is valid only during the recursive call to rewrite the graph.
+    self.state = None
+
   def create_nodes_map(self, graph):
     """Builds a mapping of node names to their defs from the graph."""
     nodes_map = {}
@@ -353,11 +362,12 @@ class GraphRewriter(object):
       output_nodes = [self.nodes_map[output_node_name]
                       for output_node_name in output_node_names]
 
-      self.already_visited = {}
       self.layers_eightbitized = []
+      self.state = EightbitizeRecursionState(already_visited={},
+                                             output_node_stack=[],
+                                             merged_with_fake_quant={})
       for output_node in output_nodes:
         self.eightbitize_nodes_recursively(output_node)
-      self.output_graph = self.quantize_weights(self.output_graph, b"MIN_FIRST")
+      self.state = None
     if self.input_range:
       self.add_output_graph_node(create_constant_node(
           "quantized_input_min_value", self.input_range[0], tf.float32, []))
@@ -477,20 +487,54 @@ class GraphRewriter(object):
     set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
     self.add_output_graph_node(dequantize_node)
 
+  def should_merge_with_fake_quant_node(self):
+    """Should the current node merge with self.state.output_node_stack[-1]?"""
+    if not self.state.output_node_stack: return False
+    top = self.state.output_node_stack[-1]
+    return top[1] == 0 and top[0].op in ["FakeQuantWithMinMaxVars"]
+
+  def should_quantize_const(self, node):
+    if not self.state.output_node_stack: return False
+    top = self.state.output_node_stack[-1]
+    if not top[2]: return False
+    assert tf.as_dtype(node.attr["dtype"].type) == tf.float32, (
+        "Quantizing constant %s" % node.name)
+    return True
+
   def eightbitize_nodes_recursively(self, current_node):
     """The entry point for transforming a graph into full eight bit."""
-    self.already_visited[current_node.name] = True
-    for input_node_name in current_node.input:
+    if current_node.name in self.state.already_visited:
+      if (self.should_merge_with_fake_quant_node() or
+          current_node.name in self.state.merged_with_fake_quant):
+        raise ValueError("Unsupported graph structure: output of node %s "
+                         "is processed by a FakeQuant* node and should have "
+                         "no other outputs.", current_node.name)
+      return
+    self.state.already_visited[current_node.name] = True
+
+    for i, input_node_name in enumerate(current_node.input):
+      quantize_input = False
+      if current_node.op in ("MatMul", "Conv2D", "BiasAdd", "MaxPool",
+                             "AvgPool", "Relu", "Relu6",
+                             "BatchNormWithGlobalNormalization"):
+        quantize_input = True
+      elif current_node.op == "Concat" and i > 0:
+        quantize_input = True
+      elif current_node.op == "Reshape" and i == 0:
+        quantize_input = True
+
+      self.state.output_node_stack.append((current_node, i, quantize_input))
+
       input_node_name = node_name_from_input(input_node_name)
-      if input_node_name in self.already_visited:
-        continue
       input_node = self.nodes_map[input_node_name]
       self.eightbitize_nodes_recursively(input_node)
+
+      self.state.output_node_stack.pop()
+
     if current_node.op == "MatMul":
       self.eightbitize_mat_mul_node(current_node)
     elif current_node.op == "Conv2D":
       self.eightbitize_conv_node(current_node)
       self.layers_eightbitized.append(current_node.name)
     elif current_node.op == "BiasAdd":
       self.eightbitize_bias_add_node(current_node)
     elif current_node.op == "MaxPool" or current_node.op == "AvgPool":
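The per-input bookkeeping above also decides which inputs of each consumer should be quantized at all, which is what lets the float min/max inputs of FakeQuantWithMinMaxVars escape weight quantization. A runnable restatement of the classification rules from the hunk above (illustrative only, not additional code in the change):

    def wants_quantized_input(consumer_op, input_index):
      # Ops whose every input is quantized.
      if consumer_op in ("MatMul", "Conv2D", "BiasAdd", "MaxPool", "AvgPool",
                         "Relu", "Relu6", "BatchNormWithGlobalNormalization"):
        return True
      # Concat input 0 is the concat dimension; only the tensors after it
      # are quantized.
      if consumer_op == "Concat":
        return input_index > 0
      # Reshape input 1 is the target shape; only the data input is quantized.
      if consumer_op == "Reshape":
        return input_index == 0
      return False

    assert not wants_quantized_input("Concat", 0)
    assert wants_quantized_input("Reshape", 0)
    # FakeQuant inputs are never marked for quantization; input 0 is merged
    # instead, and inputs 1 and 2 stay as float constants.
    assert not wants_quantized_input("FakeQuantWithMinMaxVars", 1)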
@@ -508,11 +552,29 @@ class GraphRewriter(object):
     elif (self.input_range and
           current_node.op in ("Placeholder", "PlaceholderV2")):
       self.eightbitize_placeholder_node(current_node)
+    elif current_node.op == "FakeQuantWithMinMaxVars":
+      # It will have been merged into the underlying node.
+      pass
+    elif current_node.op == "Const":
+      if self.should_quantize_const(current_node):
+        for n in quantize_weight_eightbit(current_node, b"MIN_FIRST"):
+          self.add_output_graph_node(n)
+      else:
+        new_node = tf.NodeDef()
+        new_node.CopyFrom(current_node)
+        self.add_output_graph_node(new_node)
     else:
       new_node = tf.NodeDef()
       new_node.CopyFrom(current_node)
       self.add_output_graph_node(new_node)
 
+    if (self.should_merge_with_fake_quant_node() and
+        current_node.name not in self.state.merged_with_fake_quant):
+      raise ValueError(
+          "FakeQuant* node %s failed to merge with node %s of type %s" % (
+              self.state.output_node_stack[-1][0], current_node.name,
+              current_node.op))
+
   def add_eightbit_prologue_nodes(self, original_node):
     """Adds input conversion nodes to handle quantizing the underlying node."""
     namespace_prefix = original_node.name + "_eightbit"
@@ -583,16 +645,26 @@ class GraphRewriter(object):
         quantized_output_name, quantized_output_name + ":1",
         quantized_output_name + ":2"
     ]
-    requant_range_node = create_node(
-        "RequantizationRange", original_node.name + "_eightbit_requant_range",
-        quantized_outputs)
-    set_attr_dtype(requant_range_node, "Tinput", tf.qint32)
-    self.add_output_graph_node(requant_range_node)
+
+    min_max_inputs = None
+    if self.should_merge_with_fake_quant_node():
+      # Use the inputs to the FakeQuantWithMinMaxVars node as the inputs to
+      # Requantize.
+      fake_quant_node = self.state.output_node_stack[-1][0]
+      min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]]
+      assert original_node.name not in self.state.merged_with_fake_quant
+      self.state.merged_with_fake_quant[original_node.name] = True
+    else:
+      # Add a RequantizationRange node for finding the min and max values.
+      requant_range_node = create_node(
+          "RequantizationRange", original_node.name + "_eightbit_requant_range",
+          quantized_outputs)
+      set_attr_dtype(requant_range_node, "Tinput", tf.qint32)
+      self.add_output_graph_node(requant_range_node)
+      min_max_inputs = [requant_range_node.name + ":0",
+                        requant_range_node.name + ":1"]
     requantize_node = create_node(
         "Requantize", original_node.name + "_eightbit_requantize",
-        (quantized_outputs +
-         [requant_range_node.name + ":0", requant_range_node.name + ":1"]))
+        quantized_outputs + min_max_inputs)
     set_attr_dtype(requantize_node, "Tinput", tf.qint32)
     set_attr_dtype(requantize_node, "out_type", tf.quint8)
     self.add_output_graph_node(requantize_node)
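The effect of the hunk above: when the node being rewritten feeds a FakeQuantWithMinMaxVars node, Requantize reads its output range straight from the FakeQuant min/max variables instead of a freshly added RequantizationRange node. A compact illustration of the two resulting input lists (a sketch; names follow the "_eightbit" convention used in the diff):

    def requantize_input_names(quantized_outputs, fake_quant_min_max=None,
                               requant_range_name=None):
      # quantized_outputs is [value, min, max] of the qint32 result.
      if fake_quant_min_max is not None:
        # Merged case: the range was learned during training.
        min_max_inputs = list(fake_quant_min_max)
      else:
        # Unmerged case: the range is observed at runtime.
        min_max_inputs = [requant_range_name + ":0", requant_range_name + ":1"]
      return quantized_outputs + min_max_inputs

    quantized = ["conv_eightbit:0", "conv_eightbit:1", "conv_eightbit:2"]
    print(requantize_input_names(quantized, fake_quant_min_max=["min", "max"]))
    print(requantize_input_names(
        quantized, requant_range_name="conv_eightbit_requant_range"))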
@@ -600,12 +672,20 @@ class GraphRewriter(object):
 
   def add_dequantize_result_node(self, quantized_output_name,
                                  original_node_name, min_tensor_index=1):
+    min_max_inputs = [
+        "%s:%s" % (quantized_output_name, min_tensor_index),
+        "%s:%s" % (quantized_output_name, (min_tensor_index + 1))]
+    dequantize_name = original_node_name
+    if self.should_merge_with_fake_quant_node():
+      fake_quant_node = self.state.output_node_stack[-1][0]
+      if original_node_name not in self.state.merged_with_fake_quant:
+        min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]]
+        self.state.merged_with_fake_quant[original_node_name] = True
+      dequantize_name = fake_quant_node.name
+
     dequantize_node = create_node(
-        "Dequantize", original_node_name,
-        [quantized_output_name,
-         "%s:%s" % (quantized_output_name, min_tensor_index),
-         "%s:%s" % (quantized_output_name, (min_tensor_index + 1))])
+        "Dequantize", dequantize_name,
+        [quantized_output_name, min_max_inputs[0], min_max_inputs[1]])
     set_attr_dtype(dequantize_node, "T", tf.quint8)
     set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
     self.add_output_graph_node(dequantize_node)
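One subtlety in the hunk above: when a Dequantize result replaces a FakeQuantWithMinMaxVars node, the Dequantize node is emitted under the FakeQuant node's own name, so downstream nodes that referenced it by name are rewired without editing their input lists. Roughly (an illustrative restatement, with FQ standing in for tf.NodeDef):

    import collections

    def dequantize_name_and_inputs(original_name, default_min_max, fake_quant,
                                   already_merged):
      # fake_quant is the consuming FakeQuantWithMinMaxVars node, or None.
      if fake_quant is None:
        return original_name, default_min_max
      min_max = default_min_max
      if not already_merged:
        # First merge: take the range from the FakeQuant min/max inputs.
        min_max = [fake_quant.input[1], fake_quant.input[2]]
      # Reuse the FakeQuant node's name so its consumers pick up the result.
      return fake_quant.name, min_max

    FQ = collections.namedtuple("FQ", ["name", "input"])
    fq = FQ("fake_quant", ["bias_add", "min_bias_add", "max_bias_add"])
    print(dequantize_name_and_inputs("bias_add", ["q:1", "q:2"], fq, False))
    # -> ('fake_quant', ['min_bias_add', 'max_bias_add'])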
@@ -160,7 +160,7 @@ def get_top_value(input_values):
   return max_index, max_value
 
 
-def test_graph(float_graph_def, input_map, output_names):
+def test_graph(float_graph_def, input_map, output_names, log_graph=False):
   """Runs the float graph through the rewriter and tests the results."""
   float_results = run_graph_def(float_graph_def, input_map,
                                 [output_name + ":0"
@@ -184,6 +184,9 @@ def test_graph(float_graph_def, input_map, output_names):
   for expected, result in zip(float_results, eightbit_results):
     assert are_tensors_near(expected, result, 1.0)
 
+  if log_graph:
+    tf.logging.info("8bit:\n%s", str(eightbit_graph_def))
+
   # Test the weights_rounded mode. This uses the default bit_depth.
   weights_rounded_rewriter = quantize_graph.GraphRewriter(
       float_graph_def, "weights_rounded", quantized_input_range=None)
@@ -580,6 +583,40 @@ class QuantizeGraphTest(tf.test.TestCase):
     float_graph_def.node.extend([relu_node])
     test_graph(float_graph_def, {}, [relu_name])
 
+  def test_relu_w_fake_quant_w_min_max_vars(self):
+    input_node = quantize_graph.create_constant_node(
+        "input", value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
+        dtype=tf.float32, shape=[1, 2, 6, 1])
+    relu_node = quantize_graph.create_node("Relu", "relu",
+                                           [input_node.name])
+    quantize_graph.set_attr_dtype(relu_node, "T", tf.float32)
+
+    min_node = quantize_graph.create_constant_node(
+        "min_bias_add", value=0, dtype=tf.float32, shape=[])
+    max_node = quantize_graph.create_constant_node(
+        "max_bias_add", value=12, dtype=tf.float32, shape=[])
+    fake_quant_node = quantize_graph.create_node(
+        "FakeQuantWithMinMaxVars", "fake_quant",
+        [relu_node.name, min_node.name, max_node.name])
+
+    float_graph_def = tf.GraphDef()
+    float_graph_def.node.extend([input_node, relu_node, min_node, max_node,
+                                 fake_quant_node])
+    test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True)
+
+    # Verify there is only one Quantize and one Requantize op.
+    eightbit_rewriter = quantize_graph.GraphRewriter(float_graph_def,
+                                                     "eightbit",
+                                                     quantized_input_range=None)
+    eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
+
+    ops = [node.op for node in eightbit_graph_def.node]
+    # No quantize since all inputs are const and can be quantized up-front.
+    self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
+
+    # One dequantize at the end.
+    self.assertEqual(1, ops.count("Dequantize"))
+
   def test_relu6(self):
     input_constant_name = "input_constant"
     relu6_name = "relu6"
@@ -720,6 +757,42 @@ class QuantizeGraphTest(tf.test.TestCase):
                      ops.count("QuantizeV2") + ops.count("Quantize"))
     self.assertEqual(len(output_names), ops.count("Dequantize"))
 
+  def test_bias_add_w_fake_quant_w_min_max_vars(self):
+    input_node = quantize_graph.create_constant_node(
+        "input", value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+        dtype=tf.float32, shape=[1, 1, 2, 5])
+    offset_node = quantize_graph.create_constant_node(
+        "offset", value=[1, 2, 3, 4, 5], dtype=tf.float32, shape=[5])
+    bias_add_node = quantize_graph.create_node(
+        "BiasAdd", "bias_add", [input_node.name, offset_node.name])
+    quantize_graph.set_attr_dtype(bias_add_node, "T", tf.float32)
+
+    min_node = quantize_graph.create_constant_node(
+        "min_bias_add", value=-.5, dtype=tf.float32, shape=[])
+    max_node = quantize_graph.create_constant_node(
+        "max_bias_add", value=15.5, dtype=tf.float32, shape=[])
+    fake_quant_node = quantize_graph.create_node(
+        "FakeQuantWithMinMaxVars", "fake_quant",
+        [bias_add_node.name, min_node.name, max_node.name])
+
+    float_graph_def = tf.GraphDef()
+    float_graph_def.node.extend([input_node, offset_node, bias_add_node,
+                                 min_node, max_node, fake_quant_node])
+    test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True)
+
+    # Verify there is only one Quantize and one Requantize op.
+    eightbit_rewriter = quantize_graph.GraphRewriter(float_graph_def,
+                                                     "eightbit",
+                                                     quantized_input_range=None)
+    eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
+
+    ops = [node.op for node in eightbit_graph_def.node]
+    # No quantize since all inputs are const and can be quantized up-front.
+    self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
+
+    # One dequantize at the end.
+    self.assertEqual(1, ops.count("Dequantize"))
+
   def test_remove_redundant_quantization(self):
     a_constant_name = "a_constant"
     a_constant_min_name = "a_constant_min"
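Taken together, the new tests double as usage documentation. A trimmed sketch of driving the rewriter directly, using only helpers that appear in the diff; the import path and the TF 1.x-era GraphDef API are assumptions about where and when this file lives:

    import tensorflow as tf
    from tensorflow.tools.quantization import quantize_graph  # assumed path

    input_node = quantize_graph.create_constant_node(
        "input", value=[1, 2, 3, 4], dtype=tf.float32, shape=[4])
    relu_node = quantize_graph.create_node("Relu", "relu", [input_node.name])
    quantize_graph.set_attr_dtype(relu_node, "T", tf.float32)
    min_node = quantize_graph.create_constant_node(
        "min", value=0, dtype=tf.float32, shape=[])
    max_node = quantize_graph.create_constant_node(
        "max", value=4, dtype=tf.float32, shape=[])
    fake_quant = quantize_graph.create_node(
        "FakeQuantWithMinMaxVars", "fake_quant",
        [relu_node.name, min_node.name, max_node.name])

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend(
        [input_node, relu_node, min_node, max_node, fake_quant])

    rewriter = quantize_graph.GraphRewriter(float_graph_def, "eightbit",
                                            quantized_input_range=None)
    eightbit_graph_def = rewriter.rewrite([fake_quant.name])
    # The FakeQuant node is gone; its min/max now feed the Requantize op, and
    # a single Dequantize named "fake_quant" produces the float output.
    print([node.op for node in eightbit_graph_def.node])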