Enable int8 quantization for DEPTH_TO_SPACE in Post-Training Quantization

tooling.
Fixes #45213

PiperOrigin-RevId: 346416345
Change-Id: I67a0b800a107c3db15460815366fb33d417bcf68
This commit is contained in:
Karim Nosir 2020-12-08 14:53:24 -08:00 committed by TensorFlower Gardener
parent c9cf71a50d
commit 5ff53a1d66
6 changed files with 29 additions and 8 deletions

View File

@ -61,6 +61,7 @@
* Added support for saved model's session initializer through
`TFLiteConverter.from_saved_model`.
* Added dynamic range quantization support for the BatchMatMul op.
* Added DEPTH_TO_SPACE support in Post training quantization.
* Add `RFFT2D` as builtin op. (`RFFT2D` also supports `RFFTD`.) Currently
only supports float32 input.
* TFLite Supports SingatureDef:

View File

@ -130,7 +130,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH(),
/* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_DEPTH_TO_SPACE, Register_DEPTH_TO_SPACE());
AddBuiltin(BuiltinOperator_DEPTH_TO_SPACE, Register_DEPTH_TO_SPACE(),
/* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_GATHER, Register_GATHER(),
/* min_version = */ 1,
/* max_version = */ 4);

View File

@ -28,9 +28,15 @@ def make_depth_to_space_tests(options):
"""Make a set of tests to do depth_to_space."""
test_parameters = [{
"dtype": [tf.float32, tf.int32, tf.uint8, tf.int64],
"dtype": [tf.int32, tf.uint8, tf.int64],
"input_shape": [[2, 3, 4, 16]],
"block_size": [2, 4],
"fully_quantize": [False],
}, {
"dtype": [tf.float32],
"input_shape": [[2, 3, 4, 16]],
"block_size": [2, 4],
"fully_quantize": [True, False],
}]
def build_graph(parameters):
@ -43,8 +49,15 @@ def make_depth_to_space_tests(options):
return [input_tensor], [out]
def build_inputs(parameters, sess, inputs, outputs):
if not parameters["fully_quantize"]:
input_values = create_tensor_data(parameters["dtype"],
parameters["input_shape"])
else:
input_values = create_tensor_data(
parameters["dtype"],
parameters["input_shape"],
min_value=-1,
max_value=1)
return [input_values], sess.run(
outputs, feed_dict=dict(zip(inputs, [input_values])))

View File

@ -106,6 +106,13 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) {
property.version = 2;
property.quantizable_int16 = false;
break;
case BuiltinOperator_DEPTH_TO_SPACE:
property.inputs = {{0, {}}};
property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true;
property.version = 2;
property.quantizable_int16 = false;
break;
case BuiltinOperator_SPLIT:
// We skip input 0 since it is the split dim which is not real valued.
property.inputs = {{1, {}}};

View File

@ -620,10 +620,7 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
case BuiltinOperator_SELECT:
case BuiltinOperator_RSQRT:
case BuiltinOperator_SQUARED_DIFFERENCE:
if (op_sig.input_types.at(0) == TensorType_INT8) {
return 2;
}
return 1;
case BuiltinOperator_DEPTH_TO_SPACE:
case BuiltinOperator_MIRROR_PAD:
if (op_sig.input_types.at(0) == TensorType_INT8) {
return 2;

View File

@ -100,6 +100,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_CONCATENATION, 2}, "1.14.0"},
{{BuiltinOperator_CONCATENATION, 3}, "2.3.0"},
{{BuiltinOperator_DEPTH_TO_SPACE, 1}, "2.1.0"},
{{BuiltinOperator_DEPTH_TO_SPACE, 2}, kPendingReleaseVersion},
{{BuiltinOperator_EMBEDDING_LOOKUP, 1}, "1.13.0"},
{{BuiltinOperator_EMBEDDING_LOOKUP, 2}, "1.14.0"},
{{BuiltinOperator_EMBEDDING_LOOKUP, 3}, "1.14.0"},