Generalize quantization rewrite to not rely on names.
It should now work with most graphs, regardless of whether they were built with slim or not.
PiperOrigin-RevId: 184889280
Parent: ef16b9cc9a
Commit: 190b918c8c
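The gist of the change: instead of locating weights, activations, and bypass Adds by their variable-scope names (which only lined up for slim-built graphs), the rewrite now matches op types with graph_matcher patterns. A minimal usage sketch follows, assuming a TF 1.x install with the contrib packages; the toy model and names here are illustrative and are not part of this commit.

    import tensorflow as tf
    from tensorflow.contrib.quantize.python import quantize

    graph = tf.Graph()
    with graph.as_default():
      # A plain graph, built without tf.contrib.slim.
      images = tf.placeholder(tf.float32, [8, 32, 32, 3])
      net = tf.layers.conv2d(images, 16, 3, activation=tf.nn.relu6, name='conv1')
      tf.train.create_global_step(graph)
      # Rewrites the graph in place, inserting FakeQuantWithMinMaxVars ops
      # around the matched weights and activations.
      quantize.Quantize(graph, is_training=True, quant_delay=1000)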
@@ -75,7 +75,9 @@ py_library(
        ":graph_matcher",
        ":input_to_ops",
        "//tensorflow/contrib/graph_editor:graph_editor_py",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:layers",
        "//tensorflow/python:math_ops",
@@ -83,6 +85,7 @@ py_library(
        "//tensorflow/python:nn_ops",
        "//tensorflow/python:ops",
        "//tensorflow/python:training",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
    ],
)
@@ -162,7 +165,6 @@ py_test(
        "//tensorflow/python:array_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:session",
        "//tensorflow/python:variables",
@@ -174,7 +176,7 @@ py_library(
    srcs = ["python/quantize.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":common",
        ":graph_matcher",
        ":input_to_ops",
        ":quant_ops",
        "//tensorflow/contrib/graph_editor:graph_editor_py",

@@ -53,7 +53,7 @@ def LastValueQuantize(inputs,
                      init_max=6.0,
                      updates_collection=ops.GraphKeys.UPDATE_OPS,
                      vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
                      scope=None,
                      name_prefix='LastValueQuant',
                      reuse=None,
                      is_training=True,
                      num_bits=8,
@@ -73,7 +73,7 @@ def LastValueQuantize(inputs,
      computation.
    vars_collection: (Optional) collection where to store variables for
      quantization interval ends.
    scope: Optional scope for variable_scope.
    name_prefix: name_prefix for created nodes.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    is_training: Whether the op is applied to a training or eval graph.
@@ -84,13 +84,13 @@ def LastValueQuantize(inputs,
    a tensor containing quantized values.
  """
  with variable_scope.variable_scope(
      scope, 'LastValueQuantize', values=[inputs], reuse=reuse):
      None, default_name=name_prefix, values=[inputs], reuse=reuse):
    input_shape = inputs.get_shape()
    input_dim = len(input_shape)
    if per_channel:
      # Only support quantizing 1-, 2- and 4-dimensional tensors.
      assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in '
                                      ' scope: %s' % (input_shape, scope))
                                      ' scope: %s' % (input_shape, name_prefix))
      min_max_shape = [input_shape[-1]]
    else:
      min_max_shape = []
@@ -165,7 +165,7 @@ def MovingAvgQuantize(inputs,
                      ema_decay=0.999,
                      updates_collection=ops.GraphKeys.UPDATE_OPS,
                      vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
                      scope=None,
                      name_prefix='MovingAvgQuantize',
                      reuse=None,
                      is_training=True,
                      num_bits=8,
@@ -186,7 +186,7 @@ def MovingAvgQuantize(inputs,
      computation.
    vars_collection: (Optional) collection where to store variables for
      quantization interval ends.
    scope: Optional scope for variable_scope.
    name_prefix: name_prefix for created nodes.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    is_training: Whether the op is applied to a training or eval graph.
@@ -197,13 +197,13 @@ def MovingAvgQuantize(inputs,
    a tensor containing quantized values.
  """
  with variable_scope.variable_scope(
      scope, 'MovingAvgQuantize', values=[inputs], reuse=reuse):
      None, default_name=name_prefix, values=[inputs], reuse=reuse):
    input_shape = inputs.get_shape()
    input_dim = len(input_shape)
    if per_channel:
      # Only support quantizing 1-, 2- and 4-dimensional tensors.
      assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in '
                                      ' scope: %s' % (input_shape, scope))
                                      ' scope: %s' % (input_shape, name_prefix))
      min_max_shape = [input_shape[-1]]
    else:
      min_max_shape = []

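For reference, LastValueQuantize and MovingAvgQuantize now take a name_prefix used as the variable_scope default_name instead of a caller-supplied scope. A minimal sketch of calling one of them directly, assuming the TF 1.x contrib layout; the tensor and prefix names are illustrative only.

    import tensorflow as tf
    from tensorflow.contrib.quantize.python import quant_ops

    with tf.Graph().as_default():
      acts = tf.placeholder(tf.float32, [4, 10])
      # Min/max variables and assign ops are created under 'act_quant'
      # rather than under a scope passed in by the caller.
      quantized = quant_ops.MovingAvgQuantize(
          acts, init_min=0.0, init_max=6.0, ema_decay=0.999,
          is_training=True, num_bits=8, narrow_range=False,
          name_prefix='act_quant')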
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logic to update a Tensorflow model graph with quantization operations."""
"""Logic to update a TensorFlow model graph with quantization operations."""

from __future__ import absolute_import
from __future__ import division
@@ -20,7 +20,7 @@ from __future__ import print_function

import re
from tensorflow.contrib import graph_editor
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
from tensorflow.contrib.quantize.python import quant_ops
from tensorflow.python.framework import ops
@@ -28,30 +28,29 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import training_util

# Operation types used to select operations of interest.
# Quantizable operation types that are supported by the quantization rewrite.
_QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'}

# Custom key for storing and retrieving update ops used by quantizing nodes.
_UPDATE_QUANT_OPS = 'update_quant_ops'
# Activations that are supported by the quantization rewrite.
_ACTIVATION_TYPES = {'Relu', 'Relu6', 'Identity'}

# Weight types that are supported by the quantization rewrite.
# TODO(suharshs): Add support for ResourceVariable.
_WEIGHT_TYPES = {'Variable', 'VariableV2'}


def Quantize(graph,
             weight_bits=8,
             weight_narrow_range=False,
             activation_bits=8,
             ema_decay=0.999,
             quant_delay=None,
             vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
             is_training=True,
             quantize_folded_weights_use_ema=False):
             is_training=True):
  """Updates graph with quantization operations.

  Args:
    graph: Graph to modify.
    weight_bits: Number of bits to use for quantizing weights.
    weight_narrow_range: Whether to use a more efficient narrow range for
      weights quantization. With weight_narrow_range true, the range is
      [1; 2^weight_bits - 1], with it false [0; 2^weight_bits - 1].
    activation_bits: Number of bits to use for quantizing activations.
    ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
@@ -62,345 +61,274 @@ def Quantize(graph,
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    is_training: (Optional) Whether quantizing training graph or eval graph.
    quantize_folded_weights_use_ema: (Optional, default False) Whether to
      quantize weights after batchnorm-folding with exponential average
      quantization.
  Raises:
    ValueError: When quantization fails.
  """
  context = _QuantizeContext(graph, weight_bits, weight_narrow_range,
                             activation_bits, ema_decay, quant_delay,
                             vars_collection, is_training,
                             quantize_folded_weights_use_ema)

  graph_ops = graph.get_operations()

  # Filter out backprop and summary related operations, leave only interesting
  # op types.
  def _IsInterestingOpWithWeights(op):
    return (op.type in _QUANTIZABLE_TYPES and
            not op.name.startswith(common.SKIPPED_PREFIXES))

  for op in (op for op in graph_ops if _IsInterestingOpWithWeights(op)):
    if op.name.endswith('/depthwise'):
      # Separable convolution may consist of 2 convolution nodes. If so, skip
      # .../depthwise and only quantize the top one.
      separable_conv = context.GetOperationByNameDontThrow(
          op.name[:-len('/depthwise')])
      if separable_conv and separable_conv.type == 'Conv2D':
        continue
    # Quantize add ops that come after Conv2D or DepthwiseConv2dNative.
    if op.type in ['Conv2D', 'DepthwiseConv2dNative']:
      add_context_re = re.search(r'^(.*)/[^/]+/', op.name)
      if add_context_re is not None:
        context.add_contexts.add(add_context_re.group(1))
    if not op.name.endswith('_Fold'):
      folded_op = context.GetOperationByNameDontThrow(op.name + '_Fold')
      # Do nothing if found, it will be quantized when it is iterated over.
      if not folded_op:
        context.QuantizeOpWithWeights(op, folded=False)
    else:
      context.QuantizeOpWithWeights(op, folded=True)

  context.QuantizeAddContexts()

  # Once all quantization ops have been inserted in the graph, collect update
  # ops for their variables and modify the TF Slim update barrier (see
  # https://www.tensorflow.org/code/tensorflow/contrib/slim/python/slim/learning.py)
  # to depend on them.
  try:
    update_barrier = graph.get_operation_by_name('update_barrier')
  except KeyError:
    # In evaluation graph, this barrier may not exist.
    return None
  update_quant_ops = graph.get_collection_ref(_UPDATE_QUANT_OPS)
  graph_editor.add_control_inputs(update_barrier, update_quant_ops)


class _QuantizeContext(object):
  """Context holds references needed for quantization."""

  def __init__(self,
               graph,
               weight_bits,
               weight_narrow_range,
               activation_bits,
               ema_decay=0.999,
               quant_delay=None,
               vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
               is_training=True,
               quantize_folded_weights_use_ema=False):
    """Initializes context to hold references needed for quantization.

    Args:
      graph: Graph to modify.
      weight_bits: Number of bits to use for quantizing weights.
      weight_narrow_range: Whether to use a more efficient narrow range for
        weights quantization. With weight_narrow_range true, the range is
        [1; 2^weight_bits - 1], with it false [0; 2^weight_bits - 1].
      activation_bits: Number of bits to use for quantizing activations.
      ema_decay: (Optional) Float, EMA decay parameter.
      quant_delay: (Optional, default None) Int, count of global steps for which
        to delay quantization. This helps weights stabilize at the start of
        training.
      vars_collection: (Optional) Collection where to store the variables for
        quantization interval ends.
      is_training: (Optional) Whether quantizing training or eval graph.
      quantize_folded_weights_use_ema: (Optional, default False) Whether to
        quantize weights after batchnorm-folding with exponential average
        quantization.
    """
    self.graph = graph
    self.weight_bits = weight_bits
    self.weight_narrow_range = weight_narrow_range
    self.activation_bits = activation_bits
    self.ema_decay = ema_decay
    self.quant_delay = quant_delay
    self.vars_collection = vars_collection
    self.is_training = is_training
    self.quantize_folded_weights_use_ema = quantize_folded_weights_use_ema
    self.input_to_ops_map = input_to_ops.InputToOps(graph)
    self.add_contexts = set()

  def QuantizeAddContexts(self):
    """Quantizes all add ops in self.add_contexts."""
    # Loop through sorted self.add_contexts so that op creation is
    # deterministic. This is needed when using multiple worker replicas so that
    # the ops can be initialized consistently.
    for add_context in sorted(self.add_contexts):
      add_op = self.GetOperationByNamesDontThrow([
          add_context + '/Add', add_context + '/add'])
      if add_op is not None:
        self._InsertQuantOp(
            add_context,
            add_op,
            self.input_to_ops_map.ConsumerOperations(add_op),
            name='add_quant',
            moving_avg=True,
            bits=self.activation_bits,
            narrow_range=False)

  def QuantizeOpWithWeights(self, op, folded):
    """Quantizes around the specific operation with or without batch norm.

    Args:
      op: Operation to quantize.
      folded: Operation has been folded and needs special handling if True.
    Raises:
      ValueError: When quantization fails.
    """
    # Op name component before the last slash will be used as context.
    context = re.search(r'^(.*)/([^/]+)', op.name).group(1)

    # Quantize weights.
    if folded:
      producer_op = self.graph.get_operation_by_name(context + '/mul_fold')
    else:
      try:
        input_idx = next(i for i, v in enumerate(op.inputs)
                         if '/weights/' in v.name or
                         '/depthwise_weights' in v.name)
      except StopIteration:
        raise ValueError('No inputs to quantize for op: %s' % op)
      producer_op = op.inputs[input_idx].op

    # If batch norm is used, the folded weights depend on the batch std, hence
    # it is sensible to use EMA during training to smooth out the noise. This is
    # controlled by the flag quantize_folded_weights_use_ema. Its default is
    # False for backward compatibility.
    # If there is no batch norm, weights do not depend on the batch and using
    # the latest value of min and max is more efficient.
    weight_use_ema = folded and self.quantize_folded_weights_use_ema
    self._InsertQuantOp(
  input_to_ops_map = input_to_ops.InputToOps(graph)
  for layer_match in _FindLayersToQuantize(graph):
    # Quantize the weights.
    context = _GetContextFromOp(layer_match.layer_op)
    _InsertQuantOp(
        context,
        producer_op, [op],
        layer_match.weight_tensor.op, [layer_match.layer_op],
        name='weights_quant',
        moving_avg=weight_use_ema,
        delay_requested=weight_use_ema,
        bits=self.weight_bits,
        narrow_range=self.weight_narrow_range)
        moving_avg=False,
        bits=weight_bits,
        ema_decay=ema_decay,
        quant_delay=quant_delay,
        is_training=is_training,
        narrow_range=True,
        vars_collection=vars_collection)

    # Important: do not quantize biases here. During inference they are
    # quantized to 32 bits, which is much finer than 8 bit quantization and
    # depends on weight and input activation ranges.

    # Find activation and (optionally) Add operations to quantize.
    activation_op, add_op, add_context = self._GetReluAndAddOperations(context,
                                                                       op)
    if add_op:
      original_context = context
      context = add_context

    # Quantize activation outputs.
    consumer_ops = self.input_to_ops_map.ConsumerOperations(activation_op)
    self._InsertQuantOp(
        context,
        activation_op,
    # Quantize the activations.
    consumer_ops = input_to_ops_map.ConsumerOperations(
        layer_match.activation_op)
    add_context = context
    if layer_match.bypass_op:
      add_context = re.search(r'^(.*)/([^/]+)', context).group(1)
    _InsertQuantOp(
        add_context,
        layer_match.activation_op,
        consumer_ops,
        name='act_quant',
        moving_avg=True,
        init_min=0.0,
        bits=self.activation_bits,
        narrow_range=False)
        ema_decay=ema_decay,
        quant_delay=quant_delay,
        bits=activation_bits,
        vars_collection=vars_collection)

    # When a bypass connection was found, also quantize Add op input.
    if add_op:
      def _QuantizeAddInput(add_input):
        if folded:
          return add_input.op.name.endswith('/add_fold')
        else:
          return add_input.op.name.startswith(original_context + '/')
    # Quantize the inputs and output to the bypass (if it exists). The input to
    # the bypass is the bias add, and the output is the activation.
    if layer_match.bypass_op is not None:
      _InsertQuantOp(
          context,
          layer_match.bias_add_op, [layer_match.bypass_op],
          name='conv_quant',
          moving_avg=True,
          ema_decay=ema_decay,
          quant_delay=quant_delay,
          vars_collection=vars_collection,
          bits=activation_bits)
      _InsertQuantOp(
          add_context,
          layer_match.bypass_op,
          input_to_ops_map.ConsumerOperations(layer_match.bypass_op),
          name='add_quant',
          moving_avg=True,
          bits=activation_bits)

      for add_input in add_op.inputs:
        if _QuantizeAddInput(add_input):
          self._InsertQuantOp(
              original_context,
              add_input.op, [add_op],
              name='conv_quant',
              moving_avg=True,
              bits=self.activation_bits,
              narrow_range=False)

  def _GetReluAndAddOperations(self, context, op):
    """Looks up a Relu* and Add operations in given context.
def _FindLayersToQuantize(graph):
  """Matches layers in graph to quantize.

    Args:
      context: Context where to look for operations.
      op: Operation to quantize.
  Args:
    graph: Graph to perform match on.

    Returns:
      A triplet (Operation, Operation, string), the first element is an end
      point operation, the second is Add operation (optional), the third element
      is string context where the Add operation was found (optional).
  Yields:
    _LayerMatches.
  """
  input_pattern = graph_matcher.OpTypePattern('*')
  weight_var_pattern = graph_matcher.OpTypePattern('|'.join(_WEIGHT_TYPES))
  weight_pattern = graph_matcher.OpTypePattern(
      'Identity', inputs=[weight_var_pattern])

    Raises:
      ValueError: When operations cannot be found.
    """
    activation_op = common.GetEndpointActivationOp(self.graph, context)
    if activation_op:
      return activation_op, None, None
  folded_weight_pattern = graph_matcher.OpTypePattern('Mul')

    if '/' in context:
      # If no activation op is there, look for them one level up.
      add_context = re.search(r'^(.*)/([^/]+)', context).group(1)
      activation_op = common.GetEndpointActivationOp(self.graph, add_context)
    if not activation_op:
      # Still no Relu, can happen on the top layer, just find the next node up,
      # make sure it is BiasAdd.
      consumers = [c for outp in op.outputs for c in outp.consumers()]
      if len(consumers) != 1 or consumers[0].type != 'BiasAdd':
        raise ValueError('Failed to quantize op: %s, %s' % (op.name, op.type))
      return consumers[0], None, None
    if add_context:
      add_op = self.GetOperationByNamesDontThrow([
          add_context + '/Add', add_context + '/add'])
      return activation_op, add_op, add_context
    else:
      raise ValueError('Failed to quantize op: %s, %s' % (op.name, op.type))
  # The weights inputs to the layer operation can either be from the Variable or
  # the folded weight (Mul).
  layer_pattern = graph_matcher.OpTypePattern(
      '|'.join(_QUANTIZABLE_TYPES),
      inputs=[
          input_pattern,
          graph_matcher.OneofPattern([weight_pattern, folded_weight_pattern])
      ])

  def GetOperationByNameDontThrow(self, name):
    """Returns an Operation with the given name.
  folded_bias_mul_pattern = graph_matcher.OpTypePattern(
      'Mul', inputs=[graph_matcher.OpTypePattern('*'), layer_pattern])
  post_layer_op_correction_pattern = graph_matcher.OpTypePattern(
      'Add', inputs=[folded_bias_mul_pattern,
                     graph_matcher.OpTypePattern('*')])
  folded_bias_add_pattern = graph_matcher.OpTypePattern(
      'Add',
      inputs=[
          post_layer_op_correction_pattern,
          graph_matcher.OpTypePattern('*')
      ])

    Args:
      name: Name of Operation to return.
  bias_add_pattern = graph_matcher.OpTypePattern(
      'Add|BiasAdd', inputs=[layer_pattern, '*'])

    Returns:
      The Operation with the given name. None if the name does not correspond to
      any operation in the graph.
    """
    try:
      return self.graph.get_operation_by_name(name)
    except KeyError:
      return None
  # The bias can come from the bias add or the folded bias add.
  bypass_pattern_a = graph_matcher.OpTypePattern(
      'Add',
      inputs=[
          graph_matcher.OneofPattern(
              [bias_add_pattern, folded_bias_add_pattern]), '*'
      ])
  bypass_pattern_b = graph_matcher.OpTypePattern(
      'Add',
      inputs=[
          '*',
          graph_matcher.OneofPattern(
              [bias_add_pattern, folded_bias_add_pattern])
      ])

  def GetOperationByNamesDontThrow(self, names):
    """Returns an Operation with one of the given names.
  # The input to the activation can come from bias add, fold bias add or the
  # bypasses.
  activation_pattern = graph_matcher.OpTypePattern(
      '|'.join(_ACTIVATION_TYPES),
      inputs=[
          graph_matcher.OneofPattern([
              bias_add_pattern, folded_bias_add_pattern, bypass_pattern_a,
              bypass_pattern_b
          ])
      ])

    Args:
      names: Names of Operation to return.
  layer_matcher = graph_matcher.GraphMatcher(activation_pattern)
  for match_result in layer_matcher.match_graph(graph):
    layer_op = match_result.get_op(layer_pattern)
    weight_tensor = match_result.get_tensor(weight_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(folded_weight_pattern)
    activation_op = match_result.get_op(activation_pattern)
    bias_add_op = match_result.get_op(bias_add_pattern)
    if bias_add_op is None:
      bias_add_op = match_result.get_op(folded_bias_add_pattern)
    bypass_op = match_result.get_op(bypass_pattern_a)
    if bypass_op is None:
      bypass_op = match_result.get_op(bypass_pattern_b)
    yield _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op,
                      bias_add_op)

    Returns:
      The Operation with one of the given names. None if none of the names
      corresponds to any operation in the graph.
    """
    for name in names:
      op = self.GetOperationByNameDontThrow(name)
      if op is not None:
        return op
    return None

  def _InsertQuantOp(
      self,
      context,
      producer,
      consumers,
      name,
      moving_avg=True,
      init_min=-6.0,
      init_max=6.0,
      delay_requested=True,
      bits=8,
      narrow_range=False,):
    """Inserts a quant op between a producer op and (multiple) consumer ops.
class _LayerMatch(object):
  """Contains all information related to a matched Layer."""

    Args:
      context: Context where producer and consumer operations are nested.
      producer: Producer operation of the pairs where quantization will be
        inserted.
      consumers: Consumer operations of the pairs.
      name: Name for the new quantization op within the context.
      moving_avg: Specifies whether to use exponential moving average or just
        the last value seen.
      init_min: Starting minimum value for the new quantization op.
      init_max: Starting maximum value for the new quantization op.
      delay_requested: If true, implement quantization delay where needed.
        False value explicitly disables delay quantization everywhere.
      bits: Number of bits to use for quantization, must be between 2 and 8.
      narrow_range: Whether to use the narrow quantization range
  def __init__(self, layer_op, weight_tensor, activation_op, bypass_op,
               bias_add_op):
    self._layer_op = layer_op
    self._weight_tensor = weight_tensor
    self._activation_op = activation_op
    self._bypass_op = bypass_op
    self._bias_add_op = bias_add_op

  @property
  def layer_op(self):
    return self._layer_op

  @property
  def weight_tensor(self):
    return self._weight_tensor

  @property
  def activation_op(self):
    return self._activation_op

  @property
  def bypass_op(self):
    return self._bypass_op

  @property
  def bias_add_op(self):
    return self._bias_add_op


def _InsertQuantOp(context,
                   producer,
                   consumers,
                   name,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
                   is_training=True,
                   narrow_range=False):
  """Inserts a quant op between a producer op and (multiple) consumer ops.

  Args:
    context: Context where producer and consumer operations are nested.
    producer: Producer operation of the pairs where quantization will be
      inserted.
    consumers: Consumer operations of the pairs.
    name: Name for the new quantization op within the context.
    moving_avg: Specifies whether to use exponential moving average or just
      the last value seen.
    init_min: Starting minimum value for the new quantization op.
    init_max: Starting maximum value for the new quantization op.
    bits: Number of bits to use for quantization, must be between 2 and 8.
    ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization. This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    is_training: (Optional) Whether quantizing training graph or eval graph.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^bits - 1] or wide range [0; 2^bits - 1].
  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
  scope = context + '/' + name
  inputs = producer.outputs[0]
  if moving_avg:
    quant = (quant_ops.MovingAvgQuantize(
        inputs,
        init_min=init_min,
        init_max=init_max,
        ema_decay=self.ema_decay,
        is_training=self.is_training,
        num_bits=bits,
        narrow_range=narrow_range,
        updates_collection=_UPDATE_QUANT_OPS,
        vars_collection=self.vars_collection,
        scope=scope))
  else:
    quant = (quant_ops.LastValueQuantize(
        inputs,
        init_min=init_min,
        init_max=init_max,
        is_training=self.is_training,
        num_bits=bits,
        narrow_range=narrow_range,
        updates_collection=_UPDATE_QUANT_OPS,
        vars_collection=self.vars_collection,
        scope=scope))
  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
  name_prefix = _AddContextToName(context, name)
  inputs = producer.outputs[0]
  if moving_avg:
    quant = (
        quant_ops.MovingAvgQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            ema_decay=ema_decay,
            is_training=is_training,
            num_bits=bits,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))
  else:
    quant = (
        quant_ops.LastValueQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            is_training=is_training,
            num_bits=bits,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))

  if delay_requested and self.quant_delay and self.quant_delay > 0:
    activate_quant = math_ops.greater_equal(
        training_util.get_or_create_global_step(),
        self.quant_delay,
        name=scope + '/activate_quant')
    quant = control_flow_ops.cond(
        activate_quant,
        lambda: quant,
        lambda: inputs,
        name=scope + '/delayed_quant')
  if quant_delay and quant_delay > 0:
    activate_quant = math_ops.greater_equal(
        training_util.get_or_create_global_step(),
        quant_delay,
        name=name_prefix + '/activate_quant')
    quant = control_flow_ops.cond(
        activate_quant,
        lambda: quant,
        lambda: inputs,
        name=name_prefix + '/delayed_quant')

  nodes_modified_count = graph_editor.reroute_ts(
      [quant], [inputs], can_modify=consumers)
  if nodes_modified_count != len(consumers):
    raise ValueError('Some inputs not quantized for ops: [%s]' %
                     ', '.join([consumer.name for consumer in consumers]))
  nodes_modified_count = graph_editor.reroute_ts(
      [quant], [inputs], can_modify=consumers)
  if nodes_modified_count != len(consumers):
    raise ValueError('Some inputs not quantized for ops: [%s]' % ', '.join(
        [consumer.name for consumer in consumers]))


def _GetContextFromOp(op):
  """Gets the root context name from the op name."""
  context_re = re.search(r'^(.*)/([^/]+)', op.name)
  if context_re:
    return context_re.group(1)
  return ''


def _AddContextToName(context, name):
  """Adds the context to the name if it exists."""
  if not context:
    return name
  return context + '/' + name

@@ -101,7 +101,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
        scope + '/weights_quant/AssignMaxLast', scope + '/weights/read'
    ]
    self._AssertInputOpsAre(weights_quant, expected_inputs)
    output_op_name = scope + '/Conv2D'
    output_op_name = (
        scope + '/weights_quant/delayed_quant/Switch_1'
        if delay else scope + '/Conv2D')
    self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

    if with_bypass:
@@ -176,7 +178,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
        scope + '/weights_quant/AssignMaxLast', scope + '/weights/read'
    ]
    self._AssertInputOpsAre(weights_quant, expected_inputs)
    output_op_name = scope + '/MatMul'
    output_op_name = (
        scope + '/weights_quant/delayed_quant/Switch_1'
        if delay else scope + '/MatMul')
    self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

    if with_bypass:
@@ -252,7 +256,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
        scope + '/depthwise_weights/read'
    ]
    self._AssertInputOpsAre(weights_quant, expected_inputs)
    output_op_name = scope + '/depthwise'
    output_op_name = (
        scope + '/weights_quant/delayed_quant/Switch_1'
        if delay else scope + '/depthwise')
    self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

    if with_bypass:
@@ -316,6 +322,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
    for params in parameters_list:
      test_fn(params[0], params[1], params[2], params[3], params[4])

  def testQuantize_Conv2dWithBatchNorm(self):
    self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm)

  def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
                                        with_bypass, delay, fused_batch_norm):
    """Tests quantization: inputs -> Conv2d with batch norm -> Activation.
@@ -329,39 +338,6 @@ class QuantizeTest(test_util.TensorFlowTestCase):
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
    """
    self._testQuantize_Conv2dWithBatchNorm(
        activation,
        activation_op_name,
        with_bypass,
        delay,
        fused_batch_norm,
        use_ema=True)
    self._testQuantize_Conv2dWithBatchNorm(
        activation,
        activation_op_name,
        with_bypass,
        delay,
        fused_batch_norm,
        use_ema=False)

  def testQuantize_Conv2dWithBatchNorm(self):
    self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm)

  def _testQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
                                        with_bypass, delay, fused_batch_norm,
                                        use_ema):
    """Tests quantization: inputs -> Conv2d with batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
      use_ema: Bool, when true uses EMA quantization for BN folded weights.
    """
    graph = ops.Graph()
    with graph.as_default():
      training.create_global_step(graph)
@@ -394,23 +370,19 @@ class QuantizeTest(test_util.TensorFlowTestCase):

      fold_batch_norms.FoldBatchNorms(graph)

      quantize.Quantize(
          graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)
      quantize.Quantize(graph, quant_delay=delay)

      quantization_node_name = 'FakeQuantWithMinMaxVars'
      weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
                                                  quantization_node_name)
      self.assertEqual(weights_quant.type, quantization_node_name)
      expected_inputs = [
          scope + '/weights_quant/' + ('AssignMinEma'
                                       if use_ema else 'AssignMinLast'),
          scope + '/weights_quant/' + ('AssignMaxEma'
                                       if use_ema else 'AssignMaxLast'),
          scope + '/mul_fold'
          scope + '/weights_quant/' + 'AssignMinLast',
          scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold'
      ]
      self._AssertInputOpsAre(weights_quant, expected_inputs)
      output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
                                if (delay and use_ema) else '/Conv2D_Fold')
                                if delay else '/Conv2D_Fold')
      self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

      if with_bypass:
@@ -438,6 +410,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                        if delay else 'control_dependency')
      self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])

  def testQuantize_FCWithBatchNorm(self):
    self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm)

  def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name,
                                    with_bypass, delay, fused_batch_norm):
    """Tests quantization: inputs -> FC with batch norm -> Activation.
@@ -451,39 +426,6 @@ class QuantizeTest(test_util.TensorFlowTestCase):
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
    """
    self._testQuantize_FCWithBatchNorm(
        activation,
        activation_op_name,
        with_bypass,
        delay,
        fused_batch_norm,
        use_ema=True)
    self._testQuantize_FCWithBatchNorm(
        activation,
        activation_op_name,
        with_bypass,
        delay,
        fused_batch_norm,
        use_ema=False)

  def testQuantize_FCWithBatchNorm(self):
    self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm)

  def _testQuantize_FCWithBatchNorm(self, activation, activation_op_name,
                                    with_bypass, delay, fused_batch_norm,
                                    use_ema):
    """Tests quantization: inputs -> FC with batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
      use_ema: Bool, when true uses EMA quantization for BN folded weights.
    """
    graph = ops.Graph()
    with graph.as_default():
      training.create_global_step(graph)
@@ -513,23 +455,19 @@ class QuantizeTest(test_util.TensorFlowTestCase):

      fold_batch_norms.FoldBatchNorms(graph)

      quantize.Quantize(
          graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)
      quantize.Quantize(graph, quant_delay=delay)

      quantization_node_name = 'FakeQuantWithMinMaxVars'
      weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
                                                  quantization_node_name)
      self.assertEqual(weights_quant.type, quantization_node_name)
      expected_inputs = [
          scope + '/weights_quant/' + ('AssignMinEma'
                                       if use_ema else 'AssignMinLast'),
          scope + '/weights_quant/' + ('AssignMaxEma'
                                       if use_ema else 'AssignMaxLast'),
          scope + '/mul_fold'
          scope + '/weights_quant/' + 'AssignMinLast',
          scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold'
      ]
      self._AssertInputOpsAre(weights_quant, expected_inputs)
      output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
                                if delay and use_ema else '/MatMul_Fold')
                                if delay else '/MatMul_Fold')
      self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

      if with_bypass:
@@ -557,6 +495,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                        if delay else 'control_dependency')
      self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])

  def testQuantize_DepthwiseConv2dWithBatchNorm(self):
    self._RunBatchNormTestOverParameters(
        self._TestQuantize_DepthwiseConv2dWithBatchNorm)

  def _TestQuantize_DepthwiseConv2dWithBatchNorm(
      self, activation, activation_op_name, with_bypass, delay,
      fused_batch_norm):
@@ -571,40 +513,6 @@ class QuantizeTest(test_util.TensorFlowTestCase):
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
    """
    self._testQuantize_DepthwiseConv2dWithBatchNorm(
        activation,
        activation_op_name,
        with_bypass,
        delay,
        fused_batch_norm,
        use_ema=True)
    self._testQuantize_DepthwiseConv2dWithBatchNorm(
        activation,
        activation_op_name,
        with_bypass,
        delay,
        fused_batch_norm,
        use_ema=False)

  def testQuantize_DepthwiseConv2dWithBatchNorm(self):
    self._RunBatchNormTestOverParameters(
        self._TestQuantize_DepthwiseConv2dWithBatchNorm)

  def _testQuantize_DepthwiseConv2dWithBatchNorm(
      self, activation, activation_op_name, with_bypass, delay,
      fused_batch_norm, use_ema):
    """Tests quantization: inputs -> DWConv2d with batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
      use_ema: Bool, when true uses EMA quantization for BN folded weights.
    """
    graph = ops.Graph()
    with graph.as_default():
      training.create_global_step(graph)
@@ -637,22 +545,18 @@ class QuantizeTest(test_util.TensorFlowTestCase):

      fold_batch_norms.FoldBatchNorms(graph)

      quantize.Quantize(
          graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)
      quantize.Quantize(graph, quant_delay=delay)
      quantization_node_name = 'FakeQuantWithMinMaxVars'
      weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
                                                  quantization_node_name)
      self.assertEqual(weights_quant.type, quantization_node_name)
      expected_inputs = [
          scope + '/weights_quant/' + ('AssignMinEma'
                                       if use_ema else 'AssignMinLast'),
          scope + '/weights_quant/' + ('AssignMaxEma'
                                       if use_ema else 'AssignMaxLast'),
          scope + '/mul_fold'
          scope + '/weights_quant/' + 'AssignMinLast',
          scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold'
      ]
      self._AssertInputOpsAre(weights_quant, expected_inputs)
      output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
                                if delay and use_ema else '/depthwise_Fold')
                                if delay else '/depthwise_Fold')
      self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

      if with_bypass:

@@ -45,13 +45,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
          activation_fn=None, scope='test')
      relu = nn_ops.relu6(inputs)

      context = quantize._QuantizeContext(graph=graph, weight_bits=8,
                                          weight_narrow_range=True,
                                          activation_bits=8)
      # Inserting a quantization op between two unconnected ops should fail with
      # ValueError.
      with self.assertRaises(ValueError) as err:
        context._InsertQuantOp('test', conv.op, [relu.op], 'FailingQuantOp')
        quantize._InsertQuantOp('test', conv.op, [relu.op], 'FailingQuantOp')
      self.assertEqual(
          str(err.exception), 'Some inputs not quantized for ops: [Relu6]')

@@ -70,8 +67,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      quantize.Quantize(graph=graph, weight_bits=8, weight_narrow_range=True,
                        activation_bits=8)
      quantize.Quantize(graph=graph, weight_bits=8, activation_bits=8)

      quantization_node_name = 'FakeQuantWithMinMaxVars'
      add_quant = graph.get_operation_by_name('test/add_quant/' +
@@ -94,8 +90,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      quantize.Quantize(graph=graph, weight_bits=8, weight_narrow_range=True,
                        activation_bits=8)
      quantize.Quantize(graph=graph, weight_bits=8, activation_bits=8)

      quantization_node_name = 'FakeQuantWithMinMaxVars'
      add_quant = graph.get_operation_by_name('test/add_quant/' +
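The test changes above drop the use_ema variants because folded weights are now always quantized with last-value ranges. The core of the name-free matching can be sketched with the same graph_matcher primitives used in _FindLayersToQuantize; this is a simplified illustration of the approach, not the exact pattern from the commit.

    from tensorflow.contrib.quantize.python import graph_matcher

    def find_conv_relu(graph):
      """Yields (conv_op, activation_op) pairs matched purely by op types."""
      weight_pattern = graph_matcher.OpTypePattern(
          'Identity', inputs=[graph_matcher.OpTypePattern('Variable|VariableV2')])
      conv_pattern = graph_matcher.OpTypePattern(
          'Conv2D', inputs=[graph_matcher.OpTypePattern('*'), weight_pattern])
      bias_pattern = graph_matcher.OpTypePattern(
          'Add|BiasAdd', inputs=[conv_pattern, '*'])
      act_pattern = graph_matcher.OpTypePattern(
          'Relu|Relu6', inputs=[bias_pattern])
      matcher = graph_matcher.GraphMatcher(act_pattern)
      for match in matcher.match_graph(graph):
        yield match.get_op(conv_pattern), match.get_op(act_pattern)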