Port the SVDF full integer recipe to TFLite Micro.

This version varies slightly from the current reference implementation in TFLite (original):
1.) All references to the tensor_utils:: namespace are dropped because of the build incompatibility and the binary size the include brings in (just like the rest of this port for float/hybrid-quant).
2.) Scratch tensors are re-worked into variable tensors. This is a temporary workaround until memory planning lands.
3.) An additional Tensor is required to provide pre-calculated scale values. These calculations are very expensive on low-power devices.

PiperOrigin-RevId: 286278125
Change-Id: Ibbadb2f38a6c25b5550b4fced5b32c7a5b9420df
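Editor's note (not part of this commit): the four pre-calculated scale values mentioned in point 3 are the quantized-multiplier/shift pairs consumed by MultiplyByQuantizedMultiplier() in the kernel below. A minimal sketch of how they could be derived offline is shown here, assuming the scale combinations used by the reference TFLite SVDF kernel (effective_scale_1 = input_scale * weights_feature_scale / activation_state_scale, effective_scale_2 = activation_state_scale * weights_time_scale / output_scale); the helper name PrecomputeSvdfScales is hypothetical.

#include <cstdint>

#include "tensorflow/lite/kernels/internal/quantization_util.h"

// Hypothetical offline helper: packs the four int32 values that would be
// stored in the extra input tensor (scale_1_a/b, scale_2_a/b).
struct SvdfEffectiveScales {
  int32_t scale_1_a;  // quantized multiplier for input x weights_feature
  int32_t scale_1_b;  // shift paired with scale_1_a
  int32_t scale_2_a;  // quantized multiplier for state x weights_time
  int32_t scale_2_b;  // shift paired with scale_2_a
};

SvdfEffectiveScales PrecomputeSvdfScales(float input_scale,
                                         float weights_feature_scale,
                                         float weights_time_scale,
                                         float activation_state_scale,
                                         float output_scale) {
  // Assumed effective scales, following the reference (non-micro) SVDF kernel.
  const double effective_scale_1 = static_cast<double>(input_scale) *
                                   weights_feature_scale /
                                   activation_state_scale;
  const double effective_scale_2 =
      static_cast<double>(activation_state_scale) * weights_time_scale /
      output_scale;

  SvdfEffectiveScales s;
  int shift_1 = 0;
  int shift_2 = 0;
  // QuantizeMultiplier() converts a double scale into a fixed-point
  // multiplier plus shift, so no floating point is needed at inference time.
  tflite::QuantizeMultiplier(effective_scale_1, &s.scale_1_a, &shift_1);
  tflite::QuantizeMultiplier(effective_scale_2, &s.scale_2_a, &shift_2);
  s.scale_1_b = shift_1;
  s.scale_2_b = shift_2;
  return s;
}

The resulting four int32 values correspond to the layout the Eval path below reads from input tensor index 7 (effective_scale_data[0..3]).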
parent 6ac005515f
commit 35e4344f92
tensorflow/lite/micro
@@ -72,7 +72,7 @@ static inline void ApplyTimeWeightsBiasAndActivation(

  // Initialize output with bias if provided.
  if (bias) {
    // TODO(kreeger): doc me - VectorBatchVectorAssign
    // VectorBatchVectorAssign
    const float* bias_data = GetTensorData<float>(bias);
    float* output_data = GetTensorData<float>(output);
    for (int i = 0; i < batch_size; ++i) {

@@ -95,10 +95,9 @@ static inline void ApplyTimeWeightsBiasAndActivation(
    float* scratch_ptr_batch = GetTensorData<float>(scratch) + b * num_filters;

    // Reduction sum vector
    const float* input_vector_ptr = scratch_ptr_batch;
    for (int i = 0; i < num_units; ++i) {
      for (int j = 0; j < rank; j++) {
        output_ptr_batch[i] += *input_vector_ptr++;
        output_ptr_batch[i] += *scratch_ptr_batch++;
      }
    }
  }

@@ -274,6 +273,150 @@ inline void EvalHybridSVDF(
      params->activation, activation_state, scratch, output);
}

void EvalIntegerSVDF(
    TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input_tensor,
    const TfLiteTensor* weights_feature_tensor,
    const TfLiteTensor* weights_time_tensor, const TfLiteTensor* bias_tensor,
    const TfLiteSVDFParams* params, TfLiteTensor* activation_state_tensor,
    TfLiteTensor* output_tensor, TfLiteTensor* scratch_tensor,
    TfLiteTensor* scratch_output_tensor, int32_t scale_1_a, int scale_1_b,
    int32_t scale_2_a, int scale_2_b, int32_t input_zp, int32_t output_zp) {
  const int n_rank = params->rank;
  const int n_batch = input_tensor->dims->data[0];
  const int n_input = input_tensor->dims->data[1];
  const int n_filter = weights_feature_tensor->dims->data[0];
  const int n_unit = n_filter / n_rank;
  const int n_memory = weights_time_tensor->dims->data[1];

  // Rewrite last bit of state.
  {
    for (int b = 0; b < n_batch; ++b) {
      int16_t* state_ptr_batch =
          GetTensorData<int16_t>(activation_state_tensor) +
          b * n_memory * n_filter;
      for (int c = 0; c < n_filter; ++c) {
        int16_t* state_ptr = state_ptr_batch + c * n_memory;
        state_ptr[n_memory - 1] = 0;
      }
    }
  }

  // Feature matmul.
  {
    int16_t* state = GetTensorData<int16_t>(activation_state_tensor);
    const int8_t* input = GetTensorData<int8_t>(input_tensor);
    const int8_t* weight_feature =
        GetTensorData<int8_t>(weights_feature_tensor);
    const int32_t output_max = std::numeric_limits<int16_t>::max();
    const int32_t output_min = std::numeric_limits<int16_t>::min();
    int16_t* result_in_batch = state + (n_memory - 1);
    for (int b = 0; b < n_batch; b++) {
      const int8_t* matrix_ptr = weight_feature;
      for (int r = 0; r < n_filter; r++) {
        int32_t dot_prod = 0;
        const int8_t* vector_in_batch = input + b * n_input;
        for (int c = 0; c < n_input; c++) {
          dot_prod += *matrix_ptr++ * (*vector_in_batch++ - input_zp);
        }
        dot_prod =
            MultiplyByQuantizedMultiplier(dot_prod, scale_1_a, scale_1_b);
        dot_prod = std::min(std::max(output_min, dot_prod), output_max);
        *result_in_batch = dot_prod;
        result_in_batch += n_memory;
      }
    }
  }

  // Time.
  {
    for (int b = 0; b < n_batch; ++b) {
      int32_t* scratch_ptr_batch =
          GetTensorData<int32_t>(scratch_tensor) + b * n_filter;

      // Perform batched vector dot product:
      const int16_t* vector1_ptr = GetTensorData<int16_t>(weights_time_tensor);
      const int16_t* vector2_ptr =
          GetTensorData<int16_t>(activation_state_tensor) +
          b * n_memory * n_filter;

      for (int i = 0; i < n_filter; i++) {
        *scratch_ptr_batch = 0;
        for (int j = 0; j < n_memory; j++) {
          *scratch_ptr_batch += *vector1_ptr++ * *vector2_ptr++;
        }
        scratch_ptr_batch++;
      }
    }
  }

  // Reduce, add bias, rescale, activation.
  {
    int32_t* output_temp = GetTensorData<int32_t>(scratch_output_tensor);
    // Add bias.
    if (bias_tensor) {
      // Vector batch assign:
      const int32_t* bias_data = GetTensorData<int32_t>(bias_tensor);
      for (int i = 0; i < n_batch; ++i) {
        int32_t* output_ptr = output_temp + i * n_unit;
        const int32_t* bias_ptr = bias_data;
        for (int j = 0; j < n_unit; ++j) {
          *output_ptr++ = *bias_ptr++;
        }
      }
    } else {
      int32_t* output_ptr = output_temp;
      for (int i = 0; i < n_batch * n_unit; ++i) {
        *output_ptr++ = 0;
      }
    }

    // Reduce.
    for (int b = 0; b < n_batch; ++b) {
      int32_t* output_temp_ptr = output_temp + b * n_unit;
      int32_t* scratch_ptr_batch =
          GetTensorData<int32_t>(scratch_tensor) + b * n_filter;

      // Reduction sum vector
      for (int i = 0; i < n_unit; ++i) {
        for (int j = 0; j < n_rank; ++j) {
          output_temp_ptr[i] += *scratch_ptr_batch++;
        }
      }
    }

    // Rescale.
    const int32_t output_max = std::numeric_limits<int8_t>::max();
    const int32_t output_min = std::numeric_limits<int8_t>::min();
    for (int i = 0; i < n_batch * n_unit; ++i) {
      int32_t x1 = output_temp[i];
      int32_t x2 = MultiplyByQuantizedMultiplier(x1, scale_2_a, scale_2_b);
      int32_t x3 = x2 + output_zp;
      int32_t x4 = std::min(std::max(output_min, x3), output_max);
      GetTensorData<int8_t>(output_tensor)[i] = static_cast<int8_t>(x4);
    }
  }

  // Shift state.
  {
    for (int b = 0; b < n_batch; ++b) {
      int16_t* state_ptr_batch =
          GetTensorData<int16_t>(activation_state_tensor) +
          b * n_memory * n_filter;
      for (int f = 0; f < n_filter; ++f) {
        // Shift the vector left:
        int16_t* batch_ptr = state_ptr_batch;
        int16_t* batch_start = state_ptr_batch + 1;
        int16_t* batch_end = state_ptr_batch + n_memory;
        while (batch_start != batch_end) {
          *batch_ptr++ = *batch_start++;
        }
        state_ptr_batch[n_memory - 1] = 0;
        state_ptr_batch += n_memory;
      }
    }
  }
}

}  // namespace

// Input tensors.
@@ -303,10 +446,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // [3] = Bias (optional), {1, num_units}
  // [4] = Activation State (variable),
  //       {2, batch_size, memory_size * num_filters}
  // TODO(kreeger): Use input tensor as variable until scratch tensor allocation
  // has been implemented (cl/263032056)
  // TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 6);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* weights_feature =
      GetInput(context, node, kWeightsFeatureTensor);

@@ -325,10 +465,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const int num_units = num_filters / rank;
  const int memory_size = weights_time->dims->data[1];

  // The weights are of consistent type, so it suffices to check one.
  const bool is_hybrid_op = IsHybridOp(input, weights_feature);
  const bool is_full_integer = input->type == kTfLiteInt8;

  // Validate Input Tensor:
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE(context,
                 input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);

  // Validate Tensor Output:
  // [0] = float/int8, {2, batch_size, num_units}
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
  TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units);

  // Validate Weights Feature Input Tensor:
  TF_LITE_ENSURE_EQ(context, NumDimensions(weights_feature), 2);
  TF_LITE_ENSURE_EQ(context, weights_feature->dims->data[1], input_size);

@@ -341,11 +494,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Validate Optional Bias Input Tensor:
  if (bias) {
    TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
    TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
  }

  // Validate Activation State Input Tensor:
  TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);
  TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
  TF_LITE_ENSURE_EQ(context, activation_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1],

@@ -354,26 +505,29 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Validate shared Scratch Tensor (same for full float and hybrid):
  // [0] = Holds dot-product of time-forward calculations in
  //       ApplyTimeWeightsBiasAndActivation():
  //         float, {2, batch_size, num_filters}
  //         float/int32, {2, batch_size, num_filters}
  // TODO(kreeger): Use input tensor as variable until scratch tensor allocation
  // has been implemented (cl/263032056)
  // has been implemented (b/132070898)
  // TfLiteTensor* scratch_tensor = GetTemporary(context, node, 0);
  TfLiteTensor* scratch_tensor = &context->tensors[node->inputs->data[5]];

  TF_LITE_ENSURE_EQ(context, scratch_tensor->type, kTfLiteFloat32);
  TF_LITE_ENSURE_EQ(context, NumDimensions(scratch_tensor), 2);
  TF_LITE_ENSURE_EQ(context, scratch_tensor->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, scratch_tensor->dims->data[1], num_filters);

  // The weights are of consistent type, so it suffices to check one.
  const bool is_hybrid_op = IsHybridOp(input, weights_feature);
  // TODO(kreeger): Handle full quant svdf b/139435798
  if (is_hybrid_op) {
    TF_LITE_ENSURE_EQ(context, node->inputs->size, 6);

    // Validate Input Tensor dtypes:
    TF_LITE_ENSURE(context, weights_feature->type == kTfLiteUInt8 ||
                                weights_feature->type == kTfLiteInt8);
    TF_LITE_ENSURE(context, weights_time->type == kTfLiteUInt8 ||
                                weights_time->type == kTfLiteInt8);
    TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);

    if (bias) {
      TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
    }

    // Validate Scratch Tensors:
    // [0] = (shared - see above for usage)

@@ -385,6 +539,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    TfLiteTensor* scratch_scaling_factors = GetTemporary(context, node, 2);
    TfLiteTensor* scratch_float_weights_time = GetTemporary(context, node, 3);

    // Validate shared scratch tensor type:
    TF_LITE_ENSURE_EQ(context, scratch_tensor->type, kTfLiteFloat32);

    // Validate Input Quantized Scratch Tensor:
    TF_LITE_ENSURE(context, scratch_input_quantized->type == kTfLiteUInt8 ||
                                scratch_input_quantized->type == kTfLiteInt8);

@@ -412,37 +569,75 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    // called. Use this time to do a one-time de-quantization copy of
    // the input values from the Weights Time tensor to the float weights time
    // scratch tensor.
    // TODO(kreeger): Consider doing this at model conversion time?
    // TODO(b/146029510): Consider doing this at model conversion time.
    SymmetricDequantize(GetTensorData<int8_t>(weights_time),
                        NumElements(scratch_float_weights_time),
                        weights_time->params.scale,
                        GetTensorData<float>(scratch_float_weights_time));

    TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
  } else if (is_full_integer) {
    // TODO(b/132070898): Use input tensor as variable until scratch tensor
    // allocation has been implemented
    TF_LITE_ENSURE_EQ(context, node->inputs->size, 8);

    TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
    TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16);

    if (bias) {
      TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
    }

    TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16);

    // Validate Scratch Tensors:
    // [0] = (shared - see above for usage)
    // [1] = Output Temp, int8_t, {2, num_units, batch_size}
    // TODO(b/132070898): Use input tensor as variable until scratch tensor
    // allocation has been implemented.
    /* TF_LITE_ENSURE_EQ(context, node->temporaries->size, 2); */

    // Validate shared scratch tensor type:
    TF_LITE_ENSURE_EQ(context, scratch_tensor->type, kTfLiteInt32);

    // Validate Output Temp Scratch Tensor:
    TfLiteTensor* scratch_output = &context->tensors[node->inputs->data[6]];
    TF_LITE_ENSURE_EQ(context, scratch_output->type, kTfLiteInt32);
    TF_LITE_ENSURE_EQ(context, NumDimensions(scratch_output), 2);
    TF_LITE_ENSURE_EQ(context, scratch_output->dims->data[0], num_units);
    TF_LITE_ENSURE_EQ(context, scratch_output->dims->data[1], batch_size);

    // Validate output tensor:
    TF_LITE_ENSURE_EQ(context, output->type, kTfLiteInt8);
  } else {
    TF_LITE_ENSURE_EQ(context, node->inputs->size, 6);

    // Validate Input Tensor dtypes:
    TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32);
    TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteFloat32);
    TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);

    if (bias) {
      TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
    }

    // Full-float SVDF only uses the one shared scratch tensor (see above for
    // usage).
    // TODO(kreeger): Use input tensor as variable until scratch tensor
    // allocation has been implemented (cl/263032056)
    // TODO(b/132070898): Use input tensor as variable until scratch tensor
    // allocation has been implemented.
    // TF_LITE_ENSURE_EQ(context, node->temporaries->size, 1);
  }

  // Validate Tensor Output:
  // [0] = float, {2, batch_size, num_units}
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
  TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
  TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units);
    // Validate shared scratch tensor type:
    TF_LITE_ENSURE_EQ(context, scratch_tensor->type, kTfLiteFloat32);

    TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
  }

  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
  auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* weights_feature =

@@ -451,15 +646,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
      GetInput(context, node, kWeightsTimeTensor);
  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);

  // TODO(kreeger): Use input tensor as variable until scratch tensor allocation
  // has been implemented (cl/263032056)
  // TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
  // TODO(b/132070898): Use input tensor as variable until scratch tensor
  // allocation has been implemented. TfLiteTensor* scratch =
  // GetTemporary(context, node, /*index=*/0);
  TfLiteTensor* scratch = &context->tensors[node->inputs->data[5]];

  TfLiteTensor* activation_state =
      &context->tensors[node->inputs->data[kInputActivationStateTensor]];
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  const bool is_full_integer = input->type == kTfLiteInt8;

  switch (weights_feature->type) {
    case kTfLiteFloat32: {
      EvalFloatSVDF(context, node, input, weights_feature, weights_time, bias,

@@ -470,19 +667,46 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {

    case kTfLiteUInt8:
    case kTfLiteInt8: {
      TfLiteTensor* scratch_input_quantized = GetTemporary(context, node, 1);
      TfLiteTensor* scratch_scaling_factors = GetTemporary(context, node, 2);
      TfLiteTensor* scratch_float_weights_time = GetTemporary(context, node, 3);
      EvalHybridSVDF(context, node, input, weights_feature,
                     scratch_float_weights_time, bias, params, scratch,
                     scratch_scaling_factors, scratch_input_quantized,
                     activation_state, output);
      return kTfLiteOk;
      if (is_full_integer) {
        // TODO(b/146029510): In order to prevent expensive scale calculations
        // during each eval of this Op, pre-calculated values are being stored
        // in a Tensor in the flatbuffer. Inside this Tensor, the 4 scale values
        // are stored in a int32 buffer.
        const TfLiteTensor* effective_scale_data_tensor =
            GetInput(context, node, 7);
        const int32_t* effective_scale_data =
            GetTensorData<int32_t>(effective_scale_data_tensor);

        // TODO(b/132070898): Use input tensor as variable until scratch tensor
        // allocation has been implemented TfLiteTensor*
        // output_temp = GetTemporary(context, node, /*index=*/2);
        TfLiteTensor* output_temp = &context->tensors[node->inputs->data[6]];

        // Currently supports only ReLU.
        TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu);
        EvalIntegerSVDF(context, node, input, weights_feature, weights_time,
                        bias, params, activation_state, output, scratch,
                        output_temp, effective_scale_data[0],
                        effective_scale_data[1], effective_scale_data[2],
                        effective_scale_data[3], input->params.zero_point,
                        output->params.zero_point);
        return kTfLiteOk;
      } else {
        // Hybrid quantized:
        TfLiteTensor* scratch_input_quantized = GetTemporary(context, node, 1);
        TfLiteTensor* scratch_scaling_factors = GetTemporary(context, node, 2);
        TfLiteTensor* scratch_float_weights_time =
            GetTemporary(context, node, 3);
        EvalHybridSVDF(context, node, input, weights_feature,
                       scratch_float_weights_time, bias, params, scratch,
                       scratch_scaling_factors, scratch_input_quantized,
                       activation_state, output);
        return kTfLiteOk;
      }
      break;
    }

    default:
      // TODO(kreeger): Handle this case for full quant svdf b/139435798
      context->ReportError(context, "Type %s not currently supported.",
                           TfLiteTypeGetName(weights_feature->type));
      return kTfLiteError;
@@ -146,7 +146,7 @@ void ValidateSVDFGoldens(const int batch_size, const int num_units,

  // Bias is an optional tensor:
  // TODO(kreeger): Use input tensor as variable until scratch tensor allocation
  // has been implemented (cl/263032056)
  // has been implemented (b/132070898)
  // int inputs_array_data[] = {5, 0, 1, 2, kTfLiteOptionalTensor, 3};
  int inputs_array_data[] = {6, 0, 1, 2, kTfLiteOptionalTensor, 3, 5};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);

@@ -166,7 +166,6 @@ void ValidateSVDFGoldens(const int batch_size, const int num_units,
  node.outputs = outputs_array;
  if (is_hybrid_op) {
    node.temporaries = hybrid_temporaries_array;

  } else {
    node.temporaries = temporaries_array;
  }

@@ -203,6 +202,81 @@ void ValidateSVDFGoldens(const int batch_size, const int num_units,
  }
}

void ValidateIntegerSVDFGoldens(const int batch_size, const int num_units,
                                const int input_size, const int rank,
                                TfLiteTensor* tensors, const int tensor_count,
                                int8_t* golden_input_data,
                                const int golden_input_data_size,
                                int8_t* output_data, int8_t* expected_output) {
  TfLiteContext context;
  PopulateContext(tensors, tensor_count, &context);

  ::tflite::ops::micro::AllOpsResolver resolver;
  const TfLiteRegistration* registration =
      resolver.FindOp(tflite::BuiltinOperator_SVDF, 1);
  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);

  TfLiteSVDFParams params;
  params.rank = rank;
  params.activation = kTfLiteActRelu;

  void* user_data = nullptr;
  if (registration->init) {
    user_data = registration->init(&context, nullptr, 0);
  }

  // TODO(b/132070898): Use input tensor as variable until scratch tensor
  // allocation has been implemented. int inputs_array_data[] = {5, 0, 1, 2, 3,
  // 4};
  int inputs_array_data[] = {8, 0, 1, 2, 3, 4, 6, 7, 8};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);

  int outputs_array_data[] = {1, 5};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  int temporaries_array_data[] = {2, 7, 8};
  TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = user_data;
  node.builtin_data = reinterpret_cast<void*>(&params);
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
  }
  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);

  int input_sequence_size =
      golden_input_data_size / sizeof(int8_t) / (input_size * batch_size);
  for (int i = 0; i < input_sequence_size; ++i) {
    int8_t* input_batch_start = golden_input_data + i * input_size * batch_size;
    int8_t* input_batch_end = input_batch_start + input_size * batch_size;
    int8_t* tensor_data = tensors[0].data.int8;
    while (input_batch_start != input_batch_end) {
      *tensor_data++ = *input_batch_start++;
    }

    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));

    int output_idx = 0;
    int golden_idx = i * batch_size * num_units;
    for (int j = golden_idx; j < golden_idx + batch_size * num_units; ++j) {
      TF_LITE_MICRO_EXPECT_NEAR(expected_output[j], output_data[output_idx], 0);
      output_idx++;
    }
  }

  if (registration->free) {
    registration->free(&context, user_data);
  }
}

void TestSVDF(const int batch_size, const int num_units, const int input_size,
              const int memory_size, const int rank, float* input_data,
              float* weights_feature_data, float* weights_time_data,

@@ -383,6 +457,88 @@ inline void TestHybridSVDFUint8(
      tolerance);
}

inline void TestIntegerSVDF(
    const int batch_size, const int num_units, const int input_size,
    const int memory_size, const int rank, int8_t* input_data,
    float input_scale, int8_t* weights_feature_data,
    float weights_feature_scale, int16_t* weights_time_data,
    float weights_time_scale, int32_t* bias_data, float bias_scale,
    int16_t* activation_state_data, float activation_scale,
    int32_t* scratch_data, int32_t* scratch_output_data, int8_t* output_data,
    float output_scale, int32_t effective_scale_1_a,
    int32_t effective_scale_1_b, int32_t effective_scale_2_a,
    int32_t effective_scale_2_b, int8_t* golden_input_data,
    int golden_input_data_size, int8_t* expected_output) {
  const int num_filters = num_units * rank;

  const int input_dims_arg[] = {2, batch_size, input_size};
  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_arg);

  const int weights_feature_dims_args[] = {2, num_filters, input_size};
  TfLiteIntArray* weights_feature_dims =
      IntArrayFromInts(weights_feature_dims_args);

  const int weights_time_dims_args[] = {2, num_filters, memory_size};
  TfLiteIntArray* weights_time_dims = IntArrayFromInts(weights_time_dims_args);

  const int bias_dims_data[] = {1, num_units};
  TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);

  const int activation_state_dims_args[] = {2, batch_size,
                                            memory_size * num_filters};
  TfLiteIntArray* activation_state_dims =
      IntArrayFromInts(activation_state_dims_args);

  // Scratch output is the same shape as output:
  const int scratch_dims_args[] = {2, batch_size, num_filters};
  TfLiteIntArray* scratch_dims = IntArrayFromInts(scratch_dims_args);

  // Full integer requires one more scratch tensor:
  const int scratch_output_dims_args[] = {2, num_units, batch_size};
  TfLiteIntArray* scratch_output_dims =
      IntArrayFromInts(scratch_output_dims_args);

  const int output_dims_args[] = {2, batch_size, num_units};
  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_args);

  // Tensor size is higher due to workarounds in micro buffer usage
  // (b/132070898) and re-working scale calculations (b/146029510).
  const int tensor_count = 9;  // 5 inputs, 1 output, 2 scratch, 1 temp

  const int effective_scale_dims_args[] = {1, 4};
  int32_t effective_scale_data[] = {effective_scale_1_a, effective_scale_1_b,
                                    effective_scale_2_a, effective_scale_2_b};
  TfLiteIntArray* effective_scale_dims =
      IntArrayFromInts(effective_scale_dims_args);

  TfLiteTensor tensors[] = {
      CreateQuantizedTensor(input_data, input_dims, input_scale,
                            0 /* zero-point */, "input"),
      CreateQuantizedTensor(weights_feature_data, weights_feature_dims,
                            weights_feature_scale, 0 /* zero-point */,
                            "weights_feature"),
      CreateQuantizedTensor(weights_time_data, weights_time_dims,
                            weights_time_scale, 0 /* zero-point */,
                            "weights_time"),
      CreateQuantized32Tensor(bias_data, bias_dims, "bias", bias_scale),
      CreateQuantizedTensor(activation_state_data, activation_state_dims,
                            activation_scale, 0 /* zero-point */,
                            "activation_state", true /* is_variable */),
      CreateQuantizedTensor(output_data, output_dims, output_scale,
                            0 /* zero-point */, "output"),
      CreateQuantized32Tensor(scratch_data, scratch_dims, "scratch",
                              1.f /* scale-placeholder */),
      CreateQuantized32Tensor(scratch_output_data, scratch_output_dims,
                              "scratch_output", 1.f /* scale-placeholder */),
      CreateTensor(effective_scale_data, effective_scale_dims,
                   "effective_scale"),
  };

  ValidateIntegerSVDFGoldens(
      batch_size, num_units, input_size, rank, tensors, tensor_count,
      golden_input_data, golden_input_data_size, output_data, expected_output);
}

}  // namespace
}  // namespace testing
}  // namespace tflite

@@ -754,4 +910,83 @@ TF_LITE_MICRO_TEST(BlackBoxTestHybridRank2Uint8) {
      tflite::testing::svdf_golden_output_rank_2, 0.00625109 /* tolerance */);
}

TF_LITE_MICRO_TEST(BlackBoxTestIntegerRank1) {
  constexpr int batch_size = 2;
  constexpr int num_units = 4;
  constexpr int input_size = 3;
  constexpr int memory_size = 10;
  constexpr int rank = 1;
  constexpr int num_filters = num_units * rank;

  int8_t weights_feature_data[] = {-81, -92, 2, 96, 57, 32,
                                   71, 70, 100, -92, -17, -27};
  const int weights_feature_dims_count = num_filters * input_size;

  int16_t weights_time_data[] = {
      -10464, 12324, 9142, -11842, -11836, 7273, 9029, -2175, 260, 4067,
      12795, -3488, -3202, 5011, 12987, -887, 12875, 5171, 7185, 10174,
      -12098, 12461, -7072, 8870, 7739, 11447, 5954, 11765, -5733, 10643,
      -3534, 8912, 4693, -7761, -8886, -519, -4898, 5067, 3205, -1107,
  };
  const int weights_time_dims_count = num_filters * memory_size;

  int32_t bias_data[] = {-409707, 641518, 1662434, -113372};

  int8_t input_sequences_data[] = {
      64, 25, 34, 23, 68, -99, 16, -59, -114, 46, 47, 94,
      18, -128, -96, -73, 16, 96, 64, 25, 34, 23, 68, -99,
      16, -59, -114, 46, 47, 94, 18, -128, -96, -73, 16, 96,
      64, 25, 34, 23, 68, -99, 16, -59, -114, 46, 47, 94,
      18, -128, -96, -73, 16, 96, 64, 25, 34, 23, 68, -99,
      16, -59, -114, 46, 47, 94, 18, -128, -96, -73, 16, 96,
  };

  int8_t expected_output[] = {
      -9, 24, 31, 1, -10, 10, -3, 0, 2, 4, -44, -7, -10, 32,
      52, 1, 12, -17, 9, -8, 7, 16, -11, -8, -26, 29, 28, 16,
      -23, 26, 30, -6, -8, -25, -86, -5, -44, 59, 81, 15, 62, -16,
      -37, 3, 27, 14, 34, -10, 1, 24, -25, 23, 31, 61, 67, 11,
      -64, -65, -128, -25, -53, 59, 127, 20, 20, -29, -20, -15, -28, 0,
      8, -27, 54, 61, -67, 38, 38, 64, 115, 0, -44, -75, -128, -20,
      -19, 93, 101, 35, -5, -56, 30, -18, -40, -9, -8, -31,
  };

  const int input_size_dims_count = batch_size * input_size;
  int8_t input_data[input_size_dims_count];

  const int activation_state_dims_count =
      batch_size * memory_size * num_filters;
  int16_t activation_state_data[activation_state_dims_count];

  const int scratch_dims_count = batch_size * num_filters;
  int32_t scratch_data[scratch_dims_count];

  const int scratch_output_dims_count = batch_size * num_units;
  int32_t scratch_output_data[scratch_output_dims_count];

  const int output_dims_count = batch_size * num_units;
  int8_t output_data[output_dims_count];

  float input_scale = 1.f / INT8_MAX;            // Range is [-1, 1]
  float weights_feature_scale = 0.5 / INT8_MAX;  // Range is [-0.5, 0.5]
  float weights_time_scale = 1.f / INT16_MAX;    // Range is [-1, 1]
  float activation_scale = 16.f / INT16_MAX;     // Range is [-16, 16]
  float bias_scale = 512.f / INT32_MAX;          // Range is [-512, 512]
  float output_scale = 0.5f / INT8_MAX;          // Range is [-0.5, 0.5]

  int32_t effective_scale_1_a = 1082163456;
  int32_t effective_scale_1_b = -3;
  int32_t effective_scale_2_a = 2139160192;
  int32_t effective_scale_2_b = -18;

  tflite::testing::TestIntegerSVDF(
      batch_size, num_units, input_size, memory_size, rank, input_data,
      input_scale, weights_feature_data, weights_feature_scale,
      weights_time_data, weights_time_scale, bias_data, bias_scale,
      activation_state_data, activation_scale, scratch_data,
      scratch_output_data, output_data, output_scale, effective_scale_1_a,
      effective_scale_1_b, effective_scale_2_a, effective_scale_2_b,
      input_sequences_data, sizeof(input_sequences_data), expected_output);
}

TF_LITE_MICRO_TESTS_END
@@ -27,13 +27,19 @@ namespace tflite {
namespace {

static const uint8_t kAsymmetricUInt8Min = 0;
static const uint8_t kAsymmetricUInt8Max = 255;
static const uint8_t kAsymmetricUInt8Max = UINT8_MAX;
static const uint8_t kSymmetricUInt8Min = 1;
static const uint8_t kSymmetricUInt8Max = 255;
static const int8_t kAsymmetricInt8Min = -128;
static const int8_t kAsymmetricInt8Max = 127;
static const uint8_t kSymmetricUInt8Max = UINT8_MAX;
static const int8_t kAsymmetricInt8Min = INT8_MIN;
static const int8_t kAsymmetricInt8Max = INT8_MAX;
static const int kSymmetricInt8Scale = kAsymmetricInt8Max;

static const int16_t kAsymmetricInt16Max = INT16_MAX;
static const int kSymmetricInt16Scale = kAsymmetricInt16Max;

static const int32_t kAsymmetricInt32Max = INT32_MAX;
static const int kSymmetricInt32Scale = kAsymmetricInt32Max;

}  // namespace

int ElementCount(const TfLiteIntArray& dims) {

@@ -187,6 +193,47 @@ void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims,
  }
}

void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims,
                             int16_t* quantized_values,
                             float* scaling_factor) {
  int input_size = ElementCount(*dims);

  float min = 0;
  float max = 0;
  for (int i = 0; i < input_size; i++) {
    min = fminf(min, values[i]);
    max = fmaxf(max, values[i]);
  }
  *scaling_factor = fmaxf(fabs(min), fabs(max)) / kSymmetricInt16Scale;
  for (int i = 0; i < input_size; i++) {
    const int32_t quantized_value =
        static_cast<int32_t>(roundf(values[i] / *scaling_factor));
    // Clamp: just in case some odd numeric offset.
    quantized_values[i] = fminf(kSymmetricInt16Scale,
                                fmaxf(-kSymmetricInt16Scale, quantized_value));
  }
}

void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims,
                             int32_t* quantized_values,
                             float* scaling_factor) {
  int input_size = ElementCount(*dims);

  float min = 0;
  float max = 0;
  for (int i = 0; i < input_size; i++) {
    min = fminf(min, values[i]);
    max = fmaxf(max, values[i]);
  }

  *scaling_factor = fmaxf(fabs(min), fabs(max)) / kSymmetricInt32Scale;
  for (int i = 0; i < input_size; i++) {
    const int32_t quantized_value =
        static_cast<int32_t>(roundf(values[i] / *scaling_factor));
    // Clamp: just in case some odd numeric offset.
    quantized_values[i] = fminf(kSymmetricInt32Scale,
                                fmaxf(-kSymmetricInt32Scale, quantized_value));
  }
}

void SymmetricQuantize(const float* values, TfLiteIntArray* dims,
                       uint8_t* quantized_values, float* scaling_factor) {
  SignedSymmetricQuantize(values, dims,
@@ -74,6 +74,12 @@ void SignedSymmetricPerChannelQuantize(const float* values,
void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims,
                             int8_t* quantized_values, float* scaling_factor);

void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims,
                             int16_t* quantized_values, float* scaling_factor);

void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims,
                             int32_t* quantized_values, float* scaling_factor);

void SymmetricQuantize(const float* values, TfLiteIntArray* dims,
                       uint8_t* quantized_values, float* scaling_factor);

@@ -307,6 +307,18 @@ TfLiteTensor CreateQuantizedTensor(const int8_t* data, TfLiteIntArray* dims,
  return result;
}

TfLiteTensor CreateQuantizedTensor(const int16_t* data, TfLiteIntArray* dims,
                                   float scale, int zero_point,
                                   const char* name, bool is_variable) {
  TfLiteTensor result = CreateTensor(dims, name, is_variable);
  result.type = kTfLiteInt16;
  result.data.i16 = const_cast<int16_t*>(data);
  result.params = {scale, zero_point};
  result.quantization = {kTfLiteAffineQuantization, nullptr};
  result.bytes = ElementCount(*dims) * sizeof(int16_t);
  return result;
}

TfLiteTensor CreateQuantizedTensor(const float* input, int8_t* quantized,
                                   TfLiteIntArray* dims, float scale,
                                   int zero_point, const char* name,

@@ -79,6 +79,10 @@ TfLiteTensor CreateQuantizedTensor(const int8_t* data, TfLiteIntArray* dims,
                                   float scale, int zero_point,
                                   const char* name, bool is_variable = false);

TfLiteTensor CreateQuantizedTensor(const int16_t* data, TfLiteIntArray* dims,
                                   float scale, int zero_point,
                                   const char* name, bool is_variable = false);

TfLiteTensor CreateQuantizedTensor(const float* input, int8_t* quantized,
                                   TfLiteIntArray* dims, float scale,
                                   int zero_point, const char* name,

@@ -215,6 +215,24 @@ inline TfLiteTensor CreateQuantizedTensor(float* data, int8_t* quantized_data,
  return result;
}

inline TfLiteTensor CreateQuantizedTensor(float* data, int16_t* quantized_data,
                                          TfLiteIntArray* dims,
                                          const char* name,
                                          bool is_variable = false) {
  TfLiteTensor result;
  SignedSymmetricQuantize(data, dims, quantized_data, &result.params.scale);
  result.data.i16 = quantized_data;
  result.type = kTfLiteInt16;
  result.dims = dims;
  result.params.zero_point = 0;
  result.allocation_type = kTfLiteMemNone;
  result.bytes = ElementCount(*dims) * sizeof(int16_t);
  result.allocation = nullptr;
  result.name = name;
  result.is_variable = is_variable;
  return result;
}

inline TfLiteTensor CreateQuantized32Tensor(const int32_t* data,
                                            TfLiteIntArray* dims,
                                            const char* name, float scale,