Add Reshape and Add support to quantizer.
PiperOrigin-RevId: 236044170
This commit is contained in:
parent
4eec977db7
commit
eadd56121b
@ -560,7 +560,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def calibration_gen():
|
||||
for _ in range(10):
|
||||
yield np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)
|
||||
yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]
|
||||
|
||||
sess = session.Session()
|
||||
|
||||
|
@ -53,7 +53,10 @@ py_library(
|
||||
py_test(
|
||||
name = "calibrator_test",
|
||||
srcs = ["calibrator_test.py"],
|
||||
data = [":test_data"],
|
||||
data = [
|
||||
":test_data",
|
||||
"//tensorflow/lite:testdata/multi_add.bin",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_oss"],
|
||||
deps = [
|
||||
|
@ -64,5 +64,5 @@ class Calibrator(object):
|
||||
"""
|
||||
self._calibrator.Prepare()
|
||||
for calibration_sample in dataset_gen():
|
||||
self._calibrator.FeedTensor([calibration_sample])
|
||||
self._calibrator.FeedTensor(calibration_sample)
|
||||
return self._calibrator.QuantizeModel()
|
||||
|
@ -36,7 +36,23 @@ class CalibratorTest(test_util.TensorFlowTestCase):
|
||||
# Input generator for the model.
|
||||
def input_gen():
|
||||
for _ in range(10):
|
||||
yield np.ones(shape=(1, 5, 5, 3), dtype=np.float32)
|
||||
yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]
|
||||
|
||||
quantized_model = quantizer.calibrate_and_quantize(input_gen)
|
||||
self.assertIsNotNone(quantized_model)
|
||||
|
||||
def test_calibration_with_quantization_multiple_inputs(self):
|
||||
# Load multi add model from test data.
|
||||
# This model has 4 inputs of size (1, 8, 8, 3).
|
||||
model_path = resource_loader.get_path_to_datafile(
|
||||
'../../testdata/multi_add.bin')
|
||||
float_model = open(model_path, 'rb').read()
|
||||
quantizer = _calibrator.Calibrator(float_model)
|
||||
|
||||
# Input generator for the model.
|
||||
def input_gen():
|
||||
for _ in range(10):
|
||||
yield [np.ones(shape=(1, 8, 8, 3), dtype=np.float32) for _ in range(4)]
|
||||
|
||||
quantized_model = quantizer.calibrate_and_quantize(input_gen)
|
||||
self.assertIsNotNone(quantized_model)
|
||||
@ -69,7 +85,7 @@ class CalibratorTest(test_util.TensorFlowTestCase):
|
||||
# Input generator with incorrect shape.
|
||||
def input_gen():
|
||||
for _ in range(10):
|
||||
yield np.ones(shape=(1, 2, 2, 3), dtype=np.float32)
|
||||
yield [np.ones(shape=(1, 2, 2, 3), dtype=np.float32)]
|
||||
|
||||
with self.assertRaisesWithRegexpMatch(ValueError, 'Dimension mismatch'):
|
||||
quantizer.calibrate_and_quantize(input_gen)
|
||||
|
@ -118,6 +118,7 @@ tf_cc_test(
|
||||
"--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)",
|
||||
],
|
||||
data = [
|
||||
"//tensorflow/lite/tools/optimize:testdata/multi_input_add_reshape.bin",
|
||||
"//tensorflow/lite/tools/optimize:testdata/single_avg_pool_min_minus_5_max_plus_5.bin",
|
||||
"//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
|
||||
"//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_minus_127_max_plus_127.bin",
|
||||
|
@ -307,24 +307,6 @@ TfLiteStatus SubgraphQuantizer::PropagateMinMaxForAvgAndMaxPool(
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus SubgraphQuantizer::AsymmetricQuantizeSingleInputOutputOp(
|
||||
BuiltinOperator op_code, OperatorT* op) {
|
||||
TF_LITE_ENSURE_EQ(this->error_reporter_, op->inputs.size(), 1);
|
||||
TF_LITE_ENSURE_EQ(this->error_reporter_, op->outputs.size(), 1);
|
||||
|
||||
if (IsSubgraphInput(op->inputs[0])) {
|
||||
TF_LITE_ENSURE_STATUS(AsymmetricQuantizeTensor(op_code, op->inputs[0]));
|
||||
}
|
||||
|
||||
auto output_tensor = subgraph_->tensors[op->outputs[0]].get();
|
||||
if (output_tensor->type != TensorType_FLOAT32) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
auto quant_params = absl::make_unique<QuantizationParametersT>();
|
||||
TF_LITE_ENSURE_STATUS(AsymmetricQuantizeTensor(op_code, op->outputs[0]));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus SubgraphQuantizer::AsymmetricQuantizeSoftmax(
|
||||
BuiltinOperator op_code, OperatorT* op) {
|
||||
TF_LITE_ENSURE_EQ(this->error_reporter_, op->inputs.size(), 1);
|
||||
@ -346,6 +328,29 @@ TfLiteStatus SubgraphQuantizer::AsymmetricQuantizeSoftmax(
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus SubgraphQuantizer::AsymmetricQuantizeInputsAndOutputs(
|
||||
BuiltinOperator op_code, OperatorT* op) {
|
||||
TF_LITE_ENSURE(this->error_reporter_, !op->inputs.empty());
|
||||
TF_LITE_ENSURE(this->error_reporter_, !op->outputs.empty());
|
||||
for (size_t input_idx = 0; input_idx < op->inputs.size(); ++input_idx) {
|
||||
auto input_tensor = subgraph_->tensors[op->inputs[input_idx]].get();
|
||||
if (IsSubgraphInput(op->inputs[input_idx]) &&
|
||||
input_tensor->type == TensorType_FLOAT32) {
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
AsymmetricQuantizeTensor(op_code, op->inputs[input_idx]));
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t output_idx = 0; output_idx < op->outputs.size(); ++output_idx) {
|
||||
auto output_tensor = subgraph_->tensors[op->outputs[output_idx]].get();
|
||||
if (output_tensor->type == TensorType_FLOAT32) {
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
AsymmetricQuantizeTensor(op_code, op->outputs[output_idx]));
|
||||
}
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
bool SubgraphQuantizer::IsSubgraphInput(int32_t tensor_idx) const {
|
||||
return std::find(subgraph_->inputs.begin(), subgraph_->inputs.end(),
|
||||
tensor_idx) != subgraph_->inputs.end();
|
||||
@ -363,7 +368,9 @@ TfLiteStatus SubgraphQuantizer::QuantizeOperator(int op_idx) {
|
||||
case BuiltinOperator_MAX_POOL_2D:
|
||||
return PropagateMinMaxForAvgAndMaxPool(op_code, op);
|
||||
case BuiltinOperator_SQUEEZE:
|
||||
return AsymmetricQuantizeSingleInputOutputOp(op_code, op);
|
||||
case BuiltinOperator_RESHAPE:
|
||||
case BuiltinOperator_ADD:
|
||||
return AsymmetricQuantizeInputsAndOutputs(op_code, op);
|
||||
case BuiltinOperator_SOFTMAX:
|
||||
return AsymmetricQuantizeSoftmax(op_code, op);
|
||||
default:
|
||||
|
@ -46,17 +46,16 @@ class SubgraphQuantizer {
|
||||
TfLiteStatus PropagateMinMaxForAvgAndMaxPool(BuiltinOperator op_code,
|
||||
OperatorT* op);
|
||||
|
||||
// Asymmetric quantizes inputs and outputs of an Op that has single input and
|
||||
// single output. E.g. Squeeze.
|
||||
TfLiteStatus AsymmetricQuantizeSingleInputOutputOp(BuiltinOperator op_code,
|
||||
OperatorT* op);
|
||||
|
||||
// Asymmetric quantizes inputs and outputs of an Softmax Op.
|
||||
// Input is quantized with the min-max range and output is hardcoded to have
|
||||
// 1/256 as scale and -128 as zero point.
|
||||
TfLiteStatus AsymmetricQuantizeSoftmax(BuiltinOperator op_code,
|
||||
OperatorT* op);
|
||||
|
||||
// Asymmetric quantizes an Op with multiple inputs and outputs. E.g Add.
|
||||
TfLiteStatus AsymmetricQuantizeInputsAndOutputs(BuiltinOperator op_code,
|
||||
OperatorT* op);
|
||||
|
||||
TfLiteStatus AsymmetricQuantizeTensor(BuiltinOperator op_code,
|
||||
int32_t tensor_idx);
|
||||
|
||||
|
@ -53,6 +53,10 @@ std::unique_ptr<FlatBufferModel> ReadAvgPoolModel() {
|
||||
return ReadModel(kSingleAvgPoolModelMinMinus5MaxPlus5);
|
||||
}
|
||||
|
||||
std::unique_ptr<FlatBufferModel> ReadMultiInputAddWithReshapeModel() {
|
||||
return ReadModel(kMultiInputAddWithReshape);
|
||||
}
|
||||
|
||||
TEST(SubgraphQuantizerTest, VerifyConvQuantizationWithUnitScale) {
|
||||
ASSERT_TRUE(g_test_model_dir);
|
||||
ASSERT_FALSE(g_test_model_dir->empty());
|
||||
@ -378,6 +382,120 @@ TEST(SubgraphQuantizerTest, VerifyAvgPoolQuantization) {
|
||||
EXPECT_EQ(input_quant_params->scale[0], output_quant_params->scale[0]);
|
||||
}
|
||||
|
||||
TEST(SubgraphQuantizerTest, VerifyReshapeQuantization) {
|
||||
ASSERT_TRUE(g_test_model_dir);
|
||||
ASSERT_FALSE(g_test_model_dir->empty());
|
||||
auto test_model = ReadMultiInputAddWithReshapeModel();
|
||||
ASSERT_TRUE(test_model);
|
||||
auto readonly_model = test_model->GetModel();
|
||||
ASSERT_TRUE(readonly_model);
|
||||
ASSERT_TRUE(readonly_model->subgraphs());
|
||||
ASSERT_GE(readonly_model->subgraphs()->size(), 1);
|
||||
tflite::ModelT model;
|
||||
readonly_model->UnPackTo(&model);
|
||||
auto subgraph = model.subgraphs[0].get();
|
||||
FailOnErrorReporter error_reporter;
|
||||
SubgraphQuantizer quantizer(&model, subgraph, &error_reporter);
|
||||
// 2 operators RESHAPE and ADD
|
||||
ASSERT_EQ(subgraph->operators.size(), 2);
|
||||
auto status = quantizer.QuantizeOperator(0);
|
||||
ASSERT_EQ(kTfLiteOk, status);
|
||||
status = quantizer.QuantizeOperator(1);
|
||||
ASSERT_EQ(kTfLiteOk, status);
|
||||
|
||||
// Verify Reshape is quantized.
|
||||
auto op = subgraph->operators[1].get();
|
||||
ASSERT_EQ(model.operator_codes[op->opcode_index].get()->builtin_code,
|
||||
BuiltinOperator_RESHAPE);
|
||||
|
||||
ASSERT_EQ(op->inputs.size(), 2);
|
||||
ASSERT_EQ(op->outputs.size(), 1);
|
||||
|
||||
auto float_graph = readonly_model->subgraphs()->Get(0);
|
||||
ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
|
||||
TensorType_FLOAT32);
|
||||
ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
|
||||
TensorType_FLOAT32);
|
||||
|
||||
EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8);
|
||||
EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);
|
||||
|
||||
auto float_input_quant_params =
|
||||
float_graph->tensors()->Get(op->inputs[0])->quantization();
|
||||
auto input_quant_params =
|
||||
subgraph->tensors[op->inputs[0]]->quantization.get();
|
||||
VerifyAsymmetricQuantizationScale(*float_input_quant_params,
|
||||
*input_quant_params);
|
||||
|
||||
auto float_output_quant_params =
|
||||
float_graph->tensors()->Get(op->outputs[0])->quantization();
|
||||
auto output_quant_params =
|
||||
subgraph->tensors[op->outputs[0]]->quantization.get();
|
||||
ASSERT_EQ(float_output_quant_params->min()->size(), 1);
|
||||
ASSERT_EQ(float_output_quant_params->max()->size(), 1);
|
||||
ASSERT_EQ(output_quant_params->min.size(), 1);
|
||||
ASSERT_EQ(output_quant_params->max.size(), 1);
|
||||
}
|
||||
|
||||
TEST(SubgraphQuantizerTest, VerifyAddQuantization) {
|
||||
ASSERT_TRUE(g_test_model_dir);
|
||||
ASSERT_FALSE(g_test_model_dir->empty());
|
||||
auto test_model = ReadMultiInputAddWithReshapeModel();
|
||||
ASSERT_TRUE(test_model);
|
||||
auto readonly_model = test_model->GetModel();
|
||||
ASSERT_TRUE(readonly_model);
|
||||
ASSERT_TRUE(readonly_model->subgraphs());
|
||||
ASSERT_GE(readonly_model->subgraphs()->size(), 1);
|
||||
tflite::ModelT model;
|
||||
readonly_model->UnPackTo(&model);
|
||||
auto subgraph = model.subgraphs[0].get();
|
||||
FailOnErrorReporter error_reporter;
|
||||
SubgraphQuantizer quantizer(&model, subgraph, &error_reporter);
|
||||
// 2 operators RESHAPE and ADD
|
||||
ASSERT_EQ(subgraph->operators.size(), 2);
|
||||
auto status = quantizer.QuantizeOperator(0);
|
||||
ASSERT_EQ(kTfLiteOk, status);
|
||||
status = quantizer.QuantizeOperator(1);
|
||||
ASSERT_EQ(kTfLiteOk, status);
|
||||
|
||||
// Verify ADD is quantized.
|
||||
auto op = subgraph->operators[0].get();
|
||||
ASSERT_EQ(model.operator_codes[op->opcode_index].get()->builtin_code,
|
||||
BuiltinOperator_ADD);
|
||||
|
||||
ASSERT_EQ(op->inputs.size(), 2);
|
||||
ASSERT_EQ(op->outputs.size(), 1);
|
||||
|
||||
auto float_graph = readonly_model->subgraphs()->Get(0);
|
||||
ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
|
||||
TensorType_FLOAT32);
|
||||
ASSERT_EQ(float_graph->tensors()->Get(op->inputs[1])->type(),
|
||||
TensorType_FLOAT32);
|
||||
ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
|
||||
TensorType_FLOAT32);
|
||||
|
||||
for (size_t input_idx = 0; input_idx < 2; ++input_idx) {
|
||||
EXPECT_EQ(subgraph->tensors[op->inputs[input_idx]].get()->type,
|
||||
TensorType_INT8);
|
||||
auto float_input_quant_params =
|
||||
float_graph->tensors()->Get(op->inputs[input_idx])->quantization();
|
||||
auto input_quant_params =
|
||||
subgraph->tensors[op->inputs[input_idx]]->quantization.get();
|
||||
VerifyAsymmetricQuantizationScale(*float_input_quant_params,
|
||||
*input_quant_params);
|
||||
}
|
||||
|
||||
EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8);
|
||||
auto float_output_quant_params =
|
||||
float_graph->tensors()->Get(op->outputs[0])->quantization();
|
||||
auto output_quant_params =
|
||||
subgraph->tensors[op->outputs[0]]->quantization.get();
|
||||
ASSERT_EQ(float_output_quant_params->min()->size(), 1);
|
||||
ASSERT_EQ(float_output_quant_params->max()->size(), 1);
|
||||
ASSERT_EQ(output_quant_params->min.size(), 1);
|
||||
ASSERT_EQ(output_quant_params->max.size(), 1);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace internal
|
||||
} // namespace optimize
|
||||
|
@ -33,6 +33,8 @@ const char* kSingleAvgPoolModelMinMinus5MaxPlus5 =
|
||||
|
||||
const char* kModelWithSharedWeights = "weight_shared_between_convs.bin";
|
||||
|
||||
const char* kMultiInputAddWithReshape = "multi_input_add_reshape.bin";
|
||||
|
||||
int FailOnErrorReporter::Report(const char* format, va_list args) {
|
||||
char buf[1024];
|
||||
vsnprintf(buf, sizeof(buf), format, args);
|
||||
|
@ -46,6 +46,9 @@ extern const char* kSingleAvgPoolModelMinMinus5MaxPlus5;
|
||||
// and an add operation.
|
||||
extern const char* kModelWithSharedWeights;
|
||||
|
||||
// Test model with Add followed by a reshape. Model has 2 inputs for add.
|
||||
extern const char* kMultiInputAddWithReshape;
|
||||
|
||||
// An error reporter that fails on testing.
|
||||
class FailOnErrorReporter : public ErrorReporter {
|
||||
public:
|
||||
|
@ -21,5 +21,7 @@ This directory contains test models for testing quantization.
|
||||
* `single_avg_pool_input_min_minus_5_max_5.bin` \
|
||||
A floating point model with a single average pool. The input tensor has min
|
||||
and max in range [-5, 5], not necessarily -5 or +5.
|
||||
* `weight_shared_between_convs.tflite` \
|
||||
* `weight_shared_between_convs.bin` \
|
||||
A floating point model with two convs that have a use the same weight tensor.
|
||||
* `multi_input_add_reshape.bin` \
|
||||
A floating point model with two inputs with an add followed by a reshape.
|
||||
|
BIN
tensorflow/lite/tools/optimize/testdata/multi_input_add_reshape.bin
vendored
Normal file
BIN
tensorflow/lite/tools/optimize/testdata/multi_input_add_reshape.bin
vendored
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user