Support quantization to float16
PiperOrigin-RevId: 248790891
commit 3742faa90c
parent 5c27f716a7
@@ -395,6 +395,7 @@ cc_library(
"//tensorflow/lite/kernels/internal:reference_base",
"//tensorflow/lite/kernels/internal:tensor",
"//tensorflow/lite/kernels/internal:tensor_utils",
"//third_party/eigen3",
"@farmhash_archive//:farmhash",
"@flatbuffers",
],
@@ -12,13 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h"

#include <string.h>

#include <vector>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
@@ -59,7 +63,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
OpContext op_context(context, node);

TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8 ||
op_context.input->type == kTfLiteInt8);
op_context.input->type == kTfLiteInt8 ||
op_context.input->type == kTfLiteFloat16);

op_context.output->type = kTfLiteFloat32;
// If the input tensor is constant, we can persist the dequantized value in
@@ -96,6 +101,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTensorShape(op_context.output),
GetTensorData<float>(op_context.output));
break;
case kTfLiteFloat16: {
const Eigen::half* half_data = reinterpret_cast<const Eigen::half*>(
GetTensorData<TfLiteFloat16>(op_context.input));
reference_ops::Dequantize(GetTensorShape(op_context.input), half_data,
GetTensorShape(op_context.output),
GetTensorData<float>(op_context.output));
break;
}
default:
context->ReportError(context, "Type %d not supported.",
op_context.input->type);
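Note: unlike the uint8/int8 cases above, the new kTfLiteFloat16 path carries no scale or zero point; "dequantizing" is a pure widening of each IEEE binary16 value to float32 via Eigen. A minimal standalone sketch of that per-element conversion (the helper name and the raw-bit construction are illustrative, not part of this change):

    #include <cstdint>

    #include "third_party/eigen3/Eigen/Core"

    // Widen one raw binary16 payload to float32, as the kernel does per element.
    float WidenHalf(uint16_t bits) {
      Eigen::half h;
      h.x = bits;  // raw IEEE-754 binary16 bits
      return Eigen::half_impl::half_to_float(h);
    }

    // For instance, WidenHalf(0x3C00) == 1.0f and WidenHalf(0xC000) == -2.0f.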
@@ -385,6 +385,7 @@ cc_library(
":types",
"@gemmlowp//:fixedpoint",
"@gemmlowp//:profiler",
"//third_party/eigen3",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:op_macros",
] + select({
@@ -421,6 +422,7 @@ cc_library(
":legacy_types",
":tensor",
":types",
"//third_party/eigen3",
"@gemmlowp",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:op_macros",
@@ -26,6 +26,7 @@ limitations under the License.
#include <memory>
#include <type_traits>

#include "third_party/eigen3/Eigen/Core"
#include "fixedpoint/fixedpoint.h"
#include "profiling/instrumentation.h"
#include "tensorflow/lite/c/c_api_internal.h"
@@ -2491,6 +2492,15 @@ inline void Dequantize(const tflite::DequantizationParams& op_params,
}
}

inline void Dequantize(const RuntimeShape& input_shape,
const Eigen::half* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
output_data[i] = Eigen::half_impl::half_to_float(input_data[i]);
}
}

template <typename T>
inline void AffineQuantize(const tflite::QuantizationParams& op_params,
const RuntimeShape& input_shape,
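Note: the new reference_ops::Dequantize overload is purely element-wise, so it can be exercised directly on flat buffers. A hedged usage sketch (shapes and values are made up for illustration):

    #include <vector>

    #include "third_party/eigen3/Eigen/Core"
    #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
    #include "tensorflow/lite/kernels/internal/types.h"

    void DequantizeHalfExample() {
      const tflite::RuntimeShape shape({1, 4});
      const std::vector<Eigen::half> input = {Eigen::half(0.5f), Eigen::half(-1.0f),
                                              Eigen::half(2.25f), Eigen::half(0.0f)};
      std::vector<float> output(4);
      // Widens each binary16 element to float32; both shapes must describe the
      // same number of elements (checked via MatchingFlatSize).
      tflite::reference_ops::Dequantize(shape, input.data(), shape, output.data());
    }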
@@ -172,6 +172,7 @@ struct ParsedTocoFlags {
Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
Arg<bool> allow_custom_ops = Arg<bool>(false);
Arg<bool> post_training_quantize = Arg<bool>(false);
Arg<bool> quantize_to_float16 = Arg<bool>(false);
// Deprecated flags
Arg<bool> quantize_weights = Arg<bool>(false);
Arg<string> input_type;
@@ -606,7 +606,9 @@ tensorflow::Status Export(
builder.CreateVector(subgraphs), description, buffers);
::tflite::FinishModelBuffer(builder, new_model_location);

if (params.quantize_weights) {
if (params.quantize_weights == QuantizedBufferType::NONE) {
WriteModelToString(builder, output_file_contents);
} else {
// Call the quantize_weights tool.
LOG(INFO) << "Quantizing TFLite model after conversion to flatbuffer. "
"dump_graphviz will only output the model before this "
@@ -615,14 +617,21 @@ tensorflow::Status Export(
flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);
const uint8_t* buffer = builder.GetBufferPointer();
const ::tflite::Model* input_model = ::tflite::GetModel(buffer);
if (::tflite::optimize::QuantizeWeights(&q_builder, input_model) !=
kTfLiteOk) {
::tflite::optimize::BufferType quantized_type;
if (params.quantize_weights == QuantizedBufferType::INT8) {
quantized_type = ::tflite::optimize::BufferType::QUANTIZED_INT8;
} else if (params.quantize_weights == QuantizedBufferType::FLOAT16) {
quantized_type = ::tflite::optimize::BufferType::QUANTIZED_FLOAT16;
} else {
return tensorflow::errors::InvalidArgument(
"Quantized type not recognized");
}
if (::tflite::optimize::QuantizeWeights(&q_builder, input_model,
quantized_type) != kTfLiteOk) {
return tensorflow::errors::InvalidArgument(
"Quantize weights transformation failed.");
}
WriteModelToString(q_builder, output_file_contents);
} else {
WriteModelToString(builder, output_file_contents);
}

return tensorflow::Status();
@@ -23,11 +23,13 @@ namespace toco {

namespace tflite {

enum class QuantizedBufferType { NONE, INT8, FLOAT16 };

// The parameters for exporting a TFLite model.
struct ExportParams {
bool allow_custom_ops = false;
bool enable_select_tf_ops = false;
bool quantize_weights = false;
QuantizedBufferType quantize_weights = QuantizedBufferType::NONE;
};

// Transform the given tf.mini model into a TF Lite flatbuffer and deposit the
@@ -47,7 +49,8 @@ inline void Export(const Model& model, bool allow_custom_ops,
bool quantize_weights, string* output_file_contents) {
ExportParams params;
params.allow_custom_ops = allow_custom_ops;
params.quantize_weights = quantize_weights;
params.quantize_weights =
quantize_weights ? QuantizedBufferType::INT8 : QuantizedBufferType::NONE;
auto status = Export(model, output_file_contents, params);
if (!status.ok()) LOG(QFATAL) << status.error_message();
}
@@ -60,7 +63,8 @@ inline void Export(
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
ExportParams params;
params.allow_custom_ops = allow_custom_ops;
params.quantize_weights = quantize_weights;
params.quantize_weights =
quantize_weights ? QuantizedBufferType::INT8 : QuantizedBufferType::NONE;
auto status = Export(model, output_file_contents, params, ops_by_type);
if (!status.ok()) LOG(QFATAL) << status.error_message();
}
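Note: ExportParams::quantize_weights changes from a bool to the QuantizedBufferType enum; the legacy bool overloads above map true to INT8 and false to NONE. A hedged sketch of a direct caller opting into float16 buffers (the toco::Model instance is assumed to come from the usual import path):

    #include <string>

    #include "tensorflow/core/platform/logging.h"
    #include "tensorflow/lite/toco/model.h"
    #include "tensorflow/lite/toco/tflite/export.h"

    // Hypothetical helper, not part of this change.
    std::string ExportWithFloat16Weights(const toco::Model& model) {
      toco::tflite::ExportParams params;
      params.allow_custom_ops = false;
      params.enable_select_tf_ops = false;
      params.quantize_weights = toco::tflite::QuantizedBufferType::FLOAT16;

      std::string output_file_contents;
      auto status = toco::tflite::Export(model, &output_file_contents, params);
      if (!status.ok()) LOG(QFATAL) << status.error_message();
      return output_file_contents;
    }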
@@ -219,7 +219,7 @@ TEST_F(ExportTest, Export) {
ExportParams params;
params.allow_custom_ops = true;
params.enable_select_tf_ops = false;
params.quantize_weights = false;
params.quantize_weights = QuantizedBufferType::NONE;

EXPECT_THAT(ExportAndSummarizeOperators(params),
ElementsAre("builtin:ADD", "builtin:CONV_2D", "custom:MyCrazyOp",
@@ -366,7 +366,7 @@ class OpSetsTest : public ExportTest {
import_all_ops_as_unsupported_ = true;
params_.allow_custom_ops = false;
params_.enable_select_tf_ops = false;
params_.quantize_weights = false;
params_.quantize_weights = QuantizedBufferType::NONE;

for (const OpSet& i : sets) {
switch (i) {
@@ -158,6 +158,11 @@ bool ParseTocoFlagsFromCommandLineFlags(
parsed_flags.split_tflite_lstm_inputs.default_value(),
"Split the LSTM inputs from 5 tensors to 18 tensors for TFLite. "
"Ignored if the output format is not TFLite."),
Flag("quantize_to_float16", parsed_flags.quantize_to_float16.bind(),
parsed_flags.quantize_to_float16.default_value(),
"Used in conjunction with post_training_quantize. Specifies that "
"the weights should be quantized to fp16 instead of the default "
"(int8)"),
Flag("quantize_weights", parsed_flags.quantize_weights.bind(),
parsed_flags.quantize_weights.default_value(),
"Deprecated. Please use --post_training_quantize instead."),
@@ -266,6 +271,7 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone);
READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone);
READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone);
READ_TOCO_FLAG(quantize_to_float16, FlagRequirement::kNone);
READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone);
READ_TOCO_FLAG(enable_select_tf_ops, FlagRequirement::kNone);
READ_TOCO_FLAG(force_select_tf_ops, FlagRequirement::kNone);
@@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
import "tensorflow/lite/toco/types.proto";

package toco;

import "tensorflow/lite/toco/types.proto";

// Supported I/O file formats. Some formats may be input-only or output-only.
enum FileFormat {
FILE_FORMAT_UNKNOWN = 0;
@@ -37,7 +38,7 @@ enum FileFormat {
// of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job.
//
// Next ID to use: 29.
// Next ID to use: 30.
message TocoFlags {
// Input file format
optional FileFormat input_format = 1;
@@ -205,4 +206,10 @@ message TocoFlags {
// `force_select_tf_ops` should always be used with `enable_select_tf_ops`.
// WARNING: Experimental interface, subject to change
optional bool force_select_tf_ops = 28 [default = false];

// Boolean indicating whether to convert float32 constant buffers to
// float16. This is typically done to reduce model size. Delegates may also
// wish to implement kernels on reduced precision floats for performance
// gains.
optional bool quantize_to_float16 = 29 [default = false];
}
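Note: the new field does not replace post_training_quantize; float16 buffers are only produced when both flags are set, as wired up in the Export() change further down. A hedged sketch using the generated proto2 API (standard setters assumed for the optional bool fields):

    #include "tensorflow/lite/toco/toco_flags.pb.h"

    toco::TocoFlags MakeFloat16ConversionFlags() {
      toco::TocoFlags toco_flags;
      toco_flags.set_post_training_quantize(true);
      toco_flags.set_quantize_to_float16(true);  // field 29 added above
      return toco_flags;
    }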
@@ -20,17 +20,18 @@ limitations under the License.

#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/toco/allocate_transient_arrays.h"
#include "tensorflow/lite/toco/dump_graphviz.h"
#include "tensorflow/lite/toco/export_tensorflow.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/import_tensorflow.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/tflite/export.h"
#include "tensorflow/lite/toco/tflite/import.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {
namespace {
@@ -449,8 +450,13 @@ tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model,
params.enable_select_tf_ops =
toco_flags.force_select_tf_ops() || toco_flags.enable_select_tf_ops();
params.allow_custom_ops = allow_custom_ops;
params.quantize_weights = toco_flags.post_training_quantize();

if (toco_flags.post_training_quantize()) {
if (toco_flags.quantize_to_float16()) {
params.quantize_weights = tflite::QuantizedBufferType::FLOAT16;
} else {
params.quantize_weights = tflite::QuantizedBufferType::INT8;
}
}
auto status = toco::tflite::Export(model, output_file_contents, params);
if (!status.ok()) {
LOG(ERROR) << status.error_message();
@@ -22,6 +22,7 @@ cc_library(
"//tensorflow/lite/kernels/internal:tensor_utils",
"//tensorflow/lite/kernels/internal:types",
"//tensorflow/lite/schema:schema_fbs",
"//third_party/eigen3",
"@com_google_absl//absl/memory",
],
)
@@ -104,6 +105,7 @@ cc_library(
":quantization_utils",
":model_utils",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/container:flat_hash_map",
"@flatbuffers",
"//tensorflow/lite:framework",
# TODO(suharshs): Move the relevant quantization utils to a non-internal location.
@@ -13,15 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/quantization_utils.h"

#include <cmath>
#include <cstdint>

#include "absl/memory/memory.h"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/round.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"

#include <cmath>
#include <cstdint>

namespace tflite {
namespace optimize {
namespace utils {
@@ -197,6 +199,42 @@ TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor) {
return kTfLiteOk;
}

TfLiteStatus QuantizeTensorFloat16(ModelT* model, TensorT* tensor) {
if (model == nullptr || tensor == nullptr) {
return kTfLiteError;
}

BufferT* buffer = model->buffers[tensor->buffer].get();
if (buffer == nullptr) {
return kTfLiteError;
}

uint64_t num_elements;
TF_LITE_ENSURE_STATUS(NumElements(*tensor, &num_elements));

// Copy single byte buffer data to float vector to guard against misalignment.
std::vector<float> float_vector(num_elements);
uint8_t* first = buffer->data.data();
std::copy(first, first + buffer->data.size(),
reinterpret_cast<uint8_t*>(float_vector.data()));

// Transform float data to float16.
std::vector<Eigen::half> quantized_buffer;
quantized_buffer.resize(num_elements);
std::transform(
float_vector.begin(), float_vector.end(), quantized_buffer.begin(),
[](float a) { return Eigen::half_impl::float_to_half_rtne(a); });

char* half_buffer = reinterpret_cast<char*>(quantized_buffer.data());
model->buffers[tensor->buffer]->data.assign(
half_buffer, half_buffer + sizeof(Eigen::half) * num_elements);

// Update the tensor type.
tensor->type = TensorType_FLOAT16;

return kTfLiteOk;
}

TfLiteStatus AddQuantizationParams(const std::vector<float>& scales,
const std::vector<int64_t>& zero_point,
int quantized_dimension,
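Note: QuantizeTensorFloat16 rewrites the tensor's backing buffer in place (four bytes per element down to two) and flips the tensor type to FLOAT16; it does not touch consumers, which is why the quantize_weights pass further down still inserts a Dequantize op. A hedged, self-contained usage sketch on a hand-built model (the tensor/buffer wiring here is illustrative only):

    #include <cstdint>
    #include <memory>
    #include <vector>

    #include "tensorflow/lite/c/c_api_internal.h"
    #include "tensorflow/lite/schema/schema_generated.h"
    #include "tensorflow/lite/tools/optimize/quantization_utils.h"

    int main() {
      tflite::ModelT model;

      // One float32 buffer holding four weights.
      auto buffer = std::unique_ptr<tflite::BufferT>(new tflite::BufferT());
      const std::vector<float> weights = {0.25f, 1.5f, -3.0f, 10.0f};
      const uint8_t* bytes = reinterpret_cast<const uint8_t*>(weights.data());
      buffer->data.assign(bytes, bytes + weights.size() * sizeof(float));
      model.buffers.push_back(std::move(buffer));

      // A [4] tensor pointing at that buffer.
      auto tensor = std::unique_ptr<tflite::TensorT>(new tflite::TensorT());
      tensor->type = tflite::TensorType_FLOAT32;
      tensor->shape = {4};
      tensor->buffer = 0;

      // Afterwards the buffer holds 4 * sizeof(Eigen::half) = 8 bytes and
      // tensor->type is TensorType_FLOAT16.
      TfLiteStatus status =
          tflite::optimize::utils::QuantizeTensorFloat16(&model, tensor.get());
      return status == kTfLiteOk ? 0 : 1;
    }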
@@ -64,6 +64,9 @@ void SymmetricPerChannelQuantizeValues(const float* const input,
// of the tensor.
TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor);

// Quantizes tensor to float16.
TfLiteStatus QuantizeTensorFloat16(ModelT* model, TensorT* tensor);

// Add quantization parameters.
TfLiteStatus AddQuantizationParams(const std::vector<float>& scales,
const std::vector<int64_t>& zero_point,
@@ -258,6 +258,38 @@ TEST(QuantizationUtilsTest, SymmetricQuantizeTensor) {
EXPECT_EQ(quant_buffer_size * 4, float_buffer_size);
}

TEST(QuantizationUtilsTest, QuantizeFloat16) {
// Conv model has weights between 0 and 10.
// Quantize the weights tensor.
ASSERT_TRUE(g_test_model_dir != nullptr);
ASSERT_FALSE(g_test_model_dir->empty());
auto test_model = ReadConvModel();
ASSERT_TRUE(test_model);
auto readonly_model = test_model->GetModel();
ASSERT_TRUE(readonly_model);
ASSERT_TRUE(readonly_model->subgraphs());
ASSERT_GE(readonly_model->subgraphs()->size(), 1);
tflite::ModelT model;
readonly_model->UnPackTo(&model);
auto subgraph = model.subgraphs[0].get();
auto conv_op = subgraph->operators.at(0).get();
ASSERT_EQ(model.operator_codes.at(conv_op->opcode_index)->builtin_code,
BuiltinOperator_CONV_2D);
int32_t weights_tensor_idx = conv_op->inputs[1];
TensorT* weights_tensor = subgraph->tensors.at(weights_tensor_idx).get();

EXPECT_EQ(weights_tensor->type, TensorType_FLOAT32);
size_t float_buffer_size =
model.buffers.at(weights_tensor->buffer)->data.size();

EXPECT_EQ(QuantizeTensorFloat16(&model, weights_tensor), kTfLiteOk);

size_t quant_buffer_size =
model.buffers.at(weights_tensor->buffer)->data.size();
EXPECT_EQ(weights_tensor->type, TensorType_FLOAT16);
EXPECT_EQ(quant_buffer_size * 2, float_buffer_size);
}

TEST(QuantizationUtilsTest, AddQuantizationParams) {
// Create data.
auto model = absl::make_unique<ModelT>();
@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>

#include "flatbuffers/flexbuffers.h"
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/context.h"
@@ -172,7 +173,7 @@ bool CheckAllOpInputsQuantized(const SubGraphT* subgraph, const OperatorT* op,
TfLiteStatus InsertQuantizableInputTensorsFromOperator(
const ModelT* model, const OperatorT* op, uint64_t weights_min_num_elements,
const CustomOpMap& custom_op_map,
std::unordered_map<int32_t, TensorT*>* tensor_map) {
absl::flat_hash_map<int32_t, TensorT*>* tensor_map) {
SubGraphT* subgraph = model->subgraphs.at(0).get();
const OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get();

@@ -328,11 +329,11 @@ PassQuantizationAndGetConsumers(
GetTensorConsumers(model, subgraph, output_tensor_idx));
}

TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
const Model* input_model,
bool use_hybrid_evaluation,
uint64_t weights_min_num_elements,
const CustomOpMap& custom_op_map) {
TfLiteStatus QuantizeWeightsInt8(flatbuffers::FlatBufferBuilder* builder,
const Model* input_model,
bool use_hybrid_evaluation,
uint64_t weights_min_num_elements,
const CustomOpMap& custom_op_map) {
std::unique_ptr<ModelT> model;
model.reset(input_model->UnPack());

@@ -345,15 +346,14 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,

SubGraphT* subgraph = model->subgraphs.at(0).get();

std::vector<std::unique_ptr<OperatorT>> new_operators;
std::unordered_map<int32_t, TensorT*> tensor_map;
absl::flat_hash_map<int32_t, TensorT*> tensor_map;
for (int i = 0; i < subgraph->operators.size(); ++i) {
OperatorT* op = subgraph->operators[i].get();
TF_LITE_ENSURE_STATUS(InsertQuantizableInputTensorsFromOperator(
model.get(), op, weights_min_num_elements, custom_op_map, &tensor_map));
}

// The unordered_map ensures that we quantize each tensor exactly once.
// The hash map ensures that we quantize each tensor exactly once.
// TODO(suharshs): This map key isn't sufficient when we support multiple
// subgraphs.
for (std::pair<int32_t, TensorT*> tensor_pair : tensor_map) {
@@ -396,7 +396,7 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
}
}

// Check that this tensor is an output tensor.
// Check if this tensor is an output tensor.
int32_t output_index = -1;
for (int32_t i = 0; i < subgraph->outputs.size(); ++i) {
if (subgraph->outputs[i] == tensor_idx) {
@@ -424,8 +424,6 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
utils::MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx,
dequantize_output_idx);

LOG(INFO) << "Creating Dequantize op with name " << dequant_name << ".";

// Update the op_input of all the ops that need the created dequantize
// operation.
int32_t min_op_idx = subgraph->operators.size();
@@ -455,6 +453,81 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
return kTfLiteOk;
}

TfLiteStatus QuantizeWeightsFloat16(flatbuffers::FlatBufferBuilder* builder,
const Model* input_model) {
std::unique_ptr<ModelT> model;
model.reset(input_model->UnPack());

// TODO(suharshs): When models support multiple subgraphs, add support.
if (model->subgraphs.size() != 1) {
LOG(ERROR) << "Quantize weights tool only supports tflite models with one "
"subgraph.";
return kTfLiteError;
}

SubGraphT* subgraph = model->subgraphs.at(0).get();

absl::flat_hash_map<int32_t, TensorT*> tensor_map;
for (int i = 0; i < subgraph->operators.size(); ++i) {
OperatorT* op = subgraph->operators[i].get();
for (auto tensor_idx : op->inputs) {
TensorT* tensor = subgraph->tensors[tensor_idx].get();
BufferT* buffer = model->buffers[tensor->buffer].get();
if (buffer == nullptr) {
return kTfLiteError;
}
// Quantize tensors that have data to quantize.
bool is_constant = !model->buffers[tensor->buffer].get()->data.empty();
if (tensor->type == TensorType_FLOAT32 && is_constant) {
tensor_map.insert({tensor_idx, tensor});
}
}
}

// The hash map ensures that we quantize each tensor exactly once.
for (std::pair<int32_t, TensorT*> tensor_pair : tensor_map) {
// Quantize the tensor.
TF_LITE_ENSURE_STATUS(
utils::QuantizeTensorFloat16(model.get(), tensor_pair.second));

int32_t tensor_idx = tensor_pair.first;
TensorT* tensor = tensor_pair.second;
std::vector<ConsumerOpInfo> dequant_op_infos =
GetTensorConsumers(model.get(), subgraph, tensor_idx);

// Create a new tensor to be the output of the dequantize op.
std::unique_ptr<TensorT> dequantize_output;
const string dequant_name = tensor->name + "_dequantize";
utils::MakeTensor(dequant_name, tensor->shape, TensorType_FLOAT32,
&dequantize_output);
const int32_t dequantize_output_idx = subgraph->tensors.size();
subgraph->tensors.push_back(std::move(dequantize_output));

// Create the Dequantize operation.
std::unique_ptr<OperatorT> dequantize_op;
utils::MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx,
dequantize_output_idx);

// Update the op_input of all the ops that need the created dequantize
// operation.
int32_t min_op_idx = subgraph->operators.size();
for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) {
dequant_op_info.op->inputs[dequant_op_info.op_input_idx] =
dequantize_output_idx;
min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx);
}

// Insert the newly created Dequantize operation before the earliest
// consumer, since TFLite requires operators to be topo-sorted.
subgraph->operators.insert(subgraph->operators.begin() + min_op_idx,
std::move(dequantize_op));
}

flatbuffers::Offset<Model> output_model_location =
Model::Pack(*builder, model.get());
FinishModelBuffer(*builder, output_model_location);
return kTfLiteOk;
}
} // namespace

namespace internal {
@@ -465,8 +538,8 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
// By default we require that only weights with more than
// kWeightsMinSizeDefault elements are quantized.
CustomOpMap custom_op_map;
return QuantizeWeightsInternal(builder, input_model, use_hybrid_evaluation,
weights_min_num_elements, custom_op_map);
return QuantizeWeightsInt8(builder, input_model, use_hybrid_evaluation,
weights_min_num_elements, custom_op_map);
}
} // namespace internal

@@ -474,25 +547,31 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
const Model* input_model,
uint64_t weights_min_num_elements) {
CustomOpMap custom_op_map;
return QuantizeWeightsInternal(builder, input_model, true,
weights_min_num_elements, custom_op_map);
return QuantizeWeightsInt8(builder, input_model, true,
weights_min_num_elements, custom_op_map);
}

TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
const Model* input_model) {
// By default we require that only weights with more than
// kWeightsMinSizeDefault elements are quantized.
CustomOpMap custom_op_map;
return QuantizeWeightsInternal(builder, input_model, true,
const Model* input_model, BufferType quant_type) {
switch (quant_type) {
case BufferType::QUANTIZED_INT8: {
// By default we require that only weights with more than
// kWeightsMinSizeDefault elements are quantized.
CustomOpMap custom_op_map;
return QuantizeWeightsInt8(builder, input_model, true,
kWeightsMinNumElementsDefault, custom_op_map);
}
case BufferType::QUANTIZED_FLOAT16:
return QuantizeWeightsFloat16(builder, input_model);
}
}

TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
const Model* input_model,
uint64_t weights_min_num_elements,
const CustomOpMap& custom_op_map) {
return QuantizeWeightsInternal(builder, input_model, true,
weights_min_num_elements, custom_op_map);
return QuantizeWeightsInt8(builder, input_model, true,
weights_min_num_elements, custom_op_map);
}

} // namespace optimize
@@ -26,6 +26,9 @@ limitations under the License.
namespace tflite {
namespace optimize {

// Supported resulting types from quantization process.
enum class BufferType { QUANTIZED_INT8, QUANTIZED_FLOAT16 };

// Quantizes input_model and populates the provided builder with the new model.
// By default only weight tensors with more than 1024 elements will be
// quantized.
@@ -33,8 +36,9 @@ namespace optimize {
// A tflite::Model can be obtained from the builder with:
// const uint8_t* buffer = builder->GetBufferPointer();
// tflite::Model* model = GetModel(buffer);
TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
const Model* input_model);
TfLiteStatus QuantizeWeights(
flatbuffers::FlatBufferBuilder* builder, const Model* input_model,
BufferType quant_type = BufferType::QUANTIZED_INT8);

// Same as above, but only weights with greater than or equal
// weights_min_num_elements elements will be quantized.
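Note: the added BufferType parameter defaults to QUANTIZED_INT8, so existing callers of QuantizeWeights(builder, model) keep their behavior and float16 is opt-in. A hedged sketch of driving the float16 path from serialized model bytes (reading the file into model_buffer is assumed to happen elsewhere):

    #include <string>

    #include "flatbuffers/flatbuffers.h"
    #include "tensorflow/lite/c/c_api_internal.h"
    #include "tensorflow/lite/schema/schema_generated.h"
    #include "tensorflow/lite/tools/optimize/quantize_weights.h"

    // Hypothetical helper: returns the float16-quantized flatbuffer, or an
    // empty string on failure.
    std::string QuantizeToFloat16(const std::string& model_buffer) {
      const tflite::Model* input_model = tflite::GetModel(model_buffer.data());
      flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
      if (tflite::optimize::QuantizeWeights(
              &builder, input_model,
              tflite::optimize::BufferType::QUANTIZED_FLOAT16) != kTfLiteOk) {
        return "";
      }
      return std::string(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize());
    }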
@@ -272,7 +272,7 @@ TEST_F(QuantizeWeightsTest, DequantizeConv) {
} else if (quant_tensor->name()->str() == "conv_bias") {
EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
} else if (quant_tensor->buffer() != 0) {
// If its a non-bias constant tensor, is must be the weight.
// If it's a non-bias constant tensor, it must be the weight.
EXPECT_EQ(quant_tensor->type(), TensorType_INT8);
} else {
EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
@@ -281,6 +281,65 @@ TEST_F(QuantizeWeightsTest, DequantizeConv) {
}
}

TEST_F(QuantizeWeightsTest, DequantizeConvFloat16) {
LoadBasicModel();
flatbuffers::FlatBufferBuilder builder;
auto status = tflite::optimize::QuantizeWeights(
&builder, model_, BufferType::QUANTIZED_FLOAT16);
EXPECT_EQ(status, kTfLiteOk);

const uint8_t* buffer = builder.GetBufferPointer();
const Model* output_model = GetModel(buffer);
ASSERT_TRUE(output_model);

ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
++subgraph_idx) {
const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
// The output graph should have two extra tensors from the added dequantize
// op.
ASSERT_EQ(quantized_graph->tensors()->size(),
float_graph->tensors()->size() + 2);
// Check that a dequantize op exists.
int32_t dequant_input_idx = -1;
int32_t dequant_output_idx = -1;
for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
const auto op = quantized_graph->operators()->Get(i);
const uint32_t op_code_idx = op->opcode_index();
if (output_model->operator_codes()->Get(op_code_idx)->builtin_code() ==
BuiltinOperator_DEQUANTIZE) {
dequant_input_idx = op->inputs()->Get(0);
dequant_output_idx = op->outputs()->Get(0);
}
}
ASSERT_GT(dequant_input_idx, -1);
ASSERT_GT(dequant_output_idx, -1);
for (size_t i = 0; i < quantized_graph->tensors()->size(); ++i) {
const auto quant_tensor = quantized_graph->tensors()->Get(i);
// If the tensor is a weight, it should have type FLOAT16.
// If the tensor is a bias, it should have type FLOAT16.
// If the tensor is an input or output it should have type FLOAT32.
// The input to dequantize should be FLOAT16, and all other tensors should
// be FLOAT32.
if (i == dequant_input_idx) {
EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16);
} else if (i == dequant_output_idx) {
EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
} else if (IsModelInputOrOutput(output_model, i)) {
EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
} else if (quant_tensor->name()->str() == "conv_bias") {
EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16);
} else if (quant_tensor->buffer() != 0) {
// If it's a non-bias constant tensor, it must be the weight.
EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16);
} else {
EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
}
}
}
}

TEST_F(QuantizeWeightsTest, SharedWeights_Hybrid) {
LoadSharedWeightsModel();
flatbuffers::FlatBufferBuilder builder;