Add narrow_range to DequantizeOp
PiperOrigin-RevId: 272068762
This commit is contained in:
parent
ddc1f64e08
commit
720a51829b
@ -15,7 +15,7 @@ END
|
||||
summary: "Dequantize the \'input\' tensor into a float Tensor."
|
||||
description: <<END
|
||||
[min_range, max_range] are scalar floats that specify the range for
|
||||
the 'input' data. The 'mode' attribute controls exactly which calculations are
|
||||
the output. The 'mode' attribute controls exactly which calculations are
|
||||
used to convert the quantized values to their float equivalents.
|
||||
|
||||
In 'MIN_COMBINED' mode, each value of the tensor will undergo the following:
|
||||
@ -47,45 +47,24 @@ const double offset_input = static_cast<double>(input) - lowest_quantized;
|
||||
result = range_min + ((input - numeric_limits<T>::min()) * range_scale)
|
||||
```
|
||||
|
||||
*SCALED mode Example*
|
||||
If the mode is `SCALED`, dequantization is performed by multiplying each
|
||||
input value by a scaling_factor. (Thus an input of 0 always maps to 0.0).
|
||||
|
||||
`SCALED` mode matches the quantization approach used in
|
||||
`QuantizeAndDequantize{V2|V3}`.
|
||||
The scaling_factor is determined from `min_range`, `max_range`, and
|
||||
`narrow_range` in a way that is compatible with `QuantizeAndDequantize{V2|V3}`
|
||||
and `QuantizeV2`, using the following algorithm:
|
||||
|
||||
If the mode is `SCALED`, we do not use the full range of the output type,
|
||||
choosing to elide the lowest possible value for symmetry (e.g., output range is
|
||||
-127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to
|
||||
0.
|
||||
|
||||
We first find the range of values in our tensor. The
|
||||
range we use is always centered on 0, so we find m such that
|
||||
```c++
|
||||
m = max(abs(input_min), abs(input_max))
|
||||
```
|
||||
|
||||
Our input tensor range is then `[-m, m]`.
|
||||
const int min_expected_T = std::numeric_limits<T>::min() +
|
||||
(narrow_range ? 1 : 0);
|
||||
const int max_expected_T = std::numeric_limits<T>::max();
|
||||
const float max_expected_T = std::numeric_limits<float>::max();
|
||||
|
||||
Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`.
|
||||
If T is signed, this is
|
||||
```
|
||||
num_bits = sizeof(T) * 8
|
||||
[min_fixed, max_fixed] =
|
||||
[-((1 << (num_bits - 1)) - 1), (1 << (num_bits - 1)) - 1]
|
||||
```
|
||||
|
||||
Otherwise, if T is unsigned, the fixed-point range is
|
||||
```
|
||||
[min_fixed, max_fixed] = [0, (1 << num_bits) - 1]
|
||||
```
|
||||
|
||||
From this we compute our scaling factor, s:
|
||||
```c++
|
||||
s = (2 * m) / (max_fixed - min_fixed)
|
||||
```
|
||||
|
||||
Now we can dequantize the elements of our tensor:
|
||||
```c++
|
||||
result = input * s
|
||||
const float scale_factor =
|
||||
(std::numeric_limits<T>::min() == 0) ? (max_range / max_expected_T)
|
||||
: std::max(min_range / min_expected_T,
|
||||
max_range / max_expected_T);
|
||||
```
|
||||
END
|
||||
}
|
||||
|
@ -56,6 +56,7 @@ class DequantizeOp : public OpKernel {
|
||||
} else if (mode_string == "SCALED") {
|
||||
mode_ = QUANTIZE_MODE_SCALED;
|
||||
}
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
|
||||
}
|
||||
|
||||
@ -133,10 +134,12 @@ class DequantizeOp : public OpKernel {
|
||||
output);
|
||||
}
|
||||
} else if (mode_ == QUANTIZE_MODE_SCALED) {
|
||||
const int min_output_value =
|
||||
std::numeric_limits<T>::min() + (narrow_range_ ? 1 : 0);
|
||||
const float scale_factor =
|
||||
std::numeric_limits<T>::min() == 0
|
||||
? (max_range / std::numeric_limits<T>::max())
|
||||
: std::max(min_range / std::numeric_limits<T>::min(),
|
||||
: std::max(min_range / min_output_value,
|
||||
max_range / std::numeric_limits<T>::max());
|
||||
const auto& input_tensor = input.flat<T>();
|
||||
output->flat<float>() =
|
||||
@ -168,10 +171,12 @@ class DequantizeOp : public OpKernel {
|
||||
((input.template cast<float>() + half_range) * scale_factor) +
|
||||
min_range;
|
||||
} else if (mode_ == QUANTIZE_MODE_SCALED) {
|
||||
const int min_output_value =
|
||||
std::numeric_limits<T>::min() + (narrow_range_ ? 1 : 0);
|
||||
const float scale_factor =
|
||||
std::numeric_limits<T>::min() == 0
|
||||
? (max_range / std::numeric_limits<T>::max())
|
||||
: std::max(min_range / std::numeric_limits<T>::min(),
|
||||
: std::max(min_range / min_output_value,
|
||||
max_range / std::numeric_limits<T>::max());
|
||||
output.device(d) = input.template cast<float>() * scale_factor;
|
||||
}
|
||||
@ -180,6 +185,7 @@ class DequantizeOp : public OpKernel {
|
||||
private:
|
||||
int mode_;
|
||||
int axis_;
|
||||
bool narrow_range_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
|
@ -3073,6 +3073,7 @@ REGISTER_OP("Dequantize")
|
||||
.Output("output: float")
|
||||
.Attr("T: quantizedtype")
|
||||
.Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'")
|
||||
.Attr("narrow_range: bool = false")
|
||||
.Attr("axis: int = -1")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
int axis = -1;
|
||||
|
@ -4514,7 +4514,7 @@ def quantize_v2(
|
||||
raise ValueError("input should have known rank to use negative axis.")
|
||||
axis %= input.shape.ndims
|
||||
|
||||
if compat.forward_compatible(2019, 9, 25) or axis >= 0:
|
||||
if compat.forward_compatible(2019, 9, 25) or axis >= 0 or narrow_range:
|
||||
return gen_array_ops.quantize_v2(
|
||||
input,
|
||||
min_range,
|
||||
@ -4579,14 +4579,14 @@ def quantize(
|
||||
@tf_export("quantization.dequantize", v1=["quantization.dequantize",
|
||||
"dequantize"])
|
||||
@deprecation.deprecated_endpoints("dequantize")
|
||||
def dequantize(
|
||||
def dequantize( # pylint: disable=missing-docstring
|
||||
input, # pylint: disable=redefined-builtin
|
||||
min_range,
|
||||
max_range,
|
||||
mode="MIN_COMBINED",
|
||||
name=None,
|
||||
axis=None):
|
||||
"""Dequantize tensor to the specified range."""
|
||||
axis=None,
|
||||
narrow_range=False):
|
||||
if axis is None:
|
||||
axis = -1
|
||||
elif axis < 0:
|
||||
@ -4594,12 +4594,15 @@ def dequantize(
|
||||
raise ValueError("input should have known rank to use negative axis.")
|
||||
axis %= input.shape.ndims
|
||||
|
||||
if compat.forward_compatible(2019, 9, 25) or axis >= 0:
|
||||
if compat.forward_compatible(2019, 10, 22) or axis >= 0 or narrow_range:
|
||||
return gen_array_ops.dequantize(
|
||||
input, min_range, max_range, mode=mode, name=name, axis=axis)
|
||||
input, min_range, max_range, mode=mode, name=name,
|
||||
narrow_range=narrow_range, axis=axis)
|
||||
return gen_array_ops.dequantize(
|
||||
input, min_range, max_range, mode=mode, name=name)
|
||||
|
||||
dequantize.__doc__ = gen_array_ops.dequantize.__doc__
|
||||
|
||||
|
||||
@tf_export("quantization.quantize_and_dequantize")
|
||||
def quantize_and_dequantize(
|
||||
|
@ -31,10 +31,12 @@ class DequantizeOpTest(test.TestCase):
|
||||
def __init__(self, method_name="runTest"):
|
||||
super(DequantizeOpTest, self).__init__(method_name)
|
||||
|
||||
def _testDequantizeOp(self, inputs, min_range, max_range, dtype):
|
||||
def _testDequantizeOp(self, inputs, min_range, max_range, dtype,
|
||||
mode="MIN_COMBINED", narrow_range=False):
|
||||
with self.cached_session():
|
||||
input_op = constant_op.constant(inputs, shape=[len(inputs)], dtype=dtype)
|
||||
dequantized = array_ops.dequantize(input_op, min_range, max_range)
|
||||
dequantized = array_ops.dequantize(input_op, min_range, max_range,
|
||||
mode=mode, narrow_range=narrow_range)
|
||||
tf_ans = self.evaluate(dequantized)
|
||||
|
||||
# TODO(vrv): Add support for DT_QINT32 quantization if needed.
|
||||
@ -44,19 +46,26 @@ class DequantizeOpTest(test.TestCase):
|
||||
dtypes.quint16: np.uint16,
|
||||
dtypes.qint16: np.int16
|
||||
}
|
||||
self.assertTrue(dtype in type_dict.keys())
|
||||
self.assertIn(dtype, type_dict.keys())
|
||||
v_max = np.iinfo(type_dict[dtype]).max
|
||||
v_min = np.iinfo(type_dict[dtype]).min
|
||||
self.assertTrue(min_range >= v_min)
|
||||
self.assertTrue(max_range <= v_max)
|
||||
self.assertGreaterEqual(min_range, v_min)
|
||||
self.assertLessEqual(max_range, v_max)
|
||||
type_range = v_max - v_min
|
||||
if v_min < 0:
|
||||
half_range = (type_range + 1) / 2
|
||||
else:
|
||||
half_range = 0.0
|
||||
|
||||
np_ans = ((inputs.astype(np.float32) + half_range) *
|
||||
(max_range - min_range) / type_range) + min_range
|
||||
if mode == "MIN_COMBINED":
|
||||
if v_min < 0:
|
||||
half_range = (type_range + 1) / 2
|
||||
else:
|
||||
half_range = 0.0
|
||||
np_ans = ((inputs.astype(np.float32) + half_range) *
|
||||
(max_range - min_range) / type_range) + min_range
|
||||
elif mode == "SCALED":
|
||||
if narrow_range:
|
||||
v_min += 1
|
||||
scale_factor = max(min_range / v_min, max_range / v_max)
|
||||
np_ans = inputs.astype(np.float32) * scale_factor
|
||||
|
||||
self.assertAllClose(tf_ans, np_ans, rtol=1e-5, atol=1e-5)
|
||||
|
||||
def testBasicQuint8(self):
|
||||
@ -70,6 +79,22 @@ class DequantizeOpTest(test.TestCase):
|
||||
self._testDequantizeOp(np.array([-2, 4, -17]), -5.0, -3.0, dtypes.qint8)
|
||||
self._testDequantizeOp(np.array([0, -4, 42, -108]), 5.0, 40.0, dtypes.qint8)
|
||||
|
||||
def testScaledMode(self):
|
||||
self._testDequantizeOp(np.array([-128, 0, 127]), -1.0, 2.0, dtypes.qint8,
|
||||
mode="SCALED")
|
||||
self._testDequantizeOp(np.array([-2, 4, -17]), -5.0, -3.0, dtypes.qint8,
|
||||
mode="SCALED")
|
||||
self._testDequantizeOp(np.array([0, -4, 42, -108]), 5.0, 40.0, dtypes.qint8,
|
||||
mode="SCALED")
|
||||
|
||||
def testNarrowRange(self):
|
||||
self._testDequantizeOp(np.array([-128, 0, 127]), -1.0, 2.0, dtypes.qint8,
|
||||
mode="SCALED", narrow_range=True)
|
||||
self._testDequantizeOp(np.array([-2, 4, -17]), -5.0, -3.0, dtypes.qint8,
|
||||
mode="SCALED", narrow_range=True)
|
||||
self._testDequantizeOp(np.array([0, -4, 42, -108]), 5.0, 40.0, dtypes.qint8,
|
||||
mode="SCALED", narrow_range=True)
|
||||
|
||||
def testAxis(self):
|
||||
# Generates a tensor of the specified `shape` using values from `values`
|
||||
# scaled by (slice_idx + 1) along `axis` dimension.
|
||||
|
@ -1106,7 +1106,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "dequantize"
|
||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "deserialize_many_sparse"
|
||||
|
@ -2,7 +2,7 @@ path: "tensorflow.quantization"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "dequantize"
|
||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "fake_quant_with_min_max_args"
|
||||
|
@ -1026,7 +1026,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "Dequantize"
|
||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'-1\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'False\', \'-1\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DeserializeIterator"
|
||||
|
@ -2,7 +2,7 @@ path: "tensorflow.quantization"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "dequantize"
|
||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "fake_quant_with_min_max_args"
|
||||
|
@ -1026,7 +1026,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "Dequantize"
|
||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'-1\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'False\', \'-1\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DeserializeIterator"
|
||||
|
Loading…
Reference in New Issue
Block a user