Port the SVDF full integer recipe to TFLite Micro.

This version varies slightly from the current reference implementation in TFLite (original):

1.) All references to the tensor_utils:: namespace are dropped due to build incompatibility and the binary size the include pulls in (just like the rest of this port for float/hybrid-quant).

2.) Scratch tensors are re-worked into variable tensors. This is a temporary workaround until memory planning lands.

3.) An additional Tensor is required to provide pre-calculated scale values; these calculations are very expensive on low-power devices (see the sketch below).
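For reference, here is a minimal sketch of the mechanics behind items 2.) and 3.). The helper names, header paths, and effective-scale formulas are illustrative assumptions based on the standard SVDF integer-quantization recipe, not code from this change; tflite::QuantizeMultiplier() is the existing TFLite utility that produces the (multiplier, shift) pairs consumed by MultiplyByQuantizedMultiplier() in the kernel.

#include <cstdint>

#include "tensorflow/lite/c/c_api_internal.h"  // TfLiteContext/TfLiteNode/TfLiteTensor (path may differ by version)
#include "tensorflow/lite/kernels/internal/quantization_util.h"  // tflite::QuantizeMultiplier

// Item 2.): scratch buffers arrive as extra (variable) input tensors and are
// looked up in the context's tensor list instead of node->temporaries.
TfLiteTensor* GetScratchAsVariableInput(TfLiteContext* context, TfLiteNode* node,
                                        int input_index) {
  // Instead of: GetTemporary(context, node, /*index=*/...);
  return &context->tensors[node->inputs->data[input_index]];
}

// Item 3.): the four int32 values stored in the extra scale tensor can be
// produced offline roughly like this (assumed formulas for the two SVDF
// effective scales).
void PrecomputeSvdfEffectiveScales(float input_scale, float weights_feature_scale,
                                   float weights_time_scale, float state_scale,
                                   float output_scale, int32_t scale_values[4]) {
  // Rescales the int8 feature matmul result into the int16 activation state.
  const double effective_scale_1 =
      static_cast<double>(input_scale) * weights_feature_scale / state_scale;
  // Rescales the time-weight reduction (plus bias) into the int8 output.
  const double effective_scale_2 =
      static_cast<double>(state_scale) * weights_time_scale / output_scale;

  int shift;
  tflite::QuantizeMultiplier(effective_scale_1, &scale_values[0], &shift);
  scale_values[1] = shift;  // scale_1_a, scale_1_b
  tflite::QuantizeMultiplier(effective_scale_2, &scale_values[2], &shift);
  scale_values[3] = shift;  // scale_2_a, scale_2_b
}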

PiperOrigin-RevId: 286278125
Change-Id: Ibbadb2f38a6c25b5550b4fced5b32c7a5b9420df
Author: Nick Kreeger, 2019-12-18 15:31:19 -08:00 (committed by TensorFlower Gardener)
parent 6ac005515f
commit 35e4344f92
7 changed files with 593 additions and 47 deletions

View File

@@ -72,7 +72,7 @@ static inline void ApplyTimeWeightsBiasAndActivation(
// Initialize output with bias if provided.
if (bias) {
// TODO(kreeger): doc me - VectorBatchVectorAssign
// VectorBatchVectorAssign
const float* bias_data = GetTensorData<float>(bias);
float* output_data = GetTensorData<float>(output);
for (int i = 0; i < batch_size; ++i) {
@@ -95,10 +95,9 @@ static inline void ApplyTimeWeightsBiasAndActivation(
float* scratch_ptr_batch = GetTensorData<float>(scratch) + b * num_filters;
// Reduction sum vector
const float* input_vector_ptr = scratch_ptr_batch;
for (int i = 0; i < num_units; ++i) {
for (int j = 0; j < rank; j++) {
output_ptr_batch[i] += *input_vector_ptr++;
output_ptr_batch[i] += *scratch_ptr_batch++;
}
}
}
@@ -274,6 +273,150 @@ inline void EvalHybridSVDF(
params->activation, activation_state, scratch, output);
}
void EvalIntegerSVDF(
TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input_tensor,
const TfLiteTensor* weights_feature_tensor,
const TfLiteTensor* weights_time_tensor, const TfLiteTensor* bias_tensor,
const TfLiteSVDFParams* params, TfLiteTensor* activation_state_tensor,
TfLiteTensor* output_tensor, TfLiteTensor* scratch_tensor,
TfLiteTensor* scratch_output_tensor, int32_t scale_1_a, int scale_1_b,
int32_t scale_2_a, int scale_2_b, int32_t input_zp, int32_t output_zp) {
const int n_rank = params->rank;
const int n_batch = input_tensor->dims->data[0];
const int n_input = input_tensor->dims->data[1];
const int n_filter = weights_feature_tensor->dims->data[0];
const int n_unit = n_filter / n_rank;
const int n_memory = weights_time_tensor->dims->data[1];
// Rewrite last bit of state.
{
for (int b = 0; b < n_batch; ++b) {
int16_t* state_ptr_batch =
GetTensorData<int16_t>(activation_state_tensor) +
b * n_memory * n_filter;
for (int c = 0; c < n_filter; ++c) {
int16_t* state_ptr = state_ptr_batch + c * n_memory;
state_ptr[n_memory - 1] = 0;
}
}
}
// Feature matmul.
{
int16_t* state = GetTensorData<int16_t>(activation_state_tensor);
const int8_t* input = GetTensorData<int8_t>(input_tensor);
const int8_t* weight_feature =
GetTensorData<int8_t>(weights_feature_tensor);
const int32_t output_max = std::numeric_limits<int16_t>::max();
const int32_t output_min = std::numeric_limits<int16_t>::min();
int16_t* result_in_batch = state + (n_memory - 1);
for (int b = 0; b < n_batch; b++) {
const int8_t* matrix_ptr = weight_feature;
for (int r = 0; r < n_filter; r++) {
int32_t dot_prod = 0;
const int8_t* vector_in_batch = input + b * n_input;
for (int c = 0; c < n_input; c++) {
dot_prod += *matrix_ptr++ * (*vector_in_batch++ - input_zp);
}
dot_prod =
MultiplyByQuantizedMultiplier(dot_prod, scale_1_a, scale_1_b);
dot_prod = std::min(std::max(output_min, dot_prod), output_max);
*result_in_batch = dot_prod;
result_in_batch += n_memory;
}
}
}
// Time.
{
for (int b = 0; b < n_batch; ++b) {
int32_t* scratch_ptr_batch =
GetTensorData<int32_t>(scratch_tensor) + b * n_filter;
// Perform batched vector dot product:
const int16_t* vector1_ptr = GetTensorData<int16_t>(weights_time_tensor);
const int16_t* vector2_ptr =
GetTensorData<int16_t>(activation_state_tensor) +
b * n_memory * n_filter;
for (int i = 0; i < n_filter; i++) {
*scratch_ptr_batch = 0;
for (int j = 0; j < n_memory; j++) {
*scratch_ptr_batch += *vector1_ptr++ * *vector2_ptr++;
}
scratch_ptr_batch++;
}
}
}
// Reduce, add bias, rescale, activation.
{
int32_t* output_temp = GetTensorData<int32_t>(scratch_output_tensor);
// Add bias.
if (bias_tensor) {
// Vector batch assign:
const int32_t* bias_data = GetTensorData<int32_t>(bias_tensor);
for (int i = 0; i < n_batch; ++i) {
int32_t* output_ptr = output_temp + i * n_unit;
const int32_t* bias_ptr = bias_data;
for (int j = 0; j < n_unit; ++j) {
*output_ptr++ = *bias_ptr++;
}
}
} else {
int32_t* output_ptr = output_temp;
for (int i = 0; i < n_batch * n_unit; ++i) {
*output_ptr++ = 0;
}
}
// Reduce.
for (int b = 0; b < n_batch; ++b) {
int32_t* output_temp_ptr = output_temp + b * n_unit;
int32_t* scratch_ptr_batch =
GetTensorData<int32_t>(scratch_tensor) + b * n_filter;
// Reduction sum vector
for (int i = 0; i < n_unit; ++i) {
for (int j = 0; j < n_rank; ++j) {
output_temp_ptr[i] += *scratch_ptr_batch++;
}
}
}
// Rescale.
const int32_t output_max = std::numeric_limits<int8_t>::max();
const int32_t output_min = std::numeric_limits<int8_t>::min();
for (int i = 0; i < n_batch * n_unit; ++i) {
int32_t x1 = output_temp[i];
int32_t x2 = MultiplyByQuantizedMultiplier(x1, scale_2_a, scale_2_b);
int32_t x3 = x2 + output_zp;
int32_t x4 = std::min(std::max(output_min, x3), output_max);
GetTensorData<int8_t>(output_tensor)[i] = static_cast<int8_t>(x4);
}
}
// Shift state.
{
for (int b = 0; b < n_batch; ++b) {
int16_t* state_ptr_batch =
GetTensorData<int16_t>(activation_state_tensor) +
b * n_memory * n_filter;
for (int f = 0; f < n_filter; ++f) {
// Shift the vector left:
int16_t* batch_ptr = state_ptr_batch;
int16_t* batch_start = state_ptr_batch + 1;
int16_t* batch_end = state_ptr_batch + n_memory;
while (batch_start != batch_end) {
*batch_ptr++ = *batch_start++;
}
state_ptr_batch[n_memory - 1] = 0;
state_ptr_batch += n_memory;
}
}
}
}
} // namespace
// Input tensors.
@@ -303,10 +446,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// [3] = Bias (optional), {1, num_units}
// [4] = Activation State (variable),
// {2, batch_size, memory_size * num_filters}
// TODO(kreeger): Use input tensor as variable until scratch tensor allocation
// has been implemented (cl/263032056)
// TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
TF_LITE_ENSURE_EQ(context, node->inputs->size, 6);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* weights_feature =
GetInput(context, node, kWeightsFeatureTensor);
@@ -325,10 +465,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int num_units = num_filters / rank;
const int memory_size = weights_time->dims->data[1];
// The weights are of consistent type, so it suffices to check one.
const bool is_hybrid_op = IsHybridOp(input, weights_feature);
const bool is_full_integer = input->type == kTfLiteInt8;
// Validate Input Tensor:
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE(context,
input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
// Validate Tensor Output:
// [0] = float/int8, {2, batch_size, num_units}
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units);
// Validate Weights Feature Input Tensor:
TF_LITE_ENSURE_EQ(context, NumDimensions(weights_feature), 2);
TF_LITE_ENSURE_EQ(context, weights_feature->dims->data[1], input_size);
@@ -341,11 +494,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Validate Optional Bias Input Tensor:
if (bias) {
TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
}
// Validate Activation State Input Tensor:
TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
TF_LITE_ENSURE_EQ(context, activation_state->dims->data[0], batch_size);
TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1],
@@ -354,26 +505,29 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Validate shared Scratch Tensor (same for full float and hybrid):
// [0] = Holds dot-product of time-forward calculations in
// ApplyTimeWeightsBiasAndActivation():
// float, {2, batch_size, num_filters}
// float/int32, {2, batch_size, num_filters}
// TODO(kreeger): Use input tensor as variable until scratch tensor allocation
// has been implemented (cl/263032056)
// has been implemented (b/132070898)
// TfLiteTensor* scratch_tensor = GetTemporary(context, node, 0);
TfLiteTensor* scratch_tensor = &context->tensors[node->inputs->data[5]];
TF_LITE_ENSURE_EQ(context, scratch_tensor->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, NumDimensions(scratch_tensor), 2);
TF_LITE_ENSURE_EQ(context, scratch_tensor->dims->data[0], batch_size);
TF_LITE_ENSURE_EQ(context, scratch_tensor->dims->data[1], num_filters);
// The weights are of consistent type, so it suffices to check one.
const bool is_hybrid_op = IsHybridOp(input, weights_feature);
// TODO(kreeger): Handle full quant svdf b/139435798
if (is_hybrid_op) {
TF_LITE_ENSURE_EQ(context, node->inputs->size, 6);
// Validate Input Tensor dtypes:
TF_LITE_ENSURE(context, weights_feature->type == kTfLiteUInt8 ||
weights_feature->type == kTfLiteInt8);
TF_LITE_ENSURE(context, weights_time->type == kTfLiteUInt8 ||
weights_time->type == kTfLiteInt8);
TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);
if (bias) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
}
// Validate Scratch Tensors:
// [0] = (shared - see above for usage)
@@ -385,6 +539,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* scratch_scaling_factors = GetTemporary(context, node, 2);
TfLiteTensor* scratch_float_weights_time = GetTemporary(context, node, 3);
// Validate shared scratch tensor type:
TF_LITE_ENSURE_EQ(context, scratch_tensor->type, kTfLiteFloat32);
// Validate Input Quantized Scratch Tensor:
TF_LITE_ENSURE(context, scratch_input_quantized->type == kTfLiteUInt8 ||
scratch_input_quantized->type == kTfLiteInt8);
@@ -412,37 +569,75 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// called. Use this time to do a one-time de-quantization copy of
// the input values from the Weights Time tensor to the float weights time
// scratch tensor.
// TODO(kreeger): Consider doing this at model conversion time?
// TODO(b/146029510): Consider doing this at model conversion time.
SymmetricDequantize(GetTensorData<int8_t>(weights_time),
NumElements(scratch_float_weights_time),
weights_time->params.scale,
GetTensorData<float>(scratch_float_weights_time));
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
} else if (is_full_integer) {
// TODO(b/132070898): Use input tensor as variable until scratch tensor
// allocation has been implemented
TF_LITE_ENSURE_EQ(context, node->inputs->size, 8);
TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16);
if (bias) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
}
TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16);
// Validate Scratch Tensors:
// [0] = (shared - see above for usage)
// [1] = Output Temp, int8_t, {2, num_units, batch_size}
// TODO(b/132070898): Use input tensor as variable until scratch tensor
// allocation has been implemented.
/* TF_LITE_ENSURE_EQ(context, node->temporaries->size, 2); */
// Validate shared scratch tensor type:
TF_LITE_ENSURE_EQ(context, scratch_tensor->type, kTfLiteInt32);
// Validate Output Temp Scratch Tensor:
TfLiteTensor* scratch_output = &context->tensors[node->inputs->data[6]];
TF_LITE_ENSURE_EQ(context, scratch_output->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumDimensions(scratch_output), 2);
TF_LITE_ENSURE_EQ(context, scratch_output->dims->data[0], num_units);
TF_LITE_ENSURE_EQ(context, scratch_output->dims->data[1], batch_size);
// Validate output tensor:
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteInt8);
} else {
TF_LITE_ENSURE_EQ(context, node->inputs->size, 6);
// Validate Input Tensor dtypes:
TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);
if (bias) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
}
// Full-float SVDF only uses the one shared scratch tensor (see above for
// usage).
// TODO(kreeger): Use input tensor as variable until scratch tensor
// allocation has been implemented (cl/263032056)
// TODO(b/132070898): Use input tensor as variable until scratch tensor
// allocation has been implemented.
// TF_LITE_ENSURE_EQ(context, node->temporaries->size, 1);
}
// Validate Tensor Output:
// [0] = float, {2, batch_size, num_units}
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units);
// Validate shared scratch tensor type:
TF_LITE_ENSURE_EQ(context, scratch_tensor->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
}
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* weights_feature =
@@ -451,15 +646,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetInput(context, node, kWeightsTimeTensor);
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
// TODO(kreeger): Use input tensor as variable until scratch tensor allocation
// has been implemented (cl/263032056)
// TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
// TODO(b/132070898): Use input tensor as variable until scratch tensor
// allocation has been implemented.
// TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scratch = &context->tensors[node->inputs->data[5]];
TfLiteTensor* activation_state =
&context->tensors[node->inputs->data[kInputActivationStateTensor]];
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const bool is_full_integer = input->type == kTfLiteInt8;
switch (weights_feature->type) {
case kTfLiteFloat32: {
EvalFloatSVDF(context, node, input, weights_feature, weights_time, bias,
@@ -470,19 +667,46 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteUInt8:
case kTfLiteInt8: {
TfLiteTensor* scratch_input_quantized = GetTemporary(context, node, 1);
TfLiteTensor* scratch_scaling_factors = GetTemporary(context, node, 2);
TfLiteTensor* scratch_float_weights_time = GetTemporary(context, node, 3);
EvalHybridSVDF(context, node, input, weights_feature,
scratch_float_weights_time, bias, params, scratch,
scratch_scaling_factors, scratch_input_quantized,
activation_state, output);
return kTfLiteOk;
if (is_full_integer) {
// TODO(b/146029510): In order to prevent expensive scale calculations
// during each eval of this Op, pre-calculated values are being stored
// in a Tensor in the flatbuffer. Inside this Tensor, the 4 scale values
// are stored in an int32 buffer.
const TfLiteTensor* effective_scale_data_tensor =
GetInput(context, node, 7);
const int32_t* effective_scale_data =
GetTensorData<int32_t>(effective_scale_data_tensor);
// TODO(b/132070898): Use input tensor as variable until scratch tensor
// allocation has been implemented.
// TfLiteTensor* output_temp = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* output_temp = &context->tensors[node->inputs->data[6]];
// Currently supports only ReLU.
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu);
EvalIntegerSVDF(context, node, input, weights_feature, weights_time,
bias, params, activation_state, output, scratch,
output_temp, effective_scale_data[0],
effective_scale_data[1], effective_scale_data[2],
effective_scale_data[3], input->params.zero_point,
output->params.zero_point);
return kTfLiteOk;
} else {
// Hybrid quantized:
TfLiteTensor* scratch_input_quantized = GetTemporary(context, node, 1);
TfLiteTensor* scratch_scaling_factors = GetTemporary(context, node, 2);
TfLiteTensor* scratch_float_weights_time =
GetTemporary(context, node, 3);
EvalHybridSVDF(context, node, input, weights_feature,
scratch_float_weights_time, bias, params, scratch,
scratch_scaling_factors, scratch_input_quantized,
activation_state, output);
return kTfLiteOk;
}
break;
}
default:
// TODO(kreeger): Handle this case for full quant svdf b/139435798
context->ReportError(context, "Type %s not currently supported.",
TfLiteTypeGetName(weights_feature->type));
return kTfLiteError;

View File

@@ -146,7 +146,7 @@ void ValidateSVDFGoldens(const int batch_size, const int num_units,
// Bias is an optional tensor:
// TODO(kreeger): Use input tensor as variable until scratch tensor allocation
// has been implemented (cl/263032056)
// has been implemented (b/132070898)
// int inputs_array_data[] = {5, 0, 1, 2, kTfLiteOptionalTensor, 3};
int inputs_array_data[] = {6, 0, 1, 2, kTfLiteOptionalTensor, 3, 5};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
@@ -166,7 +166,6 @@ void ValidateSVDFGoldens(const int batch_size, const int num_units,
node.outputs = outputs_array;
if (is_hybrid_op) {
node.temporaries = hybrid_temporaries_array;
} else {
node.temporaries = temporaries_array;
}
@@ -203,6 +202,81 @@ void ValidateSVDFGoldens(const int batch_size, const int num_units,
}
}
void ValidateIntegerSVDFGoldens(const int batch_size, const int num_units,
const int input_size, const int rank,
TfLiteTensor* tensors, const int tensor_count,
int8_t* golden_input_data,
const int golden_input_data_size,
int8_t* output_data, int8_t* expected_output) {
TfLiteContext context;
PopulateContext(tensors, tensor_count, &context);
::tflite::ops::micro::AllOpsResolver resolver;
const TfLiteRegistration* registration =
resolver.FindOp(tflite::BuiltinOperator_SVDF, 1);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
TfLiteSVDFParams params;
params.rank = rank;
params.activation = kTfLiteActRelu;
void* user_data = nullptr;
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
// TODO(b/132070898): Use input tensor as variable until scratch tensor
// allocation has been implemented.
// int inputs_array_data[] = {5, 0, 1, 2, 3, 4};
int inputs_array_data[] = {8, 0, 1, 2, 3, 4, 6, 7, 8};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 5};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {2, 7, 8};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&params);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
int input_sequence_size =
golden_input_data_size / sizeof(int8_t) / (input_size * batch_size);
for (int i = 0; i < input_sequence_size; ++i) {
int8_t* input_batch_start = golden_input_data + i * input_size * batch_size;
int8_t* input_batch_end = input_batch_start + input_size * batch_size;
int8_t* tensor_data = tensors[0].data.int8;
while (input_batch_start != input_batch_end) {
*tensor_data++ = *input_batch_start++;
}
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
int output_idx = 0;
int golden_idx = i * batch_size * num_units;
for (int j = golden_idx; j < golden_idx + batch_size * num_units; ++j) {
TF_LITE_MICRO_EXPECT_NEAR(expected_output[j], output_data[output_idx], 0);
output_idx++;
}
}
if (registration->free) {
registration->free(&context, user_data);
}
}
void TestSVDF(const int batch_size, const int num_units, const int input_size,
const int memory_size, const int rank, float* input_data,
float* weights_feature_data, float* weights_time_data,
@@ -383,6 +457,88 @@ inline void TestHybridSVDFUint8(
tolerance);
}
inline void TestIntegerSVDF(
const int batch_size, const int num_units, const int input_size,
const int memory_size, const int rank, int8_t* input_data,
float input_scale, int8_t* weights_feature_data,
float weights_feature_scale, int16_t* weights_time_data,
float weights_time_scale, int32_t* bias_data, float bias_scale,
int16_t* activation_state_data, float activation_scale,
int32_t* scratch_data, int32_t* scratch_output_data, int8_t* output_data,
float output_scale, int32_t effective_scale_1_a,
int32_t effective_scale_1_b, int32_t effective_scale_2_a,
int32_t effective_scale_2_b, int8_t* golden_input_data,
int golden_input_data_size, int8_t* expected_output) {
const int num_filters = num_units * rank;
const int input_dims_arg[] = {2, batch_size, input_size};
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_arg);
const int weights_feature_dims_args[] = {2, num_filters, input_size};
TfLiteIntArray* weights_feature_dims =
IntArrayFromInts(weights_feature_dims_args);
const int weights_time_dims_args[] = {2, num_filters, memory_size};
TfLiteIntArray* weights_time_dims = IntArrayFromInts(weights_time_dims_args);
const int bias_dims_data[] = {1, num_units};
TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
const int activation_state_dims_args[] = {2, batch_size,
memory_size * num_filters};
TfLiteIntArray* activation_state_dims =
IntArrayFromInts(activation_state_dims_args);
// Scratch output is the same shape as output:
const int scratch_dims_args[] = {2, batch_size, num_filters};
TfLiteIntArray* scratch_dims = IntArrayFromInts(scratch_dims_args);
// Full integer requires one more scratch tensor:
const int scratch_output_dims_args[] = {2, num_units, batch_size};
TfLiteIntArray* scratch_output_dims =
IntArrayFromInts(scratch_output_dims_args);
const int output_dims_args[] = {2, batch_size, num_units};
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_args);
// The tensor count is higher due to workarounds in micro buffer usage
// (b/132070898) and re-working scale calculations (b/146029510).
const int tensor_count = 9; // 5 inputs, 1 output, 2 scratch, 1 temp
const int effective_scale_dims_args[] = {1, 4};
int32_t effective_scale_data[] = {effective_scale_1_a, effective_scale_1_b,
effective_scale_2_a, effective_scale_2_b};
TfLiteIntArray* effective_scale_dims =
IntArrayFromInts(effective_scale_dims_args);
TfLiteTensor tensors[] = {
CreateQuantizedTensor(input_data, input_dims, input_scale,
0 /* zero-point */, "input"),
CreateQuantizedTensor(weights_feature_data, weights_feature_dims,
weights_feature_scale, 0 /* zero-point */,
"weights_feature"),
CreateQuantizedTensor(weights_time_data, weights_time_dims,
weights_time_scale, 0 /* zero-point */,
"weights_time"),
CreateQuantized32Tensor(bias_data, bias_dims, "bias", bias_scale),
CreateQuantizedTensor(activation_state_data, activation_state_dims,
activation_scale, 0 /* zero-point */,
"activation_state", true /* is_variable */),
CreateQuantizedTensor(output_data, output_dims, output_scale,
0 /* zero-point */, "output"),
CreateQuantized32Tensor(scratch_data, scratch_dims, "scratch",
1.f /* scale-placeholder */),
CreateQuantized32Tensor(scratch_output_data, scratch_output_dims,
"scratch_output", 1.f /* scale-placeholder */),
CreateTensor(effective_scale_data, effective_scale_dims,
"effective_scale"),
};
ValidateIntegerSVDFGoldens(
batch_size, num_units, input_size, rank, tensors, tensor_count,
golden_input_data, golden_input_data_size, output_data, expected_output);
} // namespace
} // namespace
} // namespace testing
} // namespace tflite
@@ -754,4 +910,83 @@ TF_LITE_MICRO_TEST(BlackBoxTestHybridRank2Uint8) {
tflite::testing::svdf_golden_output_rank_2, 0.00625109 /* tolerance */);
}
TF_LITE_MICRO_TEST(BlackBoxTestIntegerRank1) {
constexpr int batch_size = 2;
constexpr int num_units = 4;
constexpr int input_size = 3;
constexpr int memory_size = 10;
constexpr int rank = 1;
constexpr int num_filters = num_units * rank;
int8_t weights_feature_data[] = {-81, -92, 2, 96, 57, 32,
71, 70, 100, -92, -17, -27};
const int weights_feature_dims_count = num_filters * input_size;
int16_t weights_time_data[] = {
-10464, 12324, 9142, -11842, -11836, 7273, 9029, -2175, 260, 4067,
12795, -3488, -3202, 5011, 12987, -887, 12875, 5171, 7185, 10174,
-12098, 12461, -7072, 8870, 7739, 11447, 5954, 11765, -5733, 10643,
-3534, 8912, 4693, -7761, -8886, -519, -4898, 5067, 3205, -1107,
};
const int weights_time_dims_count = num_filters * memory_size;
int32_t bias_data[] = {-409707, 641518, 1662434, -113372};
int8_t input_sequences_data[] = {
64, 25, 34, 23, 68, -99, 16, -59, -114, 46, 47, 94,
18, -128, -96, -73, 16, 96, 64, 25, 34, 23, 68, -99,
16, -59, -114, 46, 47, 94, 18, -128, -96, -73, 16, 96,
64, 25, 34, 23, 68, -99, 16, -59, -114, 46, 47, 94,
18, -128, -96, -73, 16, 96, 64, 25, 34, 23, 68, -99,
16, -59, -114, 46, 47, 94, 18, -128, -96, -73, 16, 96,
};
int8_t expected_output[] = {
-9, 24, 31, 1, -10, 10, -3, 0, 2, 4, -44, -7, -10, 32,
52, 1, 12, -17, 9, -8, 7, 16, -11, -8, -26, 29, 28, 16,
-23, 26, 30, -6, -8, -25, -86, -5, -44, 59, 81, 15, 62, -16,
-37, 3, 27, 14, 34, -10, 1, 24, -25, 23, 31, 61, 67, 11,
-64, -65, -128, -25, -53, 59, 127, 20, 20, -29, -20, -15, -28, 0,
8, -27, 54, 61, -67, 38, 38, 64, 115, 0, -44, -75, -128, -20,
-19, 93, 101, 35, -5, -56, 30, -18, -40, -9, -8, -31,
};
const int input_size_dims_count = batch_size * input_size;
int8_t input_data[input_size_dims_count];
const int activation_state_dims_count =
batch_size * memory_size * num_filters;
int16_t activation_state_data[activation_state_dims_count];
const int scratch_dims_count = batch_size * num_filters;
int32_t scratch_data[scratch_dims_count];
const int scratch_output_dims_count = batch_size * num_units;
int32_t scratch_output_data[scratch_output_dims_count];
const int output_dims_count = batch_size * num_units;
int8_t output_data[output_dims_count];
float input_scale = 1.f / INT8_MAX; // Range is [-1, 1]
float weights_feature_scale = 0.5 / INT8_MAX; // Range is [-0.5, 0.5]
float weights_time_scale = 1.f / INT16_MAX; // Range is [-1, 1]
float activation_scale = 16.f / INT16_MAX; // Range is [-16, 16]
float bias_scale = 512.f / INT32_MAX; // Range is [-512, 512]
float output_scale = 0.5f / INT8_MAX; // Range is [-0.5, 0.5]
int32_t effective_scale_1_a = 1082163456;
int32_t effective_scale_1_b = -3;
int32_t effective_scale_2_a = 2139160192;
int32_t effective_scale_2_b = -18;
tflite::testing::TestIntegerSVDF(
batch_size, num_units, input_size, memory_size, rank, input_data,
input_scale, weights_feature_data, weights_feature_scale,
weights_time_data, weights_time_scale, bias_data, bias_scale,
activation_state_data, activation_scale, scratch_data,
scratch_output_data, output_data, output_scale, effective_scale_1_a,
effective_scale_1_b, effective_scale_2_a, effective_scale_2_b,
input_sequences_data, sizeof(input_sequences_data), expected_output);
}
TF_LITE_MICRO_TESTS_END
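As a usage sketch only (not part of this commit), the hard-coded effective_scale_* constants above could be produced offline from the test's float scales with tflite::QuantizeMultiplier; the effective-scale formulas are the same assumptions as in the sketch under the commit message, and the exact bit-for-bit values depend on the precise scales fed into that offline step.

#include <cstdint>
#include <cstdio>

#include "tensorflow/lite/kernels/internal/quantization_util.h"

int main() {
  // Float scales mirroring BlackBoxTestIntegerRank1 above.
  const double input_scale = 1.0 / 127;            // [-1, 1]
  const double weights_feature_scale = 0.5 / 127;  // [-0.5, 0.5]
  const double weights_time_scale = 1.0 / 32767;   // [-1, 1]
  const double activation_scale = 16.0 / 32767;    // [-16, 16]
  const double output_scale = 0.5 / 127;           // [-0.5, 0.5]

  // Assumed effective-scale formulas.
  const double effective_scale_1 =
      input_scale * weights_feature_scale / activation_scale;
  const double effective_scale_2 =
      activation_scale * weights_time_scale / output_scale;

  int32_t scale_1_a, scale_2_a;
  int scale_1_b, scale_2_b;
  tflite::QuantizeMultiplier(effective_scale_1, &scale_1_a, &scale_1_b);
  tflite::QuantizeMultiplier(effective_scale_2, &scale_2_a, &scale_2_b);
  std::printf("scale_1: %d %d  scale_2: %d %d\n",
              scale_1_a, scale_1_b, scale_2_a, scale_2_b);
  return 0;
}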

View File

@@ -27,13 +27,19 @@ namespace tflite {
namespace {
static const uint8_t kAsymmetricUInt8Min = 0;
static const uint8_t kAsymmetricUInt8Max = 255;
static const uint8_t kAsymmetricUInt8Max = UINT8_MAX;
static const uint8_t kSymmetricUInt8Min = 1;
static const uint8_t kSymmetricUInt8Max = 255;
static const int8_t kAsymmetricInt8Min = -128;
static const int8_t kAsymmetricInt8Max = 127;
static const uint8_t kSymmetricUInt8Max = UINT8_MAX;
static const int8_t kAsymmetricInt8Min = INT8_MIN;
static const int8_t kAsymmetricInt8Max = INT8_MAX;
static const int kSymmetricInt8Scale = kAsymmetricInt8Max;
static const int16_t kAsymmetricInt16Max = INT16_MAX;
static const int kSymmetricInt16Scale = kAsymmetricInt16Max;
static const int32_t kAsymmetricInt32Max = INT32_MAX;
static const int kSymmetricInt32Scale = kAsymmetricInt32Max;
} // namespace
int ElementCount(const TfLiteIntArray& dims) {
@@ -187,6 +193,47 @@ void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims,
}
}
void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims,
int16_t* quantized_values, float* scaling_factor) {
int input_size = ElementCount(*dims);
float min = 0;
float max = 0;
for (int i = 0; i < input_size; i++) {
min = fminf(min, values[i]);
max = fmaxf(max, values[i]);
}
*scaling_factor = fmaxf(fabs(min), fabs(max)) / kSymmetricInt16Scale;
for (int i = 0; i < input_size; i++) {
const int32_t quantized_value =
static_cast<int32_t>(roundf(values[i] / *scaling_factor));
// Clamp: just in case some odd numeric offset.
quantized_values[i] = fminf(kSymmetricInt16Scale,
fmaxf(-kSymmetricInt16Scale, quantized_value));
}
}
void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims,
int32_t* quantized_values, float* scaling_factor) {
int input_size = ElementCount(*dims);
float min = 0;
float max = 0;
for (int i = 0; i < input_size; i++) {
min = fminf(min, values[i]);
max = fmaxf(max, values[i]);
}
*scaling_factor = fmaxf(fabs(min), fabs(max)) / kSymmetricInt32Scale;
for (int i = 0; i < input_size; i++) {
const int32_t quantized_value =
static_cast<int32_t>(roundf(values[i] / *scaling_factor));
// Clamp: just in case some odd numeric offset.
quantized_values[i] = fminf(kSymmetricInt32Scale,
fmaxf(-kSymmetricInt32Scale, quantized_value));
}
}
void SymmetricQuantize(const float* values, TfLiteIntArray* dims,
uint8_t* quantized_values, float* scaling_factor) {
SignedSymmetricQuantize(values, dims,

View File

@@ -74,6 +74,12 @@ void SignedSymmetricPerChannelQuantize(const float* values,
void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims,
int8_t* quantized_values, float* scaling_factor);
void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims,
int16_t* quantized_values, float* scaling_factor);
void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims,
int32_t* quantized_values, float* scaling_factor);
void SymmetricQuantize(const float* values, TfLiteIntArray* dims,
uint8_t* quantized_values, float* scaling_factor);

View File

@@ -307,6 +307,18 @@ TfLiteTensor CreateQuantizedTensor(const int8_t* data, TfLiteIntArray* dims,
return result;
}
TfLiteTensor CreateQuantizedTensor(const int16_t* data, TfLiteIntArray* dims,
float scale, int zero_point,
const char* name, bool is_variable) {
TfLiteTensor result = CreateTensor(dims, name, is_variable);
result.type = kTfLiteInt16;
result.data.i16 = const_cast<int16_t*>(data);
result.params = {scale, zero_point};
result.quantization = {kTfLiteAffineQuantization, nullptr};
result.bytes = ElementCount(*dims) * sizeof(int16_t);
return result;
}
TfLiteTensor CreateQuantizedTensor(const float* input, int8_t* quantized,
TfLiteIntArray* dims, float scale,
int zero_point, const char* name,

View File

@@ -79,6 +79,10 @@ TfLiteTensor CreateQuantizedTensor(const int8_t* data, TfLiteIntArray* dims,
float scale, int zero_point,
const char* name, bool is_variable = false);
TfLiteTensor CreateQuantizedTensor(const int16_t* data, TfLiteIntArray* dims,
float scale, int zero_point,
const char* name, bool is_variable = false);
TfLiteTensor CreateQuantizedTensor(const float* input, int8_t* quantized,
TfLiteIntArray* dims, float scale,
int zero_point, const char* name,

View File

@@ -215,6 +215,24 @@ inline TfLiteTensor CreateQuantizedTensor(float* data, int8_t* quantized_data,
return result;
}
inline TfLiteTensor CreateQuantizedTensor(float* data, int16_t* quantized_data,
TfLiteIntArray* dims,
const char* name,
bool is_variable = false) {
TfLiteTensor result;
SignedSymmetricQuantize(data, dims, quantized_data, &result.params.scale);
result.data.i16 = quantized_data;
result.type = kTfLiteInt16;
result.dims = dims;
result.params.zero_point = 0;
result.allocation_type = kTfLiteMemNone;
result.bytes = ElementCount(*dims) * sizeof(int16_t);
result.allocation = nullptr;
result.name = name;
result.is_variable = is_variable;
return result;
}
inline TfLiteTensor CreateQuantized32Tensor(const int32_t* data,
TfLiteIntArray* dims,
const char* name, float scale,