Generalize quantization rewrite to not rely on names.
It should now work with most graphs regardless if they were built slim or not. PiperOrigin-RevId: 184889280
This commit is contained in:
parent
ef16b9cc9a
commit
190b918c8c
@ -75,7 +75,9 @@ py_library(
|
|||||||
":graph_matcher",
|
":graph_matcher",
|
||||||
":input_to_ops",
|
":input_to_ops",
|
||||||
"//tensorflow/contrib/graph_editor:graph_editor_py",
|
"//tensorflow/contrib/graph_editor:graph_editor_py",
|
||||||
|
"//tensorflow/core:protos_all_py",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
|
"//tensorflow/python:dtypes",
|
||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
"//tensorflow/python:layers",
|
"//tensorflow/python:layers",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
@ -83,6 +85,7 @@ py_library(
|
|||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
"//tensorflow/python:ops",
|
"//tensorflow/python:ops",
|
||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
|
"//tensorflow/python:util",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -162,7 +165,6 @@ py_test(
|
|||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:dtypes",
|
"//tensorflow/python:dtypes",
|
||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
"//tensorflow/python:framework_test_lib",
|
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python:session",
|
"//tensorflow/python:session",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
@ -174,7 +176,7 @@ py_library(
|
|||||||
srcs = ["python/quantize.py"],
|
srcs = ["python/quantize.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":common",
|
":graph_matcher",
|
||||||
":input_to_ops",
|
":input_to_ops",
|
||||||
":quant_ops",
|
":quant_ops",
|
||||||
"//tensorflow/contrib/graph_editor:graph_editor_py",
|
"//tensorflow/contrib/graph_editor:graph_editor_py",
|
||||||
|
@ -53,7 +53,7 @@ def LastValueQuantize(inputs,
|
|||||||
init_max=6.0,
|
init_max=6.0,
|
||||||
updates_collection=ops.GraphKeys.UPDATE_OPS,
|
updates_collection=ops.GraphKeys.UPDATE_OPS,
|
||||||
vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
|
vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
|
||||||
scope=None,
|
name_prefix='LastValueQuant',
|
||||||
reuse=None,
|
reuse=None,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
num_bits=8,
|
num_bits=8,
|
||||||
@ -73,7 +73,7 @@ def LastValueQuantize(inputs,
|
|||||||
computation.
|
computation.
|
||||||
vars_collection: (Optional) collection where to store variables for
|
vars_collection: (Optional) collection where to store variables for
|
||||||
quantization interval ends.
|
quantization interval ends.
|
||||||
scope: Optional scope for variable_scope.
|
name_prefix: name_prefix for created nodes.
|
||||||
reuse: whether or not the layer and its variables should be reused. To be
|
reuse: whether or not the layer and its variables should be reused. To be
|
||||||
able to reuse the layer scope must be given.
|
able to reuse the layer scope must be given.
|
||||||
is_training: Whether the op is applied to a training or eval graph.
|
is_training: Whether the op is applied to a training or eval graph.
|
||||||
@ -84,13 +84,13 @@ def LastValueQuantize(inputs,
|
|||||||
a tensor containing quantized values.
|
a tensor containing quantized values.
|
||||||
"""
|
"""
|
||||||
with variable_scope.variable_scope(
|
with variable_scope.variable_scope(
|
||||||
scope, 'LastValueQuantize', values=[inputs], reuse=reuse):
|
None, default_name=name_prefix, values=[inputs], reuse=reuse):
|
||||||
input_shape = inputs.get_shape()
|
input_shape = inputs.get_shape()
|
||||||
input_dim = len(input_shape)
|
input_dim = len(input_shape)
|
||||||
if per_channel:
|
if per_channel:
|
||||||
# Only support quantizing 1-, 2- and 4-dimensional tensors.
|
# Only support quantizing 1-, 2- and 4-dimensional tensors.
|
||||||
assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in '
|
assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in '
|
||||||
' scope: %s' % (input_shape, scope))
|
' scope: %s' % (input_shape, name_prefix))
|
||||||
min_max_shape = [input_shape[-1]]
|
min_max_shape = [input_shape[-1]]
|
||||||
else:
|
else:
|
||||||
min_max_shape = []
|
min_max_shape = []
|
||||||
@ -165,7 +165,7 @@ def MovingAvgQuantize(inputs,
|
|||||||
ema_decay=0.999,
|
ema_decay=0.999,
|
||||||
updates_collection=ops.GraphKeys.UPDATE_OPS,
|
updates_collection=ops.GraphKeys.UPDATE_OPS,
|
||||||
vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
|
vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
|
||||||
scope=None,
|
name_prefix='MovingAvgQuantize',
|
||||||
reuse=None,
|
reuse=None,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
num_bits=8,
|
num_bits=8,
|
||||||
@ -186,7 +186,7 @@ def MovingAvgQuantize(inputs,
|
|||||||
computation.
|
computation.
|
||||||
vars_collection: (Optional) collection where to store variables for
|
vars_collection: (Optional) collection where to store variables for
|
||||||
quantization interval ends.
|
quantization interval ends.
|
||||||
scope: Optional scope for variable_scope.
|
name_prefix: name_prefix for created nodes.
|
||||||
reuse: whether or not the layer and its variables should be reused. To be
|
reuse: whether or not the layer and its variables should be reused. To be
|
||||||
able to reuse the layer scope must be given.
|
able to reuse the layer scope must be given.
|
||||||
is_training: Whether the op is applied to a training or eval graph.
|
is_training: Whether the op is applied to a training or eval graph.
|
||||||
@ -197,13 +197,13 @@ def MovingAvgQuantize(inputs,
|
|||||||
a tensor containing quantized values.
|
a tensor containing quantized values.
|
||||||
"""
|
"""
|
||||||
with variable_scope.variable_scope(
|
with variable_scope.variable_scope(
|
||||||
scope, 'MovingAvgQuantize', values=[inputs], reuse=reuse):
|
None, default_name=name_prefix, values=[inputs], reuse=reuse):
|
||||||
input_shape = inputs.get_shape()
|
input_shape = inputs.get_shape()
|
||||||
input_dim = len(input_shape)
|
input_dim = len(input_shape)
|
||||||
if per_channel:
|
if per_channel:
|
||||||
# Only support quantizing 1-, 2- and 4-dimensional tensors.
|
# Only support quantizing 1-, 2- and 4-dimensional tensors.
|
||||||
assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in '
|
assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in '
|
||||||
' scope: %s' % (input_shape, scope))
|
' scope: %s' % (input_shape, name_prefix))
|
||||||
min_max_shape = [input_shape[-1]]
|
min_max_shape = [input_shape[-1]]
|
||||||
else:
|
else:
|
||||||
min_max_shape = []
|
min_max_shape = []
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Logic to update a Tensorflow model graph with quantization operations."""
|
"""Logic to update a TensorFlow model graph with quantization operations."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
@ -20,7 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
from tensorflow.contrib import graph_editor
|
from tensorflow.contrib import graph_editor
|
||||||
from tensorflow.contrib.quantize.python import common
|
from tensorflow.contrib.quantize.python import graph_matcher
|
||||||
from tensorflow.contrib.quantize.python import input_to_ops
|
from tensorflow.contrib.quantize.python import input_to_ops
|
||||||
from tensorflow.contrib.quantize.python import quant_ops
|
from tensorflow.contrib.quantize.python import quant_ops
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -28,30 +28,29 @@ from tensorflow.python.ops import control_flow_ops
|
|||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.training import training_util
|
from tensorflow.python.training import training_util
|
||||||
|
|
||||||
# Operation types used to select operations of interest.
|
# Quantizable operation types that are supported by the quantization rewrite.
|
||||||
_QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'}
|
_QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'}
|
||||||
|
|
||||||
# Custom key for storing and retrieving update ops used by quantizing nodes.
|
# Activations that are supported by the quantization rewrite.
|
||||||
_UPDATE_QUANT_OPS = 'update_quant_ops'
|
_ACTIVATION_TYPES = {'Relu', 'Relu6', 'Identity'}
|
||||||
|
|
||||||
|
# Weight types that are supported by the quantization rewrite.
|
||||||
|
# TODO(suharshs): Add support for ResourceVariable.
|
||||||
|
_WEIGHT_TYPES = {'Variable', 'VariableV2'}
|
||||||
|
|
||||||
|
|
||||||
def Quantize(graph,
|
def Quantize(graph,
|
||||||
weight_bits=8,
|
weight_bits=8,
|
||||||
weight_narrow_range=False,
|
|
||||||
activation_bits=8,
|
activation_bits=8,
|
||||||
ema_decay=0.999,
|
ema_decay=0.999,
|
||||||
quant_delay=None,
|
quant_delay=None,
|
||||||
vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
|
vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
|
||||||
is_training=True,
|
is_training=True):
|
||||||
quantize_folded_weights_use_ema=False):
|
|
||||||
"""Updates graph with quantization operations.
|
"""Updates graph with quantization operations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph: Graph to modify.
|
graph: Graph to modify.
|
||||||
weight_bits: Number of bits to use for quantizing weights.
|
weight_bits: Number of bits to use for quantizing weights.
|
||||||
weight_narrow_range: Whether to use a more efficient narrow range for
|
|
||||||
weights quantization. With weight_narrow_range true, the range is
|
|
||||||
[1; 2^weight_bits - 1], with it false [0; 2^weight_bits - 1].
|
|
||||||
activation_bits: Number of bits to use for quantizing activations.
|
activation_bits: Number of bits to use for quantizing activations.
|
||||||
ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update
|
ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update
|
||||||
quantization intervals for quantizing activations (see here about EMA:
|
quantization intervals for quantizing activations (see here about EMA:
|
||||||
@ -62,345 +61,274 @@ def Quantize(graph,
|
|||||||
vars_collection: (Optional) Collection where to store the variables for
|
vars_collection: (Optional) Collection where to store the variables for
|
||||||
quantization interval ends.
|
quantization interval ends.
|
||||||
is_training: (Optional) Whether quantizing training graph or eval graph.
|
is_training: (Optional) Whether quantizing training graph or eval graph.
|
||||||
quantize_folded_weights_use_ema: (Optional, default False) Whether to
|
|
||||||
quantize weights after batchnorm-folding with exponential average
|
|
||||||
quantization.
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: When quantization fails.
|
ValueError: When quantization fails.
|
||||||
"""
|
"""
|
||||||
context = _QuantizeContext(graph, weight_bits, weight_narrow_range,
|
input_to_ops_map = input_to_ops.InputToOps(graph)
|
||||||
activation_bits, ema_decay, quant_delay,
|
for layer_match in _FindLayersToQuantize(graph):
|
||||||
vars_collection, is_training,
|
# Quantize the weights.
|
||||||
quantize_folded_weights_use_ema)
|
context = _GetContextFromOp(layer_match.layer_op)
|
||||||
|
_InsertQuantOp(
|
||||||
graph_ops = graph.get_operations()
|
|
||||||
|
|
||||||
# Filter out backprop and summary related operations, leave only interesting
|
|
||||||
# op types.
|
|
||||||
def _IsInterestingOpWithWeights(op):
|
|
||||||
return (op.type in _QUANTIZABLE_TYPES and
|
|
||||||
not op.name.startswith(common.SKIPPED_PREFIXES))
|
|
||||||
|
|
||||||
for op in (op for op in graph_ops if _IsInterestingOpWithWeights(op)):
|
|
||||||
if op.name.endswith('/depthwise'):
|
|
||||||
# Separable convolution may consist of 2 convolution nodes. If so, skip
|
|
||||||
# .../depthwise and only quantize the top one.
|
|
||||||
separable_conv = context.GetOperationByNameDontThrow(
|
|
||||||
op.name[:-len('/depthwise')])
|
|
||||||
if separable_conv and separable_conv.type == 'Conv2D':
|
|
||||||
continue
|
|
||||||
# Quantize add ops that come after Conv2D or DepthwiseConv2dNative.
|
|
||||||
if op.type in ['Conv2D', 'DepthwiseConv2dNative']:
|
|
||||||
add_context_re = re.search(r'^(.*)/[^/]+/', op.name)
|
|
||||||
if add_context_re is not None:
|
|
||||||
context.add_contexts.add(add_context_re.group(1))
|
|
||||||
if not op.name.endswith('_Fold'):
|
|
||||||
folded_op = context.GetOperationByNameDontThrow(op.name + '_Fold')
|
|
||||||
# Do nothing if found, it will be quantized when it is iterated over.
|
|
||||||
if not folded_op:
|
|
||||||
context.QuantizeOpWithWeights(op, folded=False)
|
|
||||||
else:
|
|
||||||
context.QuantizeOpWithWeights(op, folded=True)
|
|
||||||
|
|
||||||
context.QuantizeAddContexts()
|
|
||||||
|
|
||||||
# Once all quantization ops have been inserted in the graph, collect update
|
|
||||||
# ops for their variables and modify the TF Slim update barrier (see
|
|
||||||
# https://www.tensorflow.org/code/tensorflow/contrib/slim/python/slim/learning.py)
|
|
||||||
# to depend on them.
|
|
||||||
try:
|
|
||||||
update_barrier = graph.get_operation_by_name('update_barrier')
|
|
||||||
except KeyError:
|
|
||||||
# In evaluation graph, this barrier may not exist.
|
|
||||||
return None
|
|
||||||
update_quant_ops = graph.get_collection_ref(_UPDATE_QUANT_OPS)
|
|
||||||
graph_editor.add_control_inputs(update_barrier, update_quant_ops)
|
|
||||||
|
|
||||||
|
|
||||||
class _QuantizeContext(object):
|
|
||||||
"""Context holds references needed for quantization."""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
graph,
|
|
||||||
weight_bits,
|
|
||||||
weight_narrow_range,
|
|
||||||
activation_bits,
|
|
||||||
ema_decay=0.999,
|
|
||||||
quant_delay=None,
|
|
||||||
vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
|
|
||||||
is_training=True,
|
|
||||||
quantize_folded_weights_use_ema=False):
|
|
||||||
"""Initializes context to hold references needed for quantization.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
graph: Graph to modify.
|
|
||||||
weight_bits: Number of bits to use for quantizing weights.
|
|
||||||
weight_narrow_range: Whether to use a more efficient narrow range for
|
|
||||||
weights quantization. With weight_narrow_range true, the range is
|
|
||||||
[1; 2^weight_bits - 1], with it false [0; 2^weight_bits - 1].
|
|
||||||
activation_bits: Number of bits to use for quantizing activations.
|
|
||||||
ema_decay: (Optional) Float, EMA decay parameter.
|
|
||||||
quant_delay: (Optional, default None) Int, count of global steps for which
|
|
||||||
to delay quantization. This helps weights stabilize at the start of
|
|
||||||
training.
|
|
||||||
vars_collection: (Optional) Collection where to store the variables for
|
|
||||||
quantization interval ends.
|
|
||||||
is_training: (Optional) Whether quantizing training or eval graph.
|
|
||||||
quantize_folded_weights_use_ema: (Optional, default False) Whether to
|
|
||||||
quantize weights after batchnorm-folding with exponential average
|
|
||||||
quantization.
|
|
||||||
"""
|
|
||||||
self.graph = graph
|
|
||||||
self.weight_bits = weight_bits
|
|
||||||
self.weight_narrow_range = weight_narrow_range
|
|
||||||
self.activation_bits = activation_bits
|
|
||||||
self.ema_decay = ema_decay
|
|
||||||
self.quant_delay = quant_delay
|
|
||||||
self.vars_collection = vars_collection
|
|
||||||
self.is_training = is_training
|
|
||||||
self.quantize_folded_weights_use_ema = quantize_folded_weights_use_ema
|
|
||||||
self.input_to_ops_map = input_to_ops.InputToOps(graph)
|
|
||||||
self.add_contexts = set()
|
|
||||||
|
|
||||||
def QuantizeAddContexts(self):
|
|
||||||
"""Quantizes all add ops in self.add_contexts."""
|
|
||||||
# Loop through sorted self.add_contexts so that op creation is
|
|
||||||
# deterministic. This is needed when using multiple worker replicas so that
|
|
||||||
# the ops can be initialized consistently.
|
|
||||||
for add_context in sorted(self.add_contexts):
|
|
||||||
add_op = self.GetOperationByNamesDontThrow([
|
|
||||||
add_context + '/Add', add_context + '/add'])
|
|
||||||
if add_op is not None:
|
|
||||||
self._InsertQuantOp(
|
|
||||||
add_context,
|
|
||||||
add_op,
|
|
||||||
self.input_to_ops_map.ConsumerOperations(add_op),
|
|
||||||
name='add_quant',
|
|
||||||
moving_avg=True,
|
|
||||||
bits=self.activation_bits,
|
|
||||||
narrow_range=False)
|
|
||||||
|
|
||||||
def QuantizeOpWithWeights(self, op, folded):
|
|
||||||
"""Quantizes around the specific operation with or without batch norm.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
op: Operation to quantize.
|
|
||||||
folded: Operation has been folded and needs special handling if True.
|
|
||||||
Raises:
|
|
||||||
ValueError: When quantization fails.
|
|
||||||
"""
|
|
||||||
# Op name component before the last slash will be used as context.
|
|
||||||
context = re.search(r'^(.*)/([^/]+)', op.name).group(1)
|
|
||||||
|
|
||||||
# Quantize weights.
|
|
||||||
if folded:
|
|
||||||
producer_op = self.graph.get_operation_by_name(context + '/mul_fold')
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
input_idx = next(i for i, v in enumerate(op.inputs)
|
|
||||||
if '/weights/' in v.name or
|
|
||||||
'/depthwise_weights' in v.name)
|
|
||||||
except StopIteration:
|
|
||||||
raise ValueError('No inputs to quantize for op: %s' % op)
|
|
||||||
producer_op = op.inputs[input_idx].op
|
|
||||||
|
|
||||||
# If batch norm is used, the folded weights depend on the batch std, hence
|
|
||||||
# it is sensible to use EMA during training to smooth out the noise. This is
|
|
||||||
# controlled by the flag quantize_folded_weights_use_ema. Its default is
|
|
||||||
# False for backward compatibility.
|
|
||||||
# If there is no batch norm, weights do not depend on the batch and using
|
|
||||||
# the latest value of min and max is more efficient.
|
|
||||||
weight_use_ema = folded and self.quantize_folded_weights_use_ema
|
|
||||||
self._InsertQuantOp(
|
|
||||||
context,
|
context,
|
||||||
producer_op, [op],
|
layer_match.weight_tensor.op, [layer_match.layer_op],
|
||||||
name='weights_quant',
|
name='weights_quant',
|
||||||
moving_avg=weight_use_ema,
|
moving_avg=False,
|
||||||
delay_requested=weight_use_ema,
|
bits=weight_bits,
|
||||||
bits=self.weight_bits,
|
ema_decay=ema_decay,
|
||||||
narrow_range=self.weight_narrow_range)
|
quant_delay=quant_delay,
|
||||||
|
is_training=is_training,
|
||||||
|
narrow_range=True,
|
||||||
|
vars_collection=vars_collection)
|
||||||
|
|
||||||
# Important: do not quantize biases here. During inference they are
|
# Quantize the activations.
|
||||||
# quantized to 32 bits, which is much finer than 8 bit quantization and
|
consumer_ops = input_to_ops_map.ConsumerOperations(
|
||||||
# depends on weight and input activation ranges.
|
layer_match.activation_op)
|
||||||
|
add_context = context
|
||||||
# Find activation and (optionally) Add operations to quantize.
|
if layer_match.bypass_op:
|
||||||
activation_op, add_op, add_context = self._GetReluAndAddOperations(context,
|
add_context = re.search(r'^(.*)/([^/]+)', context).group(1)
|
||||||
op)
|
_InsertQuantOp(
|
||||||
if add_op:
|
add_context,
|
||||||
original_context = context
|
layer_match.activation_op,
|
||||||
context = add_context
|
|
||||||
|
|
||||||
# Quantize activation outputs.
|
|
||||||
consumer_ops = self.input_to_ops_map.ConsumerOperations(activation_op)
|
|
||||||
self._InsertQuantOp(
|
|
||||||
context,
|
|
||||||
activation_op,
|
|
||||||
consumer_ops,
|
consumer_ops,
|
||||||
name='act_quant',
|
name='act_quant',
|
||||||
moving_avg=True,
|
moving_avg=True,
|
||||||
init_min=0.0,
|
init_min=0.0,
|
||||||
bits=self.activation_bits,
|
ema_decay=ema_decay,
|
||||||
narrow_range=False)
|
quant_delay=quant_delay,
|
||||||
|
bits=activation_bits,
|
||||||
|
vars_collection=vars_collection)
|
||||||
|
|
||||||
# When a bypass connection was found, also quantize Add op input.
|
# Quantize the inputs and output to the bypass (if it exists). The input to
|
||||||
if add_op:
|
# the bypass is the bias add, and the output is the activation.
|
||||||
def _QuantizeAddInput(add_input):
|
if layer_match.bypass_op is not None:
|
||||||
if folded:
|
_InsertQuantOp(
|
||||||
return add_input.op.name.endswith('/add_fold')
|
context,
|
||||||
else:
|
layer_match.bias_add_op, [layer_match.bypass_op],
|
||||||
return add_input.op.name.startswith(original_context + '/')
|
name='conv_quant',
|
||||||
|
moving_avg=True,
|
||||||
|
ema_decay=ema_decay,
|
||||||
|
quant_delay=quant_delay,
|
||||||
|
vars_collection=vars_collection,
|
||||||
|
bits=activation_bits)
|
||||||
|
_InsertQuantOp(
|
||||||
|
add_context,
|
||||||
|
layer_match.bypass_op,
|
||||||
|
input_to_ops_map.ConsumerOperations(layer_match.bypass_op),
|
||||||
|
name='add_quant',
|
||||||
|
moving_avg=True,
|
||||||
|
bits=activation_bits)
|
||||||
|
|
||||||
for add_input in add_op.inputs:
|
|
||||||
if _QuantizeAddInput(add_input):
|
|
||||||
self._InsertQuantOp(
|
|
||||||
original_context,
|
|
||||||
add_input.op, [add_op],
|
|
||||||
name='conv_quant',
|
|
||||||
moving_avg=True,
|
|
||||||
bits=self.activation_bits,
|
|
||||||
narrow_range=False)
|
|
||||||
|
|
||||||
def _GetReluAndAddOperations(self, context, op):
|
def _FindLayersToQuantize(graph):
|
||||||
"""Looks up a Relu* and Add operations in given context.
|
"""Matches layers in graph to quantize.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
context: Context where to look for operations.
|
graph: Graph to perform match on.
|
||||||
op: Operation to quantize.
|
|
||||||
|
|
||||||
Returns:
|
Yields:
|
||||||
A triplet (Operation, Operation, string), the first element is an end
|
_LayerMatches.
|
||||||
point operation, the second is Add operation (optional), the third element
|
"""
|
||||||
is string context where the Add operation was found (optional).
|
input_pattern = graph_matcher.OpTypePattern('*')
|
||||||
|
weight_var_pattern = graph_matcher.OpTypePattern('|'.join(_WEIGHT_TYPES))
|
||||||
|
weight_pattern = graph_matcher.OpTypePattern(
|
||||||
|
'Identity', inputs=[weight_var_pattern])
|
||||||
|
|
||||||
Raises:
|
folded_weight_pattern = graph_matcher.OpTypePattern('Mul')
|
||||||
ValueError: When operations cannot be found.
|
|
||||||
"""
|
|
||||||
activation_op = common.GetEndpointActivationOp(self.graph, context)
|
|
||||||
if activation_op:
|
|
||||||
return activation_op, None, None
|
|
||||||
|
|
||||||
if '/' in context:
|
# The weights inputs to the layer operation can either be from the Variable or
|
||||||
# If no activation op is there, look for them one level up.
|
# the folded weight (Mul).
|
||||||
add_context = re.search(r'^(.*)/([^/]+)', context).group(1)
|
layer_pattern = graph_matcher.OpTypePattern(
|
||||||
activation_op = common.GetEndpointActivationOp(self.graph, add_context)
|
'|'.join(_QUANTIZABLE_TYPES),
|
||||||
if not activation_op:
|
inputs=[
|
||||||
# Still no Relu, can happen on the top layer, just find the next node up,
|
input_pattern,
|
||||||
# make sure it is BiasAdd.
|
graph_matcher.OneofPattern([weight_pattern, folded_weight_pattern])
|
||||||
consumers = [c for outp in op.outputs for c in outp.consumers()]
|
])
|
||||||
if len(consumers) != 1 or consumers[0].type != 'BiasAdd':
|
|
||||||
raise ValueError('Failed to quantize op: %s, %s' % (op.name, op.type))
|
|
||||||
return consumers[0], None, None
|
|
||||||
if add_context:
|
|
||||||
add_op = self.GetOperationByNamesDontThrow([
|
|
||||||
add_context + '/Add', add_context + '/add'])
|
|
||||||
return activation_op, add_op, add_context
|
|
||||||
else:
|
|
||||||
raise ValueError('Failed to quantize op: %s, %s' % (op.name, op.type))
|
|
||||||
|
|
||||||
def GetOperationByNameDontThrow(self, name):
|
folded_bias_mul_pattern = graph_matcher.OpTypePattern(
|
||||||
"""Returns an Operation with the given name.
|
'Mul', inputs=[graph_matcher.OpTypePattern('*'), layer_pattern])
|
||||||
|
post_layer_op_correction_pattern = graph_matcher.OpTypePattern(
|
||||||
|
'Add', inputs=[folded_bias_mul_pattern,
|
||||||
|
graph_matcher.OpTypePattern('*')])
|
||||||
|
folded_bias_add_pattern = graph_matcher.OpTypePattern(
|
||||||
|
'Add',
|
||||||
|
inputs=[
|
||||||
|
post_layer_op_correction_pattern,
|
||||||
|
graph_matcher.OpTypePattern('*')
|
||||||
|
])
|
||||||
|
|
||||||
Args:
|
bias_add_pattern = graph_matcher.OpTypePattern(
|
||||||
name: Name of Operation to return.
|
'Add|BiasAdd', inputs=[layer_pattern, '*'])
|
||||||
|
|
||||||
Returns:
|
# The bias can come from the bias add or the folded bias add.
|
||||||
The Operation with the given name. None if the name does not correspond to
|
bypass_pattern_a = graph_matcher.OpTypePattern(
|
||||||
any operation in the graph.
|
'Add',
|
||||||
"""
|
inputs=[
|
||||||
try:
|
graph_matcher.OneofPattern(
|
||||||
return self.graph.get_operation_by_name(name)
|
[bias_add_pattern, folded_bias_add_pattern]), '*'
|
||||||
except KeyError:
|
])
|
||||||
return None
|
bypass_pattern_b = graph_matcher.OpTypePattern(
|
||||||
|
'Add',
|
||||||
|
inputs=[
|
||||||
|
'*',
|
||||||
|
graph_matcher.OneofPattern(
|
||||||
|
[bias_add_pattern, folded_bias_add_pattern])
|
||||||
|
])
|
||||||
|
|
||||||
def GetOperationByNamesDontThrow(self, names):
|
# The input to the activation can come from bias add, fold bias add or the
|
||||||
"""Returns an Operation with one of the given names.
|
# bypasses.
|
||||||
|
activation_pattern = graph_matcher.OpTypePattern(
|
||||||
|
'|'.join(_ACTIVATION_TYPES),
|
||||||
|
inputs=[
|
||||||
|
graph_matcher.OneofPattern([
|
||||||
|
bias_add_pattern, folded_bias_add_pattern, bypass_pattern_a,
|
||||||
|
bypass_pattern_b
|
||||||
|
])
|
||||||
|
])
|
||||||
|
|
||||||
Args:
|
layer_matcher = graph_matcher.GraphMatcher(activation_pattern)
|
||||||
names: Names of Operation to return.
|
for match_result in layer_matcher.match_graph(graph):
|
||||||
|
layer_op = match_result.get_op(layer_pattern)
|
||||||
|
weight_tensor = match_result.get_tensor(weight_pattern)
|
||||||
|
if weight_tensor is None:
|
||||||
|
weight_tensor = match_result.get_tensor(folded_weight_pattern)
|
||||||
|
activation_op = match_result.get_op(activation_pattern)
|
||||||
|
bias_add_op = match_result.get_op(bias_add_pattern)
|
||||||
|
if bias_add_op is None:
|
||||||
|
bias_add_op = match_result.get_op(folded_bias_add_pattern)
|
||||||
|
bypass_op = match_result.get_op(bypass_pattern_a)
|
||||||
|
if bypass_op is None:
|
||||||
|
bypass_op = match_result.get_op(bypass_pattern_b)
|
||||||
|
yield _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op,
|
||||||
|
bias_add_op)
|
||||||
|
|
||||||
Returns:
|
|
||||||
The Operation with one of the given names. None if none of the names
|
|
||||||
corresponds to any operation in the graph.
|
|
||||||
"""
|
|
||||||
for name in names:
|
|
||||||
op = self.GetOperationByNameDontThrow(name)
|
|
||||||
if op is not None:
|
|
||||||
return op
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _InsertQuantOp(
|
class _LayerMatch(object):
|
||||||
self,
|
"""Contains all information related to a matched Layer."""
|
||||||
context,
|
|
||||||
producer,
|
|
||||||
consumers,
|
|
||||||
name,
|
|
||||||
moving_avg=True,
|
|
||||||
init_min=-6.0,
|
|
||||||
init_max=6.0,
|
|
||||||
delay_requested=True,
|
|
||||||
bits=8,
|
|
||||||
narrow_range=False,):
|
|
||||||
"""Inserts a quant op between a producer op and (multiple) consumer ops.
|
|
||||||
|
|
||||||
Args:
|
def __init__(self, layer_op, weight_tensor, activation_op, bypass_op,
|
||||||
context: Context where producer and consumer operations are nested.
|
bias_add_op):
|
||||||
producer: Producer operation of the pairs where quantization will be
|
self._layer_op = layer_op
|
||||||
inserted.
|
self._weight_tensor = weight_tensor
|
||||||
consumers: Consumer operations of the pairs.
|
self._activation_op = activation_op
|
||||||
name: Name for the new quantization op within the context.
|
self._bypass_op = bypass_op
|
||||||
moving_avg: Specifies whether to use exponential moving average or just
|
self._bias_add_op = bias_add_op
|
||||||
the last value seen.
|
|
||||||
init_min: Starting minimum value for the new quantization op.
|
@property
|
||||||
init_max: Starting maximum value for the new quantization op.
|
def layer_op(self):
|
||||||
delay_requested: If true, implement quantization delay where needed.
|
return self._layer_op
|
||||||
False value explicitly disables delay quantization everywhere.
|
|
||||||
bits: Number of bits to use for quantization, must be between 2 and 8.
|
@property
|
||||||
narrow_range: Whether to use the narrow quantization range
|
def weight_tensor(self):
|
||||||
|
return self._weight_tensor
|
||||||
|
|
||||||
|
@property
|
||||||
|
def activation_op(self):
|
||||||
|
return self._activation_op
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bypass_op(self):
|
||||||
|
return self._bypass_op
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bias_add_op(self):
|
||||||
|
return self._bias_add_op
|
||||||
|
|
||||||
|
|
||||||
|
def _InsertQuantOp(context,
|
||||||
|
producer,
|
||||||
|
consumers,
|
||||||
|
name,
|
||||||
|
moving_avg=True,
|
||||||
|
init_min=-6.0,
|
||||||
|
init_max=6.0,
|
||||||
|
bits=8,
|
||||||
|
ema_decay=0.999,
|
||||||
|
quant_delay=None,
|
||||||
|
vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
|
||||||
|
is_training=True,
|
||||||
|
narrow_range=False):
|
||||||
|
"""Inserts a quant op between a producer op and (multiple) consumer ops.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context w,here producer and consumer operations are nested.
|
||||||
|
producer: Producer operation of the pairs where quantization will be
|
||||||
|
inserted.
|
||||||
|
consumers: Consumer operations of the pairs.
|
||||||
|
name: Name for the new quantization op within the context.
|
||||||
|
moving_avg: Specifies whether to use exponential moving average or just
|
||||||
|
the last value seen.
|
||||||
|
init_min: Starting minimum value for the new quantization op.
|
||||||
|
init_max: Starting maximum value for the new quantization op.
|
||||||
|
bits: Number of bits to use for quantization, must be between 2 and 8.
|
||||||
|
ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update
|
||||||
|
quantization intervals for quantizing activations (see here about EMA:
|
||||||
|
https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
|
||||||
|
quant_delay: (Optional, default None) Int, count of global steps for which
|
||||||
|
to delay quantization. This helps weights stabilize at the start of
|
||||||
|
training.
|
||||||
|
vars_collection: (Optional) Collection where to store the variables for
|
||||||
|
quantization interval ends.
|
||||||
|
is_training: (Optional) Whether quantizing training graph or eval graph.
|
||||||
|
narrow_range: Whether to use the narrow quantization range
|
||||||
[1; 2^bits - 1] or wide range [0; 2^bits - 1].
|
[1; 2^bits - 1] or wide range [0; 2^bits - 1].
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: When producer operation is not directly connected to the
|
ValueError: When producer operation is not directly connected to the
|
||||||
consumer operation.
|
consumer operation.
|
||||||
"""
|
"""
|
||||||
scope = context + '/' + name
|
name_prefix = _AddContextToName(context, name)
|
||||||
inputs = producer.outputs[0]
|
inputs = producer.outputs[0]
|
||||||
if moving_avg:
|
if moving_avg:
|
||||||
quant = (quant_ops.MovingAvgQuantize(
|
quant = (
|
||||||
inputs,
|
quant_ops.MovingAvgQuantize(
|
||||||
init_min=init_min,
|
inputs,
|
||||||
init_max=init_max,
|
init_min=init_min,
|
||||||
ema_decay=self.ema_decay,
|
init_max=init_max,
|
||||||
is_training=self.is_training,
|
ema_decay=ema_decay,
|
||||||
num_bits=bits,
|
is_training=is_training,
|
||||||
narrow_range=narrow_range,
|
num_bits=bits,
|
||||||
updates_collection=_UPDATE_QUANT_OPS,
|
narrow_range=narrow_range,
|
||||||
vars_collection=self.vars_collection,
|
vars_collection=vars_collection,
|
||||||
scope=scope))
|
name_prefix=name_prefix))
|
||||||
else:
|
else:
|
||||||
quant = (quant_ops.LastValueQuantize(
|
quant = (
|
||||||
inputs,
|
quant_ops.LastValueQuantize(
|
||||||
init_min=init_min,
|
inputs,
|
||||||
init_max=init_max,
|
init_min=init_min,
|
||||||
is_training=self.is_training,
|
init_max=init_max,
|
||||||
num_bits=bits,
|
is_training=is_training,
|
||||||
narrow_range=narrow_range,
|
num_bits=bits,
|
||||||
updates_collection=_UPDATE_QUANT_OPS,
|
narrow_range=narrow_range,
|
||||||
vars_collection=self.vars_collection,
|
vars_collection=vars_collection,
|
||||||
scope=scope))
|
name_prefix=name_prefix))
|
||||||
|
|
||||||
if delay_requested and self.quant_delay and self.quant_delay > 0:
|
if quant_delay and quant_delay > 0:
|
||||||
activate_quant = math_ops.greater_equal(
|
activate_quant = math_ops.greater_equal(
|
||||||
training_util.get_or_create_global_step(),
|
training_util.get_or_create_global_step(),
|
||||||
self.quant_delay,
|
quant_delay,
|
||||||
name=scope + '/activate_quant')
|
name=name_prefix + '/activate_quant')
|
||||||
quant = control_flow_ops.cond(
|
quant = control_flow_ops.cond(
|
||||||
activate_quant,
|
activate_quant,
|
||||||
lambda: quant,
|
lambda: quant,
|
||||||
lambda: inputs,
|
lambda: inputs,
|
||||||
name=scope + '/delayed_quant')
|
name=name_prefix + '/delayed_quant')
|
||||||
|
|
||||||
nodes_modified_count = graph_editor.reroute_ts(
|
nodes_modified_count = graph_editor.reroute_ts(
|
||||||
[quant], [inputs], can_modify=consumers)
|
[quant], [inputs], can_modify=consumers)
|
||||||
if nodes_modified_count != len(consumers):
|
if nodes_modified_count != len(consumers):
|
||||||
raise ValueError('Some inputs not quantized for ops: [%s]' %
|
raise ValueError('Some inputs not quantized for ops: [%s]' % ', '.join(
|
||||||
', '.join([consumer.name for consumer in consumers]))
|
[consumer.name for consumer in consumers]))
|
||||||
|
|
||||||
|
|
||||||
|
def _GetContextFromOp(op):
|
||||||
|
"""Gets the root context name from the op name."""
|
||||||
|
context_re = re.search(r'^(.*)/([^/]+)', op.name)
|
||||||
|
if context_re:
|
||||||
|
return context_re.group(1)
|
||||||
|
return ''
|
||||||
|
|
||||||
|
|
||||||
|
def _AddContextToName(context, name):
|
||||||
|
"""Adds the context to the name if it exists."""
|
||||||
|
if not context:
|
||||||
|
return name
|
||||||
|
return context + '/' + name
|
||||||
|
@ -101,7 +101,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
|||||||
scope + '/weights_quant/AssignMaxLast', scope + '/weights/read'
|
scope + '/weights_quant/AssignMaxLast', scope + '/weights/read'
|
||||||
]
|
]
|
||||||
self._AssertInputOpsAre(weights_quant, expected_inputs)
|
self._AssertInputOpsAre(weights_quant, expected_inputs)
|
||||||
output_op_name = scope + '/Conv2D'
|
output_op_name = (
|
||||||
|
scope + '/weights_quant/delayed_quant/Switch_1'
|
||||||
|
if delay else scope + '/Conv2D')
|
||||||
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
|
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
|
||||||
|
|
||||||
if with_bypass:
|
if with_bypass:
|
||||||
@ -176,7 +178,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
|||||||
scope + '/weights_quant/AssignMaxLast', scope + '/weights/read'
|
scope + '/weights_quant/AssignMaxLast', scope + '/weights/read'
|
||||||
]
|
]
|
||||||
self._AssertInputOpsAre(weights_quant, expected_inputs)
|
self._AssertInputOpsAre(weights_quant, expected_inputs)
|
||||||
output_op_name = scope + '/MatMul'
|
output_op_name = (
|
||||||
|
scope + '/weights_quant/delayed_quant/Switch_1'
|
||||||
|
if delay else scope + '/MatMul')
|
||||||
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
|
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
|
||||||
|
|
||||||
if with_bypass:
|
if with_bypass:
|
||||||
@ -252,7 +256,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
|||||||
scope + '/depthwise_weights/read'
|
scope + '/depthwise_weights/read'
|
||||||
]
|
]
|
||||||
self._AssertInputOpsAre(weights_quant, expected_inputs)
|
self._AssertInputOpsAre(weights_quant, expected_inputs)
|
||||||
output_op_name = scope + '/depthwise'
|
output_op_name = (
|
||||||
|
scope + '/weights_quant/delayed_quant/Switch_1'
|
||||||
|
if delay else scope + '/depthwise')
|
||||||
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
|
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
|
||||||
|
|
||||||
if with_bypass:
|
if with_bypass:
|
||||||
@ -316,6 +322,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
|||||||
for params in parameters_list:
|
for params in parameters_list:
|
||||||
test_fn(params[0], params[1], params[2], params[3], params[4])
|
test_fn(params[0], params[1], params[2], params[3], params[4])
|
||||||
|
|
||||||
|
def testQuantize_Conv2dWithBatchNorm(self):
|
||||||
|
self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm)
|
||||||
|
|
||||||
def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
|
def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
|
||||||
with_bypass, delay, fused_batch_norm):
|
with_bypass, delay, fused_batch_norm):
|
||||||
"""Tests quantization: inputs -> Conv2d with batch norm -> Activation.
|
"""Tests quantization: inputs -> Conv2d with batch norm -> Activation.
|
||||||
@ -329,39 +338,6 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
|||||||
delay: Int (optional), delay in number of steps until quantization starts.
|
delay: Int (optional), delay in number of steps until quantization starts.
|
||||||
fused_batch_norm: Bool, when true use FusedBatchNorm.
|
fused_batch_norm: Bool, when true use FusedBatchNorm.
|
||||||
"""
|
"""
|
||||||
self._testQuantize_Conv2dWithBatchNorm(
|
|
||||||
activation,
|
|
||||||
activation_op_name,
|
|
||||||
with_bypass,
|
|
||||||
delay,
|
|
||||||
fused_batch_norm,
|
|
||||||
use_ema=True)
|
|
||||||
self._testQuantize_Conv2dWithBatchNorm(
|
|
||||||
activation,
|
|
||||||
activation_op_name,
|
|
||||||
with_bypass,
|
|
||||||
delay,
|
|
||||||
fused_batch_norm,
|
|
||||||
use_ema=False)
|
|
||||||
|
|
||||||
def testQuantize_Conv2dWithBatchNorm(self):
|
|
||||||
self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm)
|
|
||||||
|
|
||||||
def _testQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
|
|
||||||
with_bypass, delay, fused_batch_norm,
|
|
||||||
use_ema):
|
|
||||||
"""Tests quantization: inputs -> Conv2d with batch norm -> Activation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
activation: Callable that returns an Operation, a factory method for the
|
|
||||||
Activation.
|
|
||||||
activation_op_name: String, name of the Activation operation.
|
|
||||||
with_bypass: Bool, when true there is an extra connection added from
|
|
||||||
inputs to just before Activation.
|
|
||||||
delay: Int (optional), delay in number of steps until quantization starts.
|
|
||||||
fused_batch_norm: Bool, when true use FusedBatchNorm.
|
|
||||||
use_ema: Bool, when true uses EMA quantization for BN folded weights.
|
|
||||||
"""
|
|
||||||
graph = ops.Graph()
|
graph = ops.Graph()
|
||||||
with graph.as_default():
|
with graph.as_default():
|
||||||
training.create_global_step(graph)
|
training.create_global_step(graph)
|
||||||
@ -394,23 +370,19 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
fold_batch_norms.FoldBatchNorms(graph)
|
fold_batch_norms.FoldBatchNorms(graph)
|
||||||
|
|
||||||
quantize.Quantize(
|
quantize.Quantize(graph, quant_delay=delay)
|
||||||
graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)
|
|
||||||
|
|
||||||
quantization_node_name = 'FakeQuantWithMinMaxVars'
|
quantization_node_name = 'FakeQuantWithMinMaxVars'
|
||||||
weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
|
weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
|
||||||
quantization_node_name)
|
quantization_node_name)
|
||||||
self.assertEqual(weights_quant.type, quantization_node_name)
|
self.assertEqual(weights_quant.type, quantization_node_name)
|
||||||
expected_inputs = [
|
expected_inputs = [
|
||||||
scope + '/weights_quant/' + ('AssignMinEma'
|
scope + '/weights_quant/' + 'AssignMinLast',
|
||||||
if use_ema else 'AssignMinLast'),
|
scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold'
|
||||||
scope + '/weights_quant/' + ('AssignMaxEma'
|
|
||||||
if use_ema else 'AssignMaxLast'),
|
|
||||||
scope + '/mul_fold'
|
|
||||||
]
|
]
|
||||||
self._AssertInputOpsAre(weights_quant, expected_inputs)
|
self._AssertInputOpsAre(weights_quant, expected_inputs)
|
||||||
output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
|
output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
|
||||||
if (delay and use_ema) else '/Conv2D_Fold')
|
if delay else '/Conv2D_Fold')
|
||||||
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
|
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
|
||||||
|
|
||||||
if with_bypass:
|
if with_bypass:
|
||||||
@ -438,6 +410,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
|||||||
if delay else 'control_dependency')
|
if delay else 'control_dependency')
|
||||||
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
|
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
|
||||||
|
|
||||||
|
def testQuantize_FCWithBatchNorm(self):
|
||||||
|
self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm)
|
||||||
|
|
||||||
def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name,
|
def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name,
|
||||||
with_bypass, delay, fused_batch_norm):
|
with_bypass, delay, fused_batch_norm):
|
||||||
"""Tests quantization: inputs -> FC with batch norm -> Activation.
|
"""Tests quantization: inputs -> FC with batch norm -> Activation.
|
||||||
@ -451,39 +426,6 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
|||||||
delay: Int (optional), delay in number of steps until quantization starts.
|
delay: Int (optional), delay in number of steps until quantization starts.
|
||||||
fused_batch_norm: Bool, when true use FusedBatchNorm.
|
fused_batch_norm: Bool, when true use FusedBatchNorm.
|
||||||
"""
|
"""
|
||||||
self._testQuantize_FCWithBatchNorm(
|
|
||||||
activation,
|
|
||||||
activation_op_name,
|
|
||||||
with_bypass,
|
|
||||||
delay,
|
|
||||||
fused_batch_norm,
|
|
||||||
use_ema=True)
|
|
||||||
self._testQuantize_FCWithBatchNorm(
|
|
||||||
activation,
|
|
||||||
activation_op_name,
|
|
||||||
with_bypass,
|
|
||||||
delay,
|
|
||||||
fused_batch_norm,
|
|
||||||
use_ema=False)
|
|
||||||
|
|
||||||
def testQuantize_FCWithBatchNorm(self):
|
|
||||||
self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm)
|
|
||||||
|
|
||||||
def _testQuantize_FCWithBatchNorm(self, activation, activation_op_name,
|
|
||||||
with_bypass, delay, fused_batch_norm,
|
|
||||||
use_ema):
|
|
||||||
"""Tests quantization: inputs -> FC with batch norm -> Activation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
activation: Callable that returns an Operation, a factory method for the
|
|
||||||
Activation.
|
|
||||||
activation_op_name: String, name of the Activation operation.
|
|
||||||
with_bypass: Bool, when true there is an extra connection added from
|
|
||||||
inputs to just before Activation.
|
|
||||||
delay: Int (optional), delay in number of steps until quantization starts.
|
|
||||||
fused_batch_norm: Bool, when true use FusedBatchNorm.
|
|
||||||
use_ema: Bool, when true uses EMA quantization for BN folded weights.
|
|
||||||
"""
|
|
||||||
graph = ops.Graph()
|
graph = ops.Graph()
|
||||||
with graph.as_default():
|
with graph.as_default():
|
||||||
training.create_global_step(graph)
|
training.create_global_step(graph)
|
||||||
@ -513,23 +455,19 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
fold_batch_norms.FoldBatchNorms(graph)
|
fold_batch_norms.FoldBatchNorms(graph)
|
||||||
|
|
||||||
quantize.Quantize(
|
quantize.Quantize(graph, quant_delay=delay)
|
||||||
graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)
|
|
||||||
|
|
||||||
quantization_node_name = 'FakeQuantWithMinMaxVars'
|
quantization_node_name = 'FakeQuantWithMinMaxVars'
|
||||||
weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
|
weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
|
||||||
quantization_node_name)
|
quantization_node_name)
|
||||||
self.assertEqual(weights_quant.type, quantization_node_name)
|
self.assertEqual(weights_quant.type, quantization_node_name)
|
||||||
expected_inputs = [
|
expected_inputs = [
|
||||||
scope + '/weights_quant/' + ('AssignMinEma'
|
scope + '/weights_quant/' + 'AssignMinLast',
|
||||||
if use_ema else 'AssignMinLast'),
|
scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold'
|
||||||
scope + '/weights_quant/' + ('AssignMaxEma'
|
|
||||||
if use_ema else 'AssignMaxLast'),
|
|
||||||
scope + '/mul_fold'
|
|
||||||
]
|
]
|
||||||
self._AssertInputOpsAre(weights_quant, expected_inputs)
|
self._AssertInputOpsAre(weights_quant, expected_inputs)
|
||||||
output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
|
output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
|
||||||
if delay and use_ema else '/MatMul_Fold')
|
if delay else '/MatMul_Fold')
|
||||||
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
|
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
|
||||||
|
|
||||||
if with_bypass:
|
if with_bypass:
|
||||||
@ -557,6 +495,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
|||||||
if delay else 'control_dependency')
|
if delay else 'control_dependency')
|
||||||
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
|
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
|
||||||
|
|
||||||
|
def testQuantize_DepthwiseConv2dWithBatchNorm(self):
|
||||||
|
self._RunBatchNormTestOverParameters(
|
||||||
|
self._TestQuantize_DepthwiseConv2dWithBatchNorm)
|
||||||
|
|
||||||
def _TestQuantize_DepthwiseConv2dWithBatchNorm(
|
def _TestQuantize_DepthwiseConv2dWithBatchNorm(
|
||||||
self, activation, activation_op_name, with_bypass, delay,
|
self, activation, activation_op_name, with_bypass, delay,
|
||||||
fused_batch_norm):
|
fused_batch_norm):
|
||||||
@ -571,40 +513,6 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
|||||||
delay: Int (optional), delay in number of steps until quantization starts.
|
delay: Int (optional), delay in number of steps until quantization starts.
|
||||||
fused_batch_norm: Bool, when true use FusedBatchNorm.
|
fused_batch_norm: Bool, when true use FusedBatchNorm.
|
||||||
"""
|
"""
|
||||||
self._testQuantize_DepthwiseConv2dWithBatchNorm(
|
|
||||||
activation,
|
|
||||||
activation_op_name,
|
|
||||||
with_bypass,
|
|
||||||
delay,
|
|
||||||
fused_batch_norm,
|
|
||||||
use_ema=True)
|
|
||||||
self._testQuantize_DepthwiseConv2dWithBatchNorm(
|
|
||||||
activation,
|
|
||||||
activation_op_name,
|
|
||||||
with_bypass,
|
|
||||||
delay,
|
|
||||||
fused_batch_norm,
|
|
||||||
use_ema=False)
|
|
||||||
|
|
||||||
def testQuantize_DepthwiseConv2dWithBatchNorm(self):
|
|
||||||
self._RunBatchNormTestOverParameters(
|
|
||||||
self._TestQuantize_DepthwiseConv2dWithBatchNorm)
|
|
||||||
|
|
||||||
def _testQuantize_DepthwiseConv2dWithBatchNorm(
|
|
||||||
self, activation, activation_op_name, with_bypass, delay,
|
|
||||||
fused_batch_norm, use_ema):
|
|
||||||
"""Tests quantization: inputs -> DWConv2d with batch norm -> Activation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
activation: Callable that returns an Operation, a factory method for the
|
|
||||||
Activation.
|
|
||||||
activation_op_name: String, name of the Activation operation.
|
|
||||||
with_bypass: Bool, when true there is an extra connection added from
|
|
||||||
inputs to just before Activation.
|
|
||||||
delay: Int (optional), delay in number of steps until quantization starts.
|
|
||||||
fused_batch_norm: Bool, when true use FusedBatchNorm.
|
|
||||||
use_ema: Bool, when true uses EMA quantization for BN folded weights.
|
|
||||||
"""
|
|
||||||
graph = ops.Graph()
|
graph = ops.Graph()
|
||||||
with graph.as_default():
|
with graph.as_default():
|
||||||
training.create_global_step(graph)
|
training.create_global_step(graph)
|
||||||
@ -637,22 +545,18 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
fold_batch_norms.FoldBatchNorms(graph)
|
fold_batch_norms.FoldBatchNorms(graph)
|
||||||
|
|
||||||
quantize.Quantize(
|
quantize.Quantize(graph, quant_delay=delay)
|
||||||
graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)
|
|
||||||
quantization_node_name = 'FakeQuantWithMinMaxVars'
|
quantization_node_name = 'FakeQuantWithMinMaxVars'
|
||||||
weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
|
weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
|
||||||
quantization_node_name)
|
quantization_node_name)
|
||||||
self.assertEqual(weights_quant.type, quantization_node_name)
|
self.assertEqual(weights_quant.type, quantization_node_name)
|
||||||
expected_inputs = [
|
expected_inputs = [
|
||||||
scope + '/weights_quant/' + ('AssignMinEma'
|
scope + '/weights_quant/' + 'AssignMinLast',
|
||||||
if use_ema else 'AssignMinLast'),
|
scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold'
|
||||||
scope + '/weights_quant/' + ('AssignMaxEma'
|
|
||||||
if use_ema else 'AssignMaxLast'),
|
|
||||||
scope + '/mul_fold'
|
|
||||||
]
|
]
|
||||||
self._AssertInputOpsAre(weights_quant, expected_inputs)
|
self._AssertInputOpsAre(weights_quant, expected_inputs)
|
||||||
output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
|
output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
|
||||||
if delay and use_ema else '/depthwise_Fold')
|
if delay else '/depthwise_Fold')
|
||||||
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
|
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
|
||||||
|
|
||||||
if with_bypass:
|
if with_bypass:
|
||||||
|
@ -45,13 +45,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
|||||||
activation_fn=None, scope='test')
|
activation_fn=None, scope='test')
|
||||||
relu = nn_ops.relu6(inputs)
|
relu = nn_ops.relu6(inputs)
|
||||||
|
|
||||||
context = quantize._QuantizeContext(graph=graph, weight_bits=8,
|
|
||||||
weight_narrow_range=True,
|
|
||||||
activation_bits=8)
|
|
||||||
# Inserting a quantization op between two unconnected ops should fail with
|
# Inserting a quantization op between two unconnected ops should fail with
|
||||||
# ValueError.
|
# ValueError.
|
||||||
with self.assertRaises(ValueError) as err:
|
with self.assertRaises(ValueError) as err:
|
||||||
context._InsertQuantOp('test', conv.op, [relu.op], 'FailingQuantOp')
|
quantize._InsertQuantOp('test', conv.op, [relu.op], 'FailingQuantOp')
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
str(err.exception), 'Some inputs not quantized for ops: [Relu6]')
|
str(err.exception), 'Some inputs not quantized for ops: [Relu6]')
|
||||||
|
|
||||||
@ -70,8 +67,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
|||||||
with ops.control_dependencies([update_barrier]):
|
with ops.control_dependencies([update_barrier]):
|
||||||
array_ops.identity(node, name='control_dependency')
|
array_ops.identity(node, name='control_dependency')
|
||||||
|
|
||||||
quantize.Quantize(graph=graph, weight_bits=8, weight_narrow_range=True,
|
quantize.Quantize(graph=graph, weight_bits=8, activation_bits=8)
|
||||||
activation_bits=8)
|
|
||||||
|
|
||||||
quantization_node_name = 'FakeQuantWithMinMaxVars'
|
quantization_node_name = 'FakeQuantWithMinMaxVars'
|
||||||
add_quant = graph.get_operation_by_name('test/add_quant/' +
|
add_quant = graph.get_operation_by_name('test/add_quant/' +
|
||||||
@ -94,8 +90,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
|||||||
with ops.control_dependencies([update_barrier]):
|
with ops.control_dependencies([update_barrier]):
|
||||||
array_ops.identity(node, name='control_dependency')
|
array_ops.identity(node, name='control_dependency')
|
||||||
|
|
||||||
quantize.Quantize(graph=graph, weight_bits=8, weight_narrow_range=True,
|
quantize.Quantize(graph=graph, weight_bits=8, activation_bits=8)
|
||||||
activation_bits=8)
|
|
||||||
|
|
||||||
quantization_node_name = 'FakeQuantWithMinMaxVars'
|
quantization_node_name = 'FakeQuantWithMinMaxVars'
|
||||||
add_quant = graph.get_operation_by_name('test/add_quant/' +
|
add_quant = graph.get_operation_by_name('test/add_quant/' +
|
||||||
|
Loading…
x
Reference in New Issue
Block a user