Support quantization to float16

PiperOrigin-RevId: 248790891
A. Unique TensorFlower, 2019-05-17 14:43:40 -07:00, committed by TensorFlower Gardener
parent 5c27f716a7
commit 3742faa90c
18 changed files with 322 additions and 46 deletions

View File

@ -395,6 +395,7 @@ cc_library(
"//tensorflow/lite/kernels/internal:reference_base",
"//tensorflow/lite/kernels/internal:tensor",
"//tensorflow/lite/kernels/internal:tensor_utils",
"//third_party/eigen3",
"@farmhash_archive//:farmhash",
"@flatbuffers",
],

View File

@ -12,13 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h"
#include <string.h>
#include <vector>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
@ -59,7 +63,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
OpContext op_context(context, node);
TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8 ||
op_context.input->type == kTfLiteInt8);
op_context.input->type == kTfLiteInt8 ||
op_context.input->type == kTfLiteFloat16);
op_context.output->type = kTfLiteFloat32;
// If the input tensor is constant, we can persist the dequantized value in
@ -96,6 +101,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTensorShape(op_context.output),
GetTensorData<float>(op_context.output));
break;
case kTfLiteFloat16: {
const Eigen::half* half_data = reinterpret_cast<const Eigen::half*>(
GetTensorData<TfLiteFloat16>(op_context.input));
reference_ops::Dequantize(GetTensorShape(op_context.input), half_data,
GetTensorShape(op_context.output),
GetTensorData<float>(op_context.output));
break;
}
default:
context->ReportError(context, "Type %d not supported.",
op_context.input->type);
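
Note: a minimal, standalone sketch of the conversion this new kTfLiteFloat16 case performs. The free function and its name are illustrative only; in the kernel the element count comes from MatchingFlatSize and the fp16 data arrives as TfLiteFloat16, reinterpreted as Eigen::half exactly as above.

#include <cstdint>
#include <vector>
#include "third_party/eigen3/Eigen/Core"

// Widen a buffer of raw IEEE-754 half-precision words to float32, the same
// element-wise conversion the kTfLiteFloat16 case delegates to
// reference_ops::Dequantize.
std::vector<float> DequantizeFloat16(const uint16_t* fp16_words,
                                     int flat_size) {
  // Both TfLiteFloat16 and Eigen::half wrap the same 16-bit pattern, so the
  // buffer can be reinterpreted directly.
  const Eigen::half* half_data =
      reinterpret_cast<const Eigen::half*>(fp16_words);
  std::vector<float> output(flat_size);
  for (int i = 0; i < flat_size; ++i) {
    output[i] = Eigen::half_impl::half_to_float(half_data[i]);
  }
  return output;
}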

View File

@ -385,6 +385,7 @@ cc_library(
":types",
"@gemmlowp//:fixedpoint",
"@gemmlowp//:profiler",
"//third_party/eigen3",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:op_macros",
] + select({
@ -421,6 +422,7 @@ cc_library(
":legacy_types",
":tensor",
":types",
"//third_party/eigen3",
"@gemmlowp",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:op_macros",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include <memory>
#include <type_traits>
#include "third_party/eigen3/Eigen/Core"
#include "fixedpoint/fixedpoint.h"
#include "profiling/instrumentation.h"
#include "tensorflow/lite/c/c_api_internal.h"
@ -2491,6 +2492,15 @@ inline void Dequantize(const tflite::DequantizationParams& op_params,
}
}
inline void Dequantize(const RuntimeShape& input_shape,
const Eigen::half* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
output_data[i] = Eigen::half_impl::half_to_float(input_data[i]);
}
}
template <typename T>
inline void AffineQuantize(const tflite::QuantizationParams& op_params,
const RuntimeShape& input_shape,

View File

@ -172,6 +172,7 @@ struct ParsedTocoFlags {
Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
Arg<bool> allow_custom_ops = Arg<bool>(false);
Arg<bool> post_training_quantize = Arg<bool>(false);
Arg<bool> quantize_to_float16 = Arg<bool>(false);
// Deprecated flags
Arg<bool> quantize_weights = Arg<bool>(false);
Arg<string> input_type;

View File

@ -606,7 +606,9 @@ tensorflow::Status Export(
builder.CreateVector(subgraphs), description, buffers);
::tflite::FinishModelBuffer(builder, new_model_location);
if (params.quantize_weights) {
if (params.quantize_weights == QuantizedBufferType::NONE) {
WriteModelToString(builder, output_file_contents);
} else {
// Call the quantize_weights tool.
LOG(INFO) << "Quantizing TFLite model after conversion to flatbuffer. "
"dump_graphviz will only output the model before this "
@ -615,14 +617,21 @@ tensorflow::Status Export(
flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);
const uint8_t* buffer = builder.GetBufferPointer();
const ::tflite::Model* input_model = ::tflite::GetModel(buffer);
if (::tflite::optimize::QuantizeWeights(&q_builder, input_model) !=
kTfLiteOk) {
::tflite::optimize::BufferType quantized_type;
if (params.quantize_weights == QuantizedBufferType::INT8) {
quantized_type = ::tflite::optimize::BufferType::QUANTIZED_INT8;
} else if (params.quantize_weights == QuantizedBufferType::FLOAT16) {
quantized_type = ::tflite::optimize::BufferType::QUANTIZED_FLOAT16;
} else {
return tensorflow::errors::InvalidArgument(
"Quantized type not recognized");
}
if (::tflite::optimize::QuantizeWeights(&q_builder, input_model,
quantized_type) != kTfLiteOk) {
return tensorflow::errors::InvalidArgument(
"Quantize weights transformation failed.");
}
WriteModelToString(q_builder, output_file_contents);
} else {
WriteModelToString(builder, output_file_contents);
}
return tensorflow::Status();

View File

@ -23,11 +23,13 @@ namespace toco {
namespace tflite {
enum class QuantizedBufferType { NONE, INT8, FLOAT16 };
// The parameters for exporting a TFLite model.
struct ExportParams {
bool allow_custom_ops = false;
bool enable_select_tf_ops = false;
bool quantize_weights = false;
QuantizedBufferType quantize_weights = QuantizedBufferType::NONE;
};
// Transform the given tf.mini model into a TF Lite flatbuffer and deposit the
@ -47,7 +49,8 @@ inline void Export(const Model& model, bool allow_custom_ops,
bool quantize_weights, string* output_file_contents) {
ExportParams params;
params.allow_custom_ops = allow_custom_ops;
params.quantize_weights = quantize_weights;
params.quantize_weights =
quantize_weights ? QuantizedBufferType::INT8 : QuantizedBufferType::NONE;
auto status = Export(model, output_file_contents, params);
if (!status.ok()) LOG(QFATAL) << status.error_message();
}
@ -60,7 +63,8 @@ inline void Export(
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
ExportParams params;
params.allow_custom_ops = allow_custom_ops;
params.quantize_weights = quantize_weights;
params.quantize_weights =
quantize_weights ? QuantizedBufferType::INT8 : QuantizedBufferType::NONE;
auto status = Export(model, output_file_contents, params, ops_by_type);
if (!status.ok()) LOG(QFATAL) << status.error_message();
}

View File

@ -219,7 +219,7 @@ TEST_F(ExportTest, Export) {
ExportParams params;
params.allow_custom_ops = true;
params.enable_select_tf_ops = false;
params.quantize_weights = false;
params.quantize_weights = QuantizedBufferType::NONE;
EXPECT_THAT(ExportAndSummarizeOperators(params),
ElementsAre("builtin:ADD", "builtin:CONV_2D", "custom:MyCrazyOp",
@ -366,7 +366,7 @@ class OpSetsTest : public ExportTest {
import_all_ops_as_unsupported_ = true;
params_.allow_custom_ops = false;
params_.enable_select_tf_ops = false;
params_.quantize_weights = false;
params_.quantize_weights = QuantizedBufferType::NONE;
for (const OpSet& i : sets) {
switch (i) {

View File

@ -158,6 +158,11 @@ bool ParseTocoFlagsFromCommandLineFlags(
parsed_flags.split_tflite_lstm_inputs.default_value(),
"Split the LSTM inputs from 5 tensors to 18 tensors for TFLite. "
"Ignored if the output format is not TFLite."),
Flag("quantize_to_float16", parsed_flags.quantize_to_float16.bind(),
parsed_flags.quantize_to_float16.default_value(),
"Used in conjuction with post_training_quantize. Specifies that "
"the weights should be quantized to fp16 instead of the default "
"(int8)"),
Flag("quantize_weights", parsed_flags.quantize_weights.bind(),
parsed_flags.quantize_weights.default_value(),
"Deprecated. Please use --post_training_quantize instead."),
@ -266,6 +271,7 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone);
READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone);
READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone);
READ_TOCO_FLAG(quantize_to_float16, FlagRequirement::kNone);
READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone);
READ_TOCO_FLAG(enable_select_tf_ops, FlagRequirement::kNone);
READ_TOCO_FLAG(force_select_tf_ops, FlagRequirement::kNone);

View File

@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
import "tensorflow/lite/toco/types.proto";
package toco;
import "tensorflow/lite/toco/types.proto";
// Supported I/O file formats. Some formats may be input-only or output-only.
enum FileFormat {
FILE_FORMAT_UNKNOWN = 0;
@ -37,7 +38,7 @@ enum FileFormat {
// of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job.
//
// Next ID to use: 29.
// Next ID to use: 30.
message TocoFlags {
// Input file format
optional FileFormat input_format = 1;
@ -205,4 +206,10 @@ message TocoFlags {
// `force_select_tf_ops` should always be used with `enable_select_tf_ops`.
// WARNING: Experimental interface, subject to change
optional bool force_select_tf_ops = 28 [default = false];
// Boolean indicating whether to convert float32 constant buffers to
// float16. This is typically done to reduce model size. Delegates may also
// wish to implement kernels on reduced precision floats for performance
// gains.
optional bool quantize_to_float16 = 29 [default = false];
}
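
Note: a hedged sketch of how a converter client could request the new behavior via these flags. The helper name is made up; the setters are the standard proto2-generated API for the fields defined above.

#include "tensorflow/lite/toco/toco_flags.pb.h"

// Hypothetical helper: request post-training weight quantization with
// float16 buffers instead of the default int8 buffers.
toco::TocoFlags MakeFloat16QuantizationFlags() {
  toco::TocoFlags flags;
  flags.set_post_training_quantize(true);
  flags.set_quantize_to_float16(true);
  return flags;
}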

View File

@ -20,17 +20,18 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/toco/allocate_transient_arrays.h"
#include "tensorflow/lite/toco/dump_graphviz.h"
#include "tensorflow/lite/toco/export_tensorflow.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/import_tensorflow.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/tflite/export.h"
#include "tensorflow/lite/toco/tflite/import.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"
namespace toco {
namespace {
@ -449,8 +450,13 @@ tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model,
params.enable_select_tf_ops =
toco_flags.force_select_tf_ops() || toco_flags.enable_select_tf_ops();
params.allow_custom_ops = allow_custom_ops;
params.quantize_weights = toco_flags.post_training_quantize();
if (toco_flags.post_training_quantize()) {
if (toco_flags.quantize_to_float16()) {
params.quantize_weights = tflite::QuantizedBufferType::FLOAT16;
} else {
params.quantize_weights = tflite::QuantizedBufferType::INT8;
}
}
auto status = toco::tflite::Export(model, output_file_contents, params);
if (!status.ok()) {
LOG(ERROR) << status.error_message();

View File

@ -22,6 +22,7 @@ cc_library(
"//tensorflow/lite/kernels/internal:tensor_utils",
"//tensorflow/lite/kernels/internal:types",
"//tensorflow/lite/schema:schema_fbs",
"//third_party/eigen3",
"@com_google_absl//absl/memory",
],
)
@ -104,6 +105,7 @@ cc_library(
":quantization_utils",
":model_utils",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/container:flat_hash_map",
"@flatbuffers",
"//tensorflow/lite:framework",
# TODO(suharshs): Move the relevant quantization utils to a non-internal location.

View File

@ -13,15 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/quantization_utils.h"
#include <cmath>
#include <cstdint>
#include "absl/memory/memory.h"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/round.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include <cmath>
#include <cstdint>
namespace tflite {
namespace optimize {
namespace utils {
@ -197,6 +199,42 @@ TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor) {
return kTfLiteOk;
}
TfLiteStatus QuantizeTensorFloat16(ModelT* model, TensorT* tensor) {
if (model == nullptr || tensor == nullptr) {
return kTfLiteError;
}
BufferT* buffer = model->buffers[tensor->buffer].get();
if (buffer == nullptr) {
return kTfLiteError;
}
uint64_t num_elements;
TF_LITE_ENSURE_STATUS(NumElements(*tensor, &num_elements));
// Copy the raw byte buffer into a float vector to guard against misalignment.
std::vector<float> float_vector(num_elements);
uint8_t* first = buffer->data.data();
std::copy(first, first + buffer->data.size(),
reinterpret_cast<uint8_t*>(float_vector.data()));
// Transform float data to float16.
std::vector<Eigen::half> quantized_buffer;
quantized_buffer.resize(num_elements);
std::transform(
float_vector.begin(), float_vector.end(), quantized_buffer.begin(),
[](float a) { return Eigen::half_impl::float_to_half_rtne(a); });
char* half_buffer = reinterpret_cast<char*>(quantized_buffer.data());
model->buffers[tensor->buffer]->data.assign(
half_buffer, half_buffer + sizeof(Eigen::half) * num_elements);
// Update the tensor type.
tensor->type = TensorType_FLOAT16;
return kTfLiteOk;
}
TfLiteStatus AddQuantizationParams(const std::vector<float>& scales,
const std::vector<int64_t>& zero_point,
int quantized_dimension,

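Note: a minimal sketch of the core conversion inside QuantizeTensorFloat16 with the flatbuffer plumbing stripped away; the helper is hypothetical, but it uses the same Eigen round-to-nearest-even routine as the code above.

#include <algorithm>
#include <vector>
#include "third_party/eigen3/Eigen/Core"

// Convert a float32 buffer to float16 with round-to-nearest-even, the same
// transform QuantizeTensorFloat16 applies before writing the bytes back into
// the model buffer.
std::vector<Eigen::half> FloatBufferToHalf(const std::vector<float>& input) {
  std::vector<Eigen::half> output(input.size());
  std::transform(input.begin(), input.end(), output.begin(), [](float value) {
    return Eigen::half_impl::float_to_half_rtne(value);
  });
  return output;
}

The quantized buffer ends up half the size of the float32 original, which the QuantizeFloat16 test below checks with EXPECT_EQ(quant_buffer_size * 2, float_buffer_size).
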
View File

@ -64,6 +64,9 @@ void SymmetricPerChannelQuantizeValues(const float* const input,
// of the tensor.
TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor);
// Quantizes tensor to float16.
TfLiteStatus QuantizeTensorFloat16(ModelT* model, TensorT* tensor);
// Add quantization parameters.
TfLiteStatus AddQuantizationParams(const std::vector<float>& scales,
const std::vector<int64_t>& zero_point,

View File

@ -258,6 +258,38 @@ TEST(QuantizationUtilsTest, SymmetricQuantizeTensor) {
EXPECT_EQ(quant_buffer_size * 4, float_buffer_size);
}
TEST(QuantizationUtilsTest, QuantizeFloat16) {
// Conv model has weights between 0 and 10.
// Quantize the weights tensor.
ASSERT_TRUE(g_test_model_dir != nullptr);
ASSERT_FALSE(g_test_model_dir->empty());
auto test_model = ReadConvModel();
ASSERT_TRUE(test_model);
auto readonly_model = test_model->GetModel();
ASSERT_TRUE(readonly_model);
ASSERT_TRUE(readonly_model->subgraphs());
ASSERT_GE(readonly_model->subgraphs()->size(), 1);
tflite::ModelT model;
readonly_model->UnPackTo(&model);
auto subgraph = model.subgraphs[0].get();
auto conv_op = subgraph->operators.at(0).get();
ASSERT_EQ(model.operator_codes.at(conv_op->opcode_index)->builtin_code,
BuiltinOperator_CONV_2D);
int32_t weights_tensor_idx = conv_op->inputs[1];
TensorT* weights_tensor = subgraph->tensors.at(weights_tensor_idx).get();
EXPECT_EQ(weights_tensor->type, TensorType_FLOAT32);
size_t float_buffer_size =
model.buffers.at(weights_tensor->buffer)->data.size();
EXPECT_EQ(QuantizeTensorFloat16(&model, weights_tensor), kTfLiteOk);
size_t quant_buffer_size =
model.buffers.at(weights_tensor->buffer)->data.size();
EXPECT_EQ(weights_tensor->type, TensorType_FLOAT16);
EXPECT_EQ(quant_buffer_size * 2, float_buffer_size);
}
TEST(QuantizationUtilsTest, AddQuantizationParams) {
// Create data.
auto model = absl::make_unique<ModelT>();

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "flatbuffers/flexbuffers.h"
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/context.h"
@ -172,7 +173,7 @@ bool CheckAllOpInputsQuantized(const SubGraphT* subgraph, const OperatorT* op,
TfLiteStatus InsertQuantizableInputTensorsFromOperator(
const ModelT* model, const OperatorT* op, uint64_t weights_min_num_elements,
const CustomOpMap& custom_op_map,
std::unordered_map<int32_t, TensorT*>* tensor_map) {
absl::flat_hash_map<int32_t, TensorT*>* tensor_map) {
SubGraphT* subgraph = model->subgraphs.at(0).get();
const OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get();
@ -328,11 +329,11 @@ PassQuantizationAndGetConsumers(
GetTensorConsumers(model, subgraph, output_tensor_idx));
}
TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
const Model* input_model,
bool use_hybrid_evaluation,
uint64_t weights_min_num_elements,
const CustomOpMap& custom_op_map) {
TfLiteStatus QuantizeWeightsInt8(flatbuffers::FlatBufferBuilder* builder,
const Model* input_model,
bool use_hybrid_evaluation,
uint64_t weights_min_num_elements,
const CustomOpMap& custom_op_map) {
std::unique_ptr<ModelT> model;
model.reset(input_model->UnPack());
@ -345,15 +346,14 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
SubGraphT* subgraph = model->subgraphs.at(0).get();
std::vector<std::unique_ptr<OperatorT>> new_operators;
std::unordered_map<int32_t, TensorT*> tensor_map;
absl::flat_hash_map<int32_t, TensorT*> tensor_map;
for (int i = 0; i < subgraph->operators.size(); ++i) {
OperatorT* op = subgraph->operators[i].get();
TF_LITE_ENSURE_STATUS(InsertQuantizableInputTensorsFromOperator(
model.get(), op, weights_min_num_elements, custom_op_map, &tensor_map));
}
// The unordered_map ensures that we quantize each tensor exactly once.
// The hash map ensures that we quantize each tensor exactly once.
// TODO(suharshs): This map key isn't sufficient when we support multiple
// subgraphs.
for (std::pair<int32_t, TensorT*> tensor_pair : tensor_map) {
@ -396,7 +396,7 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
}
}
// Check that this tensor is an output tensor.
// Check if this tensor is an output tensor.
int32_t output_index = -1;
for (int32_t i = 0; i < subgraph->outputs.size(); ++i) {
if (subgraph->outputs[i] == tensor_idx) {
@ -424,8 +424,6 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
utils::MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx,
dequantize_output_idx);
LOG(INFO) << "Creating Dequantize op with name " << dequant_name << ".";
// Update the op_input of all the ops that need the created dequantize
// operation.
int32_t min_op_idx = subgraph->operators.size();
@ -455,6 +453,81 @@ TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
return kTfLiteOk;
}
TfLiteStatus QuantizeWeightsFloat16(flatbuffers::FlatBufferBuilder* builder,
const Model* input_model) {
std::unique_ptr<ModelT> model;
model.reset(input_model->UnPack());
// TODO(suharshs): When models support multiple subgraphs, add support.
if (model->subgraphs.size() != 1) {
LOG(ERROR) << "Quantize weights tool only supports tflite models with one "
"subgraph.";
return kTfLiteError;
}
SubGraphT* subgraph = model->subgraphs.at(0).get();
absl::flat_hash_map<int32_t, TensorT*> tensor_map;
for (int i = 0; i < subgraph->operators.size(); ++i) {
OperatorT* op = subgraph->operators[i].get();
for (auto tensor_idx : op->inputs) {
TensorT* tensor = subgraph->tensors[tensor_idx].get();
BufferT* buffer = model->buffers[tensor->buffer].get();
if (buffer == nullptr) {
return kTfLiteError;
}
// Quantize tensors that have data to quantize.
bool is_constant = !model->buffers[tensor->buffer].get()->data.empty();
if (tensor->type == TensorType_FLOAT32 && is_constant) {
tensor_map.insert({tensor_idx, tensor});
}
}
}
// The hash map ensures that we quantize each tensor exactly once.
for (std::pair<int32_t, TensorT*> tensor_pair : tensor_map) {
// Quantize the tensor.
TF_LITE_ENSURE_STATUS(
utils::QuantizeTensorFloat16(model.get(), tensor_pair.second));
int32_t tensor_idx = tensor_pair.first;
TensorT* tensor = tensor_pair.second;
std::vector<ConsumerOpInfo> dequant_op_infos =
GetTensorConsumers(model.get(), subgraph, tensor_idx);
// Create a new tensor to be the output of the dequantize op.
std::unique_ptr<TensorT> dequantize_output;
const string dequant_name = tensor->name + "_dequantize";
utils::MakeTensor(dequant_name, tensor->shape, TensorType_FLOAT32,
&dequantize_output);
const int32_t dequantize_output_idx = subgraph->tensors.size();
subgraph->tensors.push_back(std::move(dequantize_output));
// Create the Dequantize operation.
std::unique_ptr<OperatorT> dequantize_op;
utils::MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx,
dequantize_output_idx);
// Update the op_input of all the ops that need the created dequantize
// operation.
int32_t min_op_idx = subgraph->operators.size();
for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) {
dequant_op_info.op->inputs[dequant_op_info.op_input_idx] =
dequantize_output_idx;
min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx);
}
// Insert the newly created Dequantize operation before the earliest
// consumer, since TFLite requires operators to be topo-sorted.
subgraph->operators.insert(subgraph->operators.begin() + min_op_idx,
std::move(dequantize_op));
}
flatbuffers::Offset<Model> output_model_location =
Model::Pack(*builder, model.get());
FinishModelBuffer(*builder, output_model_location);
return kTfLiteOk;
}
} // namespace
namespace internal {
@ -465,8 +538,8 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
// By default we require that only weights with more than
// kWeightsMinSizeDefault elements are quantized.
CustomOpMap custom_op_map;
return QuantizeWeightsInternal(builder, input_model, use_hybrid_evaluation,
weights_min_num_elements, custom_op_map);
return QuantizeWeightsInt8(builder, input_model, use_hybrid_evaluation,
weights_min_num_elements, custom_op_map);
}
} // namespace internal
@ -474,25 +547,31 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
const Model* input_model,
uint64_t weights_min_num_elements) {
CustomOpMap custom_op_map;
return QuantizeWeightsInternal(builder, input_model, true,
weights_min_num_elements, custom_op_map);
return QuantizeWeightsInt8(builder, input_model, true,
weights_min_num_elements, custom_op_map);
}
TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
const Model* input_model) {
// By default we require that only weights with more than
// kWeightsMinSizeDefault elements are quantized.
CustomOpMap custom_op_map;
return QuantizeWeightsInternal(builder, input_model, true,
const Model* input_model, BufferType quant_type) {
switch (quant_type) {
case BufferType::QUANTIZED_INT8: {
// By default we require that only weights with more than
// kWeightsMinNumElementsDefault elements are quantized.
CustomOpMap custom_op_map;
return QuantizeWeightsInt8(builder, input_model, true,
kWeightsMinNumElementsDefault, custom_op_map);
}
case BufferType::QUANTIZED_FLOAT16:
return QuantizeWeightsFloat16(builder, input_model);
}
}
TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
const Model* input_model,
uint64_t weights_min_num_elements,
const CustomOpMap& custom_op_map) {
return QuantizeWeightsInternal(builder, input_model, true,
weights_min_num_elements, custom_op_map);
return QuantizeWeightsInt8(builder, input_model, true,
weights_min_num_elements, custom_op_map);
}
} // namespace optimize

View File

@ -26,6 +26,9 @@ limitations under the License.
namespace tflite {
namespace optimize {
// Supported buffer types resulting from the quantization process.
enum class BufferType { QUANTIZED_INT8, QUANTIZED_FLOAT16 };
// Quantizes input_model and populates the provided builder with the new model.
// By default, only weight tensors with more than 1024 elements will be
// quantized.
@ -33,8 +36,9 @@ namespace optimize {
// A tflite::Model can be obtained from the builder with:
// const uint8_t* buffer = builder->GetBufferPointer();
// tflite::Model* model = GetModel(buffer);
TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
const Model* input_model);
TfLiteStatus QuantizeWeights(
flatbuffers::FlatBufferBuilder* builder, const Model* input_model,
BufferType quant_type = BufferType::QUANTIZED_INT8);
// Same as above, but only weights with at least weights_min_num_elements
// elements will be quantized.
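
Note: a usage sketch of the extended entry point. It assumes the header path tensorflow/lite/tools/optimize/quantize_weights.h and a flatbuffer model loaded elsewhere, as in the tests below; the wrapper itself is hypothetical. The defaulted quant_type keeps existing int8 callers source-compatible.

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/tools/optimize/quantize_weights.h"

// Hypothetical wrapper: quantize the constant weights of an already-loaded
// model to float16 and serialize the result into `builder`.
TfLiteStatus QuantizeModelToFloat16(const tflite::Model* input_model,
                                    flatbuffers::FlatBufferBuilder* builder) {
  return tflite::optimize::QuantizeWeights(
      builder, input_model, tflite::optimize::BufferType::QUANTIZED_FLOAT16);
}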

View File

@ -272,7 +272,7 @@ TEST_F(QuantizeWeightsTest, DequantizeConv) {
} else if (quant_tensor->name()->str() == "conv_bias") {
EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
} else if (quant_tensor->buffer() != 0) {
// If its a non-bias constant tensor, is must be the weight.
// If it's a non-bias constant tensor, it must be the weight.
EXPECT_EQ(quant_tensor->type(), TensorType_INT8);
} else {
EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
@ -281,6 +281,65 @@ TEST_F(QuantizeWeightsTest, DequantizeConv) {
}
}
TEST_F(QuantizeWeightsTest, DequantizeConvFloat16) {
LoadBasicModel();
flatbuffers::FlatBufferBuilder builder;
auto status = tflite::optimize::QuantizeWeights(
&builder, model_, BufferType::QUANTIZED_FLOAT16);
EXPECT_EQ(status, kTfLiteOk);
const uint8_t* buffer = builder.GetBufferPointer();
const Model* output_model = GetModel(buffer);
ASSERT_TRUE(output_model);
ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size());
for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size();
++subgraph_idx) {
const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx);
const auto float_graph = model_->subgraphs()->Get(subgraph_idx);
// The output graph should have two extra tensors from the added dequantize
// ops.
ASSERT_EQ(quantized_graph->tensors()->size(),
float_graph->tensors()->size() + 2);
// Check that a dequantize op exists.
int32_t dequant_input_idx = -1;
int32_t dequant_output_idx = -1;
for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) {
const auto op = quantized_graph->operators()->Get(i);
const uint32_t op_code_idx = op->opcode_index();
if (output_model->operator_codes()->Get(op_code_idx)->builtin_code() ==
BuiltinOperator_DEQUANTIZE) {
dequant_input_idx = op->inputs()->Get(0);
dequant_output_idx = op->outputs()->Get(0);
}
}
ASSERT_GT(dequant_input_idx, -1);
ASSERT_GT(dequant_output_idx, -1);
for (size_t i = 0; i < quantized_graph->tensors()->size(); ++i) {
const auto quant_tensor = quantized_graph->tensors()->Get(i);
// Weights and biases should now be FLOAT16 (they feed the new dequantize
// ops); everything else, including model inputs/outputs and the dequantize
// outputs, should remain FLOAT32.
if (i == dequant_input_idx) {
EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16);
} else if (i == dequant_output_idx) {
EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
} else if (IsModelInputOrOutput(output_model, i)) {
EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
} else if (quant_tensor->name()->str() == "conv_bias") {
EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16);
} else if (quant_tensor->buffer() != 0) {
// If it's a non-bias constant tensor, it must be the weight.
EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16);
} else {
EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32);
}
}
}
}
TEST_F(QuantizeWeightsTest, SharedWeights_Hybrid) {
LoadSharedWeightsModel();
flatbuffers::FlatBufferBuilder builder;