Add delegate support for QUANTIZED_16BIT_LSTM
PiperOrigin-RevId: 259914993
@@ -18,6 +18,8 @@ cc_library(
],
"//conditions:default": [
"nnapi_delegate.cc",
"quant_lstm_sup.h",
"quant_lstm_sup.cc",
],
}),
hdrs = ["nnapi_delegate.h"],
@@ -51,4 +53,22 @@ cc_test(
],
)

cc_test(
name = "quant_lstm_sup_test",
size = "small",
srcs = [
"quant_lstm_sup.cc",
"quant_lstm_sup.h",
"quant_lstm_sup_test.cc",
],
deps = [
":nnapi_delegate",
"//tensorflow/lite:framework",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",
],
)

tflite_portable_test_suite()
@@ -19,10 +19,12 @@ limitations under the License.
#include <cstdint>
#include <cstring>
#include <functional>
#include <initializer_list>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <vector>

#include "tensorflow/lite/allocation.h"
@@ -31,6 +33,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/delegates/nnapi/quant_lstm_sup.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/nnapi/nnapi_implementation.h"
@@ -154,6 +157,22 @@ bool NeedInt8Conversion(const TfLiteContext* context, int builtin_code,
}
}

constexpr int kLstmFullKernelInputSize = 24;
// The 20-input version is deprecated and kept only to
// support old models. The latest version of the LSTM full kernel
// is the one with 24 inputs.
constexpr int kLstmFullKernelNoOptionalParamsInputSize = 20;
constexpr int kLstmBasicKernelInputSize = 5;

inline bool isLstmBasicKernel(const TfLiteNode* node) {
return node->inputs->size == kLstmBasicKernelInputSize;
}

inline bool isLstmFullKernel(const TfLiteNode* node) {
return node->inputs->size == kLstmFullKernelInputSize ||
node->inputs->size == kLstmFullKernelNoOptionalParamsInputSize;
}

bool IsHybridOperator(const TfLiteContext* context, int builtin_code,
const TfLiteNode* node) {
switch (builtin_code) {
@@ -165,7 +184,15 @@ bool IsHybridOperator(const TfLiteContext* context, int builtin_code,
const TfLiteType filter_type = context->tensors[filter_id].type;
return IsFloat(input_type) && IsQuantized(filter_type);
}
case kTfLiteBuiltinLstm:
case kTfLiteBuiltinLstm: {
const int input_id = node->inputs->data[0];
// Input #1 is optional so use #2 to determine if hybrid.
const int weights_id = node->inputs->data[2];
const TfLiteType input_type = context->tensors[input_id].type;
const TfLiteType weights_type = context->tensors[weights_id].type;
return isLstmFullKernel(node) && IsFloat(input_type) &&
IsQuantized(weights_type);
}
case kTfLiteBuiltinUnidirectionalSequenceLstm: {
const int input_id = node->inputs->data[0];
// Input #1 is optional so use #2 to determine if hybrid.
@@ -356,6 +383,13 @@ class OperandMapping {
// be mapped.
int add_new_non_tensor_operand() { return next_ann_tensor_index_++; }

// This call is necessary for input operands generated by the delegate
// to map constant inputs not present in TFLite but required by NNAPI,
// for example when splitting one input in several ones.
int add_delegate_generated_input_ann_tensors_operand() {
return next_ann_tensor_index_++;
}

// Add a new mapping from `tflite_index` and return the NN API tensor index.
int add_new_ann_tensor_index(int tflite_index) {
if (tflite_index >= lite_tensor_to_ann_tensor_.size()) {
@@ -581,6 +615,66 @@ class NNAPIOpBuilder {
return kTfLiteOk;
}

template <typename T>
TfLiteStatus AddNewInputConstantTensor(
int32_t nn_type, TfLiteType type, const TfLiteIntArray* dims,
const std::function<TfLiteStatus(TfLitePtrUnion, int64_t)>& init_fn,
const TfLiteQuantizationParams& quant_params, int* tensor_index) {
TF_LITE_ENSURE_OK(context_,
context_->AddTensors(context_, 1, tensor_index));

TfLiteTensor* new_tensor = &context_->tensors[*tensor_index];
new_tensor->type = type;
new_tensor->allocation_type = kTfLiteDynamic;
new_tensor->params = quant_params;

// Not removing the new tensor in case of resizing errors since it will
// be cleared by the context
TF_LITE_ENSURE_OK(
context_,
context_->ResizeTensor(
context_, new_tensor,
// Resize Tensor takes ownership of the dims array passed as param
TfLiteIntArrayCopy(dims)));

const int64_t out_size = NumElements(dims);
TF_LITE_ENSURE_OK(context_, init_fn(new_tensor->data, out_size));

const uint32_t tensor_rank = static_cast<uint32_t>(dims->size);
const uint32_t* tensor_dims = reinterpret_cast<const uint32_t*>(dims->data);
ANeuralNetworksOperandType operand_type{nn_type, tensor_rank, tensor_dims,
quant_params.scale,
quant_params.zero_point};

const int ann_tensor_index =
operand_mapping_->add_delegate_generated_input_ann_tensors_operand();

RETURN_TFLITE_ERROR_IF_NN_ERROR(
context_,
nnapi_->ANeuralNetworksModel_addOperand(nn_model_, &operand_type));

augmented_inputs_.push_back(ann_tensor_index);

RETURN_TFLITE_ERROR_IF_NN_ERROR(
context_, nnapi_->ANeuralNetworksModel_setOperandValue(
nn_model_, ann_tensor_index, new_tensor->data.raw,
new_tensor->bytes));

return kTfLiteOk;
}

template <typename T>
TfLiteStatus AddNewInputConstantTensor(
int32_t nn_type, TfLiteType type, std::initializer_list<int> dims,
const std::function<TfLiteStatus(TfLitePtrUnion, int64_t)>& init_fn,
const TfLiteQuantizationParams& quant_params, int* tensor_index) {
TfLiteIntArray* dim_array = TfLiteIntArrayCreate(dims.size());
const auto result = AddNewInputConstantTensor<T>(
nn_type, type, dim_array, init_fn, quant_params, tensor_index);
TfLiteIntArrayFree(dim_array);
return result;
}

private:
// Returns a TF Lite type which has the same memory representation as a
// provided NN API type.
@@ -716,6 +810,11 @@ class NNAPIOpBuilder {
case kTfLiteBool:
nn_type = ANEURALNETWORKS_TENSOR_BOOL8;
break;
case kTfLiteInt16:
nn_type = ANEURALNETWORKS_TENSOR_QUANT16_SYMM;
scale = tensor->params.scale;
zeroPoint = tensor->params.zero_point;
break;
default:
context_->ReportError(
context_, "Failed to add NN API tensor: type %s is not supported.",
@@ -839,6 +938,7 @@ struct NNAPIOpMappingArgs {
TfLiteNode* node;
std::vector<int>* model_state_outputs;
std::vector<int>* model_state_tfl_inputs;
std::vector<std::tuple<int, int>>* feedback_loops;
};

// Mapping function simply returning the operation type without adding any
@@ -1665,20 +1765,246 @@ class NNAPIDelegateKernel {
// Hybrid operators not supported before NNAPI 1.2.
return nullptr;
}
// TODO(levp): name the constants for number of inputs in LSTM kernel.
if (node->inputs->size != 20 && node->inputs->size != 24) {
return nullptr;

const auto weight_input_index =
isLstmBasicKernel(node)
? 2 /* basic::kInputWeights */
: 4 /* full::kInputToOutputWeightsTensor */;

const TfLiteType weight_type =
context->tensors[node->inputs->data[weight_input_index]].type;

if (isLstmBasicKernel(node)) {
if (weight_type != kTfLiteUInt8) {
return nullptr;
}
const auto input_quantization_params =
context->tensors[node->inputs->data[0]].params;
if (input_quantization_params.scale != 1. / 128. ||
input_quantization_params.zero_point != 128) {
return nullptr;
}

const auto output_quantization_params =
context->tensors[node->outputs->data[0]].params;
if (output_quantization_params.scale != 1. / 128. ||
output_quantization_params.zero_point != 128) {
return nullptr;
}

const auto cell_state_quantization_params =
context->tensors[node->outputs->data[1]].params;
if (cell_state_quantization_params.scale != 16. / 32768. ||
cell_state_quantization_params.zero_point != 0) {
return nullptr;
}

auto is_const_tensor = [&node, &context](int tensor_idx) {
return context->tensors[node->inputs->data[tensor_idx]]
.allocation_type == kTfLiteMmapRo;
};

if (!is_const_tensor(2 /* kInputWeights */)) {
return nullptr;
}

if (!is_const_tensor(3 /* kInputBiases */)) {
return nullptr;
}

return [](const NNAPIOpMappingArgs& mapping_args)
-> ANeuralNetworksOperationType {
const auto output_dims =
mapping_args.context
->tensors[mapping_args.node->outputs->data[1]]
.dims;

// Inputs kInputData
mapping_args.builder->AddTensorInput(
mapping_args.node->inputs->data[0 /* kInputData */],
/* hybrid_op */ false,
/* scalar_as_tensor */ false);

// The 8 weights tensors are set by decomposing the
// kInputWeights param
const auto weight_tensor =
mapping_args.context->tensors
[mapping_args.node->inputs->data[2 /* kInputWeights */]];

std::vector<uint8_t> recurrent_to_input;
std::vector<uint8_t> input_to_input;
std::vector<uint8_t> recurrent_to_cell;
std::vector<uint8_t> input_to_cell;
std::vector<uint8_t> recurrent_to_forget;
std::vector<uint8_t> input_to_forget;
std::vector<uint8_t> recurrent_to_output;
std::vector<uint8_t> input_to_output;
tflite::delegate::nnapi::DecomposeQuantLstmWeightsTensor(
weight_tensor.data.uint8, weight_tensor.dims,
&recurrent_to_input, &input_to_input, &recurrent_to_cell,
&input_to_cell, &recurrent_to_forget, &input_to_forget,
&recurrent_to_output, &input_to_output);

const auto ui8_fill_with =
[](const std::vector<uint8_t>& read_from,
TfLitePtrUnion write_to, int64_t size) -> TfLiteStatus {
std::copy(read_from.begin(), read_from.end(), write_to.uint8);
return kTfLiteOk;
};

TfLiteIntArray* recurrent_weight_dims = TfLiteIntArrayCreate(2);
TfLiteIntArray* input_weight_dims = TfLiteIntArrayCreate(2);
tflite::delegate::nnapi::SetWeightSubmatrixDims(
weight_tensor.dims, recurrent_weight_dims, input_weight_dims);

int new_tensor_index = -1;

mapping_args.builder->AddNewInputConstantTensor<uint8_t>(
ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, kTfLiteUInt8,
input_weight_dims,
std::bind(ui8_fill_with, input_to_input,
std::placeholders::_1, std::placeholders::_2),
weight_tensor.params, &new_tensor_index);

mapping_args.builder->AddNewInputConstantTensor<uint8_t>(
ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, kTfLiteUInt8,
input_weight_dims,
std::bind(ui8_fill_with, input_to_forget,
std::placeholders::_1, std::placeholders::_2),
weight_tensor.params, &new_tensor_index);

mapping_args.builder->AddNewInputConstantTensor<uint8_t>(
ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, kTfLiteUInt8,
input_weight_dims,
std::bind(ui8_fill_with, input_to_cell, std::placeholders::_1,
std::placeholders::_2),
weight_tensor.params, &new_tensor_index);

mapping_args.builder->AddNewInputConstantTensor<uint8_t>(
ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, kTfLiteUInt8,
input_weight_dims,
std::bind(ui8_fill_with, input_to_output,
std::placeholders::_1, std::placeholders::_2),
weight_tensor.params, &new_tensor_index);

mapping_args.builder->AddNewInputConstantTensor<uint8_t>(
ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, kTfLiteUInt8,
recurrent_weight_dims,
std::bind(ui8_fill_with, recurrent_to_input,
std::placeholders::_1, std::placeholders::_2),
weight_tensor.params, &new_tensor_index);

mapping_args.builder->AddNewInputConstantTensor<uint8_t>(
ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, kTfLiteUInt8,
recurrent_weight_dims,
std::bind(ui8_fill_with, recurrent_to_forget,
std::placeholders::_1, std::placeholders::_2),
weight_tensor.params, &new_tensor_index);

mapping_args.builder->AddNewInputConstantTensor<uint8_t>(
ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, kTfLiteUInt8,
recurrent_weight_dims,
std::bind(ui8_fill_with, recurrent_to_cell,
std::placeholders::_1, std::placeholders::_2),
weight_tensor.params, &new_tensor_index);

mapping_args.builder->AddNewInputConstantTensor<uint8_t>(
ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, kTfLiteUInt8,
recurrent_weight_dims,
std::bind(ui8_fill_with, recurrent_to_output,
std::placeholders::_1, std::placeholders::_2),
weight_tensor.params, &new_tensor_index);

TfLiteIntArrayFree(input_weight_dims);
TfLiteIntArrayFree(recurrent_weight_dims);

// Biases have to be split in four
const auto i32_fill_with =
[](const std::vector<int32_t>& read_from,
TfLitePtrUnion write_to, int64_t size) -> TfLiteStatus {
std::copy(read_from.begin(), read_from.end(), write_to.i32);
return kTfLiteOk;
};

const auto bias_size = output_dims->data[1];
const TfLiteTensor& biases_tensor =
mapping_args.context->tensors
[mapping_args.node->inputs->data[3 /* kInputBiases */]];

std::vector<int32_t> input_bias;
std::vector<int32_t> cell_bias;
std::vector<int32_t> forget_bias;
std::vector<int32_t> output_bias;
delegate::nnapi::DecomposeBiasTensor(
biases_tensor.data.i32, bias_size, &input_bias, &cell_bias,
&forget_bias, &output_bias);

int input_bias_tensor = -1;
mapping_args.builder->AddNewInputConstantTensor<int32_t>(
ANEURALNETWORKS_TENSOR_INT32, kTfLiteInt32, {bias_size},
std::bind(i32_fill_with, input_bias, std::placeholders::_1,
std::placeholders::_2),
biases_tensor.params, &input_bias_tensor);
// kForgetGateBiasTensor
int forget_bias_tensor = -1;
mapping_args.builder->AddNewInputConstantTensor<int32_t>(
ANEURALNETWORKS_TENSOR_INT32, kTfLiteInt32, {bias_size},
std::bind(i32_fill_with, forget_bias, std::placeholders::_1,
std::placeholders::_2),
biases_tensor.params, &forget_bias_tensor);
// kCellGateBiasTensor
int cell_gate_bias_tensor = -1;
mapping_args.builder->AddNewInputConstantTensor<int32_t>(
ANEURALNETWORKS_TENSOR_INT32, kTfLiteInt32, {bias_size},
std::bind(i32_fill_with, cell_bias, std::placeholders::_1,
std::placeholders::_2),
biases_tensor.params, &cell_gate_bias_tensor);
// kOutputGateBiasTensor
int output_gate_bias_tensor = -1;
mapping_args.builder->AddNewInputConstantTensor<int32_t>(
ANEURALNETWORKS_TENSOR_INT32, kTfLiteInt32, {bias_size},
std::bind(i32_fill_with, output_bias, std::placeholders::_1,
std::placeholders::_2),
biases_tensor.params, &output_gate_bias_tensor);

mapping_args.builder->AddTensorInput(
mapping_args.node->inputs->data[4 /* kInputPrevState */],
/* hybrid_op */ false,
/* scalar_as_tensor */ false);

// kInputPrevActivation
mapping_args.builder->AddTensorInput(
mapping_args.node->inputs->data[1 /* kInputPrevActivation */],
/* hybrid_op */ false,
/* scalar_as_tensor */ false);

// Configuring the copy from the activation, state outputs
// to their associated inputs
mapping_args.feedback_loops->push_back(std::make_tuple(
0 /*kOutputActivation*/, 1 /*kInputPrevActivation*/));

mapping_args.feedback_loops->push_back(
std::make_tuple(1 /*kOutputState*/, 4 /*kInputPrevState*/));

// OUTPUTS
// Setting only the first two since the remaining ones are
// ignored by NNAPI
mapping_args.builder->AddTensorOutput(
mapping_args.node->outputs->data[1 /* kOutputState */], 0);

mapping_args.builder->AddTensorOutput(
mapping_args.node->outputs
->data[0 /* kOutputkOutputActivationState */],
0);

return ANEURALNETWORKS_QUANTIZED_16BIT_LSTM;
};
}
if (node->inputs->size == 24 &&
android_sdk_version < kMinSdkVersionForNNAPI12) {
// LSTM with layer norm introduced in API level 29
return nullptr;
}
const TfLiteType weight_type =
context
->tensors[node->inputs
->data[/*kInputToOutputWeightsTensor*/ 4]]
.type;
if (weight_type != kTfLiteFloat32 && weight_type != kTfLiteUInt8) {
return nullptr;
}
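For reference, the basic-kernel checks in the hunk above pin the exact quantization that ANEURALNETWORKS_QUANTIZED_16BIT_LSTM accepts: uint8 input and output activations with scale 1/128 and zero point 128, and an int16 cell state with scale 16/32768 (that is, 1/2048) and zero point 0. A minimal standalone sketch of the same constraint follows; the helper name is illustrative and not part of the patch.

bool HasQuantized16BitLstmParams(const TfLiteQuantizationParams& activation,
                                 const TfLiteQuantizationParams& cell_state) {
  // Activations: asymmetric uint8 with scale 1/128 and zero point 128.
  const bool activation_ok =
      activation.scale == 1. / 128. && activation.zero_point == 128;
  // Cell state: symmetric int16 with scale 16/32768 == 2^-11.
  const bool cell_state_ok =
      cell_state.scale == 16. / 32768. && cell_state.zero_point == 0;
  return activation_ok && cell_state_ok;
}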
@@ -2358,6 +2684,11 @@ class NNAPIDelegateKernel {
int relative_output_index = 0;
size_t output_offset = 0;
for (auto output_index : TfLiteIntArrayView(node->outputs)) {
// If the NNAPI implementation doesn't have some of the outputs
// they are left unmapped and we should not try to read their value here
if (operand_mapping_.lite_index_to_ann(output_index) == -1) {
continue;
}
TfLiteTensor* tensor = &context->tensors[output_index];
if (tensor->buffer_handle != kTfLiteNullBufferHandle &&
tensor->buffer_handle < tensor_memory_map_->size()) {
@@ -2432,6 +2763,20 @@ class NNAPIDelegateKernel {
output_offset += getNumPaddingBytes(tensor->bytes);
}

// copy output of all output tensors in feedback_loops_ into the
// associated input
for (auto feedback_loop : feedback_loops_) {
int output_tensor_idx;
int input_tensor_idx;
std::tie(output_tensor_idx, input_tensor_idx) = feedback_loop;
TfLiteTensor* src =
&context->tensors[node->outputs->data[output_tensor_idx]];
TfLiteTensor* dest =
&context->tensors[node->inputs->data[input_tensor_idx]];

memcpy(dest->data.raw, src->data.raw, src->bytes);
}

return kTfLiteOk;
}

@@ -2456,6 +2801,10 @@ class NNAPIDelegateKernel {
tensor_memory_map_;
std::vector<int> model_state_outputs_;
std::vector<int> model_state_tfl_inputs_;
// This is the equivalent of the pair model_state_outputs_,
// model_state_tfl_inputs_ for all tensors where we have to keep the output
// data available for TFLite model users
std::vector<std::tuple<int, int>> feedback_loops_;

std::unique_ptr<NNMemory> nn_input_memory_;
std::unique_ptr<NNMemory> nn_output_memory_;
@@ -2552,13 +2901,19 @@ class NNAPIDelegateKernel {
input_tensor_flags | NN_TENSOR_FLAG_INT8_CONVERSION));
continue;
}
if (reg->builtin_code == kTfLiteBuiltinLstm && input_pos >= 20) {
if (reg->builtin_code == kTfLiteBuiltinLstm && isLstmFullKernel(node) &&
input_pos >= 20) {
// Skip layer normalization weights. They are added in the Map
// function (after all the other inputs added there) since layer
// normalization weights are the last four inputs of the LSTM op in
// NNAPI.
continue;
}
if (reg->builtin_code == kTfLiteBuiltinLstm &&
isLstmBasicKernel(node)) {
// Configuring all inputs in the Map function
continue;
}
if (reg->builtin_code == kTfLiteBuiltinUnidirectionalSequenceLstm) {
if (input_pos >= 20) {
// Skip layer normalization weights. They are added in the Map
@@ -2694,13 +3049,21 @@ class NNAPIDelegateKernel {
int nn_op_type = Map(
context, reg->builtin_code, reg->version, nnapi_->android_sdk_version,
node)({context, &builder, node, &model_state_outputs_,
&model_state_tfl_inputs_});
&model_state_tfl_inputs_, &feedback_loops_});
// Map outputs to NN API tensor indices.
int output_tensor_flags = 0;
if (need_int8_conversion) {
output_tensor_flags |= NN_TENSOR_FLAG_INT8_CONVERSION;
}
for (auto output_index : TfLiteIntArrayView(node->outputs)) {
for (int output_pos = 0; output_pos < node->outputs->size; ++output_pos) {
const auto output_index = node->outputs->data[output_pos];

// Outputs for basic LSTM cell are set in the Map function since
if (reg->builtin_code == kTfLiteBuiltinLstm &&
isLstmBasicKernel(node)) {
continue;
}

TF_LITE_ENSURE_STATUS(
builder.AddTensorOutput(output_index, output_tensor_flags));
}
@@ -2731,7 +3094,10 @@ class NNAPIDelegateKernel {
for (int i : TfLiteIntArrayView(input_tensors)) {
// Constant tensors are not NNAPI inputs.
if (i != kOptionalTensor &&
context->tensors[i].allocation_type != kTfLiteMmapRo) {
context->tensors[i].allocation_type != kTfLiteMmapRo &&
// The delegate might not have mapped this input (this can
// happen if one tensor is split in several ones)
operand_mapping_.lite_index_to_ann(i) != -1) {
inputs.push_back(operand_mapping_.lite_index_to_ann(i));
if (context->tensors[i].buffer_handle != kTfLiteNullBufferHandle) {
continue;
@@ -2754,7 +3120,11 @@ class NNAPIDelegateKernel {

size_t total_output_byte_size = 0;
for (int i : TfLiteIntArrayView(output_tensors)) {
outputs.push_back(operand_mapping_.lite_index_to_ann(i));
const int output_tensor_ann_index = operand_mapping_.lite_index_to_ann(i);
// Unmapped outputs are not added
if (output_tensor_ann_index != -1) {
outputs.push_back(output_tensor_ann_index);
}
if (context->tensors[i].buffer_handle != kTfLiteNullBufferHandle) {
continue;
}
tensorflow/lite/delegates/nnapi/quant_lstm_sup.cc (new file, 153 lines)
@@ -0,0 +1,153 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/nnapi/quant_lstm_sup.h"

#include <algorithm>

#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace delegate {
namespace nnapi {

// The function extracts a submatrix of the weights at a given row
// and column offsets from a 2D matrix
void ExtractQuantLstmWeightsSubmatrix(const TfLiteIntArray* submatrix_dims,
const int32_t offset_row,
const int32_t offset_column,
const TfLiteIntArray* weight_dims,
const uint8_t* weights,
std::vector<uint8_t>* submatrix) {
auto const& submatrix_rows = submatrix_dims->data[0];
auto const& submatrix_cols = submatrix_dims->data[1];
auto const& weight_cols = weight_dims->data[1];

submatrix->resize(NumElements(submatrix_dims));

for (uint32_t i = 0; i < submatrix_rows * submatrix_cols; ++i) {
const uint32_t row = i / submatrix_cols;
const uint32_t column = i % submatrix_cols;
(*submatrix)[i] =
weights[(row + offset_row) * weight_cols + column + offset_column];
}
}

inline int OutputDepth(const TfLiteIntArray* weight_dims) {
return weight_dims->data[0] / 4;
}

inline int InputDepth(const TfLiteIntArray* weight_dims) {
return weight_dims->data[1] - OutputDepth(weight_dims);
}

void SetWeightSubmatrixDims(const TfLiteIntArray* weight_dims,
TfLiteIntArray* recurrent_submatrix_dims,
TfLiteIntArray* input_submatrix_dims) {
const auto input_depth = InputDepth(weight_dims);
const auto output_depth = OutputDepth(weight_dims);

recurrent_submatrix_dims->data[0] = output_depth;
recurrent_submatrix_dims->data[1] = output_depth;

input_submatrix_dims->data[0] = output_depth;
input_submatrix_dims->data[1] = input_depth;
}

// Doing exactly the opposite work of QuantizedLSTMCell::concatenateWeights
// in NNAPI, decomposing the concat_weights tensor data into its 8 components
// according to the following diagram
//
// +-----------------------------------+
// | recurrentToInput  | inputToInput  |
// |-------------------+---------------|
// | recurrentToCell   | inputToCell   |
// |-------------------+---------------|
// | recurrentToForget | inputToForget |
// |-------------------+---------------|
// | recurrentToOutput | inputToOutput |
// +-----------------------------------+
void DecomposeQuantLstmWeightsTensor(const uint8_t* concat_weights,
const TfLiteIntArray* weight_dims,
std::vector<uint8_t>* recurrent_to_input,
std::vector<uint8_t>* input_to_input,
std::vector<uint8_t>* recurrent_to_cell,
std::vector<uint8_t>* input_to_cell,
std::vector<uint8_t>* recurrent_to_forget,
std::vector<uint8_t>* input_to_forget,
std::vector<uint8_t>* recurrent_to_output,
std::vector<uint8_t>* input_to_output) {
const auto output_depth = OutputDepth(weight_dims);

TfLiteIntArray* recurrent_submatrix_dims = TfLiteIntArrayCreate(2);
TfLiteIntArray* input_submatrix_dims = TfLiteIntArrayCreate(2);
SetWeightSubmatrixDims(weight_dims, recurrent_submatrix_dims,
input_submatrix_dims);

ExtractQuantLstmWeightsSubmatrix(recurrent_submatrix_dims, 0 * output_depth,
0, weight_dims, concat_weights,
recurrent_to_input);
ExtractQuantLstmWeightsSubmatrix(input_submatrix_dims, 0 * output_depth,
output_depth, weight_dims, concat_weights,
input_to_input);

ExtractQuantLstmWeightsSubmatrix(recurrent_submatrix_dims, 1 * output_depth,
0, weight_dims, concat_weights,
recurrent_to_cell);
ExtractQuantLstmWeightsSubmatrix(input_submatrix_dims, 1 * output_depth,
output_depth, weight_dims, concat_weights,
input_to_cell);

ExtractQuantLstmWeightsSubmatrix(recurrent_submatrix_dims, 2 * output_depth,
0, weight_dims, concat_weights,
recurrent_to_forget);
ExtractQuantLstmWeightsSubmatrix(input_submatrix_dims, 2 * output_depth,
output_depth, weight_dims, concat_weights,
input_to_forget);

ExtractQuantLstmWeightsSubmatrix(recurrent_submatrix_dims, 3 * output_depth,
0, weight_dims, concat_weights,
recurrent_to_output);
ExtractQuantLstmWeightsSubmatrix(input_submatrix_dims, 3 * output_depth,
output_depth, weight_dims, concat_weights,
input_to_output);

TfLiteIntArrayFree(recurrent_submatrix_dims);
TfLiteIntArrayFree(input_submatrix_dims);
}

void DecomposeBiasTensor(const int32_t* biases, int bias_size,
std::vector<int32_t>* input_bias,
std::vector<int32_t>* cell_bias,
std::vector<int32_t>* forget_bias,
std::vector<int32_t>* output_bias) {
input_bias->resize(bias_size);
std::copy(biases, biases + bias_size, input_bias->begin());

cell_bias->resize(bias_size);
std::copy(biases + bias_size, biases + 2 * bias_size, cell_bias->begin());

forget_bias->resize(bias_size);
std::copy(biases + 2 * bias_size, biases + 3 * bias_size,
forget_bias->begin());

output_bias->resize(bias_size);
std::copy(biases + 3 * bias_size, biases + 4 * bias_size,
output_bias->begin());
}

}  // namespace nnapi
}  // namespace delegate
}  // namespace tflite
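As a worked example of the layout handled above: for the 8x5 concatenated weight tensor used in the tests below, OutputDepth is 8 / 4 = 2 and InputDepth is 5 - 2 = 3, so every recurrent_to_* block is a 2x2 submatrix starting at column 0 and every input_to_* block is a 2x3 submatrix starting at column 2, with row offsets 0, 2, 4 and 6 for the input, cell, forget and output gates. A small self-contained sketch of the same row-major indexing in plain C++ (illustrative only, not part of the patch):

#include <cstdint>
#include <vector>

// Copies the rows x cols block that starts at (row0, col0) out of a
// row-major matrix with total_cols columns; this mirrors what
// ExtractQuantLstmWeightsSubmatrix does above.
std::vector<uint8_t> CopyBlock(const std::vector<uint8_t>& matrix,
                               int total_cols, int row0, int col0, int rows,
                               int cols) {
  std::vector<uint8_t> block;
  block.reserve(rows * cols);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      block.push_back(matrix[(row0 + r) * total_cols + (col0 + c)]);
    }
  }
  return block;
}

With output_depth = 2, the forget-gate pieces of an 8x5 matrix would be CopyBlock(weights, 5, 4, 0, 2, 2) for recurrent_to_forget and CopyBlock(weights, 5, 4, 2, 2, 3) for input_to_forget, which matches the expectations in quant_lstm_sup_test.cc below.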
tensorflow/lite/delegates/nnapi/quant_lstm_sup.h (new file, 58 lines)
@@ -0,0 +1,58 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_NNAPI_QUANT_LSTM_SUP_H_
#define TENSORFLOW_LITE_DELEGATES_NNAPI_QUANT_LSTM_SUP_H_

#include <vector>

#include "tensorflow/lite/c/c_api_internal.h"

namespace tflite {
namespace delegate {
namespace nnapi {

void ExtractQuantLstmWeightsSubmatrix(const TfLiteIntArray* submatrix_dims,
const int32_t offset_row,
const int32_t offset_column,
const TfLiteIntArray* weight_dims,
const uint8_t* weights,
std::vector<uint8_t>* submatrix);

void DecomposeQuantLstmWeightsTensor(const uint8_t* concat_weights,
const TfLiteIntArray* weight_dims,
std::vector<uint8_t>* recurrent_to_input,
std::vector<uint8_t>* input_to_input,
std::vector<uint8_t>* recurrent_to_cell,
std::vector<uint8_t>* input_to_cell,
std::vector<uint8_t>* recurrent_to_forget,
std::vector<uint8_t>* input_to_forget,
std::vector<uint8_t>* recurrent_to_output,
std::vector<uint8_t>* input_to_output);

void SetWeightSubmatrixDims(const TfLiteIntArray* weight_dims,
TfLiteIntArray* recurrent_submatrix_dims,
TfLiteIntArray* input_submatrix_dims);

void DecomposeBiasTensor(const int32_t* biases, int bias_size,
std::vector<int32_t>* input_bias,
std::vector<int32_t>* cell_bias,
std::vector<int32_t>* forget_bias,
std::vector<int32_t>* output_bias);

}  // namespace nnapi
}  // namespace delegate
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_NNAPI_QUANT_LSTM_SUP_H_
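A short sketch of the intended call pattern, mirroring how nnapi_delegate.cc above uses these helpers when it lowers the basic quantized LSTM kernel; the wrapping function and variable names are illustrative only:

#include <cstdint>
#include <vector>

#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/delegates/nnapi/quant_lstm_sup.h"

// Splits the two concatenated TFLite tensors (kInputWeights, kInputBiases)
// into the eight weight and four bias blocks the NNAPI op expects.
void DecomposeForNnapi(const TfLiteTensor& weight_tensor,
                       const TfLiteTensor& bias_tensor, int num_units) {
  TfLiteIntArray* recurrent_dims = TfLiteIntArrayCreate(2);
  TfLiteIntArray* input_dims = TfLiteIntArrayCreate(2);
  tflite::delegate::nnapi::SetWeightSubmatrixDims(weight_tensor.dims,
                                                  recurrent_dims, input_dims);

  std::vector<uint8_t> rec_to_in, in_to_in, rec_to_cell, in_to_cell,
      rec_to_forget, in_to_forget, rec_to_out, in_to_out;
  tflite::delegate::nnapi::DecomposeQuantLstmWeightsTensor(
      weight_tensor.data.uint8, weight_tensor.dims, &rec_to_in, &in_to_in,
      &rec_to_cell, &in_to_cell, &rec_to_forget, &in_to_forget, &rec_to_out,
      &in_to_out);

  std::vector<int32_t> input_bias, cell_bias, forget_bias, output_bias;
  tflite::delegate::nnapi::DecomposeBiasTensor(bias_tensor.data.i32, num_units,
                                               &input_bias, &cell_bias,
                                               &forget_bias, &output_bias);

  // ... hand the decomposed vectors to the NNAPI model builder here ...

  TfLiteIntArrayFree(recurrent_dims);
  TfLiteIntArrayFree(input_dims);
}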
tensorflow/lite/delegates/nnapi/quant_lstm_sup_test.cc (new file, 344 lines)
@@ -0,0 +1,344 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/nnapi/quant_lstm_sup.h"

#include <cstdint>
#include <initializer_list>
#include <memory>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/testing/util.h"

namespace {

using ::testing::ElementsAreArray;
using ::testing::Test;

class DimsAllocatingTest : public Test {
protected:
DimsAllocatingTest() : allocated_dims_() {}

~DimsAllocatingTest() override {
for (TfLiteIntArray* dim : allocated_dims_) {
TfLiteIntArrayFree(dim);
}
}

TfLiteIntArray* CreateDimArray(int size,
std::initializer_list<int> dimensions) {
TfLiteIntArray* dims = TfLiteIntArrayCreate(size);
allocated_dims_.push_back(dims);

int i = 0;
for (const int dimension : dimensions) {
dims->data[i++] = dimension;
}

return dims;
}

private:
std::vector<TfLiteIntArray*> allocated_dims_;
};

using tflite::delegate::nnapi::ExtractQuantLstmWeightsSubmatrix;

class ExtractQuantLstmWeightsSubmatrixTest : public DimsAllocatingTest {};

TEST_F(ExtractQuantLstmWeightsSubmatrixTest, TopLeftSubmatrixIsExtracted) {
std::vector<uint8_t> weights = {1, 2, 3, 4, 5, //
11, 12, 13, 14, 15, //
101, 102, 103, 104, 105, //
111, 112, 113, 114, 115, //
201, 202, 203, 204, 205, //
211, 212, 213, 214, 215, //
221, 222, 223, 224, 225, //
231, 232, 233, 234, 235};
const TfLiteIntArray* weight_dims = CreateDimArray(2, {8, 5});

std::vector<uint8_t> submatrix;
const TfLiteIntArray* submatrix_dims = CreateDimArray(2, {2, 3});

ExtractQuantLstmWeightsSubmatrix(submatrix_dims, 0 /* offset_row */,
0 /* offset_column */, weight_dims,
weights.data(), &submatrix);

EXPECT_THAT(submatrix, ElementsAreArray({1, 2, 3, 11, 12, 13}));
}

TEST_F(ExtractQuantLstmWeightsSubmatrixTest, TopRightSubmatrixIsExtracted) {
std::vector<uint8_t> weights = {1, 2, 3, 4, 5, //
11, 12, 13, 14, 15, //
101, 102, 103, 104, 105, //
111, 112, 113, 114, 115, //
201, 202, 203, 204, 205, //
211, 212, 213, 214, 215, //
221, 222, 223, 224, 225, //
231, 232, 233, 234, 235};
const TfLiteIntArray* weight_dims = CreateDimArray(2, {8, 5});

std::vector<uint8_t> submatrix;
const TfLiteIntArray* submatrix_dims = CreateDimArray(2, {2, 2});

ExtractQuantLstmWeightsSubmatrix(submatrix_dims, 0 /* offset_row */,
3 /* offset_column */, weight_dims,
weights.data(), &submatrix);

EXPECT_THAT(submatrix, ElementsAreArray({4, 5, 14, 15}));
}

TEST_F(ExtractQuantLstmWeightsSubmatrixTest, RightCentralSubmatrixIsExtracted) {
std::vector<uint8_t> weights = {1, 2, 3, 4, 5, //
11, 12, 13, 14, 15, //
101, 102, 103, 104, 105, //
111, 112, 113, 114, 115, //
201, 202, 203, 204, 205, //
211, 212, 213, 214, 215, //
221, 222, 223, 224, 225, //
231, 232, 233, 234, 235};
const TfLiteIntArray* weight_dims = CreateDimArray(2, {8, 5});

std::vector<uint8_t> submatrix;
const TfLiteIntArray* submatrix_dims = CreateDimArray(2, {2, 2});

ExtractQuantLstmWeightsSubmatrix(
submatrix_dims, 1 * submatrix_dims->data[0] /* offset_row */,
3 /* offset_column */, weight_dims, weights.data(), &submatrix);

EXPECT_THAT(submatrix, ElementsAreArray({104, 105, 114, 115}));
}

using tflite::delegate::nnapi::DecomposeQuantLstmWeightsTensor;

class QuantLstmWeightDecompTest : public DimsAllocatingTest {
protected:
QuantLstmWeightDecompTest()
: weights_({1, 2, 3, 4, 5, //
11, 12, 13, 14, 15, //
101, 102, 103, 104, 105, //
111, 112, 113, 114, 115, //
201, 202, 203, 204, 205, //
211, 212, 213, 214, 215, //
221, 222, 223, 224, 225, //
231, 232, 233, 234, 235}),
// Creating the arrays empty, the size is set by the decomposition
// function
recurrent_to_input_(),
input_to_input_(),
recurrent_to_cell_(),
input_to_cell_(),
recurrent_to_forget_(),
input_to_forget_(),
recurrent_to_output_(),
input_to_output_() {
weight_dims_ = CreateDimArray(2, {8, 5});
}

const std::vector<uint8_t> weights_;
const TfLiteIntArray* weight_dims_;
std::vector<uint8_t> recurrent_to_input_;
std::vector<uint8_t> input_to_input_;
std::vector<uint8_t> recurrent_to_cell_;
std::vector<uint8_t> input_to_cell_;
std::vector<uint8_t> recurrent_to_forget_;
std::vector<uint8_t> input_to_forget_;
std::vector<uint8_t> recurrent_to_output_;
std::vector<uint8_t> input_to_output_;
};

TEST_F(QuantLstmWeightDecompTest, ExtractRecurrentToInput) {
DecomposeQuantLstmWeightsTensor(
weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
&recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
&input_to_forget_, &recurrent_to_output_, &input_to_output_);

EXPECT_THAT(recurrent_to_input_, ElementsAreArray({1, 2, //
11, 12}));
}

TEST_F(QuantLstmWeightDecompTest, ExtractInputToInput) {
DecomposeQuantLstmWeightsTensor(
weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
&recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
&input_to_forget_, &recurrent_to_output_, &input_to_output_);

EXPECT_THAT(input_to_input_, ElementsAreArray({3, 4, 5, //
13, 14, 15}));
}

TEST_F(QuantLstmWeightDecompTest, ExtractRecurrentToCell) {
DecomposeQuantLstmWeightsTensor(
weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
&recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
&input_to_forget_, &recurrent_to_output_, &input_to_output_);

EXPECT_THAT(recurrent_to_cell_, ElementsAreArray({101, 102, //
111, 112}));
}

TEST_F(QuantLstmWeightDecompTest, ExtractInputToCell) {
DecomposeQuantLstmWeightsTensor(
weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
&recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
&input_to_forget_, &recurrent_to_output_, &input_to_output_);

EXPECT_THAT(input_to_cell_, ElementsAreArray({103, 104, 105, //
113, 114, 115}));
}

TEST_F(QuantLstmWeightDecompTest, ExtractRecurrentToForget) {
DecomposeQuantLstmWeightsTensor(
weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
&recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
&input_to_forget_, &recurrent_to_output_, &input_to_output_);

EXPECT_THAT(recurrent_to_forget_, ElementsAreArray({201, 202, //
211, 212}));
}

TEST_F(QuantLstmWeightDecompTest, ExtractInputToForget) {
DecomposeQuantLstmWeightsTensor(
weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
&recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
&input_to_forget_, &recurrent_to_output_, &input_to_output_);

EXPECT_THAT(input_to_forget_, ElementsAreArray({203, 204, 205, //
213, 214, 215}));
}

TEST_F(QuantLstmWeightDecompTest, ExtractRecurrentToOutput) {
DecomposeQuantLstmWeightsTensor(
weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
&recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
&input_to_forget_, &recurrent_to_output_, &input_to_output_);

EXPECT_THAT(recurrent_to_output_, ElementsAreArray({221, 222, //
231, 232}));
}

TEST_F(QuantLstmWeightDecompTest, ExtractInputToOutput) {
DecomposeQuantLstmWeightsTensor(
weights_.data(), weight_dims_, &recurrent_to_input_, &input_to_input_,
&recurrent_to_cell_, &input_to_cell_, &recurrent_to_forget_,
&input_to_forget_, &recurrent_to_output_, &input_to_output_);

EXPECT_THAT(input_to_output_, ElementsAreArray({223, 224, 225, //
233, 234, 235}));
}

using tflite::delegate::nnapi::DecomposeBiasTensor;

TEST(DecomposeBiasTensor, ExtractInputBias) {
// clang-format off
std::vector<int32_t> biases
// inputGateBias
{-7876, 13488, -726, 32839,
// cellGateBias
39481, 48624, 48976, -21419,
// forgetGateBias
9206, -46884, -11693, -38724,
// outputGateBias
-58999, -17050, -41852, -40538};
// clang-format on

std::vector<int32_t> input_bias;
std::vector<int32_t> cell_bias;
std::vector<int32_t> forget_bias;
std::vector<int32_t> output_bias;
DecomposeBiasTensor(biases.data(), 4, &input_bias, &cell_bias, &forget_bias,
&output_bias);

EXPECT_THAT(input_bias, ElementsAreArray({-7876, 13488, -726, 32839}));
}

TEST(DecomposeBiasTensor, ExtractCellBias) {
// clang-format off
std::vector<int32_t> biases
// inputGateBias
{-7876, 13488, -726, 32839,
// cellGateBias
39481, 48624, 48976, -21419,
// forgetGateBias
9206, -46884, -11693, -38724,
// outputGateBias
-58999, -17050, -41852, -40538};
// clang-format on

std::vector<int32_t> input_bias;
std::vector<int32_t> cell_bias;
std::vector<int32_t> forget_bias;
std::vector<int32_t> output_bias;
DecomposeBiasTensor(biases.data(), 4, &input_bias, &cell_bias, &forget_bias,
&output_bias);

EXPECT_THAT(cell_bias, ElementsAreArray({39481, 48624, 48976, -21419}));
}

TEST(DecomposeBiasTensor, ExtractForgetBias) {
// clang-format off
std::vector<int32_t> biases
// inputGateBias
{-7876, 13488, -726, 32839,
// cellGateBias
39481, 48624, 48976, -21419,
// forgetGateBias
9206, -46884, -11693, -38724,
// outputGateBias
-58999, -17050, -41852, -40538};
// clang-format on

std::vector<int32_t> input_bias;
std::vector<int32_t> cell_bias;
std::vector<int32_t> forget_bias;
std::vector<int32_t> output_bias;
DecomposeBiasTensor(biases.data(), 4, &input_bias, &cell_bias, &forget_bias,
&output_bias);

EXPECT_THAT(forget_bias, ElementsAreArray({9206, -46884, -11693, -38724}));
}

TEST(DecomposeBiasTensor, ExtractOutputBias) {
// clang-format off
std::vector<int32_t> biases
// inputGateBias
{-7876, 13488, -726, 32839,
// cellGateBias
39481, 48624, 48976, -21419,
// forgetGateBias
9206, -46884, -11693, -38724,
// outputGateBias
-58999, -17050, -41852, -40538};
// clang-format on

std::vector<int32_t> input_bias;
std::vector<int32_t> cell_bias;
std::vector<int32_t> forget_bias;
std::vector<int32_t> output_bias;
DecomposeBiasTensor(biases.data(), 4, &input_bias, &cell_bias, &forget_bias,
&output_bias);

EXPECT_THAT(output_bias, ElementsAreArray({-58999, -17050, -41852, -40538}));
}

}  // namespace

int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
@@ -1836,3 +1836,18 @@ cc_test(
"@com_google_googletest//:gtest",
],
)

cc_test(
name = "quant_basic_lstm_test",
size = "small",
srcs = ["quant_basic_lstm_test.cc"],
tags = ["tflite_nnapi"],
deps = [
":builtin_ops",
":kernel_util",
":test_main",
":test_util",
"//tensorflow/lite:framework",
"@com_google_googletest//:gtest",
],
)
@@ -54,14 +54,18 @@ inline int NumIntermediates(const TfLiteNode* node) {
return node->intermediates->size;
}

inline int64_t NumElements(const TfLiteTensor* t) {
inline int64_t NumElements(const TfLiteIntArray* dims) {
int64_t count = 1;
for (int i = 0; i < NumDimensions(t); ++i) {
count *= SizeOfDimension(t, i);
for (int i = 0; i < dims->size; ++i) {
count *= dims->data[i];
}
return count;
}

inline int64_t NumElements(const TfLiteTensor* t) {
return NumElements(t->dims);
}

inline const TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context,
const TfLiteNode* node,
int index) {
tensorflow/lite/kernels/quant_basic_lstm_test.cc (new file, 230 lines)
@@ -0,0 +1,230 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <initializer_list>
#include <memory>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"

namespace tflite {
namespace {

using ::testing::ElementsAreArray;

class QuantizedLSTMOpModel : public SingleOpModel {
public:
QuantizedLSTMOpModel(int numBatches, int inputSize, float weightsScale,
int32_t weightsZeroPoint, int outputSize,
std::initializer_list<uint8_t> weights,
std::initializer_list<int32_t> biases) {
std::vector<uint32_t> inputs;

input_size_ = inputSize;
output_size_ = outputSize;

std::vector<int> input_shape{numBatches, inputSize};
std::vector<int> output_shape{numBatches, outputSize};
std::vector<int> weight_shape{4 * outputSize, outputSize + inputSize};
std::vector<int> state_shape{numBatches, outputSize};
std::vector<int> bias_shape{4 * outputSize};

input_ =
AddInput({TensorType_UINT8, input_shape, 0.0f, 0.0f, 1. / 128., 128});
prev_output_ =
AddInput({TensorType_UINT8, output_shape, 0.0f, 0.0f, 1. / 128., 128});
// Biases and Weights have to be constant in order to allow NNAPI
// delegation
weights_ = AddConstInput<uint8_t>({TensorType_UINT8, weight_shape, 0.0f,
0.0f, weightsScale, weightsZeroPoint},
weights);
biases_ = AddConstInput<int32_t>(
{TensorType_INT32, bias_shape, 0.0f, 0.0f, weightsScale / 128, 0},
biases);
prev_cell_state_ =
AddInput({TensorType_INT16, state_shape, 0.0f, 0.0f, 1. / 2048., 0});

output_ =
AddOutput({TensorType_UINT8, output_shape, 0.0f, 0.0f, 1. / 128., 128});
cell_state_out_ =
AddOutput({TensorType_INT16, state_shape, 0.0f, 0.0f, 1. / 2048., 0});
output_concat_temp_ =
AddOutput({TensorType_UINT8, output_shape, 0.0f, 0.0f, 1. / 128., 128});
output_activation_temp_ =
AddOutput({TensorType_INT16, output_shape, 0.0f, 0.0f, 1. / 128., 128});

SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
CreateLSTMOptions(builder_, ActivationFunctionType_TANH, 0.0,
0.0, LSTMKernelType_BASIC)
.Union());

BuildInterpreter({GetShape(input_), GetShape(prev_output_),
GetShape(weights_), GetShape(biases_),
GetShape(prev_cell_state_)});

// init feedback inputs to zero
std::vector<int16_t> initial_state(GetTensorSize(cell_state_out_), 0);
PopulateTensor(prev_cell_state_, initial_state);
std::vector<uint8_t> initial_prev_output(GetTensorSize(output_), 0);
PopulateTensor(prev_output_, initial_prev_output);
}

int inputSize() { return input_size_; }

int outputSize() { return output_size_; }

void setInput(const std::vector<uint8_t>& input) {
PopulateTensor(input_, input);
}

std::vector<uint8_t> getOutput() { return ExtractVector<uint8_t>(output_); }

private:
// Inputs
int input_;
int weights_;
int biases_;
int prev_cell_state_;
int prev_output_;
// Outputs
int cell_state_out_;
int output_;
int output_concat_temp_;
int output_activation_temp_;

int input_size_;
int output_size_;
};

class QuantizedLstmTest : public ::testing::Test {
protected:
void VerifyGoldens(const std::vector<std::vector<uint8_t>>& input,
const std::vector<std::vector<uint8_t>>& output,
QuantizedLSTMOpModel* lstm) {
const int numBatches = input.size();
ASSERT_GT(numBatches, 0);
const int inputSize = lstm->inputSize();
ASSERT_GT(inputSize, 0);
const int inputSequenceSize = input[0].size() / inputSize;
ASSERT_GT(inputSequenceSize, 0);
for (int i = 0; i < inputSequenceSize; ++i) {
std::vector<uint8_t> inputStep;
for (int b = 0; b < numBatches; ++b) {
const uint8_t* batchStart = input[b].data() + i * inputSize;
const uint8_t* batchEnd = batchStart + inputSize;
inputStep.insert(inputStep.end(), batchStart, batchEnd);
}
lstm->setInput(inputStep);
lstm->Invoke();

const int outputSize = lstm->outputSize();
std::vector<float> expected;
for (int b = 0; b < numBatches; ++b) {
const uint8_t* goldenBatchStart = output[b].data() + i * outputSize;
const uint8_t* goldenBatchEnd = goldenBatchStart + outputSize;
expected.insert(expected.end(), goldenBatchStart, goldenBatchEnd);
}
EXPECT_THAT(lstm->getOutput(), ElementsAreArray(expected));
}
}
};

// Inputs and weights in this test are random and the test only checks that the
// outputs are equal to outputs obtained from running TF Lite version of
// quantized LSTM on the same inputs.
TEST_F(QuantizedLstmTest, BasicQuantizedLstmTest) {
const int numBatches = 2;
const int inputSize = 2;
const int outputSize = 4;

float weightsScale = 0.00408021;
int weightsZeroPoint = 100;

QuantizedLSTMOpModel lstm(
numBatches, inputSize, weightsScale, weightsZeroPoint, outputSize,

// This data is copied from QuantizedLSTMTest.cpp in the NNAPI source code
// I have to recompose the weight matrix before passing it to the model

// recurrentToInputWeights inputToInputWeights
{254, 206, 77, 168, 146, 250, 71, 20, 215, 6, 235, 171, 223, 7, 118, 225,
10, 218, 59, 130, 174, 26, 171, 108,

// recurrentToCellWeights inputToCellWeights
172, 60, 205, 65, 133, 34, 14, 0, 140, 168, 29, 49, 240, 223, 133, 56,
206, 109, 142, 64, 246, 216, 54, 183,

// recurrentToForgetWeights inputToForgetWeights
137, 240, 103, 52, 24, 50, 68, 51, 237, 112, 132, 179, 0, 220, 89, 23,
158, 110, 69, 4, 207, 253, 3, 169,

// recurrentToOutputWeights inputToOutputWeights
106, 214, 67, 23, 195, 187, 59, 158, 45, 3, 11, 99, 119, 132, 49, 205,
109, 10, 129, 218, 11, 98, 218, 48},

// inputGateBias
{-7876, 13488, -726, 32839,
// cellGateBias
39481, 48624, 48976, -21419,
// forgetGateBias
9206, -46884, -11693, -38724,
// outputGateBias
-58999, -17050, -41852, -40538});
// clang-format on

// LSTM input is stored as numBatches x (sequenceLength x inputSize) vector.
std::vector<std::vector<uint8_t>> lstmInput;
// clang-format off
lstmInput = {{154, 166,
166, 179,
141, 141},
{100, 200,
50, 150,
111, 222}};
// clang-format on

// LSTM output is stored as numBatches x (sequenceLength x outputSize) vector.
std::vector<std::vector<uint8_t>> lstmGoldenOutput;
/*
This is the output used in NNAPI's QuantizedLSTMTest.cpp
I get slightly different values that are consistent running with or
without acceleration

lstmGoldenOutput = {{136, 150, 140, 115,
140, 151, 146, 112,
139, 153, 146, 114},
{135, 152, 138, 112,
136, 156, 142, 112,
141, 154, 146, 108}};
*/

// clang-format off
lstmGoldenOutput = {{131, 152, 136, 109,
138, 150, 145, 111,
139, 152, 146, 113},
{131, 153, 135, 107,
134, 154, 140, 111,
140, 154, 145, 108}};
// clang-format on
VerifyGoldens(lstmInput, lstmGoldenOutput, &lstm);
}

}  // namespace
}  // namespace tflite
@@ -41,6 +41,7 @@ enum {
ANEURALNETWORKS_TENSOR_QUANT8_ASYMM = 5,
ANEURALNETWORKS_BOOL = 6,
ANEURALNETWORKS_TENSOR_BOOL8 = 9,
ANEURALNETWORKS_TENSOR_QUANT16_SYMM = 7,
ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL = 11,
ANEURALNETWORKS_TENSOR_QUANT8_SYMM = 13,
};
@@ -115,6 +116,7 @@ enum {
ANEURALNETWORKS_POW = 70,
ANEURALNETWORKS_PRELU = 71,
ANEURALNETWORKS_QUANTIZE = 72,
ANEURALNETWORKS_QUANTIZED_16BIT_LSTM = 73,
ANEURALNETWORKS_REDUCE_ANY = 76,
ANEURALNETWORKS_REDUCE_MAX = 77,
ANEURALNETWORKS_REDUCE_MIN = 78,
@@ -164,6 +164,7 @@ endif
ifeq ($(BUILD_WITH_NNAPI),true)
CORE_CC_ALL_SRCS += tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
CORE_CC_ALL_SRCS += tensorflow/lite/nnapi/nnapi_implementation.cc
CORE_CC_ALL_SRCS += tensorflow/lite/nnapi/quant_lstm_sup.cc
else
CORE_CC_ALL_SRCS += tensorflow/lite/delegates/nnapi/nnapi_delegate_disabled.cc
CORE_CC_ALL_SRCS += tensorflow/lite/nnapi/nnapi_implementation_disabled.cc