diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index b7d525a1fa2..ada336e6235 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -75,7 +75,9 @@ py_library( ":graph_matcher", ":input_to_ops", "//tensorflow/contrib/graph_editor:graph_editor_py", + "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:layers", "//tensorflow/python:math_ops", @@ -83,6 +85,7 @@ py_library( "//tensorflow/python:nn_ops", "//tensorflow/python:ops", "//tensorflow/python:training", + "//tensorflow/python:util", "//tensorflow/python:variables", ], ) @@ -162,7 +165,6 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", "//tensorflow/python:session", "//tensorflow/python:variables", @@ -174,7 +176,7 @@ py_library( srcs = ["python/quantize.py"], srcs_version = "PY2AND3", deps = [ - ":common", + ":graph_matcher", ":input_to_ops", ":quant_ops", "//tensorflow/contrib/graph_editor:graph_editor_py", diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py index f80d427ff0a..0a8e35080cb 100644 --- a/tensorflow/contrib/quantize/python/quant_ops.py +++ b/tensorflow/contrib/quantize/python/quant_ops.py @@ -53,7 +53,7 @@ def LastValueQuantize(inputs, init_max=6.0, updates_collection=ops.GraphKeys.UPDATE_OPS, vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, - scope=None, + name_prefix='LastValueQuant', reuse=None, is_training=True, num_bits=8, @@ -73,7 +73,7 @@ def LastValueQuantize(inputs, computation. vars_collection: (Optional) collection where to store variables for quantization interval ends. - scope: Optional scope for variable_scope. + name_prefix: name_prefix for created nodes. 
reuse: whether or not the layer and its variables should be reused. To be able to reuse the layer scope must be given. is_training: Whether the op is applied to a training or eval graph. @@ -84,13 +84,13 @@ def LastValueQuantize(inputs, a tensor containing quantized values. """ with variable_scope.variable_scope( - scope, 'LastValueQuantize', values=[inputs], reuse=reuse): + None, default_name=name_prefix, values=[inputs], reuse=reuse): input_shape = inputs.get_shape() input_dim = len(input_shape) if per_channel: # Only support quantizing 1-, 2- and 4-dimensional tensors. assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in ' - ' scope: %s' % (input_shape, scope)) + ' scope: %s' % (input_shape, name_prefix)) min_max_shape = [input_shape[-1]] else: min_max_shape = [] @@ -165,7 +165,7 @@ def MovingAvgQuantize(inputs, ema_decay=0.999, updates_collection=ops.GraphKeys.UPDATE_OPS, vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, - scope=None, + name_prefix='MovingAvgQuantize', reuse=None, is_training=True, num_bits=8, @@ -186,7 +186,7 @@ def MovingAvgQuantize(inputs, computation. vars_collection: (Optional) collection where to store variables for quantization interval ends. - scope: Optional scope for variable_scope. + name_prefix: name_prefix for created nodes. reuse: whether or not the layer and its variables should be reused. To be able to reuse the layer scope must be given. is_training: Whether the op is applied to a training or eval graph. @@ -197,13 +197,13 @@ def MovingAvgQuantize(inputs, a tensor containing quantized values. """ with variable_scope.variable_scope( - scope, 'MovingAvgQuantize', values=[inputs], reuse=reuse): + None, default_name=name_prefix, values=[inputs], reuse=reuse): input_shape = inputs.get_shape() input_dim = len(input_shape) if per_channel: # Only support quantizing 1-, 2- and 4-dimensional tensors. 
assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in ' - ' scope: %s' % (input_shape, scope)) + ' scope: %s' % (input_shape, name_prefix)) min_max_shape = [input_shape[-1]] else: min_max_shape = [] diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 50a2b4c91c9..1a63b0a2ce0 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Logic to update a Tensorflow model graph with quantization operations.""" +"""Logic to update a TensorFlow model graph with quantization operations.""" from __future__ import absolute_import from __future__ import division @@ -20,7 +20,7 @@ from __future__ import print_function import re from tensorflow.contrib import graph_editor -from tensorflow.contrib.quantize.python import common +from tensorflow.contrib.quantize.python import graph_matcher from tensorflow.contrib.quantize.python import input_to_ops from tensorflow.contrib.quantize.python import quant_ops from tensorflow.python.framework import ops @@ -28,30 +28,29 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.training import training_util -# Operation types used to select operations of interest. +# Quantizable operation types that are supported by the quantization rewrite. _QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'} -# Custom key for storing and retrieving update ops used by quantizing nodes. -_UPDATE_QUANT_OPS = 'update_quant_ops' +# Activations that are supported by the quantization rewrite. +_ACTIVATION_TYPES = {'Relu', 'Relu6', 'Identity'} + +# Weight types that are supported by the quantization rewrite. 
+# TODO(suharshs): Add support for ResourceVariable. +_WEIGHT_TYPES = {'Variable', 'VariableV2'} def Quantize(graph, weight_bits=8, - weight_narrow_range=False, activation_bits=8, ema_decay=0.999, quant_delay=None, vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, - is_training=True, - quantize_folded_weights_use_ema=False): + is_training=True): """Updates graph with quantization operations. Args: graph: Graph to modify. weight_bits: Number of bits to use for quantizing weights. - weight_narrow_range: Whether to use a more efficient narrow range for - weights quantization. With weight_narrow_range true, the range is - [1; 2^weight_bits - 1], with it false [0; 2^weight_bits - 1]. activation_bits: Number of bits to use for quantizing activations. ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update quantization intervals for quantizing activations (see here about EMA: @@ -62,345 +61,274 @@ def Quantize(graph, vars_collection: (Optional) Collection where to store the variables for quantization interval ends. is_training: (Optional) Whether quantizing training graph or eval graph. - quantize_folded_weights_use_ema: (Optional, default False) Whether to - quantize weights after batchnorm-folding with exponential average - quantization. Raises: ValueError: When quantization fails. """ - context = _QuantizeContext(graph, weight_bits, weight_narrow_range, - activation_bits, ema_decay, quant_delay, - vars_collection, is_training, - quantize_folded_weights_use_ema) - - graph_ops = graph.get_operations() - - # Filter out backprop and summary related operations, leave only interesting - # op types. - def _IsInterestingOpWithWeights(op): - return (op.type in _QUANTIZABLE_TYPES and - not op.name.startswith(common.SKIPPED_PREFIXES)) - - for op in (op for op in graph_ops if _IsInterestingOpWithWeights(op)): - if op.name.endswith('/depthwise'): - # Separable convolution may consist of 2 convolution nodes. 
If so, skip - # .../depthwise and only quantize the top one. - separable_conv = context.GetOperationByNameDontThrow( - op.name[:-len('/depthwise')]) - if separable_conv and separable_conv.type == 'Conv2D': - continue - # Quantize add ops that come after Conv2D or DepthwiseConv2dNative. - if op.type in ['Conv2D', 'DepthwiseConv2dNative']: - add_context_re = re.search(r'^(.*)/[^/]+/', op.name) - if add_context_re is not None: - context.add_contexts.add(add_context_re.group(1)) - if not op.name.endswith('_Fold'): - folded_op = context.GetOperationByNameDontThrow(op.name + '_Fold') - # Do nothing if found, it will be quantized when it is iterated over. - if not folded_op: - context.QuantizeOpWithWeights(op, folded=False) - else: - context.QuantizeOpWithWeights(op, folded=True) - - context.QuantizeAddContexts() - - # Once all quantization ops have been inserted in the graph, collect update - # ops for their variables and modify the TF Slim update barrier (see - # https://www.tensorflow.org/code/tensorflow/contrib/slim/python/slim/learning.py) - # to depend on them. - try: - update_barrier = graph.get_operation_by_name('update_barrier') - except KeyError: - # In evaluation graph, this barrier may not exist. - return None - update_quant_ops = graph.get_collection_ref(_UPDATE_QUANT_OPS) - graph_editor.add_control_inputs(update_barrier, update_quant_ops) - - -class _QuantizeContext(object): - """Context holds references needed for quantization.""" - - def __init__(self, - graph, - weight_bits, - weight_narrow_range, - activation_bits, - ema_decay=0.999, - quant_delay=None, - vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, - is_training=True, - quantize_folded_weights_use_ema=False): - """Initializes context to hold references needed for quantization. - - Args: - graph: Graph to modify. - weight_bits: Number of bits to use for quantizing weights. - weight_narrow_range: Whether to use a more efficient narrow range for - weights quantization. 
With weight_narrow_range true, the range is - [1; 2^weight_bits - 1], with it false [0; 2^weight_bits - 1]. - activation_bits: Number of bits to use for quantizing activations. - ema_decay: (Optional) Float, EMA decay parameter. - quant_delay: (Optional, default None) Int, count of global steps for which - to delay quantization. This helps weights stabilize at the start of - training. - vars_collection: (Optional) Collection where to store the variables for - quantization interval ends. - is_training: (Optional) Whether quantizing training or eval graph. - quantize_folded_weights_use_ema: (Optional, default False) Whether to - quantize weights after batchnorm-folding with exponential average - quantization. - """ - self.graph = graph - self.weight_bits = weight_bits - self.weight_narrow_range = weight_narrow_range - self.activation_bits = activation_bits - self.ema_decay = ema_decay - self.quant_delay = quant_delay - self.vars_collection = vars_collection - self.is_training = is_training - self.quantize_folded_weights_use_ema = quantize_folded_weights_use_ema - self.input_to_ops_map = input_to_ops.InputToOps(graph) - self.add_contexts = set() - - def QuantizeAddContexts(self): - """Quantizes all add ops in self.add_contexts.""" - # Loop through sorted self.add_contexts so that op creation is - # deterministic. This is needed when using multiple worker replicas so that - # the ops can be initialized consistently. - for add_context in sorted(self.add_contexts): - add_op = self.GetOperationByNamesDontThrow([ - add_context + '/Add', add_context + '/add']) - if add_op is not None: - self._InsertQuantOp( - add_context, - add_op, - self.input_to_ops_map.ConsumerOperations(add_op), - name='add_quant', - moving_avg=True, - bits=self.activation_bits, - narrow_range=False) - - def QuantizeOpWithWeights(self, op, folded): - """Quantizes around the specific operation with or without batch norm. - - Args: - op: Operation to quantize. 
- folded: Operation has been folded and needs special handling if True. - Raises: - ValueError: When quantization fails. - """ - # Op name component before the last slash will be used as context. - context = re.search(r'^(.*)/([^/]+)', op.name).group(1) - - # Quantize weights. - if folded: - producer_op = self.graph.get_operation_by_name(context + '/mul_fold') - else: - try: - input_idx = next(i for i, v in enumerate(op.inputs) - if '/weights/' in v.name or - '/depthwise_weights' in v.name) - except StopIteration: - raise ValueError('No inputs to quantize for op: %s' % op) - producer_op = op.inputs[input_idx].op - - # If batch norm is used, the folded weights depend on the batch std, hence - # it is sensible to use EMA during training to smooth out the noise. This is - # controlled by the flag quantize_folded_weights_use_ema. Its default is - # False for backward compatibility. - # If there is no batch norm, weights do not depend on the batch and using - # the latest value of min and max is more efficient. - weight_use_ema = folded and self.quantize_folded_weights_use_ema - self._InsertQuantOp( + input_to_ops_map = input_to_ops.InputToOps(graph) + for layer_match in _FindLayersToQuantize(graph): + # Quantize the weights. + context = _GetContextFromOp(layer_match.layer_op) + _InsertQuantOp( context, - producer_op, [op], + layer_match.weight_tensor.op, [layer_match.layer_op], name='weights_quant', - moving_avg=weight_use_ema, - delay_requested=weight_use_ema, - bits=self.weight_bits, - narrow_range=self.weight_narrow_range) + moving_avg=False, + bits=weight_bits, + ema_decay=ema_decay, + quant_delay=quant_delay, + is_training=is_training, + narrow_range=True, + vars_collection=vars_collection) - # Important: do not quantize biases here. During inference they are - # quantized to 32 bits, which is much finer than 8 bit quantization and - # depends on weight and input activation ranges. - - # Find activation and (optionally) Add operations to quantize. 
- activation_op, add_op, add_context = self._GetReluAndAddOperations(context, - op) - if add_op: - original_context = context - context = add_context - - # Quantize activation outputs. - consumer_ops = self.input_to_ops_map.ConsumerOperations(activation_op) - self._InsertQuantOp( - context, - activation_op, + # Quantize the activations. + consumer_ops = input_to_ops_map.ConsumerOperations( + layer_match.activation_op) + add_context = context + if layer_match.bypass_op: + add_context = re.search(r'^(.*)/([^/]+)', context).group(1) + _InsertQuantOp( + add_context, + layer_match.activation_op, consumer_ops, name='act_quant', moving_avg=True, init_min=0.0, - bits=self.activation_bits, - narrow_range=False) + ema_decay=ema_decay, + quant_delay=quant_delay, + bits=activation_bits, + vars_collection=vars_collection) - # When a bypass connection was found, also quantize Add op input. - if add_op: - def _QuantizeAddInput(add_input): - if folded: - return add_input.op.name.endswith('/add_fold') - else: - return add_input.op.name.startswith(original_context + '/') + # Quantize the inputs and output to the bypass (if it exists). The input to + # the bypass is the bias add, and the output is the activation. 
+ if layer_match.bypass_op is not None: + _InsertQuantOp( + context, + layer_match.bias_add_op, [layer_match.bypass_op], + name='conv_quant', + moving_avg=True, + ema_decay=ema_decay, + quant_delay=quant_delay, + vars_collection=vars_collection, + bits=activation_bits) + _InsertQuantOp( + add_context, + layer_match.bypass_op, + input_to_ops_map.ConsumerOperations(layer_match.bypass_op), + name='add_quant', + moving_avg=True, + bits=activation_bits) - for add_input in add_op.inputs: - if _QuantizeAddInput(add_input): - self._InsertQuantOp( - original_context, - add_input.op, [add_op], - name='conv_quant', - moving_avg=True, - bits=self.activation_bits, - narrow_range=False) - def _GetReluAndAddOperations(self, context, op): - """Looks up a Relu* and Add operations in given context. +def _FindLayersToQuantize(graph): + """Matches layers in graph to quantize. - Args: - context: Context where to look for operations. - op: Operation to quantize. + Args: + graph: Graph to perform match on. - Returns: - A triplet (Operation, Operation, string), the first element is an end - point operation, the second is Add operation (optional), the third element - is string context where the Add operation was found (optional). + Yields: + _LayerMatches. + """ + input_pattern = graph_matcher.OpTypePattern('*') + weight_var_pattern = graph_matcher.OpTypePattern('|'.join(_WEIGHT_TYPES)) + weight_pattern = graph_matcher.OpTypePattern( + 'Identity', inputs=[weight_var_pattern]) - Raises: - ValueError: When operations cannot be found. - """ - activation_op = common.GetEndpointActivationOp(self.graph, context) - if activation_op: - return activation_op, None, None + folded_weight_pattern = graph_matcher.OpTypePattern('Mul') - if '/' in context: - # If no activation op is there, look for them one level up. 
- add_context = re.search(r'^(.*)/([^/]+)', context).group(1) - activation_op = common.GetEndpointActivationOp(self.graph, add_context) - if not activation_op: - # Still no Relu, can happen on the top layer, just find the next node up, - # make sure it is BiasAdd. - consumers = [c for outp in op.outputs for c in outp.consumers()] - if len(consumers) != 1 or consumers[0].type != 'BiasAdd': - raise ValueError('Failed to quantize op: %s, %s' % (op.name, op.type)) - return consumers[0], None, None - if add_context: - add_op = self.GetOperationByNamesDontThrow([ - add_context + '/Add', add_context + '/add']) - return activation_op, add_op, add_context - else: - raise ValueError('Failed to quantize op: %s, %s' % (op.name, op.type)) + # The weights inputs to the layer operation can either be from the Variable or + # the folded weight (Mul). + layer_pattern = graph_matcher.OpTypePattern( + '|'.join(_QUANTIZABLE_TYPES), + inputs=[ + input_pattern, + graph_matcher.OneofPattern([weight_pattern, folded_weight_pattern]) + ]) - def GetOperationByNameDontThrow(self, name): - """Returns an Operation with the given name. + folded_bias_mul_pattern = graph_matcher.OpTypePattern( + 'Mul', inputs=[graph_matcher.OpTypePattern('*'), layer_pattern]) + post_layer_op_correction_pattern = graph_matcher.OpTypePattern( + 'Add', inputs=[folded_bias_mul_pattern, + graph_matcher.OpTypePattern('*')]) + folded_bias_add_pattern = graph_matcher.OpTypePattern( + 'Add', + inputs=[ + post_layer_op_correction_pattern, + graph_matcher.OpTypePattern('*') + ]) - Args: - name: Name of Operation to return. + bias_add_pattern = graph_matcher.OpTypePattern( + 'Add|BiasAdd', inputs=[layer_pattern, '*']) - Returns: - The Operation with the given name. None if the name does not correspond to - any operation in the graph. - """ - try: - return self.graph.get_operation_by_name(name) - except KeyError: - return None + # The bias can come from the bias add or the folded bias add. 
+ bypass_pattern_a = graph_matcher.OpTypePattern( + 'Add', + inputs=[ + graph_matcher.OneofPattern( + [bias_add_pattern, folded_bias_add_pattern]), '*' + ]) + bypass_pattern_b = graph_matcher.OpTypePattern( + 'Add', + inputs=[ + '*', + graph_matcher.OneofPattern( + [bias_add_pattern, folded_bias_add_pattern]) + ]) - def GetOperationByNamesDontThrow(self, names): - """Returns an Operation with one of the given names. + # The input to the activation can come from bias add, fold bias add or the + # bypasses. + activation_pattern = graph_matcher.OpTypePattern( + '|'.join(_ACTIVATION_TYPES), + inputs=[ + graph_matcher.OneofPattern([ + bias_add_pattern, folded_bias_add_pattern, bypass_pattern_a, + bypass_pattern_b + ]) + ]) - Args: - names: Names of Operation to return. + layer_matcher = graph_matcher.GraphMatcher(activation_pattern) + for match_result in layer_matcher.match_graph(graph): + layer_op = match_result.get_op(layer_pattern) + weight_tensor = match_result.get_tensor(weight_pattern) + if weight_tensor is None: + weight_tensor = match_result.get_tensor(folded_weight_pattern) + activation_op = match_result.get_op(activation_pattern) + bias_add_op = match_result.get_op(bias_add_pattern) + if bias_add_op is None: + bias_add_op = match_result.get_op(folded_bias_add_pattern) + bypass_op = match_result.get_op(bypass_pattern_a) + if bypass_op is None: + bypass_op = match_result.get_op(bypass_pattern_b) + yield _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op, + bias_add_op) - Returns: - The Operation with one of the given names. None if none of the names - corresponds to any operation in the graph. 
- """ - for name in names: - op = self.GetOperationByNameDontThrow(name) - if op is not None: - return op - return None - def _InsertQuantOp( - self, - context, - producer, - consumers, - name, - moving_avg=True, - init_min=-6.0, - init_max=6.0, - delay_requested=True, - bits=8, - narrow_range=False,): - """Inserts a quant op between a producer op and (multiple) consumer ops. +class _LayerMatch(object): + """Contains all information related to a matched Layer.""" - Args: - context: Context where producer and consumer operations are nested. - producer: Producer operation of the pairs where quantization will be - inserted. - consumers: Consumer operations of the pairs. - name: Name for the new quantization op within the context. - moving_avg: Specifies whether to use exponential moving average or just - the last value seen. - init_min: Starting minimum value for the new quantization op. - init_max: Starting maximum value for the new quantization op. - delay_requested: If true, implement quantization delay where needed. - False value explicitly disables delay quantization everywhere. - bits: Number of bits to use for quantization, must be between 2 and 8. 
- narrow_range: Whether to use the narrow quantization range + def __init__(self, layer_op, weight_tensor, activation_op, bypass_op, + bias_add_op): + self._layer_op = layer_op + self._weight_tensor = weight_tensor + self._activation_op = activation_op + self._bypass_op = bypass_op + self._bias_add_op = bias_add_op + + @property + def layer_op(self): + return self._layer_op + + @property + def weight_tensor(self): + return self._weight_tensor + + @property + def activation_op(self): + return self._activation_op + + @property + def bypass_op(self): + return self._bypass_op + + @property + def bias_add_op(self): + return self._bias_add_op + + +def _InsertQuantOp(context, + producer, + consumers, + name, + moving_avg=True, + init_min=-6.0, + init_max=6.0, + bits=8, + ema_decay=0.999, + quant_delay=None, + vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, + is_training=True, + narrow_range=False): + """Inserts a quant op between a producer op and (multiple) consumer ops. + + Args: + context: Context where producer and consumer operations are nested. + producer: Producer operation of the pairs where quantization will be + inserted. + consumers: Consumer operations of the pairs. + name: Name for the new quantization op within the context. + moving_avg: Specifies whether to use exponential moving average or just + the last value seen. + init_min: Starting minimum value for the new quantization op. + init_max: Starting maximum value for the new quantization op. + bits: Number of bits to use for quantization, must be between 2 and 8. + ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update + quantization intervals for quantizing activations (see here about EMA: + https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average). + quant_delay: (Optional, default None) Int, count of global steps for which + to delay quantization. This helps weights stabilize at the start of + training.
+ vars_collection: (Optional) Collection where to store the variables for + quantization interval ends. + is_training: (Optional) Whether quantizing training graph or eval graph. + narrow_range: Whether to use the narrow quantization range [1; 2^bits - 1] or wide range [0; 2^bits - 1]. - Raises: - ValueError: When producer operation is not directly connected to the - consumer operation. - """ - scope = context + '/' + name - inputs = producer.outputs[0] - if moving_avg: - quant = (quant_ops.MovingAvgQuantize( - inputs, - init_min=init_min, - init_max=init_max, - ema_decay=self.ema_decay, - is_training=self.is_training, - num_bits=bits, - narrow_range=narrow_range, - updates_collection=_UPDATE_QUANT_OPS, - vars_collection=self.vars_collection, - scope=scope)) - else: - quant = (quant_ops.LastValueQuantize( - inputs, - init_min=init_min, - init_max=init_max, - is_training=self.is_training, - num_bits=bits, - narrow_range=narrow_range, - updates_collection=_UPDATE_QUANT_OPS, - vars_collection=self.vars_collection, - scope=scope)) + Raises: + ValueError: When producer operation is not directly connected to the + consumer operation. 
+ """ + name_prefix = _AddContextToName(context, name) + inputs = producer.outputs[0] + if moving_avg: + quant = ( + quant_ops.MovingAvgQuantize( + inputs, + init_min=init_min, + init_max=init_max, + ema_decay=ema_decay, + is_training=is_training, + num_bits=bits, + narrow_range=narrow_range, + vars_collection=vars_collection, + name_prefix=name_prefix)) + else: + quant = ( + quant_ops.LastValueQuantize( + inputs, + init_min=init_min, + init_max=init_max, + is_training=is_training, + num_bits=bits, + narrow_range=narrow_range, + vars_collection=vars_collection, + name_prefix=name_prefix)) - if delay_requested and self.quant_delay and self.quant_delay > 0: - activate_quant = math_ops.greater_equal( - training_util.get_or_create_global_step(), - self.quant_delay, - name=scope + '/activate_quant') - quant = control_flow_ops.cond( - activate_quant, - lambda: quant, - lambda: inputs, - name=scope + '/delayed_quant') + if quant_delay and quant_delay > 0: + activate_quant = math_ops.greater_equal( + training_util.get_or_create_global_step(), + quant_delay, + name=name_prefix + '/activate_quant') + quant = control_flow_ops.cond( + activate_quant, + lambda: quant, + lambda: inputs, + name=name_prefix + '/delayed_quant') - nodes_modified_count = graph_editor.reroute_ts( - [quant], [inputs], can_modify=consumers) - if nodes_modified_count != len(consumers): - raise ValueError('Some inputs not quantized for ops: [%s]' % - ', '.join([consumer.name for consumer in consumers])) + nodes_modified_count = graph_editor.reroute_ts( + [quant], [inputs], can_modify=consumers) + if nodes_modified_count != len(consumers): + raise ValueError('Some inputs not quantized for ops: [%s]' % ', '.join( + [consumer.name for consumer in consumers])) + + +def _GetContextFromOp(op): + """Gets the root context name from the op name.""" + context_re = re.search(r'^(.*)/([^/]+)', op.name) + if context_re: + return context_re.group(1) + return '' + + +def _AddContextToName(context, name): + """Adds the 
context to the name if it exists.""" + if not context: + return name + return context + '/' + name diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py index 57dab03f162..f1fe322049d 100644 --- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py +++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py @@ -101,7 +101,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): scope + '/weights_quant/AssignMaxLast', scope + '/weights/read' ] self._AssertInputOpsAre(weights_quant, expected_inputs) - output_op_name = scope + '/Conv2D' + output_op_name = ( + scope + '/weights_quant/delayed_quant/Switch_1' + if delay else scope + '/Conv2D') self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) if with_bypass: @@ -176,7 +178,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): scope + '/weights_quant/AssignMaxLast', scope + '/weights/read' ] self._AssertInputOpsAre(weights_quant, expected_inputs) - output_op_name = scope + '/MatMul' + output_op_name = ( + scope + '/weights_quant/delayed_quant/Switch_1' + if delay else scope + '/MatMul') self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) if with_bypass: @@ -252,7 +256,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): scope + '/depthwise_weights/read' ] self._AssertInputOpsAre(weights_quant, expected_inputs) - output_op_name = scope + '/depthwise' + output_op_name = ( + scope + '/weights_quant/delayed_quant/Switch_1' + if delay else scope + '/depthwise') self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) if with_bypass: @@ -316,6 +322,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): for params in parameters_list: test_fn(params[0], params[1], params[2], params[3], params[4]) + def testQuantize_Conv2dWithBatchNorm(self): + self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm) + def _TestQuantize_Conv2dWithBatchNorm(self, 
activation, activation_op_name, with_bypass, delay, fused_batch_norm): """Tests quantization: inputs -> Conv2d with batch norm -> Activation. @@ -329,39 +338,6 @@ class QuantizeTest(test_util.TensorFlowTestCase): delay: Int (optional), delay in number of steps until quantization starts. fused_batch_norm: Bool, when true use FusedBatchNorm. """ - self._testQuantize_Conv2dWithBatchNorm( - activation, - activation_op_name, - with_bypass, - delay, - fused_batch_norm, - use_ema=True) - self._testQuantize_Conv2dWithBatchNorm( - activation, - activation_op_name, - with_bypass, - delay, - fused_batch_norm, - use_ema=False) - - def testQuantize_Conv2dWithBatchNorm(self): - self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm) - - def _testQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay, fused_batch_norm, - use_ema): - """Tests quantization: inputs -> Conv2d with batch norm -> Activation. - - Args: - activation: Callable that returns an Operation, a factory method for the - Activation. - activation_op_name: String, name of the Activation operation. - with_bypass: Bool, when true there is an extra connection added from - inputs to just before Activation. - delay: Int (optional), delay in number of steps until quantization starts. - fused_batch_norm: Bool, when true use FusedBatchNorm. - use_ema: Bool, when true uses EMA quantization for BN folded weights. 
- """ graph = ops.Graph() with graph.as_default(): training.create_global_step(graph) @@ -394,23 +370,19 @@ class QuantizeTest(test_util.TensorFlowTestCase): fold_batch_norms.FoldBatchNorms(graph) - quantize.Quantize( - graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) + quantize.Quantize(graph, quant_delay=delay) quantization_node_name = 'FakeQuantWithMinMaxVars' weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + quantization_node_name) self.assertEqual(weights_quant.type, quantization_node_name) expected_inputs = [ - scope + '/weights_quant/' + ('AssignMinEma' - if use_ema else 'AssignMinLast'), - scope + '/weights_quant/' + ('AssignMaxEma' - if use_ema else 'AssignMaxLast'), - scope + '/mul_fold' + scope + '/weights_quant/' + 'AssignMinLast', + scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold' ] self._AssertInputOpsAre(weights_quant, expected_inputs) output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' - if (delay and use_ema) else '/Conv2D_Fold') + if delay else '/Conv2D_Fold') self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) if with_bypass: @@ -438,6 +410,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): if delay else 'control_dependency') self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + def testQuantize_FCWithBatchNorm(self): + self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm) + def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name, with_bypass, delay, fused_batch_norm): """Tests quantization: inputs -> FC with batch norm -> Activation. @@ -451,39 +426,6 @@ class QuantizeTest(test_util.TensorFlowTestCase): delay: Int (optional), delay in number of steps until quantization starts. fused_batch_norm: Bool, when true use FusedBatchNorm. 
""" - self._testQuantize_FCWithBatchNorm( - activation, - activation_op_name, - with_bypass, - delay, - fused_batch_norm, - use_ema=True) - self._testQuantize_FCWithBatchNorm( - activation, - activation_op_name, - with_bypass, - delay, - fused_batch_norm, - use_ema=False) - - def testQuantize_FCWithBatchNorm(self): - self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm) - - def _testQuantize_FCWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay, fused_batch_norm, - use_ema): - """Tests quantization: inputs -> FC with batch norm -> Activation. - - Args: - activation: Callable that returns an Operation, a factory method for the - Activation. - activation_op_name: String, name of the Activation operation. - with_bypass: Bool, when true there is an extra connection added from - inputs to just before Activation. - delay: Int (optional), delay in number of steps until quantization starts. - fused_batch_norm: Bool, when true use FusedBatchNorm. - use_ema: Bool, when true uses EMA quantization for BN folded weights. 
- """ graph = ops.Graph() with graph.as_default(): training.create_global_step(graph) @@ -513,23 +455,19 @@ class QuantizeTest(test_util.TensorFlowTestCase): fold_batch_norms.FoldBatchNorms(graph) - quantize.Quantize( - graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) + quantize.Quantize(graph, quant_delay=delay) quantization_node_name = 'FakeQuantWithMinMaxVars' weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + quantization_node_name) self.assertEqual(weights_quant.type, quantization_node_name) expected_inputs = [ - scope + '/weights_quant/' + ('AssignMinEma' - if use_ema else 'AssignMinLast'), - scope + '/weights_quant/' + ('AssignMaxEma' - if use_ema else 'AssignMaxLast'), - scope + '/mul_fold' + scope + '/weights_quant/' + 'AssignMinLast', + scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold' ] self._AssertInputOpsAre(weights_quant, expected_inputs) output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' - if delay and use_ema else '/MatMul_Fold') + if delay else '/MatMul_Fold') self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) if with_bypass: @@ -557,6 +495,10 @@ class QuantizeTest(test_util.TensorFlowTestCase): if delay else 'control_dependency') self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + def testQuantize_DepthwiseConv2dWithBatchNorm(self): + self._RunBatchNormTestOverParameters( + self._TestQuantize_DepthwiseConv2dWithBatchNorm) + def _TestQuantize_DepthwiseConv2dWithBatchNorm( self, activation, activation_op_name, with_bypass, delay, fused_batch_norm): @@ -571,40 +513,6 @@ class QuantizeTest(test_util.TensorFlowTestCase): delay: Int (optional), delay in number of steps until quantization starts. fused_batch_norm: Bool, when true use FusedBatchNorm. 
""" - self._testQuantize_DepthwiseConv2dWithBatchNorm( - activation, - activation_op_name, - with_bypass, - delay, - fused_batch_norm, - use_ema=True) - self._testQuantize_DepthwiseConv2dWithBatchNorm( - activation, - activation_op_name, - with_bypass, - delay, - fused_batch_norm, - use_ema=False) - - def testQuantize_DepthwiseConv2dWithBatchNorm(self): - self._RunBatchNormTestOverParameters( - self._TestQuantize_DepthwiseConv2dWithBatchNorm) - - def _testQuantize_DepthwiseConv2dWithBatchNorm( - self, activation, activation_op_name, with_bypass, delay, - fused_batch_norm, use_ema): - """Tests quantization: inputs -> DWConv2d with batch norm -> Activation. - - Args: - activation: Callable that returns an Operation, a factory method for the - Activation. - activation_op_name: String, name of the Activation operation. - with_bypass: Bool, when true there is an extra connection added from - inputs to just before Activation. - delay: Int (optional), delay in number of steps until quantization starts. - fused_batch_norm: Bool, when true use FusedBatchNorm. - use_ema: Bool, when true uses EMA quantization for BN folded weights. 
- """ graph = ops.Graph() with graph.as_default(): training.create_global_step(graph) @@ -637,22 +545,18 @@ class QuantizeTest(test_util.TensorFlowTestCase): fold_batch_norms.FoldBatchNorms(graph) - quantize.Quantize( - graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) + quantize.Quantize(graph, quant_delay=delay) quantization_node_name = 'FakeQuantWithMinMaxVars' weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + quantization_node_name) self.assertEqual(weights_quant.type, quantization_node_name) expected_inputs = [ - scope + '/weights_quant/' + ('AssignMinEma' - if use_ema else 'AssignMinLast'), - scope + '/weights_quant/' + ('AssignMaxEma' - if use_ema else 'AssignMaxLast'), - scope + '/mul_fold' + scope + '/weights_quant/' + 'AssignMinLast', + scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold' ] self._AssertInputOpsAre(weights_quant, expected_inputs) output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' - if delay and use_ema else '/depthwise_Fold') + if delay else '/depthwise_Fold') self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) if with_bypass: diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py index 1e4dd7cf67d..53cbd667410 100644 --- a/tensorflow/contrib/quantize/python/quantize_test.py +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -45,13 +45,10 @@ class QuantizeTest(test_util.TensorFlowTestCase): activation_fn=None, scope='test') relu = nn_ops.relu6(inputs) - context = quantize._QuantizeContext(graph=graph, weight_bits=8, - weight_narrow_range=True, - activation_bits=8) # Inserting a quantization op between two unconnected ops should fail with # ValueError. 
with self.assertRaises(ValueError) as err: - context._InsertQuantOp('test', conv.op, [relu.op], 'FailingQuantOp') + quantize._InsertQuantOp('test', conv.op, [relu.op], 'FailingQuantOp') self.assertEqual( str(err.exception), 'Some inputs not quantized for ops: [Relu6]') @@ -70,8 +67,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - quantize.Quantize(graph=graph, weight_bits=8, weight_narrow_range=True, - activation_bits=8) + quantize.Quantize(graph=graph, weight_bits=8, activation_bits=8) quantization_node_name = 'FakeQuantWithMinMaxVars' add_quant = graph.get_operation_by_name('test/add_quant/' + @@ -94,8 +90,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - quantize.Quantize(graph=graph, weight_bits=8, weight_narrow_range=True, - activation_bits=8) + quantize.Quantize(graph=graph, weight_bits=8, activation_bits=8) quantization_node_name = 'FakeQuantWithMinMaxVars' add_quant = graph.get_operation_by_name('test/add_quant/' +