Generalize quantization rewrite to not rely on names.

It should now work with most graphs, regardless of whether they were built with slim or not.

PiperOrigin-RevId: 184889280
Suharsh Sivakumar 2018-02-07 13:47:02 -08:00 committed by TensorFlower Gardener
parent ef16b9cc9a
commit 190b918c8c
5 changed files with 299 additions and 470 deletions
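
As a hedged illustration of what the generalized rewrite enables (this sketch is not part of the change itself), the snippet below builds a small layer from raw TensorFlow ops, with no slim layers and no special op names, and runs the rewrite over it. The shapes, variable names, and parameter values are invented for the example; only the Quantize call form is taken from the updated tests.

import tensorflow as tf
from tensorflow.contrib.quantize.python import quantize

graph = tf.Graph()
with graph.as_default():
  # A layer built from raw ops: Conv2D -> BiasAdd -> Relu6, no slim naming.
  images = tf.placeholder(tf.float32, [8, 32, 32, 3])
  weights = tf.Variable(tf.truncated_normal([3, 3, 3, 16]), name='weights')
  biases = tf.Variable(tf.zeros([16]), name='biases')
  conv = tf.nn.conv2d(images, weights, [1, 1, 1, 1], 'SAME')
  act = tf.nn.relu6(tf.nn.bias_add(conv, biases))
  # The rewrite now finds Conv2D/MatMul/DepthwiseConv2dNative layers by
  # structural pattern matching, so the op names above do not matter.
  quantize.Quantize(graph=graph, weight_bits=8, activation_bits=8)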

View File

@ -75,7 +75,9 @@ py_library(
":graph_matcher",
":input_to_ops",
"//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:layers",
"//tensorflow/python:math_ops",
@ -83,6 +85,7 @@ py_library(
"//tensorflow/python:nn_ops",
"//tensorflow/python:ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variables",
],
)
@ -162,7 +165,6 @@ py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
"//tensorflow/python:variables",
@ -174,7 +176,7 @@ py_library(
srcs = ["python/quantize.py"],
srcs_version = "PY2AND3",
deps = [
":common",
":graph_matcher",
":input_to_ops",
":quant_ops",
"//tensorflow/contrib/graph_editor:graph_editor_py",

View File

@ -53,7 +53,7 @@ def LastValueQuantize(inputs,
init_max=6.0,
updates_collection=ops.GraphKeys.UPDATE_OPS,
vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
scope=None,
name_prefix='LastValueQuant',
reuse=None,
is_training=True,
num_bits=8,
@ -73,7 +73,7 @@ def LastValueQuantize(inputs,
computation.
vars_collection: (Optional) collection where to store variables for
quantization interval ends.
scope: Optional scope for variable_scope.
name_prefix: name_prefix for created nodes.
reuse: whether or not the layer and its variables should be reused. To be
able to reuse the layer scope must be given.
is_training: Whether the op is applied to a training or eval graph.
@ -84,13 +84,13 @@ def LastValueQuantize(inputs,
a tensor containing quantized values.
"""
with variable_scope.variable_scope(
scope, 'LastValueQuantize', values=[inputs], reuse=reuse):
None, default_name=name_prefix, values=[inputs], reuse=reuse):
input_shape = inputs.get_shape()
input_dim = len(input_shape)
if per_channel:
# Only support quantizing 1-, 2- and 4-dimensional tensors.
assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in '
' scope: %s' % (input_shape, scope))
' scope: %s' % (input_shape, name_prefix))
min_max_shape = [input_shape[-1]]
else:
min_max_shape = []
@ -165,7 +165,7 @@ def MovingAvgQuantize(inputs,
ema_decay=0.999,
updates_collection=ops.GraphKeys.UPDATE_OPS,
vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
scope=None,
name_prefix='MovingAvgQuantize',
reuse=None,
is_training=True,
num_bits=8,
@ -186,7 +186,7 @@ def MovingAvgQuantize(inputs,
computation.
vars_collection: (Optional) collection where to store variables for
quantization interval ends.
scope: Optional scope for variable_scope.
name_prefix: name_prefix for created nodes.
reuse: whether or not the layer and its variables should be reused. To be
able to reuse the layer scope must be given.
is_training: Whether the op is applied to a training or eval graph.
@ -197,13 +197,13 @@ def MovingAvgQuantize(inputs,
a tensor containing quantized values.
"""
with variable_scope.variable_scope(
scope, 'MovingAvgQuantize', values=[inputs], reuse=reuse):
None, default_name=name_prefix, values=[inputs], reuse=reuse):
input_shape = inputs.get_shape()
input_dim = len(input_shape)
if per_channel:
# Only support quantizing 1-, 2- and 4-dimensional tensors.
assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in '
' scope: %s' % (input_shape, scope))
' scope: %s' % (input_shape, name_prefix))
min_max_shape = [input_shape[-1]]
else:
min_max_shape = []
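
For reference, a hedged sketch of calling the renamed op directly; the input shape, interval bounds, and prefix string are illustrative. `name_prefix` simply feeds the `default_name` of the variable scope where `scope` used to.

import tensorflow as tf
from tensorflow.contrib.quantize.python import quant_ops

inputs = tf.placeholder(tf.float32, [8, 16])
quantized = quant_ops.MovingAvgQuantize(
    inputs,
    init_min=-6.0,
    init_max=6.0,
    ema_decay=0.999,
    is_training=True,
    num_bits=8,
    narrow_range=False,
    name_prefix='act_quant')  # previously passed as scope='act_quant'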

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logic to update a Tensorflow model graph with quantization operations."""
"""Logic to update a TensorFlow model graph with quantization operations."""
from __future__ import absolute_import
from __future__ import division
@ -20,7 +20,7 @@ from __future__ import print_function
import re
from tensorflow.contrib import graph_editor
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
from tensorflow.contrib.quantize.python import quant_ops
from tensorflow.python.framework import ops
@ -28,30 +28,29 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import training_util
# Operation types used to select operations of interest.
# Quantizable operation types that are supported by the quantization rewrite.
_QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'}
# Custom key for storing and retrieving update ops used by quantizing nodes.
_UPDATE_QUANT_OPS = 'update_quant_ops'
# Activations that are supported by the quantization rewrite.
_ACTIVATION_TYPES = {'Relu', 'Relu6', 'Identity'}
# Weight types that are supported by the quantization rewrite.
# TODO(suharshs): Add support for ResourceVariable.
_WEIGHT_TYPES = {'Variable', 'VariableV2'}
def Quantize(graph,
weight_bits=8,
weight_narrow_range=False,
activation_bits=8,
ema_decay=0.999,
quant_delay=None,
vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
is_training=True,
quantize_folded_weights_use_ema=False):
is_training=True):
"""Updates graph with quantization operations.
Args:
graph: Graph to modify.
weight_bits: Number of bits to use for quantizing weights.
weight_narrow_range: Whether to use a more efficient narrow range for
weights quantization. With weight_narrow_range true, the range is
[1; 2^weight_bits - 1], with it false [0; 2^weight_bits - 1].
activation_bits: Number of bits to use for quantizing activations.
ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update
quantization intervals for quantizing activations (see here about EMA:
@ -62,345 +61,274 @@ def Quantize(graph,
vars_collection: (Optional) Collection where to store the variables for
quantization interval ends.
is_training: (Optional) Whether quantizing training graph or eval graph.
quantize_folded_weights_use_ema: (Optional, default False) Whether to
quantize weights after batchnorm-folding with exponential average
quantization.
Raises:
ValueError: When quantization fails.
"""
context = _QuantizeContext(graph, weight_bits, weight_narrow_range,
activation_bits, ema_decay, quant_delay,
vars_collection, is_training,
quantize_folded_weights_use_ema)
graph_ops = graph.get_operations()
# Filter out backprop and summary related operations, leave only interesting
# op types.
def _IsInterestingOpWithWeights(op):
return (op.type in _QUANTIZABLE_TYPES and
not op.name.startswith(common.SKIPPED_PREFIXES))
for op in (op for op in graph_ops if _IsInterestingOpWithWeights(op)):
if op.name.endswith('/depthwise'):
# Separable convolution may consist of 2 convolution nodes. If so, skip
# .../depthwise and only quantize the top one.
separable_conv = context.GetOperationByNameDontThrow(
op.name[:-len('/depthwise')])
if separable_conv and separable_conv.type == 'Conv2D':
continue
# Quantize add ops that come after Conv2D or DepthwiseConv2dNative.
if op.type in ['Conv2D', 'DepthwiseConv2dNative']:
add_context_re = re.search(r'^(.*)/[^/]+/', op.name)
if add_context_re is not None:
context.add_contexts.add(add_context_re.group(1))
if not op.name.endswith('_Fold'):
folded_op = context.GetOperationByNameDontThrow(op.name + '_Fold')
# Do nothing if found, it will be quantized when it is iterated over.
if not folded_op:
context.QuantizeOpWithWeights(op, folded=False)
else:
context.QuantizeOpWithWeights(op, folded=True)
context.QuantizeAddContexts()
# Once all quantization ops have been inserted in the graph, collect update
# ops for their variables and modify the TF Slim update barrier (see
# https://www.tensorflow.org/code/tensorflow/contrib/slim/python/slim/learning.py)
# to depend on them.
try:
update_barrier = graph.get_operation_by_name('update_barrier')
except KeyError:
# In evaluation graph, this barrier may not exist.
return None
update_quant_ops = graph.get_collection_ref(_UPDATE_QUANT_OPS)
graph_editor.add_control_inputs(update_barrier, update_quant_ops)
class _QuantizeContext(object):
"""Context holds references needed for quantization."""
def __init__(self,
graph,
weight_bits,
weight_narrow_range,
activation_bits,
ema_decay=0.999,
quant_delay=None,
vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
is_training=True,
quantize_folded_weights_use_ema=False):
"""Initializes context to hold references needed for quantization.
Args:
graph: Graph to modify.
weight_bits: Number of bits to use for quantizing weights.
weight_narrow_range: Whether to use a more efficient narrow range for
weights quantization. With weight_narrow_range true, the range is
[1; 2^weight_bits - 1], with it false [0; 2^weight_bits - 1].
activation_bits: Number of bits to use for quantizing activations.
ema_decay: (Optional) Float, EMA decay parameter.
quant_delay: (Optional, default None) Int, count of global steps for which
to delay quantization. This helps weights stabilize at the start of
training.
vars_collection: (Optional) Collection where to store the variables for
quantization interval ends.
is_training: (Optional) Whether quantizing training or eval graph.
quantize_folded_weights_use_ema: (Optional, default False) Whether to
quantize weights after batchnorm-folding with exponential average
quantization.
"""
self.graph = graph
self.weight_bits = weight_bits
self.weight_narrow_range = weight_narrow_range
self.activation_bits = activation_bits
self.ema_decay = ema_decay
self.quant_delay = quant_delay
self.vars_collection = vars_collection
self.is_training = is_training
self.quantize_folded_weights_use_ema = quantize_folded_weights_use_ema
self.input_to_ops_map = input_to_ops.InputToOps(graph)
self.add_contexts = set()
def QuantizeAddContexts(self):
"""Quantizes all add ops in self.add_contexts."""
# Loop through sorted self.add_contexts so that op creation is
# deterministic. This is needed when using multiple worker replicas so that
# the ops can be initialized consistently.
for add_context in sorted(self.add_contexts):
add_op = self.GetOperationByNamesDontThrow([
add_context + '/Add', add_context + '/add'])
if add_op is not None:
self._InsertQuantOp(
add_context,
add_op,
self.input_to_ops_map.ConsumerOperations(add_op),
name='add_quant',
moving_avg=True,
bits=self.activation_bits,
narrow_range=False)
def QuantizeOpWithWeights(self, op, folded):
"""Quantizes around the specific operation with or without batch norm.
Args:
op: Operation to quantize.
folded: Operation has been folded and needs special handling if True.
Raises:
ValueError: When quantization fails.
"""
# Op name component before the last slash will be used as context.
context = re.search(r'^(.*)/([^/]+)', op.name).group(1)
# Quantize weights.
if folded:
producer_op = self.graph.get_operation_by_name(context + '/mul_fold')
else:
try:
input_idx = next(i for i, v in enumerate(op.inputs)
if '/weights/' in v.name or
'/depthwise_weights' in v.name)
except StopIteration:
raise ValueError('No inputs to quantize for op: %s' % op)
producer_op = op.inputs[input_idx].op
# If batch norm is used, the folded weights depend on the batch std, hence
# it is sensible to use EMA during training to smooth out the noise. This is
# controlled by the flag quantize_folded_weights_use_ema. Its default is
# False for backward compatibility.
# If there is no batch norm, weights do not depend on the batch and using
# the latest value of min and max is more efficient.
weight_use_ema = folded and self.quantize_folded_weights_use_ema
self._InsertQuantOp(
input_to_ops_map = input_to_ops.InputToOps(graph)
for layer_match in _FindLayersToQuantize(graph):
# Quantize the weights.
context = _GetContextFromOp(layer_match.layer_op)
_InsertQuantOp(
context,
producer_op, [op],
layer_match.weight_tensor.op, [layer_match.layer_op],
name='weights_quant',
moving_avg=weight_use_ema,
delay_requested=weight_use_ema,
bits=self.weight_bits,
narrow_range=self.weight_narrow_range)
moving_avg=False,
bits=weight_bits,
ema_decay=ema_decay,
quant_delay=quant_delay,
is_training=is_training,
narrow_range=True,
vars_collection=vars_collection)
# Important: do not quantize biases here. During inference they are
# quantized to 32 bits, which is much finer than 8 bit quantization and
# depends on weight and input activation ranges.
# Find activation and (optionally) Add operations to quantize.
activation_op, add_op, add_context = self._GetReluAndAddOperations(context,
op)
if add_op:
original_context = context
context = add_context
# Quantize activation outputs.
consumer_ops = self.input_to_ops_map.ConsumerOperations(activation_op)
self._InsertQuantOp(
context,
activation_op,
# Quantize the activations.
consumer_ops = input_to_ops_map.ConsumerOperations(
layer_match.activation_op)
add_context = context
if layer_match.bypass_op:
add_context = re.search(r'^(.*)/([^/]+)', context).group(1)
_InsertQuantOp(
add_context,
layer_match.activation_op,
consumer_ops,
name='act_quant',
moving_avg=True,
init_min=0.0,
bits=self.activation_bits,
narrow_range=False)
ema_decay=ema_decay,
quant_delay=quant_delay,
bits=activation_bits,
vars_collection=vars_collection)
# When a bypass connection was found, also quantize Add op input.
if add_op:
def _QuantizeAddInput(add_input):
if folded:
return add_input.op.name.endswith('/add_fold')
else:
return add_input.op.name.startswith(original_context + '/')
# Quantize the inputs and output to the bypass (if it exists). The input to
# the bypass is the bias add, and the output is the activation.
if layer_match.bypass_op is not None:
_InsertQuantOp(
context,
layer_match.bias_add_op, [layer_match.bypass_op],
name='conv_quant',
moving_avg=True,
ema_decay=ema_decay,
quant_delay=quant_delay,
vars_collection=vars_collection,
bits=activation_bits)
_InsertQuantOp(
add_context,
layer_match.bypass_op,
input_to_ops_map.ConsumerOperations(layer_match.bypass_op),
name='add_quant',
moving_avg=True,
bits=activation_bits)
for add_input in add_op.inputs:
if _QuantizeAddInput(add_input):
self._InsertQuantOp(
original_context,
add_input.op, [add_op],
name='conv_quant',
moving_avg=True,
bits=self.activation_bits,
narrow_range=False)
def _GetReluAndAddOperations(self, context, op):
"""Looks up a Relu* and Add operations in given context.
def _FindLayersToQuantize(graph):
"""Matches layers in graph to quantize.
Args:
context: Context where to look for operations.
op: Operation to quantize.
Args:
graph: Graph to perform match on.
Returns:
A triplet (Operation, Operation, string), the first element is an end
point operation, the second is Add operation (optional), the third element
is string context where the Add operation was found (optional).
Yields:
_LayerMatches.
"""
input_pattern = graph_matcher.OpTypePattern('*')
weight_var_pattern = graph_matcher.OpTypePattern('|'.join(_WEIGHT_TYPES))
weight_pattern = graph_matcher.OpTypePattern(
'Identity', inputs=[weight_var_pattern])
Raises:
ValueError: When operations cannot be found.
"""
activation_op = common.GetEndpointActivationOp(self.graph, context)
if activation_op:
return activation_op, None, None
folded_weight_pattern = graph_matcher.OpTypePattern('Mul')
if '/' in context:
# If no activation op is there, look for them one level up.
add_context = re.search(r'^(.*)/([^/]+)', context).group(1)
activation_op = common.GetEndpointActivationOp(self.graph, add_context)
if not activation_op:
# Still no Relu, can happen on the top layer, just find the next node up,
# make sure it is BiasAdd.
consumers = [c for outp in op.outputs for c in outp.consumers()]
if len(consumers) != 1 or consumers[0].type != 'BiasAdd':
raise ValueError('Failed to quantize op: %s, %s' % (op.name, op.type))
return consumers[0], None, None
if add_context:
add_op = self.GetOperationByNamesDontThrow([
add_context + '/Add', add_context + '/add'])
return activation_op, add_op, add_context
else:
raise ValueError('Failed to quantize op: %s, %s' % (op.name, op.type))
# The weights inputs to the layer operation can either be from the Variable or
# the folded weight (Mul).
layer_pattern = graph_matcher.OpTypePattern(
'|'.join(_QUANTIZABLE_TYPES),
inputs=[
input_pattern,
graph_matcher.OneofPattern([weight_pattern, folded_weight_pattern])
])
def GetOperationByNameDontThrow(self, name):
"""Returns an Operation with the given name.
folded_bias_mul_pattern = graph_matcher.OpTypePattern(
'Mul', inputs=[graph_matcher.OpTypePattern('*'), layer_pattern])
post_layer_op_correction_pattern = graph_matcher.OpTypePattern(
'Add', inputs=[folded_bias_mul_pattern,
graph_matcher.OpTypePattern('*')])
folded_bias_add_pattern = graph_matcher.OpTypePattern(
'Add',
inputs=[
post_layer_op_correction_pattern,
graph_matcher.OpTypePattern('*')
])
Args:
name: Name of Operation to return.
bias_add_pattern = graph_matcher.OpTypePattern(
'Add|BiasAdd', inputs=[layer_pattern, '*'])
Returns:
The Operation with the given name. None if the name does not correspond to
any operation in the graph.
"""
try:
return self.graph.get_operation_by_name(name)
except KeyError:
return None
# The bias can come from the bias add or the folded bias add.
bypass_pattern_a = graph_matcher.OpTypePattern(
'Add',
inputs=[
graph_matcher.OneofPattern(
[bias_add_pattern, folded_bias_add_pattern]), '*'
])
bypass_pattern_b = graph_matcher.OpTypePattern(
'Add',
inputs=[
'*',
graph_matcher.OneofPattern(
[bias_add_pattern, folded_bias_add_pattern])
])
def GetOperationByNamesDontThrow(self, names):
"""Returns an Operation with one of the given names.
# The input to the activation can come from bias add, fold bias add or the
# bypasses.
activation_pattern = graph_matcher.OpTypePattern(
'|'.join(_ACTIVATION_TYPES),
inputs=[
graph_matcher.OneofPattern([
bias_add_pattern, folded_bias_add_pattern, bypass_pattern_a,
bypass_pattern_b
])
])
Args:
names: Names of Operation to return.
layer_matcher = graph_matcher.GraphMatcher(activation_pattern)
for match_result in layer_matcher.match_graph(graph):
layer_op = match_result.get_op(layer_pattern)
weight_tensor = match_result.get_tensor(weight_pattern)
if weight_tensor is None:
weight_tensor = match_result.get_tensor(folded_weight_pattern)
activation_op = match_result.get_op(activation_pattern)
bias_add_op = match_result.get_op(bias_add_pattern)
if bias_add_op is None:
bias_add_op = match_result.get_op(folded_bias_add_pattern)
bypass_op = match_result.get_op(bypass_pattern_a)
if bypass_op is None:
bypass_op = match_result.get_op(bypass_pattern_b)
yield _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op,
bias_add_op)
Returns:
The Operation with one of the given names. None if none of the names
corresponds to any operation in the graph.
"""
for name in names:
op = self.GetOperationByNameDontThrow(name)
if op is not None:
return op
return None
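
To make the pattern machinery above concrete, here is a minimal, hedged sketch (not part of this change) that matches a bare Conv2D -> BiasAdd -> Relu chain with the same graph_matcher primitives; the toy layer and tensor shapes are invented for illustration.

import tensorflow as tf
from tensorflow.contrib.quantize.python import graph_matcher

g = tf.Graph()
with g.as_default():
  x = tf.placeholder(tf.float32, [1, 8, 8, 3])
  w = tf.Variable(tf.zeros([3, 3, 3, 4]), name='w')
  y = tf.nn.relu(tf.nn.bias_add(
      tf.nn.conv2d(x, w, [1, 1, 1, 1], 'SAME'), tf.zeros([4])))

# Patterns compose bottom-up: the BiasAdd must consume the Conv2D, and the
# Relu must consume the BiasAdd. '*' matches any input.
conv_pattern = graph_matcher.OpTypePattern('Conv2D')
bias_pattern = graph_matcher.OpTypePattern('BiasAdd', inputs=[conv_pattern, '*'])
relu_pattern = graph_matcher.OpTypePattern('Relu', inputs=[bias_pattern])

matcher = graph_matcher.GraphMatcher(relu_pattern)
for match in matcher.match_graph(g):
  conv_op = match.get_op(conv_pattern)  # the matched Conv2D operation
  relu_op = match.get_op(relu_pattern)  # the matched Relu operation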
def _InsertQuantOp(
self,
context,
producer,
consumers,
name,
moving_avg=True,
init_min=-6.0,
init_max=6.0,
delay_requested=True,
bits=8,
narrow_range=False,):
"""Inserts a quant op between a producer op and (multiple) consumer ops.
class _LayerMatch(object):
"""Contains all information related to a matched Layer."""
Args:
context: Context where producer and consumer operations are nested.
producer: Producer operation of the pairs where quantization will be
inserted.
consumers: Consumer operations of the pairs.
name: Name for the new quantization op within the context.
moving_avg: Specifies whether to use exponential moving average or just
the last value seen.
init_min: Starting minimum value for the new quantization op.
init_max: Starting maximum value for the new quantization op.
delay_requested: If true, implement quantization delay where needed.
False value explicitly disables delay quantization everywhere.
bits: Number of bits to use for quantization, must be between 2 and 8.
narrow_range: Whether to use the narrow quantization range
def __init__(self, layer_op, weight_tensor, activation_op, bypass_op,
bias_add_op):
self._layer_op = layer_op
self._weight_tensor = weight_tensor
self._activation_op = activation_op
self._bypass_op = bypass_op
self._bias_add_op = bias_add_op
@property
def layer_op(self):
return self._layer_op
@property
def weight_tensor(self):
return self._weight_tensor
@property
def activation_op(self):
return self._activation_op
@property
def bypass_op(self):
return self._bypass_op
@property
def bias_add_op(self):
return self._bias_add_op
def _InsertQuantOp(context,
producer,
consumers,
name,
moving_avg=True,
init_min=-6.0,
init_max=6.0,
bits=8,
ema_decay=0.999,
quant_delay=None,
vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
is_training=True,
narrow_range=False):
"""Inserts a quant op between a producer op and (multiple) consumer ops.
Args:
context: Context where producer and consumer operations are nested.
producer: Producer operation of the pairs where quantization will be
inserted.
consumers: Consumer operations of the pairs.
name: Name for the new quantization op within the context.
moving_avg: Specifies whether to use exponential moving average or just
the last value seen.
init_min: Starting minimum value for the new quantization op.
init_max: Starting maximum value for the new quantization op.
bits: Number of bits to use for quantization, must be between 2 and 8.
ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update
quantization intervals for quantizing activations (see here about EMA:
https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
quant_delay: (Optional, default None) Int, count of global steps for which
to delay quantization. This helps weights stabilize at the start of
training.
vars_collection: (Optional) Collection where to store the variables for
quantization interval ends.
is_training: (Optional) Whether quantizing training graph or eval graph.
narrow_range: Whether to use the narrow quantization range
[1; 2^bits - 1] or wide range [0; 2^bits - 1].
Raises:
ValueError: When producer operation is not directly connected to the
consumer operation.
"""
scope = context + '/' + name
inputs = producer.outputs[0]
if moving_avg:
quant = (quant_ops.MovingAvgQuantize(
inputs,
init_min=init_min,
init_max=init_max,
ema_decay=self.ema_decay,
is_training=self.is_training,
num_bits=bits,
narrow_range=narrow_range,
updates_collection=_UPDATE_QUANT_OPS,
vars_collection=self.vars_collection,
scope=scope))
else:
quant = (quant_ops.LastValueQuantize(
inputs,
init_min=init_min,
init_max=init_max,
is_training=self.is_training,
num_bits=bits,
narrow_range=narrow_range,
updates_collection=_UPDATE_QUANT_OPS,
vars_collection=self.vars_collection,
scope=scope))
Raises:
ValueError: When producer operation is not directly connected to the
consumer operation.
"""
name_prefix = _AddContextToName(context, name)
inputs = producer.outputs[0]
if moving_avg:
quant = (
quant_ops.MovingAvgQuantize(
inputs,
init_min=init_min,
init_max=init_max,
ema_decay=ema_decay,
is_training=is_training,
num_bits=bits,
narrow_range=narrow_range,
vars_collection=vars_collection,
name_prefix=name_prefix))
else:
quant = (
quant_ops.LastValueQuantize(
inputs,
init_min=init_min,
init_max=init_max,
is_training=is_training,
num_bits=bits,
narrow_range=narrow_range,
vars_collection=vars_collection,
name_prefix=name_prefix))
if delay_requested and self.quant_delay and self.quant_delay > 0:
activate_quant = math_ops.greater_equal(
training_util.get_or_create_global_step(),
self.quant_delay,
name=scope + '/activate_quant')
quant = control_flow_ops.cond(
activate_quant,
lambda: quant,
lambda: inputs,
name=scope + '/delayed_quant')
if quant_delay and quant_delay > 0:
activate_quant = math_ops.greater_equal(
training_util.get_or_create_global_step(),
quant_delay,
name=name_prefix + '/activate_quant')
quant = control_flow_ops.cond(
activate_quant,
lambda: quant,
lambda: inputs,
name=name_prefix + '/delayed_quant')
nodes_modified_count = graph_editor.reroute_ts(
[quant], [inputs], can_modify=consumers)
if nodes_modified_count != len(consumers):
raise ValueError('Some inputs not quantized for ops: [%s]' %
', '.join([consumer.name for consumer in consumers]))
nodes_modified_count = graph_editor.reroute_ts(
[quant], [inputs], can_modify=consumers)
if nodes_modified_count != len(consumers):
raise ValueError('Some inputs not quantized for ops: [%s]' % ', '.join(
[consumer.name for consumer in consumers]))
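
The quant_delay branch above gates the fake-quant output on the global step. The hedged sketch below shows the same control-flow pattern in isolation, with an invented delay value and a fixed quantization range standing in for the learned min/max variables.

import tensorflow as tf

quant_delay = 2000  # illustrative number of steps to wait before quantizing
global_step = tf.train.get_or_create_global_step()
inputs = tf.placeholder(tf.float32, [4, 10])
quant = tf.fake_quant_with_min_max_args(inputs, min=-6.0, max=6.0)

# Until global_step >= quant_delay the original float tensor flows through;
# afterwards consumers see the fake-quantized tensor.
activate_quant = tf.greater_equal(global_step, quant_delay,
                                  name='act_quant/activate_quant')
delayed = tf.cond(activate_quant, lambda: quant, lambda: inputs,
                  name='act_quant/delayed_quant')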
def _GetContextFromOp(op):
"""Gets the root context name from the op name."""
context_re = re.search(r'^(.*)/([^/]+)', op.name)
if context_re:
return context_re.group(1)
return ''
def _AddContextToName(context, name):
"""Adds the context to the name if it exists."""
if not context:
return name
return context + '/' + name
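
To illustrate the two helpers above, a small, hedged stand-alone version of the same logic applied to plain strings (the op names are hypothetical):

import re

def _context_from_op_name(op_name):
  # Mirrors _GetContextFromOp, but takes a name string for clarity.
  m = re.search(r'^(.*)/([^/]+)', op_name)
  return m.group(1) if m else ''

def _add_context_to_name(context, name):
  # Mirrors _AddContextToName.
  return context + '/' + name if context else name

assert _context_from_op_name('net/conv1/Conv2D') == 'net/conv1'
assert _context_from_op_name('Conv2D') == ''
assert _add_context_to_name('net/conv1', 'weights_quant') == 'net/conv1/weights_quant'
assert _add_context_to_name('', 'weights_quant') == 'weights_quant'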

View File

@ -101,7 +101,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
scope + '/weights_quant/AssignMaxLast', scope + '/weights/read'
]
self._AssertInputOpsAre(weights_quant, expected_inputs)
output_op_name = scope + '/Conv2D'
output_op_name = (
scope + '/weights_quant/delayed_quant/Switch_1'
if delay else scope + '/Conv2D')
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
if with_bypass:
@ -176,7 +178,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
scope + '/weights_quant/AssignMaxLast', scope + '/weights/read'
]
self._AssertInputOpsAre(weights_quant, expected_inputs)
output_op_name = scope + '/MatMul'
output_op_name = (
scope + '/weights_quant/delayed_quant/Switch_1'
if delay else scope + '/MatMul')
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
if with_bypass:
@ -252,7 +256,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
scope + '/depthwise_weights/read'
]
self._AssertInputOpsAre(weights_quant, expected_inputs)
output_op_name = scope + '/depthwise'
output_op_name = (
scope + '/weights_quant/delayed_quant/Switch_1'
if delay else scope + '/depthwise')
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
if with_bypass:
@ -316,6 +322,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
for params in parameters_list:
test_fn(params[0], params[1], params[2], params[3], params[4])
def testQuantize_Conv2dWithBatchNorm(self):
self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm)
def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
with_bypass, delay, fused_batch_norm):
"""Tests quantization: inputs -> Conv2d with batch norm -> Activation.
@ -329,39 +338,6 @@ class QuantizeTest(test_util.TensorFlowTestCase):
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
"""
self._testQuantize_Conv2dWithBatchNorm(
activation,
activation_op_name,
with_bypass,
delay,
fused_batch_norm,
use_ema=True)
self._testQuantize_Conv2dWithBatchNorm(
activation,
activation_op_name,
with_bypass,
delay,
fused_batch_norm,
use_ema=False)
def testQuantize_Conv2dWithBatchNorm(self):
self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm)
def _testQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
with_bypass, delay, fused_batch_norm,
use_ema):
"""Tests quantization: inputs -> Conv2d with batch norm -> Activation.
Args:
activation: Callable that returns an Operation, a factory method for the
Activation.
activation_op_name: String, name of the Activation operation.
with_bypass: Bool, when true there is an extra connection added from
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_ema: Bool, when true uses EMA quantization for BN folded weights.
"""
graph = ops.Graph()
with graph.as_default():
training.create_global_step(graph)
@ -394,23 +370,19 @@ class QuantizeTest(test_util.TensorFlowTestCase):
fold_batch_norms.FoldBatchNorms(graph)
quantize.Quantize(
graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)
quantize.Quantize(graph, quant_delay=delay)
quantization_node_name = 'FakeQuantWithMinMaxVars'
weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
quantization_node_name)
self.assertEqual(weights_quant.type, quantization_node_name)
expected_inputs = [
scope + '/weights_quant/' + ('AssignMinEma'
if use_ema else 'AssignMinLast'),
scope + '/weights_quant/' + ('AssignMaxEma'
if use_ema else 'AssignMaxLast'),
scope + '/mul_fold'
scope + '/weights_quant/' + 'AssignMinLast',
scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold'
]
self._AssertInputOpsAre(weights_quant, expected_inputs)
output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
if (delay and use_ema) else '/Conv2D_Fold')
if delay else '/Conv2D_Fold')
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
if with_bypass:
@ -438,6 +410,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
if delay else 'control_dependency')
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
def testQuantize_FCWithBatchNorm(self):
self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm)
def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name,
with_bypass, delay, fused_batch_norm):
"""Tests quantization: inputs -> FC with batch norm -> Activation.
@ -451,39 +426,6 @@ class QuantizeTest(test_util.TensorFlowTestCase):
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
"""
self._testQuantize_FCWithBatchNorm(
activation,
activation_op_name,
with_bypass,
delay,
fused_batch_norm,
use_ema=True)
self._testQuantize_FCWithBatchNorm(
activation,
activation_op_name,
with_bypass,
delay,
fused_batch_norm,
use_ema=False)
def testQuantize_FCWithBatchNorm(self):
self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm)
def _testQuantize_FCWithBatchNorm(self, activation, activation_op_name,
with_bypass, delay, fused_batch_norm,
use_ema):
"""Tests quantization: inputs -> FC with batch norm -> Activation.
Args:
activation: Callable that returns an Operation, a factory method for the
Activation.
activation_op_name: String, name of the Activation operation.
with_bypass: Bool, when true there is an extra connection added from
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_ema: Bool, when true uses EMA quantization for BN folded weights.
"""
graph = ops.Graph()
with graph.as_default():
training.create_global_step(graph)
@ -513,23 +455,19 @@ class QuantizeTest(test_util.TensorFlowTestCase):
fold_batch_norms.FoldBatchNorms(graph)
quantize.Quantize(
graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)
quantize.Quantize(graph, quant_delay=delay)
quantization_node_name = 'FakeQuantWithMinMaxVars'
weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
quantization_node_name)
self.assertEqual(weights_quant.type, quantization_node_name)
expected_inputs = [
scope + '/weights_quant/' + ('AssignMinEma'
if use_ema else 'AssignMinLast'),
scope + '/weights_quant/' + ('AssignMaxEma'
if use_ema else 'AssignMaxLast'),
scope + '/mul_fold'
scope + '/weights_quant/' + 'AssignMinLast',
scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold'
]
self._AssertInputOpsAre(weights_quant, expected_inputs)
output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
if delay and use_ema else '/MatMul_Fold')
if delay else '/MatMul_Fold')
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
if with_bypass:
@ -557,6 +495,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
if delay else 'control_dependency')
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
def testQuantize_DepthwiseConv2dWithBatchNorm(self):
self._RunBatchNormTestOverParameters(
self._TestQuantize_DepthwiseConv2dWithBatchNorm)
def _TestQuantize_DepthwiseConv2dWithBatchNorm(
self, activation, activation_op_name, with_bypass, delay,
fused_batch_norm):
@ -571,40 +513,6 @@ class QuantizeTest(test_util.TensorFlowTestCase):
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
"""
self._testQuantize_DepthwiseConv2dWithBatchNorm(
activation,
activation_op_name,
with_bypass,
delay,
fused_batch_norm,
use_ema=True)
self._testQuantize_DepthwiseConv2dWithBatchNorm(
activation,
activation_op_name,
with_bypass,
delay,
fused_batch_norm,
use_ema=False)
def testQuantize_DepthwiseConv2dWithBatchNorm(self):
self._RunBatchNormTestOverParameters(
self._TestQuantize_DepthwiseConv2dWithBatchNorm)
def _testQuantize_DepthwiseConv2dWithBatchNorm(
self, activation, activation_op_name, with_bypass, delay,
fused_batch_norm, use_ema):
"""Tests quantization: inputs -> DWConv2d with batch norm -> Activation.
Args:
activation: Callable that returns an Operation, a factory method for the
Activation.
activation_op_name: String, name of the Activation operation.
with_bypass: Bool, when true there is an extra connection added from
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_ema: Bool, when true uses EMA quantization for BN folded weights.
"""
graph = ops.Graph()
with graph.as_default():
training.create_global_step(graph)
@ -637,22 +545,18 @@ class QuantizeTest(test_util.TensorFlowTestCase):
fold_batch_norms.FoldBatchNorms(graph)
quantize.Quantize(
graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)
quantize.Quantize(graph, quant_delay=delay)
quantization_node_name = 'FakeQuantWithMinMaxVars'
weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
quantization_node_name)
self.assertEqual(weights_quant.type, quantization_node_name)
expected_inputs = [
scope + '/weights_quant/' + ('AssignMinEma'
if use_ema else 'AssignMinLast'),
scope + '/weights_quant/' + ('AssignMaxEma'
if use_ema else 'AssignMaxLast'),
scope + '/mul_fold'
scope + '/weights_quant/' + 'AssignMinLast',
scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold'
]
self._AssertInputOpsAre(weights_quant, expected_inputs)
output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
if delay and use_ema else '/depthwise_Fold')
if delay else '/depthwise_Fold')
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
if with_bypass:

View File

@ -45,13 +45,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
activation_fn=None, scope='test')
relu = nn_ops.relu6(inputs)
context = quantize._QuantizeContext(graph=graph, weight_bits=8,
weight_narrow_range=True,
activation_bits=8)
# Inserting a quantization op between two unconnected ops should fail with
# ValueError.
with self.assertRaises(ValueError) as err:
context._InsertQuantOp('test', conv.op, [relu.op], 'FailingQuantOp')
quantize._InsertQuantOp('test', conv.op, [relu.op], 'FailingQuantOp')
self.assertEqual(
str(err.exception), 'Some inputs not quantized for ops: [Relu6]')
@ -70,8 +67,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
quantize.Quantize(graph=graph, weight_bits=8, weight_narrow_range=True,
activation_bits=8)
quantize.Quantize(graph=graph, weight_bits=8, activation_bits=8)
quantization_node_name = 'FakeQuantWithMinMaxVars'
add_quant = graph.get_operation_by_name('test/add_quant/' +
@ -94,8 +90,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
quantize.Quantize(graph=graph, weight_bits=8, weight_narrow_range=True,
activation_bits=8)
quantize.Quantize(graph=graph, weight_bits=8, activation_bits=8)
quantization_node_name = 'FakeQuantWithMinMaxVars'
add_quant = graph.get_operation_by_name('test/add_quant/' +