diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD
index ff5ebbd979b..1f57c5c9d80 100644
--- a/tensorflow/lite/tools/optimize/BUILD
+++ b/tensorflow/lite/tools/optimize/BUILD
@@ -219,6 +219,8 @@ tf_cc_test(
         "//tensorflow/lite/tools/optimize:testdata/argmax.bin",
         "//tensorflow/lite/tools/optimize:testdata/concat.bin",
         "//tensorflow/lite/tools/optimize:testdata/fc.bin",
+        "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated.bin",
+        "//tensorflow/lite/tools/optimize:testdata/lstm_quantized.bin",
         "//tensorflow/lite/tools/optimize:testdata/mixed.bin",
         "//tensorflow/lite/tools/optimize:testdata/multi_input_add_reshape.bin",
         "//tensorflow/lite/tools/optimize:testdata/single_avg_pool_min_minus_5_max_plus_5.bin",
diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc
index fd7459472d1..eec8bacea23 100644
--- a/tensorflow/lite/tools/optimize/operator_property.cc
+++ b/tensorflow/lite/tools/optimize/operator_property.cc
@@ -168,9 +168,73 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
     case BuiltinOperator_LSTM: {
       // TODO(jianlijianli): extend LSTM op spec to inlucde input, bias etc.
       // TODO(jianlijianli): extend this to other variants of LSTM.
-      // LSTM need 5 intermediate tensors. This agrees with the fully quantized
+      // LSTM needs 5 intermediate tensors. This agrees with the fully quantized
       // kernels in lstm_eval.cc
-      property.intermediates = {{0, {}}, {1, {}}, {2, {}}, {3, {}}, {4, {}}};
+      static const float alpha = static_cast<float>(std::pow(2, -10));
+
+      TensorProperty tensor_property_12;
+      tensor_property_12.use_derived_scale = true;
+      tensor_property_12.number_of_bits = 32;
+      tensor_property_12.derived_scale = {{20}, {}, {alpha}};
+      TensorProperty tensor_property_13;
+      tensor_property_13.use_derived_scale = true;
+      tensor_property_13.number_of_bits = 32;
+      tensor_property_13.derived_scale = {{21}, {}, {alpha}};
+      TensorProperty tensor_property_14;
+      tensor_property_14.use_derived_scale = true;
+      tensor_property_14.number_of_bits = 32;
+      tensor_property_14.derived_scale = {{22}, {}, {alpha}};
+      TensorProperty tensor_property_15;
+      tensor_property_15.use_derived_scale = true;
+      tensor_property_15.number_of_bits = 32;
+      tensor_property_15.derived_scale = {{23}, {}, {alpha}};
+      TensorProperty tensor_property_17;
+      tensor_property_17.use_derived_scale = true;
+      tensor_property_17.number_of_bits = 32;
+      tensor_property_17.derived_scale = {{16}, {4}, {}};
+      TensorProperty tensor_property_19;
+      tensor_property_19.extend_to_power_of_two = true;
+      tensor_property_19.number_of_bits = 16;
+      tensor_property_19.state_tensor = true;
+      tensor_property_19.symmetric = true;
+      TensorProperty tensor_property_20;
+      tensor_property_20.number_of_bits = 16;
+      tensor_property_20.symmetric = true;
+
+      property.inputs = {
+          {0, {}},
+          {1, {}},
+          {2, {}},
+          {3, {}},
+          {4, {}},
+          {5, {}},
+          {6, {}},
+          {7, {}},
+          {8, {}},
+          {9, {}},
+          {10, {}},
+          {11, {}},
+          {16, {}},
+          {19, tensor_property_19},
+          {20, tensor_property_20},
+          {21, tensor_property_20},
+          {22, tensor_property_20},
+          {23, tensor_property_20},
+          {12, tensor_property_12},
+          {13, tensor_property_13},
+          {14, tensor_property_14},
+          {15, tensor_property_15},
+          {17, tensor_property_17},
+      };
+      property.outputs = {{0, {}}};
+      property.intermediates = {
+          {0, tensor_property_20},
+          {1, tensor_property_20},
+          {2, tensor_property_20},
+          {3, tensor_property_20},
+          {4, {}},
+      };
+      property.restrict_scale = {{18, 0}};
       property.version = 2;
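Note on the `derived_scale` entries above (this note is commentary, not part of the patch): a spec such as `{{20}, {}, {alpha}}` lists input-tensor indices, intermediate-tensor indices, and constant factors; the scales of the referenced tensors and the factors are multiplied together to form the scale of the annotated tensor, which is what `utils::GetEffectiveScale` computes later in this patch. The standalone sketch below only illustrates that product; the numeric scale values are invented.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Standalone sketch: an effective (derived) scale is the product of the
// scales of the referenced tensors and of the constant factors.
// All concrete values here are hypothetical.
float EffectiveScale(const std::vector<float>& referenced_tensor_scales,
                     const std::vector<float>& factors) {
  float scale = 1.0f;
  for (float s : referenced_tensor_scales) scale *= s;
  for (float f : factors) scale *= f;
  return scale;
}

int main() {
  // Mimics tensor_property_12's spec {{20}, {}, {alpha}}: the scale of input
  // tensor 20 times alpha = 2^-10. The 0.007f scale is a made-up example.
  const float alpha = static_cast<float>(std::pow(2, -10));
  const float bias_scale = EffectiveScale({0.007f}, {alpha});
  std::printf("derived bias scale: %g\n", bias_scale);

  // Mimics tensor_property_17's spec {{16}, {4}, {}}: scale of input 16 times
  // scale of intermediate 4, with no extra factor.
  const float proj_scale = EffectiveScale({0.003f, 1.0f / 32768.0f}, {});
  std::printf("derived projection scale: %g\n", proj_scale);
  return 0;
}
```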
       break;
     }
diff --git a/tensorflow/lite/tools/optimize/operator_property.h b/tensorflow/lite/tools/optimize/operator_property.h
index 92edf72624f..5d37aa304e5 100644
--- a/tensorflow/lite/tools/optimize/operator_property.h
+++ b/tensorflow/lite/tools/optimize/operator_property.h
@@ -44,6 +44,21 @@ struct TensorProperty {
   bool restriction = false;
   // scale/zero_point hardcoded.
   std::pair<float, int> restricted_value = {0.0, 0};
+
+  // Use derived scale.
+  bool use_derived_scale = false;
+  // The derived scale.
+  DerivedScale derived_scale;
+
+  // The number of bits for this tensor. It can be 8, 16, 32 or even a value
+  // that is not a power of two.
+  int number_of_bits = 8;
+
+  // Extend the range to a power of two.
+  bool extend_to_power_of_two = false;
+
+  // State tensor.
+  bool state_tensor = false;
 };
 
 struct OperatorProperty {
@@ -55,10 +70,13 @@
   // Op has arbitrary number of outputs, such as slice.
   bool arbitrary_outputs = false;
   // Input indexes -> input tensor property.
+  // Must be topologically sorted since there are derived scales.
   std::vector<std::pair<int, TensorProperty>> inputs = {};
   // Output indexes -> output tensor property.
   std::vector<std::pair<int, TensorProperty>> outputs = {};
   // Bias indexes.
+  // TODO(jianlijianli): remove this by putting biases into inputs as well,
+  // since we can now model "derived scale".
   std::vector<int> biases = {};
 
   // Intermediate indexes -> intermediate tensor property.
@@ -67,6 +85,12 @@
   // Force output to reuse the same scale and zero point of input.
   bool restrict_same_input_output_scale = false;
 
+  // Use the same min of min and max of max for each group.
+  // Incompatible with restrict_same_input_output_scale and restricted_value.
+  // TODO(jianlijianli): make it compatible with other restrictions when there
+  // is a use case.
+  std::vector<std::vector<int>> restrict_scale = {};
+
   // Op version.
   int version = 1;
 };
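The `DerivedScale` type used by `TensorProperty::derived_scale` is not shown in this patch. Judging from the aggregate initializers in operator_property.cc and the member names consumed in quantize_model.cc (`input_tensors`, `intermediate_tensors`, `factors`), it presumably looks roughly like the sketch below; treat it as an assumption rather than the actual definition.

```cpp
#include <vector>

// Hypothetical reconstruction of the DerivedScale aggregate referenced by
// TensorProperty::derived_scale. The field names follow the call sites in
// quantize_model.cc, but the real definition is not part of this patch.
struct DerivedScale {
  // Indices into the op's inputs whose scales are multiplied in.
  std::vector<int> input_tensors = {};
  // Indices into the op's intermediates whose scales are multiplied in.
  std::vector<int> intermediate_tensors = {};
  // Extra constant multipliers, e.g. 2^-10 for the LSTM gate biases.
  std::vector<float> factors = {};
};
```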
diff --git a/tensorflow/lite/tools/optimize/quantization_utils.cc b/tensorflow/lite/tools/optimize/quantization_utils.cc
index 314fc42cbea..7a6f6cb634d 100644
--- a/tensorflow/lite/tools/optimize/quantization_utils.cc
+++ b/tensorflow/lite/tools/optimize/quantization_utils.cc
@@ -625,11 +625,15 @@ float GetEffectiveScale(ModelT* model, SubGraphT* subgraph, int op_idx,
   float scale = 1.0f;
   OperatorT* op = subgraph->operators[op_idx].get();
   for (int i = 0; i < input_index.size(); ++i) {
-    TensorT* tensor = subgraph->tensors[op->inputs[i]].get();
+    const int index_local = input_index[i];
+    const int index_global = op->inputs[index_local];
+    const TensorT* tensor = subgraph->tensors[index_global].get();
     scale *= tensor->quantization->scale[0];
   }
   for (int i = 0; i < intermediate_index.size(); ++i) {
-    TensorT* tensor = subgraph->tensors[op->intermediates[i]].get();
+    const int index_local = intermediate_index[i];
+    const int index_global = op->intermediates[index_local];
+    const TensorT* tensor = subgraph->tensors[index_global].get();
     scale *= tensor->quantization->scale[0];
   }
   for (int i = 0; i < factors.size(); ++i) {
@@ -646,6 +650,15 @@ void QuantizeActivation(TensorT* tensor) {
   tensor->type = TensorType_INT8;
 }
 
+TfLiteStatus QuantizeActivationToInt16(TensorT* tensor, float scale) {
+  const int32 zero_point = 0;
+  tensor->quantization = absl::make_unique<QuantizationParametersT>();
+  tensor->quantization->scale.push_back(scale);
+  tensor->quantization->zero_point.push_back(zero_point);
+  tensor->type = TensorType_INT16;
+  return kTfLiteOk;
+}
+
 int GetPowerOfTwoScale(float min, float max) {
   const float range = std::max(std::abs(min), std::abs(max));
   int pot = 0;
diff --git a/tensorflow/lite/tools/optimize/quantization_utils.h b/tensorflow/lite/tools/optimize/quantization_utils.h
index 060e0c283be..18ed707e175 100644
--- a/tensorflow/lite/tools/optimize/quantization_utils.h
+++ b/tensorflow/lite/tools/optimize/quantization_utils.h
@@ -138,6 +138,9 @@ float GetEffectiveScale(ModelT* model, SubGraphT* subgraph, int op_idx,
 // Quantize activation.
 void QuantizeActivation(TensorT* tensor);
 
+// Quantize activation to 16 bits.
+TfLiteStatus QuantizeActivationToInt16(TensorT* tensor, float scale);
+
 // Get the power of two scale for min and max for symmetric quantization case.
 int GetPowerOfTwoScale(float min, float max);
 
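`QuantizeActivationToInt16` only records a scale and a zero point of 0 on the tensor; the actual value mapping happens in the kernels. As a rough, standalone illustration of the symmetric int16 scheme it sets up, the sketch below derives a scale from a calibrated min/max (using the 32768 divisor of the activation path added later in this patch) and quantizes a single value. The calibration numbers are made up.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Sketch of symmetric int16 activation quantization: zero point is 0 and the
// calibrated range [min, max] is mapped onto [-32768, 32767] with one scale.
// Not the TFLite implementation; example values are hypothetical.
int16_t QuantizeToInt16(float value, float min, float max) {
  const float range = std::max(std::abs(min), std::abs(max));
  const float scale = range / 32768.0f;  // matches the activation path
  const float q = std::round(value / scale);
  return static_cast<int16_t>(std::min(32767.0f, std::max(-32768.0f, q)));
}

int main() {
  // Hypothetical calibration range for an LSTM intermediate tensor.
  const float min = -3.2f, max = 2.9f;
  std::printf("q(1.0)  = %d\n", QuantizeToInt16(1.0f, min, max));
  std::printf("q(-3.2) = %d\n", QuantizeToInt16(-3.2f, min, max));
  return 0;
}
```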
diff --git a/tensorflow/lite/tools/optimize/quantize_model.cc b/tensorflow/lite/tools/optimize/quantize_model.cc
index 99cde4ccf63..cc14a48ae69 100644
--- a/tensorflow/lite/tools/optimize/quantize_model.cc
+++ b/tensorflow/lite/tools/optimize/quantize_model.cc
@@ -421,6 +421,10 @@ TfLiteStatus QuantizeOpInput(
     return kTfLiteError;
   }
   const int32_t tensor_idx = op->inputs[input_idx];
+  if (tensor_idx == -1) {
+    // Skip optional tensor.
+    return kTfLiteOk;
+  }
   TensorT* tensor = subgraph->tensors[tensor_idx].get();
   // Assumes op is quantized to int8.
   const bool is_input_quantized = utils::QuantizationParametersExist(tensor);
@@ -429,9 +433,59 @@
     if (utils::HasBuffer(model, subgraph, tensor_idx)) {
       // TODO(suharshs): Look at consumers, throw error if one consumer is
       // per-channel and one per-layer.
-      if (utils::QuantizeWeight(model, tensor, tensor_property.per_axis,
-                                tensor_property.per_axis_index,
-                                error_reporter) == kTfLiteError) {
+      if (tensor_property.number_of_bits == 8) {
+        if (tensor_property.use_derived_scale) {
+          // Currently 8bit tensors in input do not accept derived scale.
+          return kTfLiteError;
+        }
+        if (utils::QuantizeWeight(model, tensor, tensor_property.per_axis,
+                                  tensor_property.per_axis_index,
+                                  error_reporter) == kTfLiteError) {
+          error_reporter->Report(
+              "Unable to quantize buffer or min/max value for input %d "
+              "in op %s in subgraph %d, node: %d",
+              input_idx, EnumNameBuiltinOperator(op_code), subgraph_idx,
+              *op_idx);
+          return kTfLiteError;
+        }
+      } else if (tensor_property.number_of_bits == 16) {
+        if (tensor_property.use_derived_scale) {
+          // Currently 16bit tensors in input do not accept derived scale.
+          return kTfLiteError;
+        }
+        TensorT* tensor = subgraph->tensors[tensor_idx].get();
+        int total_size = 1;
+        for (int i = 0; i < tensor->shape.size(); ++i) {
+          total_size *= tensor->shape[i];
+        }
+        BufferT* buffer = model->buffers[tensor->buffer].get();
+        float* float_data = reinterpret_cast<float*>(buffer->data.data());
+        auto minmax = std::minmax_element(float_data, float_data + total_size);
+        const float min = *minmax.first;
+        const float max = *minmax.second;
+        const float range = std::max(std::abs(min), std::abs(max));
+        // The narrow range quantized value for int16.
+        const float quantize_range = 32767.0;
+        const float scale = range / quantize_range;
+        return utils::SymmetricQuantizeFloatsToInt16(model, tensor, scale,
+                                                     error_reporter);
+      } else if (tensor_property.number_of_bits == 32) {
+        if (!tensor_property.use_derived_scale) {
+          // Currently 32 bit tensors in input only accept derived scale.
+          return kTfLiteError;
+        }
+        TensorT* tensor = subgraph->tensors[tensor_idx].get();
+        const float scale = utils::GetEffectiveScale(
+            model, subgraph, *op_idx,
+            tensor_property.derived_scale.input_tensors,
+            tensor_property.derived_scale.intermediate_tensors,
+            tensor_property.derived_scale.factors);
+        return utils::SymmetricPerLayerBiasQuantize(model, tensor, scale,
+                                                    error_reporter);
+
+      } else {
+        // Only 8, 16, 32 are supported.
+        // TODO(jianlijianli): extend this to support arbitrary bits.
         error_reporter->Report(
             "Unable to quantize buffer or min/max value for input %d "
             "in op %s in subgraph %d, node: %d",
@@ -439,9 +493,27 @@
         return kTfLiteError;
       }
     } else if (utils::HasMinMax(tensor)) {
-      // TODO(suharshs): Handle per-channel dynamic tensor.
-      if (IsSubgraphInput(subgraph, tensor_idx)) {
-        utils::QuantizeActivation(tensor);
+      if (IsSubgraphInput(subgraph, tensor_idx) ||
+          tensor_property.state_tensor) {
+        if (tensor_property.number_of_bits == 8) {
+          if (tensor_property.use_derived_scale) {
+            // Currently 8bit tensors in input do not accept derived scale.
+            return kTfLiteError;
+          }
+          utils::QuantizeActivation(tensor);
+        } else if (tensor_property.number_of_bits == 16) {
+          TensorT* tensor = subgraph->tensors[tensor_idx].get();
+          float range = std::max(std::abs(tensor->quantization->min[0]),
+                                 std::abs(tensor->quantization->max[0]));
+          if (tensor_property.extend_to_power_of_two) {
+            const int power_of_two_scale = utils::GetPowerOfTwoScale(
+                tensor->quantization->min[0], tensor->quantization->max[0]);
+            range = std::pow(2, power_of_two_scale);
+          }
+          const float quantized_range = 32768.0;
+          const float scale = range / quantized_range;
+          utils::QuantizeActivationToInt16(tensor, scale);
+        }
       } else {
         // If the tensor is not a model input, we need to add a Quantize
         // operation since the preceding op may require a float output.
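Two different divisors appear in the hunk above: constant 16-bit inputs use the narrow range 32767, while activations and state tensors use 32768 and, for the LSTM state tensor (input 19 with `extend_to_power_of_two`), first round the range up to a power of two. The standalone sketch below contrasts the two scale choices; the power-of-two rounding is an assumption about what `GetPowerOfTwoScale` effectively yields, and the sample ranges are invented.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

// Narrow-range scale used for 16-bit constant inputs in the patch above.
float WeightScale(float min, float max) {
  return std::max(std::abs(min), std::abs(max)) / 32767.0f;
}

// Scale used for 16-bit activations and state tensors. When
// extend_to_power_of_two is set, the range is first rounded up to the next
// power of two; this mirrors how GetPowerOfTwoScale is used above, though its
// exact rounding rule is not shown in this patch.
float ActivationScale(float min, float max, bool extend_to_power_of_two) {
  float range = std::max(std::abs(min), std::abs(max));
  if (extend_to_power_of_two) {
    range = std::pow(2.0f, std::ceil(std::log2(range)));
  }
  return range / 32768.0f;
}

int main() {
  // Hypothetical calibrated ranges.
  std::printf("weight scale: %g\n", WeightScale(-0.48f, 0.52f));
  std::printf("cell state scale: %g\n", ActivationScale(-9.1f, 7.3f, true));
  return 0;
}
```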
@@ -515,6 +587,10 @@ TfLiteStatus QuantizeOpOutput(
   }
   TensorT* output_tensor = subgraph->tensors[op->outputs[output_idx]].get();
+  if (utils::QuantizationParametersExist(output_tensor)) {
+    // Skip output if it has been quantized.
+    return kTfLiteOk;
+  }
   if (ShouldRestrictSameInputOutputScale(property)) {
     // Copy quantization parameter. For average pool, max pool, etc
     // min/max can be different but we want them to be the same.
@@ -576,6 +652,122 @@
   return kTfLiteOk;
 }
 
+TfLiteStatus QuantizeIntemediateTensors(ModelT* model,
+                                        ErrorReporter* error_reporter) {
+  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
+       subgraph_idx++) {
+    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
+    for (size_t op_idx = 0; op_idx < subgraph->operators.size(); op_idx++) {
+      operator_property::OperatorProperty property =
+          operator_property::GetOperatorProperty(model, subgraph_idx, op_idx);
+      if (!property.intermediates.empty()) {
+        OperatorT* op = subgraph->operators[op_idx].get();
+        const BuiltinOperator op_code =
+            model->operator_codes[op->opcode_index]->builtin_code;
+        for (const std::pair<int, operator_property::TensorProperty>& input :
+             property.intermediates) {
+          const int index_local = input.first;
+          const int index_global = op->intermediates[index_local];
+          if (index_global == -1) {
+            // Skip optional tensor.
+            continue;
+          }
+          if (input.second.number_of_bits == 8 &&
+              input.second.symmetric == false) {
+            TensorT* tensor = subgraph->tensors[index_global].get();
+            if (utils::HasMinMax(tensor)) {
+              utils::QuantizeActivation(tensor);
+            } else {
+              error_reporter->Report(
+                  "Unable to find min/max value for output %d in %s in "
+                  "subgraph %d, node: %d",
+                  tensor, EnumNameBuiltinOperator(op_code), subgraph_idx,
+                  op_idx);
+              return kTfLiteError;
+            }
+          } else if (input.second.number_of_bits == 16 &&
+                     input.second.symmetric == true) {
+            TensorT* tensor = subgraph->tensors[index_global].get();
+            if (tensor->quantization == nullptr) {
+              continue;
+            }
+            const float min = tensor->quantization->min[0];
+            const float max = tensor->quantization->max[0];
+            const float range = std::max(std::abs(min), std::abs(max));
+            if (range < 1e-8) {
+              return kTfLiteError;
+            }
+
+            // Get scale and zero point.
+            const float quantized_range = 32767.0;
+            const float scale = range / quantized_range;
+            utils::QuantizeActivationToInt16(tensor, scale);
+          } else {
+            return kTfLiteError;
+          }
+        }
+      }
+    }
+  }
+  return kTfLiteOk;
+}
+
+// Quantize tensors that have shared range. For example, in LSTM, the output
+// tensor and input state tensor should share the same range because they are
+// using the same scale and zero point.
+TfLiteStatus QuantizeSharedRange(ModelT* model, ErrorReporter* error_reporter) {
+  for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
+       subgraph_idx++) {
+    SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
+    for (size_t op_idx = 0; op_idx < subgraph->operators.size(); op_idx++) {
+      operator_property::OperatorProperty property =
+          operator_property::GetOperatorProperty(model, subgraph_idx, op_idx);
+      if (!property.intermediates.empty()) {
+        OperatorT* op = subgraph->operators[op_idx].get();
+        for (const std::vector<int>& input : property.restrict_scale) {
+          if (input.empty()) {
+            continue;
+          }
+          // Currently only pairs of two tensors are supported.
+          // TODO(jianlijianli): extend to arbitrary number of tensors.
+          if (input.size() != 2) {
+            return kTfLiteError;
+          }
+          const int index_1 = input[0];
+          const int index_2 = input[1];
+          // TODO(jianlijianli): model input/output.
+          TensorT* tensor_1 = subgraph->tensors[op->inputs[index_1]].get();
+          TensorT* tensor_2 = subgraph->tensors[op->outputs[index_2]].get();
+          const float min_of_min = std::min(tensor_1->quantization->min[0],
+                                            tensor_2->quantization->min[0]);
+          const float max_of_max = std::max(tensor_1->quantization->max[0],
+                                            tensor_2->quantization->max[0]);
+          if (min_of_min == 0.0 && max_of_max == 0.0) {
+            return kTfLiteError;
+          }
+
+          // Asymmetric quantization to 8 bit.
+          auto quantization_params =
+              absl::make_unique<QuantizationParametersT>();
+          utils::GetAsymmetricQuantizationParams(
+              min_of_min, max_of_max, -128, 127, quantization_params.get());
+
+          // Populate both tensors with the same parameters.
+          const float scale = quantization_params->scale[0];
+          const int32 zero_point = quantization_params->zero_point[0];
+          for (TensorT* tensor : {tensor_1, tensor_2}) {
+            tensor->quantization = absl::make_unique<QuantizationParametersT>();
+            tensor->quantization->scale.push_back(scale);
+            tensor->quantization->zero_point.push_back(zero_point);
+            tensor->type = TensorType_INT8;
+          }
+        }
+      }
+    }
+  }
+  return kTfLiteOk;
+}
+
 // Quantize inputs and weights.
 // Because of ops such as lstm, still need to do per op, instead of weights.
 TfLiteStatus QuantizeWeightsInputOutput(
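`QuantizeSharedRange` merges the calibrated ranges of the paired tensors and hands the merged `[min_of_min, max_of_max]` interval to `utils::GetAsymmetricQuantizationParams`. The standalone sketch below shows the arithmetic such a helper typically performs for an int8 target; the exact rounding and nudging in the TFLite helper may differ, and the sample ranges are hypothetical.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

struct AsymmetricParams {
  float scale;
  int zero_point;
};

// Sketch: map [min, max] onto [qmin, qmax] = [-128, 127]. The range is first
// extended to include 0 so that 0.0f stays exactly representable.
AsymmetricParams GetAsymmetricParamsSketch(float min, float max) {
  min = std::min(min, 0.0f);
  max = std::max(max, 0.0f);
  const int qmin = -128, qmax = 127;
  const float scale = (max - min) / static_cast<float>(qmax - qmin);
  int zero_point = static_cast<int>(std::round(qmin - min / scale));
  zero_point = std::min(qmax, std::max(qmin, zero_point));
  return {scale, zero_point};
}

int main() {
  // Hypothetical calibration ranges for an LSTM output and its output state.
  const float min_of_min = std::min(-0.9f, -1.1f);
  const float max_of_max = std::max(0.8f, 0.95f);
  const AsymmetricParams p = GetAsymmetricParamsSketch(min_of_min, max_of_max);
  std::printf("shared scale=%g zero_point=%d\n", p.scale, p.zero_point);
  return 0;
}
```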
@@ -713,6 +905,10 @@ TfLiteStatus FillQuantizationParams(
       // Get tensor.
       const int32_t input_idx = input.first;
       const int32_t tensor_idx = op->inputs[input_idx];
+      if (tensor_idx == -1) {
+        // Skip optional tensor.
+        continue;
+      }
       TensorT* tensor = subgraph->tensors[tensor_idx].get();
 
       // Static tensor.
@@ -918,6 +1114,8 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
       FillQuantizationParams(model, operator_names, error_reporter));
   TF_LITE_ENSURE_STATUS(
       EnsureBiasScaleCompatibility(model, operator_names, error_reporter));
+  TF_LITE_ENSURE_STATUS(QuantizeIntemediateTensors(model, error_reporter));
+  TF_LITE_ENSURE_STATUS(QuantizeSharedRange(model, error_reporter));
   TF_LITE_ENSURE_STATUS(QuantizeWeightsInputOutput(
       model, allow_float, operator_names, error_reporter));
   TF_LITE_ENSURE_STATUS(
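With the two new passes wired into `QuantizeModel`, quantizing a calibrated LSTM model end to end follows the same flow as the `VerifyLSTM` test below. The sketch condenses that flow into one helper; the file path is a placeholder and the includes/namespaces may need adjusting for a given build.

```cpp
#include <memory>

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/stderr_reporter.h"
#include "tensorflow/lite/tools/optimize/quantize_model.h"

// Sketch only: load a calibrated LSTM model, run the full quantization
// pipeline (including the new intermediate and shared-range passes), and
// leave the serialized result in `builder`. The path is a placeholder.
TfLiteStatus QuantizeCalibratedLstm(const char* path,
                                    flatbuffers::FlatBufferBuilder* builder) {
  auto fb_model = tflite::FlatBufferModel::BuildFromFile(path);
  if (!fb_model) return kTfLiteError;
  tflite::ModelT model;
  fb_model->GetModel()->UnPackTo(&model);
  tflite::StderrReporter error_reporter;
  return tflite::optimize::QuantizeModel(
      builder, &model, tflite::TensorType_FLOAT32, tflite::TensorType_FLOAT32,
      &error_reporter);
}
```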
diff --git a/tensorflow/lite/tools/optimize/quantize_model_test.cc b/tensorflow/lite/tools/optimize/quantize_model_test.cc
index 679b681c49e..247e4479c28 100644
--- a/tensorflow/lite/tools/optimize/quantize_model_test.cc
+++ b/tensorflow/lite/tools/optimize/quantize_model_test.cc
@@ -979,6 +979,53 @@ TEST_F(QuantizeArgMaxTest, VerifyArgMax) {
   EXPECT_EQ(model_.operator_codes[0]->version, 2);
 }
 
+class QuantizeLSTMTest : public QuantizeModelTest {
+ protected:
+  QuantizeLSTMTest() {
+    input_model_ = ReadModel(internal::kLstmCalibrated);
+    readonly_model_ = input_model_->GetModel();
+    readonly_model_->UnPackTo(&model_);
+  }
+};
+
+TEST_F(QuantizeLSTMTest, VerifyLSTM) {
+  // Quantize model.
+  auto status = QuantizeModel(&builder_, &model_, TensorType_FLOAT32,
+                              TensorType_FLOAT32, &error_reporter_);
+  ASSERT_EQ(kTfLiteOk, status);
+
+  // Read expected model.
+  auto expected_fb_model = ReadModel(internal::kLstmQuantized);
+  auto expected_read_only_model = expected_fb_model->GetModel();
+  ModelT expected_model;
+  expected_read_only_model->UnPackTo(&expected_model);
+
+  // Comparison.
+  ASSERT_EQ(model_.subgraphs.size(), expected_model.subgraphs.size());
+  for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
+       subgraph_idx++) {
+    const auto graph = model_.subgraphs[subgraph_idx].get();
+    const auto expected_graph = expected_model.subgraphs[subgraph_idx].get();
+    ASSERT_EQ(graph->tensors.size(), expected_graph->tensors.size());
+    for (size_t i = 0; i < graph->tensors.size(); i++) {
+      const auto tensor = graph->tensors[i].get();
+      const auto expected_tensor = expected_graph->tensors[i].get();
+      EXPECT_EQ(tensor->buffer, expected_tensor->buffer);
+      EXPECT_EQ(tensor->is_variable, expected_tensor->is_variable);
+      EXPECT_EQ(tensor->shape, expected_tensor->shape);
+      EXPECT_EQ(tensor->name, expected_tensor->name);
+      EXPECT_EQ(tensor->type, expected_tensor->type);
+    }
+  }
+  ASSERT_EQ(model_.buffers.size(), expected_model.buffers.size());
+  for (size_t buffer_idx = 0; buffer_idx < model_.buffers.size();
+       ++buffer_idx) {
+    const auto buffer = model_.buffers[buffer_idx].get()->data;
+    const auto expected_buffer = expected_model.buffers[buffer_idx].get()->data;
+    EXPECT_EQ(buffer, expected_buffer);
+  }
+}
+
 class QuantizeFCTest : public QuantizeModelTest {
  protected:
   QuantizeFCTest() {
diff --git a/tensorflow/lite/tools/optimize/test_util.cc b/tensorflow/lite/tools/optimize/test_util.cc
index 3cfd5f5b701..74524a18081 100644
--- a/tensorflow/lite/tools/optimize/test_util.cc
+++ b/tensorflow/lite/tools/optimize/test_util.cc
@@ -49,6 +49,9 @@
 const char* kModelMixed = "mixed.bin";
 const char* kModelSplit = "split.bin";
 
+const char* kLstmCalibrated = "lstm_calibrated.bin";
+const char* kLstmQuantized = "lstm_quantized.bin";
+
 int FailOnErrorReporter::Report(const char* format, va_list args) {
   char buf[1024];
   vsnprintf(buf, sizeof(buf), format, args);
diff --git a/tensorflow/lite/tools/optimize/test_util.h b/tensorflow/lite/tools/optimize/test_util.h
index bf42a30b99a..12c46aa882b 100644
--- a/tensorflow/lite/tools/optimize/test_util.h
+++ b/tensorflow/lite/tools/optimize/test_util.h
@@ -76,6 +76,10 @@
 extern const char* kModelMixed;
 
 // Test model with split op.
 extern const char* kModelSplit;
 
+// Test model with LSTM op.
+extern const char* kLstmCalibrated;
+extern const char* kLstmQuantized;
+
 // An error reporter that fails on testing.
 class FailOnErrorReporter : public ErrorReporter {
  public:
diff --git a/tensorflow/lite/tools/optimize/testdata/lstm_calibrated.bin b/tensorflow/lite/tools/optimize/testdata/lstm_calibrated.bin
new file mode 100644
index 00000000000..97a6a5732af
Binary files /dev/null and b/tensorflow/lite/tools/optimize/testdata/lstm_calibrated.bin differ
diff --git a/tensorflow/lite/tools/optimize/testdata/lstm_quantized.bin b/tensorflow/lite/tools/optimize/testdata/lstm_quantized.bin
new file mode 100644
index 00000000000..336c2b80005
Binary files /dev/null and b/tensorflow/lite/tools/optimize/testdata/lstm_quantized.bin differ