Enable affine quantized tensor in writer_lib.

PiperOrigin-RevId: 281766372
Change-Id: I65d2dacdf4ea2a5df61406abd0b35bdffda2571e
This commit is contained in:
Mihai Maruseac 2019-11-21 09:30:25 -08:00 committed by TensorFlower Gardener
parent 080ecf5149
commit 82e32167b6

View File

@ -164,38 +164,19 @@ InterpreterWriter::ExportTensors(flatbuffers::FlatBufferBuilder* fbb) {
// Primitive type.
TensorType type = TfLiteTypeToSchemaType(tensor->type);
// Handle quantization
flatbuffers::Offset<QuantizationParameters> quantization_params;
const flatbuffers::Offset<flatbuffers::Vector<float>> null_array;
flatbuffers::Offset<flatbuffers::Vector<float>> scale_array;
flatbuffers::Offset<flatbuffers::Vector<int64_t>> zero_point_array;
// Multi channel quantization.
if (tensor->quantization.type == kTfLiteAffineQuantization) {
const TfLiteAffineQuantization* params =
reinterpret_cast<TfLiteAffineQuantization*>(
tensor->quantization.params);
const size_t num_scales = params->scale->size;
const int channel_index = params->quantized_dimension;
std::vector<float> scale_vector(
{params->scale->data, params->scale->data + num_scales});
std::vector<int64_t> zero_point_vector(
{params->zero_point->data, params->zero_point->data + num_scales});
scale_array = fbb->CreateVector<float>(scale_vector);
zero_point_array = fbb->CreateVector<int64_t>(zero_point_vector);
quantization_params = CreateQuantizationParameters(
*fbb, null_array, null_array, scale_array, zero_point_array,
QuantizationDetails_NONE, 0, channel_index);
} else {
// Quantization with a single argument array.
if (tensor->params.scale != 0.f) {
// We have quantization, make a single arugment array (multi channel
// quant needs updating here).
scale_array = fbb->CreateVector<float>({tensor->params.scale});
zero_point_array =
fbb->CreateVector<int64_t>({tensor->params.zero_point});
}
quantization_params = CreateQuantizationParameters(
*fbb, null_array, null_array, scale_array, zero_point_array);
}
flatbuffers::Offset<QuantizationParameters> quantization_params =
CreateQuantizationParameters(*fbb, null_array, null_array,
scale_array, zero_point_array);
// Shape
TfLiteIntArrayView shape_view(tensor->dims);
std::vector<int> shape =