From 3742faa90cdda4290987736e6d68aa27ad8c3845 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 17 May 2019 14:43:40 -0700 Subject: [PATCH] Support quantization to float16 PiperOrigin-RevId: 248790891 --- tensorflow/lite/kernels/BUILD | 1 + tensorflow/lite/kernels/dequantize.cc | 17 ++- tensorflow/lite/kernels/internal/BUILD | 2 + .../internal/reference/reference_ops.h | 10 ++ tensorflow/lite/toco/args.h | 1 + tensorflow/lite/toco/tflite/export.cc | 19 ++- tensorflow/lite/toco/tflite/export.h | 10 +- tensorflow/lite/toco/tflite/export_test.cc | 4 +- tensorflow/lite/toco/toco_cmdline_flags.cc | 6 + tensorflow/lite/toco/toco_flags.proto | 11 +- tensorflow/lite/toco/toco_tooling.cc | 12 +- tensorflow/lite/tools/optimize/BUILD | 2 + .../lite/tools/optimize/quantization_utils.cc | 44 +++++- .../lite/tools/optimize/quantization_utils.h | 3 + .../tools/optimize/quantization_utils_test.cc | 32 +++++ .../lite/tools/optimize/quantize_weights.cc | 125 ++++++++++++++---- .../lite/tools/optimize/quantize_weights.h | 8 +- .../tools/optimize/quantize_weights_test.cc | 61 ++++++++- 18 files changed, 322 insertions(+), 46 deletions(-) diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 95f554adcfa..33351093968 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -395,6 +395,7 @@ cc_library( "//tensorflow/lite/kernels/internal:reference_base", "//tensorflow/lite/kernels/internal:tensor", "//tensorflow/lite/kernels/internal:tensor_utils", + "//third_party/eigen3", "@farmhash_archive//:farmhash", "@flatbuffers", ], diff --git a/tensorflow/lite/kernels/dequantize.cc b/tensorflow/lite/kernels/dequantize.cc index 7f03c73c9c9..7c17cae7607 100644 --- a/tensorflow/lite/kernels/dequantize.cc +++ b/tensorflow/lite/kernels/dequantize.cc @@ -12,13 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h" + #include + #include +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/op_macros.h" @@ -59,7 +63,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { OpContext op_context(context, node); TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8 || - op_context.input->type == kTfLiteInt8); + op_context.input->type == kTfLiteInt8 || + op_context.input->type == kTfLiteFloat16); op_context.output->type = kTfLiteFloat32; // If the input tensor is constant, we can persist the dequantized value in @@ -96,6 +101,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTensorShape(op_context.output), GetTensorData<float>(op_context.output)); break; + case kTfLiteFloat16: { + const Eigen::half* half_data = reinterpret_cast<const Eigen::half*>( + GetTensorData(op_context.input)); + reference_ops::Dequantize(GetTensorShape(op_context.input), half_data, + GetTensorShape(op_context.output), + GetTensorData<float>(op_context.output)); + break; + } default: context->ReportError(context, "Type %d not supported.", op_context.input->type); diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 5bafcdc00ce..a908e3c4b65 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -385,6 +385,7 @@ cc_library( ":types", "@gemmlowp//:fixedpoint", "@gemmlowp//:profiler", + "//third_party/eigen3", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/kernels:op_macros", ] + select({ @@ -421,6 +422,7 @@ cc_library( ":legacy_types", ":tensor", ":types", + "//third_party/eigen3", "@gemmlowp", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/kernels:op_macros", diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index 8488f7ae266..f4dc0c1b828 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -26,6 +26,7 @@ limitations under the License.
#include #include +#include "third_party/eigen3/Eigen/Core" #include "fixedpoint/fixedpoint.h" #include "profiling/instrumentation.h" #include "tensorflow/lite/c/c_api_internal.h" @@ -2491,6 +2492,15 @@ inline void Dequantize(const tflite::DequantizationParams& op_params, } } +inline void Dequantize(const RuntimeShape& input_shape, + const Eigen::half* input_data, + const RuntimeShape& output_shape, float* output_data) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); + for (int i = 0; i < flat_size; i++) { + output_data[i] = Eigen::half_impl::half_to_float(input_data[i]); + } +} + template inline void AffineQuantize(const tflite::QuantizationParams& op_params, const RuntimeShape& input_shape, diff --git a/tensorflow/lite/toco/args.h b/tensorflow/lite/toco/args.h index c6eeb2859a9..1003a157e42 100644 --- a/tensorflow/lite/toco/args.h +++ b/tensorflow/lite/toco/args.h @@ -172,6 +172,7 @@ struct ParsedTocoFlags { Arg reorder_across_fake_quant = Arg(false); Arg allow_custom_ops = Arg(false); Arg post_training_quantize = Arg(false); + Arg quantize_to_float16 = Arg(false); // Deprecated flags Arg quantize_weights = Arg(false); Arg input_type; diff --git a/tensorflow/lite/toco/tflite/export.cc b/tensorflow/lite/toco/tflite/export.cc index 05d889be5f0..1d816d489a9 100644 --- a/tensorflow/lite/toco/tflite/export.cc +++ b/tensorflow/lite/toco/tflite/export.cc @@ -606,7 +606,9 @@ tensorflow::Status Export( builder.CreateVector(subgraphs), description, buffers); ::tflite::FinishModelBuffer(builder, new_model_location); - if (params.quantize_weights) { + if (params.quantize_weights == QuantizedBufferType::NONE) { + WriteModelToString(builder, output_file_contents); + } else { // Call the quantize_weights tool. LOG(INFO) << "Quantizing TFLite model after conversion to flatbuffer. " "dump_graphviz will only output the model before this " @@ -615,14 +617,21 @@ tensorflow::Status Export( flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240); const uint8_t* buffer = builder.GetBufferPointer(); const ::tflite::Model* input_model = ::tflite::GetModel(buffer); - if (::tflite::optimize::QuantizeWeights(&q_builder, input_model) != - kTfLiteOk) { + ::tflite::optimize::BufferType quantized_type; + if (params.quantize_weights == QuantizedBufferType::INT8) { + quantized_type = ::tflite::optimize::BufferType::QUANTIZED_INT8; + } else if (params.quantize_weights == QuantizedBufferType::FLOAT16) { + quantized_type = ::tflite::optimize::BufferType::QUANTIZED_FLOAT16; + } else { + return tensorflow::errors::InvalidArgument( + "Quantized type not recognized"); + } + if (::tflite::optimize::QuantizeWeights(&q_builder, input_model, + quantized_type) != kTfLiteOk) { return tensorflow::errors::InvalidArgument( "Quantize weights transformation failed."); } WriteModelToString(q_builder, output_file_contents); - } else { - WriteModelToString(builder, output_file_contents); } return tensorflow::Status(); diff --git a/tensorflow/lite/toco/tflite/export.h b/tensorflow/lite/toco/tflite/export.h index 08d9c956778..3a6031d22b8 100644 --- a/tensorflow/lite/toco/tflite/export.h +++ b/tensorflow/lite/toco/tflite/export.h @@ -23,11 +23,13 @@ namespace toco { namespace tflite { +enum class QuantizedBufferType { NONE, INT8, FLOAT16 }; + // The parameters for exporting a TFLite model. 
struct ExportParams { bool allow_custom_ops = false; bool enable_select_tf_ops = false; - bool quantize_weights = false; + QuantizedBufferType quantize_weights = QuantizedBufferType::NONE; }; // Transform the given tf.mini model into a TF Lite flatbuffer and deposit the @@ -47,7 +49,8 @@ inline void Export(const Model& model, bool allow_custom_ops, bool quantize_weights, string* output_file_contents) { ExportParams params; params.allow_custom_ops = allow_custom_ops; - params.quantize_weights = quantize_weights; + params.quantize_weights = + quantize_weights ? QuantizedBufferType::INT8 : QuantizedBufferType::NONE; auto status = Export(model, output_file_contents, params); if (!status.ok()) LOG(QFATAL) << status.error_message(); } @@ -60,7 +63,8 @@ inline void Export( const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) { ExportParams params; params.allow_custom_ops = allow_custom_ops; - params.quantize_weights = quantize_weights; + params.quantize_weights = + quantize_weights ? QuantizedBufferType::INT8 : QuantizedBufferType::NONE; auto status = Export(model, output_file_contents, params, ops_by_type); if (!status.ok()) LOG(QFATAL) << status.error_message(); } diff --git a/tensorflow/lite/toco/tflite/export_test.cc b/tensorflow/lite/toco/tflite/export_test.cc index 85b824f4462..bbebf46a3b9 100644 --- a/tensorflow/lite/toco/tflite/export_test.cc +++ b/tensorflow/lite/toco/tflite/export_test.cc @@ -219,7 +219,7 @@ TEST_F(ExportTest, Export) { ExportParams params; params.allow_custom_ops = true; params.enable_select_tf_ops = false; - params.quantize_weights = false; + params.quantize_weights = QuantizedBufferType::NONE; EXPECT_THAT(ExportAndSummarizeOperators(params), ElementsAre("builtin:ADD", "builtin:CONV_2D", "custom:MyCrazyOp", @@ -366,7 +366,7 @@ class OpSetsTest : public ExportTest { import_all_ops_as_unsupported_ = true; params_.allow_custom_ops = false; params_.enable_select_tf_ops = false; - params_.quantize_weights = false; + params_.quantize_weights = QuantizedBufferType::NONE; for (const OpSet& i : sets) { switch (i) { diff --git a/tensorflow/lite/toco/toco_cmdline_flags.cc b/tensorflow/lite/toco/toco_cmdline_flags.cc index 7d525ae5583..c36b3de7748 100644 --- a/tensorflow/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/lite/toco/toco_cmdline_flags.cc @@ -158,6 +158,11 @@ bool ParseTocoFlagsFromCommandLineFlags( parsed_flags.split_tflite_lstm_inputs.default_value(), "Split the LSTM inputs from 5 tensors to 18 tensors for TFLite. " "Ignored if the output format is not TFLite."), + Flag("quantize_to_float16", parsed_flags.quantize_to_float16.bind(), + parsed_flags.quantize_to_float16.default_value(), + "Used in conjunction with post_training_quantize. Specifies that " + "the weights should be quantized to fp16 instead of the default " + "(int8)"), Flag("quantize_weights", parsed_flags.quantize_weights.bind(), parsed_flags.quantize_weights.default_value(), "Deprecated. 
Please use --post_training_quantize instead."), @@ -266,6 +271,7 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone); READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone); READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone); + READ_TOCO_FLAG(quantize_to_float16, FlagRequirement::kNone); READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone); READ_TOCO_FLAG(enable_select_tf_ops, FlagRequirement::kNone); READ_TOCO_FLAG(force_select_tf_ops, FlagRequirement::kNone); diff --git a/tensorflow/lite/toco/toco_flags.proto b/tensorflow/lite/toco/toco_flags.proto index cb015ba3d2a..50e9d332749 100644 --- a/tensorflow/lite/toco/toco_flags.proto +++ b/tensorflow/lite/toco/toco_flags.proto @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. syntax = "proto2"; -import "tensorflow/lite/toco/types.proto"; package toco; +import "tensorflow/lite/toco/types.proto"; + // Supported I/O file formats. Some formats may be input-only or output-only. enum FileFormat { FILE_FORMAT_UNKNOWN = 0; @@ -37,7 +38,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 29. +// Next ID to use: 30. message TocoFlags { // Input file format optional FileFormat input_format = 1; @@ -205,4 +206,10 @@ message TocoFlags { // `force_select_tf_ops` should always be used with `enable_select_tf_ops`. // WARNING: Experimental interface, subject to change optional bool force_select_tf_ops = 28 [default = false]; + + // Boolean indicating whether to convert float32 constant buffers to + // float16. This is typically done to reduce model size. Delegates may also + // wish to implement kernels on reduced precision floats for performance + // gains. + optional bool quantize_to_float16 = 29 [default = false]; } diff --git a/tensorflow/lite/toco/toco_tooling.cc b/tensorflow/lite/toco/toco_tooling.cc index c66ef1db915..3257fcdf2f8 100644 --- a/tensorflow/lite/toco/toco_tooling.cc +++ b/tensorflow/lite/toco/toco_tooling.cc @@ -20,17 +20,18 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/str_join.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/toco/allocate_transient_arrays.h" #include "tensorflow/lite/toco/dump_graphviz.h" #include "tensorflow/lite/toco/export_tensorflow.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/import_tensorflow.h" +#include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/tflite/export.h" #include "tensorflow/lite/toco/tflite/import.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { namespace { @@ -449,8 +450,13 @@ tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model, params.enable_select_tf_ops = toco_flags.force_select_tf_ops() || toco_flags.enable_select_tf_ops(); params.allow_custom_ops = allow_custom_ops; - params.quantize_weights = toco_flags.post_training_quantize(); - + if (toco_flags.post_training_quantize()) { + if (toco_flags.quantize_to_float16()) { + params.quantize_weights = tflite::QuantizedBufferType::FLOAT16; + } else { + params.quantize_weights = tflite::QuantizedBufferType::INT8; + } + } auto status = toco::tflite::Export(model, output_file_contents, params); if (!status.ok()) { LOG(ERROR) << status.error_message(); diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD index 22a473bf9dc..ce78fb701c2 100644 --- a/tensorflow/lite/tools/optimize/BUILD +++ b/tensorflow/lite/tools/optimize/BUILD @@ -22,6 +22,7 @@ cc_library( "//tensorflow/lite/kernels/internal:tensor_utils", "//tensorflow/lite/kernels/internal:types", "//tensorflow/lite/schema:schema_fbs", + "//third_party/eigen3", "@com_google_absl//absl/memory", ], ) @@ -104,6 +105,7 @@ cc_library( ":quantization_utils", ":model_utils", "@com_google_absl//absl/memory", + "@com_google_absl//absl/container:flat_hash_map", "@flatbuffers", "//tensorflow/lite:framework", # TODO(suharshs): Move the relevant quantization utils to a non-internal location. diff --git a/tensorflow/lite/tools/optimize/quantization_utils.cc b/tensorflow/lite/tools/optimize/quantization_utils.cc index 80c14fc0676..62af42fd6a8 100644 --- a/tensorflow/lite/tools/optimize/quantization_utils.cc +++ b/tensorflow/lite/tools/optimize/quantization_utils.cc @@ -13,15 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include "tensorflow/lite/tools/optimize/quantization_utils.h" + +#include +#include + #include "absl/memory/memory.h" +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/kernels/internal/round.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/kernels/internal/types.h" -#include -#include - namespace tflite { namespace optimize { namespace utils { @@ -197,6 +199,42 @@ TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor) { return kTfLiteOk; } +TfLiteStatus QuantizeTensorFloat16(ModelT* model, TensorT* tensor) { + if (model == nullptr || tensor == nullptr) { + return kTfLiteError; + } + + BufferT* buffer = model->buffers[tensor->buffer].get(); + if (buffer == nullptr) { + return kTfLiteError; + } + + uint64_t num_elements; + TF_LITE_ENSURE_STATUS(NumElements(*tensor, &num_elements)); + + // Copy single byte buffer data to float vector to guard against misalignment. + std::vector<float> float_vector(num_elements); + uint8_t* first = buffer->data.data(); + std::copy(first, first + buffer->data.size(), + reinterpret_cast<uint8_t*>(float_vector.data())); + + // Transform float data to float16. + std::vector<Eigen::half> quantized_buffer; + quantized_buffer.resize(num_elements); + std::transform( + float_vector.begin(), float_vector.end(), quantized_buffer.begin(), + [](float a) { return Eigen::half_impl::float_to_half_rtne(a); }); + + char* half_buffer = reinterpret_cast<char*>(quantized_buffer.data()); + model->buffers[tensor->buffer]->data.assign( + half_buffer, half_buffer + sizeof(Eigen::half) * num_elements); + + // Update the tensor type. + tensor->type = TensorType_FLOAT16; + + return kTfLiteOk; +} + TfLiteStatus AddQuantizationParams(const std::vector<float>& scales, const std::vector<int64_t>& zero_point, int quantized_dimension, diff --git a/tensorflow/lite/tools/optimize/quantization_utils.h b/tensorflow/lite/tools/optimize/quantization_utils.h index 274862a536c..4cc67cfe40a 100644 --- a/tensorflow/lite/tools/optimize/quantization_utils.h +++ b/tensorflow/lite/tools/optimize/quantization_utils.h @@ -64,6 +64,9 @@ void SymmetricPerChannelQuantizeValues(const float* const input, // of the tensor. TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor); +// Quantizes tensor to float16. +TfLiteStatus QuantizeTensorFloat16(ModelT* model, TensorT* tensor); + // Add quantization parameters. TfLiteStatus AddQuantizationParams(const std::vector<float>& scales, const std::vector<int64_t>& zero_point, diff --git a/tensorflow/lite/tools/optimize/quantization_utils_test.cc b/tensorflow/lite/tools/optimize/quantization_utils_test.cc index 74813c180cb..c19d1879533 100644 --- a/tensorflow/lite/tools/optimize/quantization_utils_test.cc +++ b/tensorflow/lite/tools/optimize/quantization_utils_test.cc @@ -258,6 +258,38 @@ TEST(QuantizationUtilsTest, SymmetricQuantizeTensor) { EXPECT_EQ(quant_buffer_size * 4, float_buffer_size); } +TEST(QuantizationUtilsTest, QuantizeFloat16) { + // Conv model has weights between 0 and 10. + // Quantize the weights tensor. 
+ ASSERT_TRUE(g_test_model_dir != nullptr); + ASSERT_FALSE(g_test_model_dir->empty()); + auto test_model = ReadConvModel(); + ASSERT_TRUE(test_model); + auto readonly_model = test_model->GetModel(); + ASSERT_TRUE(readonly_model); + ASSERT_TRUE(readonly_model->subgraphs()); + ASSERT_GE(readonly_model->subgraphs()->size(), 1); + tflite::ModelT model; + readonly_model->UnPackTo(&model); + auto subgraph = model.subgraphs[0].get(); + auto conv_op = subgraph->operators.at(0).get(); + ASSERT_EQ(model.operator_codes.at(conv_op->opcode_index)->builtin_code, + BuiltinOperator_CONV_2D); + int32_t weights_tensor_idx = conv_op->inputs[1]; + TensorT* weights_tensor = subgraph->tensors.at(weights_tensor_idx).get(); + + EXPECT_EQ(weights_tensor->type, TensorType_FLOAT32); + size_t float_buffer_size = + model.buffers.at(weights_tensor->buffer)->data.size(); + + EXPECT_EQ(QuantizeTensorFloat16(&model, weights_tensor), kTfLiteOk); + + size_t quant_buffer_size = + model.buffers.at(weights_tensor->buffer)->data.size(); + EXPECT_EQ(weights_tensor->type, TensorType_FLOAT16); + EXPECT_EQ(quant_buffer_size * 2, float_buffer_size); +} + TEST(QuantizationUtilsTest, AddQuantizationParams) { // Create data. auto model = absl::make_unique<ModelT>(); diff --git a/tensorflow/lite/tools/optimize/quantize_weights.cc b/tensorflow/lite/tools/optimize/quantize_weights.cc index dda825393a5..89965e1190e 100644 --- a/tensorflow/lite/tools/optimize/quantize_weights.cc +++ b/tensorflow/lite/tools/optimize/quantize_weights.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "flatbuffers/flexbuffers.h" +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/context.h" @@ -172,7 +173,7 @@ bool CheckAllOpInputsQuantized(const SubGraphT* subgraph, const OperatorT* op, TfLiteStatus InsertQuantizableInputTensorsFromOperator( const ModelT* model, const OperatorT* op, uint64_t weights_min_num_elements, const CustomOpMap& custom_op_map, - std::unordered_map<int32_t, TensorT*>* tensor_map) { + absl::flat_hash_map<int32_t, TensorT*>* tensor_map) { SubGraphT* subgraph = model->subgraphs.at(0).get(); const OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get(); @@ -328,11 +329,11 @@ PassQuantizationAndGetConsumers( GetTensorConsumers(model, subgraph, output_tensor_idx)); } -TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder, - const Model* input_model, - bool use_hybrid_evaluation, - uint64_t weights_min_num_elements, - const CustomOpMap& custom_op_map) { +TfLiteStatus QuantizeWeightsInt8(flatbuffers::FlatBufferBuilder* builder, - const Model* input_model, + const Model* input_model, + bool use_hybrid_evaluation, + uint64_t weights_min_num_elements, + const CustomOpMap& custom_op_map) { std::unique_ptr<ModelT> model; model.reset(input_model->UnPack()); @@ -345,15 +346,14 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder, SubGraphT* subgraph = model->subgraphs.at(0).get(); - std::vector<std::unique_ptr<OperatorT>> new_operators; - std::unordered_map<int32_t, TensorT*> tensor_map; + absl::flat_hash_map<int32_t, TensorT*> tensor_map; for (int i = 0; i < subgraph->operators.size(); ++i) { OperatorT* op = subgraph->operators[i].get(); TF_LITE_ENSURE_STATUS(InsertQuantizableInputTensorsFromOperator( model.get(), op, weights_min_num_elements, custom_op_map, &tensor_map)); } - // The unordered_map ensures that we quantize each tensor exactly once. + // The hash map ensures that we quantize each tensor exactly once. // TODO(suharshs): This map key isn't sufficient when we support multiple // subgraphs. 
for (std::pair<int32_t, TensorT*> tensor_pair : tensor_map) { @@ -396,7 +396,7 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder, } } - // Check that this tensor is an output tensor. + // Check if this tensor is an output tensor. int32_t output_index = -1; for (int32_t i = 0; i < subgraph->outputs.size(); ++i) { if (subgraph->outputs[i] == tensor_idx) { @@ -424,8 +424,6 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder, utils::MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx, dequantize_output_idx); - LOG(INFO) << "Creating Dequantize op with name " << dequant_name << "."; - // Update the op_input of all the ops that need the created dequantize // operation. int32_t min_op_idx = subgraph->operators.size(); @@ -455,6 +453,81 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder, return kTfLiteOk; } +TfLiteStatus QuantizeWeightsFloat16(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model) { + std::unique_ptr<ModelT> model; + model.reset(input_model->UnPack()); + + // TODO(suharshs): When models support multiple subgraphs, add support. + if (model->subgraphs.size() != 1) { + LOG(ERROR) << "Quantize weights tool only supports tflite models with one " "subgraph."; + return kTfLiteError; + } + + SubGraphT* subgraph = model->subgraphs.at(0).get(); + + absl::flat_hash_map<int32_t, TensorT*> tensor_map; + for (int i = 0; i < subgraph->operators.size(); ++i) { + OperatorT* op = subgraph->operators[i].get(); + for (auto tensor_idx : op->inputs) { + TensorT* tensor = subgraph->tensors[tensor_idx].get(); + BufferT* buffer = model->buffers[tensor->buffer].get(); + if (buffer == nullptr) { + return kTfLiteError; + } + // Quantize tensors that have data to quantize. + bool is_constant = !model->buffers[tensor->buffer].get()->data.empty(); + if (tensor->type == TensorType_FLOAT32 && is_constant) { + tensor_map.insert({tensor_idx, tensor}); + } + } + } + + // The hash map ensures that we quantize each tensor exactly once. + for (std::pair<int32_t, TensorT*> tensor_pair : tensor_map) { + // Quantize the tensor. + TF_LITE_ENSURE_STATUS( + utils::QuantizeTensorFloat16(model.get(), tensor_pair.second)); + + int32_t tensor_idx = tensor_pair.first; + TensorT* tensor = tensor_pair.second; + std::vector<ConsumerOpInfo> dequant_op_infos = GetTensorConsumers(model.get(), subgraph, tensor_idx); + + // Create a new tensor to be the output of the dequantize op. + std::unique_ptr<TensorT> dequantize_output; + const string dequant_name = tensor->name + "_dequantize"; + utils::MakeTensor(dequant_name, tensor->shape, TensorType_FLOAT32, + &dequantize_output); + const int32_t dequantize_output_idx = subgraph->tensors.size(); + subgraph->tensors.push_back(std::move(dequantize_output)); + + // Create the Dequantize operation. + std::unique_ptr<OperatorT> dequantize_op; + utils::MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx, + dequantize_output_idx); + + // Update the op_input of all the ops that need the created dequantize + // operation. + int32_t min_op_idx = subgraph->operators.size(); + for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) { + dequant_op_info.op->inputs[dequant_op_info.op_input_idx] = + dequantize_output_idx; + min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx); + } + + // Insert the newly created Dequantize operation before the earliest + // consumer, since TFLite requires operators to be topo-sorted. 
+ subgraph->operators.insert(subgraph->operators.begin() + min_op_idx, + std::move(dequantize_op)); + } + + flatbuffers::Offset output_model_location = + Model::Pack(*builder, model.get()); + FinishModelBuffer(*builder, output_model_location); + return kTfLiteOk; +} } // namespace namespace internal { @@ -465,8 +538,8 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, // By default we require that only weights with more than // kWeightsMinSizeDefault elements are quantized. CustomOpMap custom_op_map; - return QuantizeWeightsInternal(builder, input_model, use_hybrid_evaluation, - weights_min_num_elements, custom_op_map); + return QuantizeWeightsInt8(builder, input_model, use_hybrid_evaluation, + weights_min_num_elements, custom_op_map); } } // namespace internal @@ -474,25 +547,31 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const Model* input_model, uint64_t weights_min_num_elements) { CustomOpMap custom_op_map; - return QuantizeWeightsInternal(builder, input_model, true, - weights_min_num_elements, custom_op_map); + return QuantizeWeightsInt8(builder, input_model, true, + weights_min_num_elements, custom_op_map); } TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, - const Model* input_model) { - // By default we require that only weights with more than - // kWeightsMinSizeDefault elements are quantized. - CustomOpMap custom_op_map; - return QuantizeWeightsInternal(builder, input_model, true, + const Model* input_model, BufferType quant_type) { + switch (quant_type) { + case BufferType::QUANTIZED_INT8: { + // By default we require that only weights with more than + // kWeightsMinSizeDefault elements are quantized. + CustomOpMap custom_op_map; + return QuantizeWeightsInt8(builder, input_model, true, kWeightsMinNumElementsDefault, custom_op_map); + } + case BufferType::QUANTIZED_FLOAT16: + return QuantizeWeightsFloat16(builder, input_model); + } } TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const Model* input_model, uint64_t weights_min_num_elements, const CustomOpMap& custom_op_map) { - return QuantizeWeightsInternal(builder, input_model, true, - weights_min_num_elements, custom_op_map); + return QuantizeWeightsInt8(builder, input_model, true, + weights_min_num_elements, custom_op_map); } } // namespace optimize diff --git a/tensorflow/lite/tools/optimize/quantize_weights.h b/tensorflow/lite/tools/optimize/quantize_weights.h index 62f9584011e..528614f0b7b 100644 --- a/tensorflow/lite/tools/optimize/quantize_weights.h +++ b/tensorflow/lite/tools/optimize/quantize_weights.h @@ -26,6 +26,9 @@ limitations under the License. namespace tflite { namespace optimize { +// Supported resulting types from quantization process. +enum class BufferType { QUANTIZED_INT8, QUANTIZED_FLOAT16 }; + // Quantizes input_model and populates the provided builder with the new model. // By default only weights tensors weight more than 1024 elements will be // quantized. 
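The BufferType enum introduced above, together with the defaulted quant_type argument added to QuantizeWeights in the declaration that follows, keeps existing int8 callers source-compatible while making float16 an explicit opt-in. As a rough illustration of how a caller might drive the new float16 path, here is a minimal sketch; the wrapper function name and the way the serialized model is handed back are assumptions made for the example, not part of this change:

  #include <string>

  #include "flatbuffers/flatbuffers.h"
  #include "tensorflow/lite/schema/schema_generated.h"
  #include "tensorflow/lite/tools/optimize/quantize_weights.h"

  // Hypothetical helper: shrinks the float32 constant buffers of `input_model`
  // to float16 and returns the rewritten flatbuffer in `output_bytes`.
  TfLiteStatus QuantizeModelToFloat16(const tflite::Model* input_model,
                                      std::string* output_bytes) {
    flatbuffers::FlatBufferBuilder builder;
    // Overload added by this change; the third argument selects float16.
    TfLiteStatus status = tflite::optimize::QuantizeWeights(
        &builder, input_model, tflite::optimize::BufferType::QUANTIZED_FLOAT16);
    if (status != kTfLiteOk) return status;
    // The quantized model lives in the builder; copy it out for serialization.
    output_bytes->assign(
        reinterpret_cast<const char*>(builder.GetBufferPointer()),
        builder.GetSize());
    return kTfLiteOk;
  }

Omitting the third argument preserves the previous int8 behavior, matching the declared default.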
@@ -33,8 +36,9 @@ namespace optimize { // A tflite::Model can be obtained from the builder with: // const uint8_t* buffer = builder->GetBufferPointer(); // tflite::Model* model = GetModel(buffer); -TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, - const Model* input_model); +TfLiteStatus QuantizeWeights( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + BufferType quant_type = BufferType::QUANTIZED_INT8); // Same as above, but only weights with greater than or equal // weights_min_num_elements elements will be quantized. diff --git a/tensorflow/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/lite/tools/optimize/quantize_weights_test.cc index e1cd031cc6d..c35259ef437 100644 --- a/tensorflow/lite/tools/optimize/quantize_weights_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_weights_test.cc @@ -272,7 +272,7 @@ TEST_F(QuantizeWeightsTest, DequantizeConv) { } else if (quant_tensor->name()->str() == "conv_bias") { EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); } else if (quant_tensor->buffer() != 0) { - // If its a non-bias constant tensor, is must be the weight. + // If it's a non-bias constant tensor, it must be the weight. EXPECT_EQ(quant_tensor->type(), TensorType_INT8); } else { EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); @@ -281,6 +281,65 @@ TEST_F(QuantizeWeightsTest, DequantizeConv) { } } +TEST_F(QuantizeWeightsTest, DequantizeConvFloat16) { + LoadBasicModel(); + flatbuffers::FlatBufferBuilder builder; + auto status = tflite::optimize::QuantizeWeights( + &builder, model_, BufferType::QUANTIZED_FLOAT16); + EXPECT_EQ(status, kTfLiteOk); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + ++subgraph_idx) { + const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx); + const auto float_graph = model_->subgraphs()->Get(subgraph_idx); + // The output graph should have two extra tensors from the added dequantize + // op. + ASSERT_EQ(quantized_graph->tensors()->size(), + float_graph->tensors()->size() + 2); + // Check that a dequantize op exists. + int32_t dequant_input_idx = -1; + int32_t dequant_output_idx = -1; + for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) { + const auto op = quantized_graph->operators()->Get(i); + const uint32_t op_code_idx = op->opcode_index(); + if (output_model->operator_codes()->Get(op_code_idx)->builtin_code() == + BuiltinOperator_DEQUANTIZE) { + dequant_input_idx = op->inputs()->Get(0); + dequant_output_idx = op->outputs()->Get(0); + } + } + ASSERT_GT(dequant_input_idx, -1); + ASSERT_GT(dequant_output_idx, -1); + for (size_t i = 0; i < quantized_graph->tensors()->size(); ++i) { + const auto quant_tensor = quantized_graph->tensors()->Get(i); + // If the tensor is a weight, it should have type FLOAT16. + // If the tensor is a bias, it should have type FLOAT16. + // If the tensor is an input or output it should have type FLOAT32. + // The input to dequantize should be FLOAT16, and all other tensors should + // be FLOAT32. 
+ if (i == dequant_input_idx) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16); + } else if (i == dequant_output_idx) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (IsModelInputOrOutput(output_model, i)) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (quant_tensor->name()->str() == "conv_bias") { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16); + } else if (quant_tensor->buffer() != 0) { + // If it's a non-bias constant tensor, it must be the weight. + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16); + } else { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } + } + } +} + TEST_F(QuantizeWeightsTest, SharedWeights_Hybrid) { LoadSharedWeightsModel(); flatbuffers::FlatBufferBuilder builder;
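The type expectations asserted in DequantizeConvFloat16 above reduce to a simple symmetry: QuantizeTensorFloat16 shrinks constant buffers with Eigen's round-to-nearest-even float-to-half conversion, and the new reference Dequantize expands them back to float32 through the inserted DEQUANTIZE ops at inference time. A small self-contained sketch of that numeric round trip follows; the sample values are arbitrary, and it relies on Eigen::half's public float constructor and float conversion, which are assumed to wrap the same half_impl routines the patch calls directly:

  #include <cmath>
  #include <cstdio>

  #include "third_party/eigen3/Eigen/Core"

  int main() {
    // Representative float32 weights; chosen only for demonstration.
    const float weights[] = {0.1f, 3.14159f, -2.5f, 9.75f};
    for (float w : weights) {
      // Storage side: the same rounding QuantizeTensorFloat16 applies to buffers.
      Eigen::half h(w);
      // Inference side: the same expansion the reference Dequantize performs.
      float restored = static_cast<float>(h);
      std::printf("%.6f -> %.6f (abs err %g)\n", w, restored,
                  std::fabs(w - restored));
    }
    return 0;
  }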