Change quantize_graph in eightbit mode to remove FakeQuantWithMinMaxVars
nodes and use the information they provide to set the min/max values on
quantization-related nodes.
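
For illustration, a rough sketch of the intended end-to-end behavior, mirroring the new tests at the bottom of this change (the tools/quantization import path and the toy node values are assumptions for the sketch, not part of the change):

    import tensorflow as tf
    from tensorflow.tools.quantization import quantize_graph

    # Float graph: Const -> Relu -> FakeQuantWithMinMaxVars(min=0, max=12).
    input_node = quantize_graph.create_constant_node(
        "input", value=[1, 2, 3, 4, 5, 6], dtype=tf.float32, shape=[1, 1, 6, 1])
    relu_node = quantize_graph.create_node("Relu", "relu", [input_node.name])
    quantize_graph.set_attr_dtype(relu_node, "T", tf.float32)
    min_node = quantize_graph.create_constant_node(
        "min", value=0, dtype=tf.float32, shape=[])
    max_node = quantize_graph.create_constant_node(
        "max", value=12, dtype=tf.float32, shape=[])
    fake_quant_node = quantize_graph.create_node(
        "FakeQuantWithMinMaxVars", "fake_quant",
        [relu_node.name, min_node.name, max_node.name])

    float_graph_def = tf.GraphDef()
    float_graph_def.node.extend([input_node, relu_node, min_node, max_node,
                                 fake_quant_node])
    rewriter = quantize_graph.GraphRewriter(float_graph_def, "eightbit",
                                            quantized_input_range=None)
    eightbit_graph_def = rewriter.rewrite([fake_quant_node.name])

    ops = [node.op for node in eightbit_graph_def.node]
    # The FakeQuant node is gone; its min/max consts now drive the
    # quantization-related nodes instead.
    assert "FakeQuantWithMinMaxVars" not in ops
    # The final Dequantize reuses the FakeQuant node's name, so downstream
    # consumers of "fake_quant" still resolve.
    assert any(node.op == "Dequantize" and node.name == "fake_quant"
               for node in eightbit_graph_def.node)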

In eightbit mode, also change how constant weights are quantized: instead of
quantizing them in a separate pass after the main recursion, do it during the
recursion itself. This allows the float inputs to FakeQuantWithMinMaxVars
(its min/max constants) to be excluded from quantization.
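
A rough sketch of what that up-front quantization emits for one weight, based on the return value of quantize_weight_eightbit visible in the diff below (weight_const_node is a stand-in for any float weight Const reached during the recursion):

    # Replace one float weight Const, as the rewriter now does during the
    # recursion rather than in a separate pass afterwards.
    new_nodes = quantize_weight_eightbit(weight_const_node, b"MIN_FIRST")
    # A quint8 Const with the quantized values, scalar min/max float Consts,
    # and a Dequantize standing in for the original node so existing
    # consumers still resolve:
    print([n.op for n in new_nodes])  # ['Const', 'Const', 'Const', 'Dequantize']
    for n in new_nodes:
      self.add_output_graph_node(n)  # as in GraphRewriter's new Const branch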

In eightbit mode, maintain more state during the recursion: an explicit stack
with one entry per (consumer node, input index) pair being visited, recording
whether that input should be quantized.

Also change the QuantizedReshape kernel registration to register
unconditionally instead of via the TF_CALL_xyz macros; this matches the other
quantized ops.
Change: 137877226
A. Unique TensorFlower 2016-11-01 14:10:52 -08:00 committed by TensorFlower Gardener
parent 1dadfdd276
commit 95f7166b88
3 changed files with 175 additions and 22 deletions

tensorflow/core/kernels/quantized_reshape_op.cc

@@ -50,8 +50,8 @@ class QuantizedReshapeOp : public ReshapeOp {
.TypeConstraint<type>("T"), \
QuantizedReshapeOp)
-TF_CALL_quint8(REGISTER_CPU_KERNEL);
-TF_CALL_qint32(REGISTER_CPU_KERNEL);
+REGISTER_CPU_KERNEL(::tensorflow::quint8);
+REGISTER_CPU_KERNEL(::tensorflow::qint32);
#undef REGISTER_CPU_KERNEL

tensorflow/tools/quantization/quantize_graph.py

@@ -28,6 +28,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import re
import numpy as np
import tensorflow as tf
@@ -283,6 +284,11 @@ def quantize_weight_eightbit(input_node, quantization_mode):
return [quint8_const_node, min_node, max_node, dequantize_node]
+
+EightbitizeRecursionState = collections.namedtuple(
+    "EightbitizeRecursionState", ["already_visited", "output_node_stack",
+                                  "merged_with_fake_quant"])
+
class GraphRewriter(object):
"""Takes a float graph, and rewrites it in quantized form."""
@@ -316,6 +322,9 @@ class GraphRewriter(object):
else:
self.input_range = None
+
+    # Data that is valid only during the recursive call to rewrite the graph.
+    self.state = None
def create_nodes_map(self, graph):
"""Builds a mapping of node names to their defs from the graph."""
nodes_map = {}
@@ -353,11 +362,12 @@
output_nodes = [self.nodes_map[output_node_name]
for output_node_name in output_node_names]
-      self.already_visited = {}
-      self.layers_eightbitized = []
+      self.state = EightbitizeRecursionState(already_visited={},
+                                             output_node_stack=[],
+                                             merged_with_fake_quant={})
for output_node in output_nodes:
self.eightbitize_nodes_recursively(output_node)
-      self.output_graph = self.quantize_weights(self.output_graph, b"MIN_FIRST")
+      self.state = None
if self.input_range:
self.add_output_graph_node(create_constant_node(
"quantized_input_min_value", self.input_range[0], tf.float32, []))
@@ -477,20 +487,54 @@
set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
self.add_output_graph_node(dequantize_node)
+
+  def should_merge_with_fake_quant_node(self):
+    """Should the current node merge with self.state.output_node_stack[-1]?"""
+    if not self.state.output_node_stack: return False
+    top = self.state.output_node_stack[-1]
+    return top[1] == 0 and top[0].op in ["FakeQuantWithMinMaxVars"]
+
+  def should_quantize_const(self, node):
+    if not self.state.output_node_stack: return False
+    top = self.state.output_node_stack[-1]
+    if not top[2]: return False
+    assert tf.as_dtype(node.attr["dtype"].type) == tf.float32, (
+        "Quantizing constant %s" % node.name)
+    return True
+
def eightbitize_nodes_recursively(self, current_node):
"""The entry point for transforming a graph into full eight bit."""
-    self.already_visited[current_node.name] = True
-    for input_node_name in current_node.input:
+    if current_node.name in self.state.already_visited:
+      if (self.should_merge_with_fake_quant_node() or
+          current_node.name in self.state.merged_with_fake_quant):
+        raise ValueError("Unsupported graph structure: output of node %s "
+                         "is processed by a FakeQuant* node and should have "
+                         "no other outputs.", current_node.name)
+      return
+    self.state.already_visited[current_node.name] = True
+
+    for i, input_node_name in enumerate(current_node.input):
+      quantize_input = False
+      if current_node.op in ("MatMul", "Conv2D", "BiasAdd", "MaxPool",
+                             "AvgPool", "Relu", "Relu6",
+                             "BatchNormWithGlobalNormalization"):
+        quantize_input = True
+      elif current_node.op == "Concat" and i > 0:
+        quantize_input = True
+      elif current_node.op == "Reshape" and i == 0:
+        quantize_input = True
+      # Each stack entry is (consumer node, index of this input on it,
+      # whether the input should be fed to it quantized).
+      self.state.output_node_stack.append((current_node, i, quantize_input))
input_node_name = node_name_from_input(input_node_name)
-      if input_node_name in self.already_visited:
-        continue
input_node = self.nodes_map[input_node_name]
self.eightbitize_nodes_recursively(input_node)
+      self.state.output_node_stack.pop()
if current_node.op == "MatMul":
self.eightbitize_mat_mul_node(current_node)
elif current_node.op == "Conv2D":
self.eightbitize_conv_node(current_node)
-      self.layers_eightbitized.append(current_node.name)
elif current_node.op == "BiasAdd":
self.eightbitize_bias_add_node(current_node)
elif current_node.op == "MaxPool" or current_node.op == "AvgPool":
@@ -508,11 +552,29 @@
elif (self.input_range and
current_node.op in ("Placeholder", "PlaceholderV2")):
self.eightbitize_placeholder_node(current_node)
elif current_node.op == "FakeQuantWithMinMaxVars":
# It will have been merged into the underlying node.
pass
elif current_node.op == "Const":
if self.should_quantize_const(current_node):
for n in quantize_weight_eightbit(current_node, b"MIN_FIRST"):
self.add_output_graph_node(n)
else:
new_node = tf.NodeDef()
new_node.CopyFrom(current_node)
self.add_output_graph_node(new_node)
else:
new_node = tf.NodeDef()
new_node.CopyFrom(current_node)
self.add_output_graph_node(new_node)
+
+    if (self.should_merge_with_fake_quant_node() and
+        current_node.name not in self.state.merged_with_fake_quant):
+      raise ValueError(
+          "FakeQuant* node %s failed to merge with node %s of type %s" % (
+              self.state.output_node_stack[-1][0], current_node.name,
+              current_node.op))
def add_eightbit_prologue_nodes(self, original_node):
"""Adds input conversion nodes to handle quantizing the underlying node."""
namespace_prefix = original_node.name + "_eightbit"
@@ -583,16 +645,26 @@
quantized_output_name, quantized_output_name + ":1",
quantized_output_name + ":2"
]
-    requant_range_node = create_node(
-        "RequantizationRange", original_node.name + "_eightbit_requant_range",
-        quantized_outputs)
-    set_attr_dtype(requant_range_node, "Tinput", tf.qint32)
-    self.add_output_graph_node(requant_range_node)
+    min_max_inputs = None
+    if self.should_merge_with_fake_quant_node():
+      # Use the inputs to the FakeQuantWithMinMaxVars node as the inputs to
+      # Requantize.
+      fake_quant_node = self.state.output_node_stack[-1][0]
+      min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]]
+      assert original_node.name not in self.state.merged_with_fake_quant
+      self.state.merged_with_fake_quant[original_node.name] = True
+    else:
+      # Add a RequantizationRange node for finding the min and max values.
+      requant_range_node = create_node(
+          "RequantizationRange", original_node.name + "_eightbit_requant_range",
+          quantized_outputs)
+      set_attr_dtype(requant_range_node, "Tinput", tf.qint32)
+      self.add_output_graph_node(requant_range_node)
+      min_max_inputs = [requant_range_node.name + ":0",
+                        requant_range_node.name + ":1"]
requantize_node = create_node(
"Requantize", original_node.name + "_eightbit_requantize",
-        (quantized_outputs +
-         [requant_range_node.name + ":0", requant_range_node.name + ":1"]))
+        quantized_outputs + min_max_inputs)
set_attr_dtype(requantize_node, "Tinput", tf.qint32)
set_attr_dtype(requantize_node, "out_type", tf.quint8)
self.add_output_graph_node(requantize_node)
@@ -600,12 +672,20 @@
def add_dequantize_result_node(self, quantized_output_name,
original_node_name, min_tensor_index=1):
+    min_max_inputs = [
+        "%s:%s" % (quantized_output_name, min_tensor_index),
+        "%s:%s" % (quantized_output_name, (min_tensor_index + 1))]
+    dequantize_name = original_node_name
+    if self.should_merge_with_fake_quant_node():
+      fake_quant_node = self.state.output_node_stack[-1][0]
+      if original_node_name not in self.state.merged_with_fake_quant:
+        min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]]
+        self.state.merged_with_fake_quant[original_node_name] = True
+      dequantize_name = fake_quant_node.name
dequantize_node = create_node(
"Dequantize", dequantize_name,
[quantized_output_name,
"%s:%s" % (quantized_output_name, min_tensor_index),
"%s:%s" % (quantized_output_name, (min_tensor_index + 1))])
[quantized_output_name, min_max_inputs[0], min_max_inputs[1]])
set_attr_dtype(dequantize_node, "T", tf.quint8)
set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
self.add_output_graph_node(dequantize_node)

tensorflow/tools/quantization/quantize_graph_test.py

@@ -160,7 +160,7 @@ def get_top_value(input_values):
return max_index, max_value
-def test_graph(float_graph_def, input_map, output_names):
+def test_graph(float_graph_def, input_map, output_names, log_graph=False):
"""Runs the float graph through the rewriter and tests the results."""
float_results = run_graph_def(float_graph_def, input_map,
[output_name + ":0"
@@ -184,6 +184,9 @@ def test_graph(float_graph_def, input_map, output_names):
for expected, result in zip(float_results, eightbit_results):
assert are_tensors_near(expected, result, 1.0)
+  if log_graph:
+    tf.logging.info("8bit:\n%s", str(eightbit_graph_def))
+
# Test the weights_rounded mode. This uses the default bit_depth.
weights_rounded_rewriter = quantize_graph.GraphRewriter(
float_graph_def, "weights_rounded", quantized_input_range=None)
@@ -580,6 +583,40 @@ class QuantizeGraphTest(tf.test.TestCase):
float_graph_def.node.extend([relu_node])
test_graph(float_graph_def, {}, [relu_name])
+  def test_relu_w_fake_quant_w_min_max_vars(self):
+    input_node = quantize_graph.create_constant_node(
+        "input", value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
+        dtype=tf.float32, shape=[1, 2, 6, 1])
+    relu_node = quantize_graph.create_node("Relu", "relu",
+                                           [input_node.name])
+    quantize_graph.set_attr_dtype(relu_node, "T", tf.float32)
+    min_node = quantize_graph.create_constant_node(
+        "min_bias_add", value=0, dtype=tf.float32, shape=[])
+    max_node = quantize_graph.create_constant_node(
+        "max_bias_add", value=12, dtype=tf.float32, shape=[])
+    fake_quant_node = quantize_graph.create_node(
+        "FakeQuantWithMinMaxVars", "fake_quant",
+        [relu_node.name, min_node.name, max_node.name])
+
+    float_graph_def = tf.GraphDef()
+    float_graph_def.node.extend([input_node, relu_node, min_node, max_node,
+                                 fake_quant_node])
+    test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True)
+
+    # Run the graph through the eightbit rewriter and verify the op mix.
+    eightbit_rewriter = quantize_graph.GraphRewriter(float_graph_def,
+                                                     "eightbit",
+                                                     quantized_input_range=None)
+    eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
+
+    ops = [node.op for node in eightbit_graph_def.node]
+    # No quantize since all inputs are const and can be quantized up-front.
+    self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
+    # One dequantize at the end.
+    self.assertEqual(1, ops.count("Dequantize"))
+
def test_relu6(self):
input_constant_name = "input_constant"
relu6_name = "relu6"
@@ -720,6 +757,42 @@
ops.count("QuantizeV2") + ops.count("Quantize"))
self.assertEqual(len(output_names), ops.count("Dequantize"))
+  def test_bias_add_w_fake_quant_w_min_max_vars(self):
+    input_node = quantize_graph.create_constant_node(
+        "input", value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+        dtype=tf.float32, shape=[1, 1, 2, 5])
+    offset_node = quantize_graph.create_constant_node(
+        "offset", value=[1, 2, 3, 4, 5], dtype=tf.float32, shape=[5])
+    bias_add_node = quantize_graph.create_node(
+        "BiasAdd", "bias_add", [input_node.name, offset_node.name])
+    quantize_graph.set_attr_dtype(bias_add_node, "T", tf.float32)
+    min_node = quantize_graph.create_constant_node(
+        "min_bias_add", value=-.5, dtype=tf.float32, shape=[])
+    max_node = quantize_graph.create_constant_node(
+        "max_bias_add", value=15.5, dtype=tf.float32, shape=[])
+    fake_quant_node = quantize_graph.create_node(
+        "FakeQuantWithMinMaxVars", "fake_quant",
+        [bias_add_node.name, min_node.name, max_node.name])
+
+    float_graph_def = tf.GraphDef()
+    float_graph_def.node.extend([input_node, offset_node, bias_add_node,
+                                 min_node, max_node, fake_quant_node])
+    test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True)
+
+    # Run the graph through the eightbit rewriter and verify the op mix.
+    eightbit_rewriter = quantize_graph.GraphRewriter(float_graph_def,
+                                                     "eightbit",
+                                                     quantized_input_range=None)
+    eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
+
+    ops = [node.op for node in eightbit_graph_def.node]
+    # No quantize since all inputs are const and can be quantized up-front.
+    self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
+    # One dequantize at the end.
+    self.assertEqual(1, ops.count("Dequantize"))
+
def test_remove_redundant_quantization(self):
a_constant_name = "a_constant"
a_constant_min_name = "a_constant_min"