Adds asymmetric quantized inputs for hybrid ops in future models.

PiperOrigin-RevId: 304559648
Change-Id: I8028ae6f65308c9b9fa928b8d755919af1faa7be
Author: David Rim, 2020-04-03 00:25:49 -07:00, committed by TensorFlower Gardener
parent 0752177439
commit bb130adb39
29 changed files with 1721 additions and 491 deletions


@@ -124,21 +124,33 @@ typedef struct {
typedef struct {
int rank;
TfLiteFusedActivation activation;
// Parameter for SVDF version 4.
bool asymmetric_quantize_inputs;
} TfLiteSVDFParams;
typedef struct {
TfLiteFusedActivation activation;
// Parameter for RNN version 3.
bool asymmetric_quantize_inputs;
} TfLiteRNNParams;
typedef struct {
bool time_major;
TfLiteFusedActivation activation;
// Parameter for Sequence RNN version 3.
bool asymmetric_quantize_inputs;
} TfLiteSequenceRNNParams;
typedef struct {
bool time_major;
TfLiteFusedActivation activation;
bool merge_outputs;
// Parameter for Bidirectional RNN version 3.
bool asymmetric_quantize_inputs;
} TfLiteBidirectionalSequenceRNNParams;
typedef enum {
@@ -158,6 +170,11 @@ typedef struct {
// tensors are the same. Furthermore, all but the last dimension of the input
// and output shapes will be equal.
bool keep_num_dims;
// Parameters for FullyConnected version 7 or above.
// If set to true and the weights are quantized, then non constant inputs
// are quantized at evaluation time with asymmetric quantization.
bool asymmetric_quantize_inputs;
} TfLiteFullyConnectedParams;
typedef enum {
@@ -228,6 +245,9 @@ typedef struct {
// Parameters for LSTM version 2.
// kTfLiteLSTMBasicKernel is only supported in version 2 or above.
TfLiteLSTMKernelType kernel_type;
// Parameters for LSTM version 4.
bool asymmetric_quantize_inputs;
} TfLiteLSTMParams;
typedef struct {
@@ -238,6 +258,9 @@ typedef struct {
// If set to true then the first dimension is time, otherwise batch.
bool time_major;
// Parameter for unidirectional sequence LSTM version 3.
bool asymmetric_quantize_inputs;
} TfLiteUnidirectionalSequenceLSTMParams;
typedef struct {
@@ -253,6 +276,10 @@ typedef struct {
// Parameters supported by version 2:
// If set to true then the first dimension is time, otherwise batch.
bool time_major;
// Parameters supported by version 4:
// If set to true, then hybrid ops use asymmetric quantization for inputs.
bool asymmetric_quantize_inputs;
} TfLiteBidirectionalSequenceLSTMParams;
typedef struct {
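Each comment above ties the new field to the op version that first honors it, so a model that sets the flag must also declare the newer op version. Conceptually, when the flag is true a hybrid kernel quantizes its float inputs with a per-row zero point instead of forcing the zero point to 0. A minimal sketch of the difference follows (standalone helper, illustrative name and signature, not the TFLite implementation, which lives in the kernels' shared tensor utilities):

// Sketch only: per-row input quantization for a hybrid (float activations,
// int8 weights) op. Helper name and signature are hypothetical.
#include <algorithm>
#include <cmath>
#include <cstdint>

void QuantizeInputRow(const float* input, int size, bool asymmetric,
                      int8_t* quantized, float* scale, int32_t* zero_point) {
  const auto minmax = std::minmax_element(input, input + size);
  const float min_val = std::min(*minmax.first, 0.0f);
  const float max_val = std::max(*minmax.second, 0.0f);
  if (asymmetric) {
    // Affine mapping: real ~= scale * (q - zero_point), q in [-128, 127].
    *scale = std::max((max_val - min_val) / 255.0f, 1e-6f);
    *zero_point = static_cast<int32_t>(std::round(-128.0f - min_val / *scale));
  } else {
    // Symmetric mapping: zero_point pinned at 0, so part of the int8 range is
    // wasted whenever the data is one-sided (e.g. post-ReLU activations).
    *scale = std::max(
        std::max(std::fabs(min_val), std::fabs(max_val)) / 127.0f, 1e-6f);
    *zero_point = 0;
  }
  for (int i = 0; i < size; ++i) {
    const int32_t q =
        static_cast<int32_t>(std::round(input[i] / *scale)) + *zero_point;
    quantized[i] = static_cast<int8_t>(std::min(127, std::max(-128, q)));
  }
}

Asymmetric mode spends the full int8 range even on skewed activation data, at the cost of the zero-point bookkeeping that the kernel changes below take care of.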


@@ -269,6 +269,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
params->rank = svdf_params->rank();
params->activation =
parse_activation(svdf_params->fused_activation_function());
params->asymmetric_quantize_inputs =
svdf_params->asymmetric_quantize_inputs();
}
*builtin_data = reinterpret_cast<void*>(params.release());
break;
@@ -280,6 +282,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
params->activation =
parse_activation(sequence_rnn_params->fused_activation_function());
params->time_major = sequence_rnn_params->time_major();
params->asymmetric_quantize_inputs =
sequence_rnn_params->asymmetric_quantize_inputs();
}
*builtin_data = reinterpret_cast<void*>(params.release());
break;
@@ -293,6 +297,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
bidi_sequence_rnn_params->fused_activation_function());
params->time_major = bidi_sequence_rnn_params->time_major();
params->merge_outputs = bidi_sequence_rnn_params->merge_outputs();
params->asymmetric_quantize_inputs =
bidi_sequence_rnn_params->asymmetric_quantize_inputs();
}
*builtin_data = reinterpret_cast<void*>(params.release());
break;
@@ -302,6 +308,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
if (const auto* rnn_params = op->builtin_options_as_RNNOptions()) {
params->activation =
parse_activation(rnn_params->fused_activation_function());
params->asymmetric_quantize_inputs =
rnn_params->asymmetric_quantize_inputs();
}
*builtin_data = reinterpret_cast<void*>(params.release());
break;
@@ -323,6 +331,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
params->activation = parse_activation(
fully_connected_params->fused_activation_function());
params->keep_num_dims = fully_connected_params->keep_num_dims();
params->asymmetric_quantize_inputs =
fully_connected_params->asymmetric_quantize_inputs();
switch (fully_connected_params->weights_format()) {
case FullyConnectedOptionsWeightsFormat_DEFAULT:
params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault;
@@ -440,6 +450,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
lstm_params->kernel_type());
return kTfLiteError;
}
params->asymmetric_quantize_inputs =
lstm_params->asymmetric_quantize_inputs();
} else {
TF_LITE_REPORT_ERROR(error_reporter,
"No valid LSTM builtin options exist");
@@ -458,6 +470,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
params->cell_clip = seq_lstm_params->cell_clip();
params->proj_clip = seq_lstm_params->proj_clip();
params->time_major = seq_lstm_params->time_major();
params->asymmetric_quantize_inputs =
seq_lstm_params->asymmetric_quantize_inputs();
}
*builtin_data = reinterpret_cast<void*>(params.release());
break;
@@ -473,6 +487,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
params->proj_clip = bidi_lstm_params->proj_clip();
params->merge_outputs = bidi_lstm_params->merge_outputs();
params->time_major = bidi_lstm_params->time_major();
params->asymmetric_quantize_inputs =
bidi_lstm_params->asymmetric_quantize_inputs();
}
*builtin_data = reinterpret_cast<void*>(params.release());
break;
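The parser above only copies whatever the serialized model carries, so the flag has to be written into the builtin options when the flatbuffer is built. A minimal sketch using the generated schema API, mirroring what the updated tests below do (the wrapper function and include paths are illustrative):

// Sketch: serialize RNNOptions with the new asymmetric_quantize_inputs field.
// CreateRNNOptions comes from the generated TFLite schema; older models simply
// omit the field and the parser is left with its default of false.
#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"

flatbuffers::Offset<tflite::RNNOptions> BuildRnnOptions(
    flatbuffers::FlatBufferBuilder& builder, bool asymmetric_quantize_inputs) {
  return tflite::CreateRNNOptions(builder, tflite::ActivationFunctionType_RELU,
                                  asymmetric_quantize_inputs);
}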


@@ -26,6 +26,15 @@ namespace ops {
namespace builtin {
namespace rnn {
namespace {
struct OpData {
int scratch_tensor_index;
bool compute_row_sums = false;
};
} // namespace
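The bare int that used to hold the first scratch-tensor index becomes a small OpData struct so the kernel can also remember whether the weight row sums still need to be computed. The lifecycle is the usual TFLite one: Init allocates the struct, Prepare sets compute_row_sums back to true (presumably so a re-prepare forces a refresh), the first hybrid Eval fills the persistent row_sums tensor and, inside the shared kernel utilities that receive the flag's address, clears it, and Free releases the struct. A compressed sketch of that flow (hypothetical kernel body, registration glue omitted):

// Sketch of the user_data lifecycle the RNN/LSTM kernels follow here.
// Struct and field names mirror the diff; the kernel itself is hypothetical.
#include <cstddef>
#include "tensorflow/lite/c/common.h"

struct OpData {
  int scratch_tensor_index = 0;
  bool compute_row_sums = false;  // One-shot: cleared after the first Eval.
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  return new OpData();  // The runtime stores this in node->user_data.
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* op_data = static_cast<OpData*>(node->user_data);
  op_data->compute_row_sums = true;  // Force a refresh after (re)preparation.
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* op_data = static_cast<OpData*>(node->user_data);
  if (op_data->compute_row_sums) {
    // Fill the persistent row_sums scratch tensor from the constant weights.
    op_data->compute_row_sums = false;
  }
  return kTfLiteOk;
}

void Free(TfLiteContext* context, void* buffer) {
  delete static_cast<OpData*>(buffer);
}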
constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kRecurrentWeightsTensor = 2;
@@ -36,13 +45,14 @@ constexpr int kHiddenStateTensor = 4;
constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData();
context->AddTensors(context, /*tensors_to_add=*/6,
&op_data->scratch_tensor_index);
return op_data;
}
void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<OpData*>(buffer);
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
@@ -89,10 +99,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate temporary tensors to store quantized values of input and
// hidden_state tensors.
if (is_hybrid) {
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
op_data->compute_row_sums = true;
TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(6);
node->temporaries->data[0] = op_data->scratch_tensor_index;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
input_quantized->type = input_weights->type;
input_quantized->allocation_type = kTfLiteArenaRw;
@@ -101,7 +112,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
input_quantized_size));
}
node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
TfLiteTensor* hidden_state_quantized =
GetTemporary(context, node, /*index=*/1);
hidden_state_quantized->type = input_weights->type;
@@ -114,7 +125,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ResizeTensor(context, hidden_state_quantized,
hidden_state_quantized_size));
}
node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
@@ -125,8 +136,43 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
scaling_factors_size));
}
node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/3);
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {num_units, batch_size};
if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
accum_scratch_dims)) {
TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
accum_scratch_size->data[0] = accum_scratch_dims[0];
accum_scratch_size->data[1] = accum_scratch_dims[1];
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
accum_scratch_size));
}
node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
zero_points->type = kTfLiteInt32;
zero_points->allocation_type = kTfLiteArenaRw;
int zero_points_dims[1] = {batch_size};
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
zero_points_size->data[0] = batch_size;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
zero_points_size));
}
node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_dims[2] = {2, num_units};
if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
row_sums_size->data[0] = row_sums_dims[0];
row_sums_size->data[1] = row_sums_dims[1];
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, row_sums, row_sums_size));
}
}
return kTfLiteOk;
}
@@ -165,7 +211,9 @@ TfLiteStatus EvalHybrid(const TfLiteTensor* input,
TfLiteTensor* input_scratch,
TfLiteTensor* hidden_state_scratch,
TfLiteTensor* scaling_factors,
TfLiteTensor* hidden_state, TfLiteTensor* output,
TfLiteTensor* zero_points, TfLiteTensor* accum_scratch,
TfLiteTensor* row_sums, bool* compute_row_sums) {
const int batch_size = input->dims->data[0];
const int num_units = input_weights->dims->data[0];
const int input_size = input->dims->data[1];
@@ -190,26 +238,34 @@ TfLiteStatus EvalHybrid(const TfLiteTensor* input,
int8_t* quantized_hidden_state_ptr =
GetTensorData<int8_t>(hidden_state_scratch);
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
int32_t* zero_points_ptr = nullptr;
int32_t* row_sums_ptr = nullptr;
if (params->asymmetric_quantize_inputs) {
zero_points_ptr = GetTensorData<int32_t>(zero_points);
row_sums_ptr = GetTensorData<int32_t>(row_sums);
}
kernel_utils::RnnBatchStep(
input_ptr_batch, input_weights_ptr, input_weights_scale,
recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
num_units, batch_size, output_batch_leading_dim, params->activation,
quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr,
hidden_state_ptr_batch, output_ptr_batch,
params->asymmetric_quantize_inputs, zero_points_ptr, accum_scratch_ptr,
row_sums_ptr, compute_row_sums);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data);
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
TfLiteTensor* hidden_state =
&context->tensors[node->inputs->data[kHiddenStateTensor]];
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// We already checked that weight types are consistent, so branch on one.
@@ -223,9 +279,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
TfLiteTensor* accum_scratch = GetTemporary(context, node, 3);
TfLiteTensor* zero_points = GetTemporary(context, node, 4);
TfLiteTensor* row_sums = GetTemporary(context, node, 5);
return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
input_quantized, hidden_state_quantized,
scaling_factors, hidden_state, output, zero_points,
accum_scratch, row_sums, &op_data->compute_row_sums);
}
default:
context->ReportError(context, "Type %d not currently supported.",
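Why the extra scratch tensors: with an asymmetric input, x is roughly scale_x * (q - zero_point), so each output row of the int8 matmul expands to scale_x * scale_w * (sum_i qw_i * q_i - zero_point * sum_i qw_i). The second sum depends only on the constant weights; that is what the persistent row_sums tensor caches and what compute_row_sums guards. A rough standalone illustration (plain loops, not the vectorized TFLite kernel):

// Sketch of the zero-point correction behind the new zero_points / row_sums /
// accum_scratch temporaries. Function name and layout are illustrative.
#include <cstdint>
#include <vector>

void HybridMatVec(const int8_t* weights, int rows, int cols, float weight_scale,
                  const int8_t* quant_input, float input_scale,
                  int32_t input_zero_point, std::vector<int32_t>* row_sums,
                  bool* compute_row_sums, float* output) {
  if (*compute_row_sums) {
    row_sums->assign(rows, 0);
    for (int r = 0; r < rows; ++r)
      for (int c = 0; c < cols; ++c) (*row_sums)[r] += weights[r * cols + c];
    *compute_row_sums = false;  // Weights are constant, so compute only once.
  }
  for (int r = 0; r < rows; ++r) {
    int32_t acc = 0;  // Corresponds to one accum_scratch entry.
    for (int c = 0; c < cols; ++c) {
      acc += static_cast<int32_t>(weights[r * cols + c]) * quant_input[c];
    }
    // Remove the offset that asymmetric quantization baked into quant_input.
    acc -= input_zero_point * (*row_sums)[r];
    output[r] = input_scale * weight_scale * acc;  // Bias/activation omitted.
  }
}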


@@ -175,7 +175,8 @@ class RNNOpModel : public SingleOpModel {
public:
RNNOpModel(int batches, int units, int size,
const TensorType& weights = TensorType_FLOAT32,
const TensorType& recurrent_weights = TensorType_FLOAT32,
bool asymmetric_quantize_inputs = false)
: batches_(batches), units_(units), input_size_(size) {
input_ = AddInput(TensorType_FLOAT32);
weights_ = AddInput(weights);
@@ -183,9 +184,10 @@ class RNNOpModel : public SingleOpModel {
bias_ = AddInput(TensorType_FLOAT32);
hidden_state_ = AddInput(TensorType_FLOAT32, true);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
CreateRNNOptions(builder_, ActivationFunctionType_RELU,
asymmetric_quantize_inputs)
.Union());
BuildInterpreter({{batches_, input_size_}, // input tensor
{units_, input_size_}, // weights tensor
{units_, units_}, // recurrent weights tensor
@@ -233,8 +235,10 @@ class RNNOpModel : public SingleOpModel {
// The hybrid model has quantized weights and recurrent_weights.
class HybridRNNOpModel : public RNNOpModel {
public:
HybridRNNOpModel(int batches, int units, int size, TensorType tensor_type,
bool asymmetric_quantize_inputs)
: RNNOpModel(batches, units, size, tensor_type, tensor_type,
asymmetric_quantize_inputs) {
tensor_type_ = tensor_type;
}
@@ -282,8 +286,10 @@ TEST(RnnOpTest, BlackBoxTest) {
}
}
class HybridRnnOpTest : public ::testing::TestWithParam<bool> {};

TEST_P(HybridRnnOpTest, BlackBoxTestUint8) {
HybridRNNOpModel rnn(2, 16, 8, TensorType_UINT8, GetParam());
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
@@ -310,8 +316,8 @@ TEST(HybridRnnOpTest, BlackBoxTestUint8) {
}
}
TEST_P(HybridRnnOpTest, BlackBoxTestInt8) {
HybridRNNOpModel rnn(2, 16, 8, TensorType_INT8, GetParam());
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
@@ -338,5 +344,8 @@ TEST(HybridRnnOpTest, BlackBoxTestInt8) {
}
}
INSTANTIATE_TEST_SUITE_P(HybridRnnOpTest, HybridRnnOpTest,
::testing::ValuesIn({false, true}));
} // namespace
} // namespace tflite
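The hybrid RNN black-box tests now run once per quantization mode via GetParam(), both times comparing against the same float reference. The property they lean on is that dequantization stays within one quantization step of the original value in either mode. A self-contained illustration in the same gtest style (not the TFLite test fixture; the numbers are arbitrary):

// Sketch: value-parameterized test exercising both quantization modes,
// mirroring the TestWithParam<bool> / GetParam() pattern used above.
#include <algorithm>
#include <cmath>
#include <gtest/gtest.h>

class AsymmetricFlagTest : public ::testing::TestWithParam<bool> {};

TEST_P(AsymmetricFlagTest, RoundTripStaysWithinOneStep) {
  const bool asymmetric = GetParam();
  const float value = 0.37f, min_val = 0.0f, max_val = 6.0f;  // e.g. ReLU6 data
  const float scale =
      asymmetric ? (max_val - min_val) / 255.0f
                 : std::max(std::fabs(min_val), std::fabs(max_val)) / 127.0f;
  const int zero_point = asymmetric ? -128 : 0;  // -128 - min_val / scale
  const int q = static_cast<int>(std::round(value / scale)) + zero_point;
  const float dequantized = scale * (q - zero_point);
  EXPECT_NEAR(value, dequantized, scale);
}

INSTANTIATE_TEST_SUITE_P(BothModes, AsymmetricFlagTest, ::testing::Bool());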


@@ -139,18 +139,28 @@ enum TemporaryTensor {
kProductScalingFactors = 8,
kRecoveredCellWeights = 9,
kAccumScratchBuffer = 10,
kZeroPoints = 11,
kFwRowSums = 12,
kBwRowSums = 13,
kAuxInputQuantized = 14, // Optional, quantized tensor for auxiliary input.
kNumTemporaryTensors = 15
};
struct OpData {
int scratch_tensor_index;
bool compute_fw_row_sums = false;
bool compute_bw_row_sums = false;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData();
context->AddTensors(context, kNumTemporaryTensors,
&op_data->scratch_tensor_index);
return op_data;
}
void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<OpData*>(buffer);
}
// Check that input tensor dimensions matches with each other.
@@ -385,7 +395,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
// Resize the output and scratch tensors based on the sizes of the input
// tensors. Also check that the size of the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
node->builtin_data);
@@ -522,7 +532,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
node->temporaries = TfLiteIntArrayCreate(2); // the two scratch buffers.
}
// Create a scratch buffer tensor.
node->temporaries->data[kFwScratchBuffer] = op_data->scratch_tensor_index;
TfLiteTensor* fw_scratch_buffer =
GetTemporary(context, node, kFwScratchBuffer);
fw_scratch_buffer->type = input->type;
@@ -581,7 +591,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Create a scratch buffer tensor.
node->temporaries->data[kBwScratchBuffer] =
op_data->scratch_tensor_index + kBwScratchBuffer;
TfLiteTensor* bw_scratch_buffer =
GetTemporary(context, node, kBwScratchBuffer);
bw_scratch_buffer->type = input->type;
@@ -606,10 +616,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer,
bw_scratch_buffer_size));
if (is_hybrid_op) {
// Compute the row sums for cached zero_point offset calculation.
op_data->compute_fw_row_sums = true;
op_data->compute_bw_row_sums = true;
// Allocate temporary tensors to store quantized values of input, aux_input
// (if present), activation_state and cell_state tensors.
node->temporaries->data[kInputQuantized] =
op_data->scratch_tensor_index + kInputQuantized;
TfLiteTensor* input_quantized =
GetTemporary(context, node, kInputQuantized);
input_quantized->type = fw_input_to_output_weights->type;
@@ -621,7 +634,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kFwActivationStateQuantized] =
op_data->scratch_tensor_index + kFwActivationStateQuantized;
TfLiteTensor* fw_activation_state_quantized =
GetTemporary(context, node, kFwActivationStateQuantized);
fw_activation_state_quantized->type = fw_input_to_output_weights->type;
@@ -635,7 +648,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
fw_activation_state_quantized_size));
}
node->temporaries->data[kBwActivationStateQuantized] =
op_data->scratch_tensor_index + kBwActivationStateQuantized;
TfLiteTensor* bw_activation_state_quantized =
GetTemporary(context, node, kBwActivationStateQuantized);
bw_activation_state_quantized->type = fw_input_to_output_weights->type;
@@ -649,7 +662,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
bw_activation_state_quantized_size));
}
node->temporaries->data[kFwCellStateQuantized] =
op_data->scratch_tensor_index + kFwCellStateQuantized;
TfLiteTensor* fw_cell_state_quantized =
GetTemporary(context, node, kFwCellStateQuantized);
fw_cell_state_quantized->type = fw_input_to_output_weights->type;
@@ -663,7 +676,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
fw_cell_state_quantized_size));
}
node->temporaries->data[kBwCellStateQuantized] =
op_data->scratch_tensor_index + kBwCellStateQuantized;
TfLiteTensor* bw_cell_state_quantized =
GetTemporary(context, node, kBwCellStateQuantized);
bw_cell_state_quantized->type = fw_input_to_output_weights->type;
@@ -683,7 +696,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// different matrices (which requires multiplying the scaling factors with
// the scaling factor of the matrix).
node->temporaries->data[kScalingFactors] =
op_data->scratch_tensor_index + kScalingFactors;
TfLiteTensor* scaling_factors =
GetTemporary(context, node, kScalingFactors);
scaling_factors->type = kTfLiteFloat32;
@@ -696,7 +709,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
scaling_factors_size));
}
node->temporaries->data[kProductScalingFactors] =
op_data->scratch_tensor_index + kProductScalingFactors;
TfLiteTensor* prod_scaling_factors =
GetTemporary(context, node, kProductScalingFactors);
prod_scaling_factors->type = kTfLiteFloat32;
@@ -713,7 +726,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate a temporary tensor to store the recovered cell weights. Since
// this is used for diagonal matrices, only need to store n_cell values.
node->temporaries->data[kRecoveredCellWeights] =
op_data->scratch_tensor_index + kRecoveredCellWeights;
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, kRecoveredCellWeights);
recovered_cell_weights->type = kTfLiteFloat32;
@@ -730,7 +743,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate a temporary tensor to store the accumulated int32 values.
node->temporaries->data[kAccumScratchBuffer] =
op_data->scratch_tensor_index + kAccumScratchBuffer;
TfLiteTensor* accum_scratch =
GetTemporary(context, node, kAccumScratchBuffer);
accum_scratch->type = kTfLiteInt32;
@@ -750,11 +763,72 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context, context->ResizeTensor(context, accum_scratch, accum_size));
}
// Allocate temporary tensors for storing zero-points.
node->temporaries->data[kZeroPoints] =
op_data->scratch_tensor_index + kZeroPoints;
TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
zero_points->type = kTfLiteFloat32;
zero_points->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, scaling_dims)) {
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
zero_points_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
zero_points_size));
}
// Allocate temporary tensors for caching row sums for hybrid zero-point
// calculations.
int fw_row_sums_rows = fw_use_cifg ? 6 : 8;
if (has_aux_input) {
fw_row_sums_rows += fw_use_cifg ? 3 : 4;
}
const TfLiteTensor* fw_projection_weights =
GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor);
if (fw_projection_weights != nullptr) {
fw_row_sums_rows += ceil(n_fw_output / n_fw_cell);
}
node->temporaries->data[kFwRowSums] =
op_data->scratch_tensor_index + kFwRowSums;
TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
fw_row_sums->type = kTfLiteInt32;
fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
int fw_row_sums_dims[2] = {fw_row_sums_rows, n_fw_cell};
if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) {
TfLiteIntArray* fw_hybrid_scratch_size = TfLiteIntArrayCreate(2);
fw_hybrid_scratch_size->data[0] = fw_row_sums_dims[0];
fw_hybrid_scratch_size->data[1] = fw_row_sums_dims[1];
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums,
fw_hybrid_scratch_size));
}
int bw_row_sums_rows = bw_use_cifg ? 6 : 8;
if (has_aux_input) {
bw_row_sums_rows += bw_use_cifg ? 3 : 4;
}
const TfLiteTensor* bw_projection_weights =
GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor);
if (bw_projection_weights != nullptr) {
bw_row_sums_rows += ceil(n_bw_output / n_bw_cell);
}
node->temporaries->data[kBwRowSums] =
op_data->scratch_tensor_index + kBwRowSums;
TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
bw_row_sums->type = kTfLiteInt32;
bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
int bw_row_sums_dims[2] = {bw_row_sums_rows, n_bw_cell};
if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) {
TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2);
bw_row_sums_size->data[0] = bw_row_sums_dims[0];
bw_row_sums_size->data[1] = bw_row_sums_dims[1];
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums,
bw_row_sums_size));
}
// Only allocate a temporary tensor for quantized auxiliary input if we are
// actually going to use it.
if (has_aux_input) {
node->temporaries->data[kAuxInputQuantized] =
op_data->scratch_tensor_index + kAuxInputQuantized;
TfLiteTensor* aux_input_quantized =
GetTemporary(context, node, kAuxInputQuantized);
aux_input_quantized->type = fw_input_to_output_weights->type;
@@ -775,7 +849,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
node->builtin_data);
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
// Input tensor.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
@@ -909,7 +983,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Populate a TfLiteLSTMParams struct for the evaluation functions.
TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip,
params->proj_clip, kTfLiteLSTMFullKernel,
params->asymmetric_quantize_inputs};
const int bw_output_offset =
params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0;
@@ -1003,7 +1078,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
: nullptr;
TfLiteTensor* accum_scratch =
GetTemporary(context, node, kAccumScratchBuffer);
TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
const int fw_row_sums_size = fw_row_sums->dims->data[0];
const int bw_row_sums_size = bw_row_sums->dims->data[0];
TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
input, fw_input_to_input_weights, fw_input_to_forget_weights,
fw_input_to_cell_weights, fw_input_to_output_weights,
@@ -1025,6 +1104,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
recovered_cell_weights, input_quantized, aux_input_quantized,
fw_activation_state_quantized, fw_cell_state_quantized,
fw_activation_state, fw_cell_state, accum_scratch, fw_output,
zero_points, fw_row_sums, fw_row_sums_size,
&op_data->compute_fw_row_sums,
CpuBackendContext::GetFromContext(context));
TF_LITE_ENSURE_OK(context, fw_pass_status);
@@ -1049,6 +1130,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
recovered_cell_weights, input_quantized, aux_input_quantized,
bw_activation_state_quantized, bw_cell_state_quantized,
bw_activation_state, bw_cell_state, accum_scratch, actual_bw_output,
zero_points, bw_row_sums, bw_row_sums_size,
&op_data->compute_bw_row_sums,
CpuBackendContext::GetFromContext(context));
TF_LITE_ENSURE_OK(context, bw_pass_status);
return kTfLiteOk;
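The forward and backward row-sum tensors above get one row per hybrid weight matrix in that direction: eight input-to-gate plus recurrent-to-gate matrices without CIFG (six with it), three or four more when an auxiliary input is wired in, and an extra contribution for the projection weights. A small standalone helper restating that bookkeeping (it mirrors the Prepare() logic shown above but is not TFLite code):

// Sketch: row count of the fw/bw row_sums tensors in the hybrid bidirectional
// LSTM. Each counted matrix contributes one vector of n_cell column sums.
#include <cmath>

int LstmRowSumsRows(bool use_cifg, bool has_aux_input, bool has_projection,
                    int n_output, int n_cell) {
  int rows = use_cifg ? 6 : 8;                  // input-to-* and recurrent-to-*
  if (has_aux_input) rows += use_cifg ? 3 : 4;  // aux-input-to-* matrices
  if (has_projection) {
    // The diff writes ceil(n_output / n_cell); with two ints that divides
    // first, so the promotion below is presumably what the expression intends.
    rows += static_cast<int>(std::ceil(static_cast<float>(n_output) / n_cell));
  }
  return rows;
}

// Example: non-CIFG, aux input present, projection with n_output == n_cell
// gives 8 + 4 + 1 = 13 rows, i.e. a [13, n_cell] persistent tensor.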


@@ -40,7 +40,8 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
bool use_projection_bias, bool merge_outputs,
bool use_aux_input, float cell_clip, float proj_clip,
bool quantize_weights, bool time_major,
const std::vector<std::vector<int>>& input_shapes,
bool asymmetric_quantize_inputs = false)
: n_batch_(n_batch),
n_input_(n_input),
n_fw_cell_(n_cell),
@@ -207,12 +208,13 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
bw_aux_input_to_output_weights_ = AddNullInput();
}
SetBuiltinOp(
BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
BuiltinOptions_BidirectionalSequenceLSTMOptions,
CreateBidirectionalSequenceLSTMOptions(
builder_, ActivationFunctionType_TANH, cell_clip, proj_clip,
merge_outputs, time_major, asymmetric_quantize_inputs)
.Union());
BuildInterpreter(input_shapes);
}
@@ -424,11 +426,14 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
bool quantize_weights_;
};
// Declare LSTMOpTest as a parameterized test.
class LSTMOpTest
: public ::testing::TestWithParam<::testing::tuple<bool, bool>> {};
INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, LSTMOpTest,
::testing::Combine(
/*quantize_weights*/ ::testing::Bool(),
/*asymmetric_quantize_inputs*/ ::testing::Bool()));
TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
const int n_batch = 1;
@@ -437,7 +442,9 @@ TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
const int n_cell = 4;
const int n_output = 4;
const int sequence_length = 3;
auto params = GetParam();
const bool quantize_weights = std::get<0>(params);
const bool asymmetric_quantize_inputs = std::get<1>(params);
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
@@ -509,7 +516,8 @@ TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
{0}, // aux_bw_input_to_forget tensor
{0}, // aux_bw_input_to_cell tensor
{0}, // aux_bw_input_to_output tensor
},
asymmetric_quantize_inputs);
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
-0.34550029, 0.04266912, -0.15680569,
@@ -600,7 +608,9 @@ TEST_P(LSTMOpTest, BlackBoxTestMergedOutput) {
const int n_cell = 4;
const int n_output = 4;
const int sequence_length = 3;
auto params = GetParam();
const bool quantize_weights = std::get<0>(params);
const bool asymmetric_quantize_inputs = std::get<1>(params);
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
@@ -672,7 +682,8 @@ TEST_P(LSTMOpTest, BlackBoxTestMergedOutput) {
{0}, // aux_bw_input_to_forget tensor
{0}, // aux_bw_input_to_cell tensor
{0}, // aux_bw_input_to_output tensor
},
asymmetric_quantize_inputs);
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
-0.34550029, 0.04266912, -0.15680569,
@@ -2631,7 +2642,9 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInputZeroAuxWeight) {
const int n_cell = 4;
const int n_output = 4;
const int sequence_length = 3;
auto params = GetParam();
const bool quantize_weights = std::get<0>(params);
const bool asymmetric_quantize_inputs = std::get<1>(params);
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
@@ -2703,7 +2716,8 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInputZeroAuxWeight) {
{n_cell, n_input}, // aux_bw_input_to_forget tensor
{n_cell, n_input}, // aux_bw_input_to_cell tensor
{n_cell, n_input}, // aux_bw_input_to_output tensor
},
asymmetric_quantize_inputs);
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
-0.34550029, 0.04266912, -0.15680569,
@@ -2802,7 +2816,9 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInput) {
const int n_cell = 4;
const int n_output = 4;
const int sequence_length = 3;
auto params = GetParam();
const bool quantize_weights = std::get<0>(params);
const bool asymmetric_quantize_inputs = std::get<1>(params);
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
@@ -2874,7 +2890,8 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInput) {
{n_cell, n_input}, // aux_bw_input_to_forget tensor
{n_cell, n_input}, // aux_bw_input_to_cell tensor
{n_cell, n_input}, // aux_bw_input_to_output tensor
},
asymmetric_quantize_inputs);
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
-0.34550029, 0.04266912, -0.15680569,
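For the bidirectional LSTM tests the single boolean parameter becomes a (quantize_weights, asymmetric_quantize_inputs) pair, and ::testing::Combine instantiates every black-box test for all four combinations. A stripped-down sketch of that pattern (hypothetical suite name, real model setup elided):

// Sketch: tuple-valued test parameters, as used by LSTMOpTest above.
// Combine(Bool(), Bool()) yields (false,false), (false,true), (true,false),
// and (true,true); std::get<> unpacks the tuple inside each TEST_P body.
#include <tuple>
#include <gtest/gtest.h>

class HybridLstmModes
    : public ::testing::TestWithParam<std::tuple<bool, bool>> {};

TEST_P(HybridLstmModes, RunsEveryCombination) {
  const bool quantize_weights = std::get<0>(GetParam());
  const bool asymmetric_quantize_inputs = std::get<1>(GetParam());
  // The flag only changes behavior when weights are quantized (hybrid path);
  // the float-weight runs effectively check that it is ignored otherwise.
  SUCCEED() << "quantize_weights=" << quantize_weights
            << " asymmetric_quantize_inputs=" << asymmetric_quantize_inputs;
}

INSTANTIATE_TEST_SUITE_P(AllModes, HybridLstmModes,
                         ::testing::Combine(::testing::Bool(),
                                            ::testing::Bool()));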


@@ -27,6 +27,16 @@ namespace ops {
namespace builtin {
namespace bidirectional_sequence_rnn {
namespace {
struct OpData {
int scratch_tensor_index;
bool fw_compute_row_sums = false;
bool bw_compute_row_sums = false;
};
} // namespace
// LINT.IfChange
constexpr int kInputTensor = 0;
@@ -58,18 +68,23 @@ enum TemporaryTensor {
kFwHiddenStateQuantized = 1,
kBwHiddenStateQuantized = 2,
kScalingFactors = 3,
kAccumScratch = 4,
kZeroPoints = 5,
kFwRowSums = 6,
kBwRowSums = 7,
kAuxInputQuantized = 8,
kNumTemporaryTensors = 9
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData();
context->AddTensors(context, kNumTemporaryTensors,
&op_data->scratch_tensor_index);
return op_data;
}
void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<OpData*>(buffer);
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
@@ -157,8 +172,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
if (IsHybridOp(input, fw_input_weights)) {
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
op_data->fw_compute_row_sums = true;
op_data->bw_compute_row_sums = true;
TfLiteIntArrayFree(node->temporaries);
if (has_aux_input) {
node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
@@ -168,7 +184,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kInputQuantized] =
op_data->scratch_tensor_index + kInputQuantized;
TfLiteTensor* input_quantized =
GetTemporary(context, node, kInputQuantized);
input_quantized->type = fw_input_weights->type;
@@ -180,7 +196,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kFwHiddenStateQuantized] =
op_data->scratch_tensor_index + kFwHiddenStateQuantized;
TfLiteTensor* fw_hidden_state_quantized =
GetTemporary(context, node, kFwHiddenStateQuantized);
fw_hidden_state_quantized->type = fw_input_weights->type;
@@ -195,7 +211,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kBwHiddenStateQuantized] =
op_data->scratch_tensor_index + kBwHiddenStateQuantized;
TfLiteTensor* bw_hidden_state_quantized =
GetTemporary(context, node, kBwHiddenStateQuantized);
bw_hidden_state_quantized->type = fw_input_weights->type;
@@ -211,7 +227,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate temporary tensors to store scaling factors of quantization.
node->temporaries->data[kScalingFactors] =
op_data->scratch_tensor_index + kScalingFactors;
TfLiteTensor* scaling_factors =
GetTemporary(context, node, kScalingFactors);
scaling_factors->type = kTfLiteFloat32;
@@ -223,10 +239,66 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
scaling_factors_size));
}
node->temporaries->data[kAccumScratch] =
op_data->scratch_tensor_index + kAccumScratch;
TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch);
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {std::max(fw_num_units, bw_num_units),
batch_size};
if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
accum_scratch_dims)) {
TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
accum_scratch_size->data[0] = accum_scratch_dims[0];
accum_scratch_size->data[1] = accum_scratch_dims[1];
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
accum_scratch_size));
}
node->temporaries->data[kZeroPoints] =
op_data->scratch_tensor_index + kZeroPoints;
TfLiteTensor* zero_points =
GetTemporary(context, node, /*index=*/kZeroPoints);
zero_points->type = kTfLiteInt32;
zero_points->allocation_type = kTfLiteArenaRw;
int zero_points_dims[1] = {batch_size};
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
zero_points_size->data[0] = batch_size;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
zero_points_size));
}
const int num_row_sums = has_aux_input ? 3 : 2;
node->temporaries->data[kFwRowSums] =
op_data->scratch_tensor_index + kFwRowSums;
TfLiteTensor* fw_row_sums =
GetTemporary(context, node, /*index=*/kFwRowSums);
fw_row_sums->type = kTfLiteInt32;
fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
int fw_row_sums_dims[2] = {num_row_sums, fw_num_units};
if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) {
TfLiteIntArray* fw_row_sums_size = TfLiteIntArrayCreate(2);
fw_row_sums_size->data[0] = fw_row_sums_dims[0];
fw_row_sums_size->data[1] = fw_row_sums_dims[1];
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums,
fw_row_sums_size));
}
node->temporaries->data[kBwRowSums] =
op_data->scratch_tensor_index + kBwRowSums;
TfLiteTensor* bw_row_sums = GetTemporary(context, node,
/*index=*/kBwRowSums);
bw_row_sums->type = kTfLiteInt32;
bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
int bw_row_sums_dims[2] = {num_row_sums, bw_num_units};
if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) {
TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2);
bw_row_sums_size->data[0] = bw_row_sums_dims[0];
bw_row_sums_size->data[1] = bw_row_sums_dims[1];
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums,
bw_row_sums_size));
}
if (has_aux_input) {
node->temporaries->data[kAuxInputQuantized] =
op_data->scratch_tensor_index + kAuxInputQuantized;
TfLiteTensor* aux_input_quantized =
GetTemporary(context, node, kAuxInputQuantized);
aux_input_quantized->type = fw_input_weights->type;
@@ -418,7 +490,10 @@ TfLiteStatus EvalHybrid(
TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized,
TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
TfLiteTensor* bw_output, TfLiteTensor* zero_points,
TfLiteTensor* accum_scratch, TfLiteTensor* fw_row_sums,
TfLiteTensor* bw_row_sums, bool* fw_compute_row_sums,
bool* bw_compute_row_sums) {
const bool time_major = params->time_major;
const int batch_size =
(time_major) ? input->dims->data[1] : input->dims->data[0];
@@ -464,11 +539,20 @@ TfLiteStatus EvalHybrid(
int8_t* bw_quantized_hidden_state_ptr =
GetTensorData<int8_t>(bw_hidden_state_quantized);
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
int32_t* zero_points_ptr = nullptr;
int32_t* fw_row_sums_ptr = nullptr;
int32_t* bw_row_sums_ptr = nullptr;
if (params->asymmetric_quantize_inputs) {
zero_points_ptr = GetTensorData<int32_t>(zero_points);
fw_row_sums_ptr = GetTensorData<int32_t>(fw_row_sums);
bw_row_sums_ptr = GetTensorData<int32_t>(bw_row_sums);
}
const int fw_output_step =
params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
const int bw_output_step =
params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
if (time_major) {
for (int t = 0; t < max_time; t++) {
// Forward cell.
@@ -491,7 +575,9 @@ TfLiteStatus EvalHybrid(
fw_num_units, batch_size, fw_output_step, params->activation,
quantized_input_ptr, aux_quantized_input_ptr,
fw_quantized_hidden_state_ptr, scaling_factors_ptr,
fw_hidden_state_ptr_batch, output_ptr_batch,
params->asymmetric_quantize_inputs, zero_points_ptr,
accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums);
} }
// Backward cell. // Backward cell.
float* bw_hidden_state_ptr_batch = GetTensorData<float>(bw_hidden_state); float* bw_hidden_state_ptr_batch = GetTensorData<float>(bw_hidden_state);
@ -516,7 +602,9 @@ TfLiteStatus EvalHybrid(
bw_num_units, batch_size, bw_output_step, params->activation, bw_num_units, batch_size, bw_output_step, params->activation,
quantized_input_ptr, aux_quantized_input_ptr, quantized_input_ptr, aux_quantized_input_ptr,
bw_quantized_hidden_state_ptr, scaling_factors_ptr, bw_quantized_hidden_state_ptr, scaling_factors_ptr,
bw_hidden_state_ptr_batch, output_ptr_batch); bw_hidden_state_ptr_batch, output_ptr_batch,
params->asymmetric_quantize_inputs, zero_points_ptr,
accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums);
} }
} }
} else { } else {
@ -545,7 +633,9 @@ TfLiteStatus EvalHybrid(
fw_num_units, /*batch_size=*/1, fw_output_step, params->activation, fw_num_units, /*batch_size=*/1, fw_output_step, params->activation,
quantized_input_ptr, aux_quantized_input_ptr, quantized_input_ptr, aux_quantized_input_ptr,
fw_quantized_hidden_state_ptr, scaling_factors_ptr, fw_quantized_hidden_state_ptr, scaling_factors_ptr,
fw_hidden_state_ptr_batch, output_ptr_batch); fw_hidden_state_ptr_batch, output_ptr_batch,
params->asymmetric_quantize_inputs, zero_points_ptr,
accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums);
} }
// Backward cell. // Backward cell.
float* bw_hidden_state_ptr_batch = float* bw_hidden_state_ptr_batch =
@ -574,7 +664,9 @@ TfLiteStatus EvalHybrid(
bw_num_units, /*batch_size=*/1, bw_output_step, params->activation, bw_num_units, /*batch_size=*/1, bw_output_step, params->activation,
quantized_input_ptr, aux_quantized_input_ptr, quantized_input_ptr, aux_quantized_input_ptr,
bw_quantized_hidden_state_ptr, scaling_factors_ptr, bw_quantized_hidden_state_ptr, scaling_factors_ptr,
bw_hidden_state_ptr_batch, output_ptr_batch); bw_hidden_state_ptr_batch, output_ptr_batch,
params->asymmetric_quantize_inputs, zero_points_ptr,
accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums);
} }
} }
} }
@ -656,17 +748,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, kBwHiddenStateQuantized); GetTemporary(context, node, kBwHiddenStateQuantized);
TfLiteTensor* scaling_factors = TfLiteTensor* scaling_factors =
GetTemporary(context, node, kScalingFactors); GetTemporary(context, node, kScalingFactors);
TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch);
TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
TfLiteTensor* aux_input_quantized = TfLiteTensor* aux_input_quantized =
use_aux_input ? GetTemporary(context, node, kAuxInputQuantized) use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
: nullptr; : nullptr;
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
return EvalHybrid(input, bw_input, fw_input_weights, fw_recurrent_weights, return EvalHybrid(
fw_bias, bw_input_weights, bw_recurrent_weights, input, bw_input, fw_input_weights, fw_recurrent_weights, fw_bias,
bw_bias, real_aux_input, fw_aux_input_weights, bw_input_weights, bw_recurrent_weights, bw_bias, real_aux_input,
bw_aux_input_weights, params, scaling_factors, fw_aux_input_weights, bw_aux_input_weights, params, scaling_factors,
input_quantized, aux_input_quantized, input_quantized, aux_input_quantized, fw_hidden_state_quantized,
fw_hidden_state_quantized, fw_hidden_state, fw_output, fw_hidden_state, fw_output, bw_hidden_state_quantized,
bw_hidden_state_quantized, bw_hidden_state, bw_output); bw_hidden_state, bw_output, zero_points, accum_scratch, fw_row_sums,
bw_row_sums, &op_data->fw_compute_row_sums,
&op_data->bw_compute_row_sums);
} }
default: default:
context->ReportError(context, "Type not currently supported."); context->ReportError(context, "Type not currently supported.");
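For reference, the correction that the new zero_points, accum_scratch, and row-sum arguments implement: with asymmetric input quantization q_j = round(x_j / scale) + zero_point, each output row accumulates sum_j(w_ij * x_j) ~= scale * (sum_j(w_ij * q_j) - zero_point * sum_j(w_ij)). The second term depends only on the constant weights, which is why the per-row weight sums are cached in the persistent fw/bw row-sum temporaries and only the subtraction runs per invocation. A minimal scalar sketch of that correction follows; the names are illustrative and this is not one of the optimized TFLite kernels.

// Minimal scalar sketch of one hybrid matmul with an asymmetric input zero
// point; illustrative only, not the TFLite kernel.
#include <cstdint>
void HybridMatmulSketch(const int8_t* weights,          // [num_rows * num_cols]
                        const int8_t* quantized_input,  // [num_cols]
                        int num_rows, int num_cols, float scale,
                        int32_t zero_point, const int32_t* row_sums,
                        float* output) {
  for (int r = 0; r < num_rows; ++r) {
    int32_t acc = 0;
    for (int c = 0; c < num_cols; ++c) {
      acc += static_cast<int32_t>(weights[r * num_cols + c]) *
             static_cast<int32_t>(quantized_input[c]);
    }
    // Remove the zero-point contribution using the precomputed per-row weight
    // sum, then rescale the integer accumulator back to float.
    acc -= zero_point * row_sums[r];
    output[r] += static_cast<float>(acc) * scale;
  }
}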

View File

@ -662,20 +662,24 @@ class BidirectionalRNNOpModel : public SingleOpModel {
BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units, BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units,
int bw_units, int input_size, int aux_input_size, int bw_units, int input_size, int aux_input_size,
AuxInputMode aux_input_mode, bool time_major, AuxInputMode aux_input_mode, bool time_major,
bool merge_outputs) bool merge_outputs, bool quantize_weights = false,
bool asymmetric_quantize_weights = false)
: batches_(batches), : batches_(batches),
sequence_len_(sequence_len), sequence_len_(sequence_len),
fw_units_(fw_units), fw_units_(fw_units),
bw_units_(bw_units), bw_units_(bw_units),
input_size_(input_size), input_size_(input_size),
aux_input_size_(aux_input_size) { aux_input_size_(aux_input_size),
quantize_weights_(quantize_weights) {
const TensorType tensor_type =
quantize_weights ? TensorType_UINT8 : TensorType_FLOAT32;
input_ = AddInput(TensorType_FLOAT32); input_ = AddInput(TensorType_FLOAT32);
fw_weights_ = AddInput(TensorType_FLOAT32); fw_weights_ = AddInput(tensor_type);
fw_recurrent_weights_ = AddInput(TensorType_FLOAT32); fw_recurrent_weights_ = AddInput(tensor_type);
fw_bias_ = AddInput(TensorType_FLOAT32); fw_bias_ = AddInput(TensorType_FLOAT32);
fw_hidden_state_ = AddInput(TensorType_FLOAT32, true); fw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
bw_weights_ = AddInput(TensorType_FLOAT32); bw_weights_ = AddInput(tensor_type);
bw_recurrent_weights_ = AddInput(TensorType_FLOAT32); bw_recurrent_weights_ = AddInput(tensor_type);
bw_bias_ = AddInput(TensorType_FLOAT32); bw_bias_ = AddInput(TensorType_FLOAT32);
bw_hidden_state_ = AddInput(TensorType_FLOAT32, true); bw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
@ -697,8 +701,8 @@ class BidirectionalRNNOpModel : public SingleOpModel {
} }
if (aux_input_mode == AuxInputMode::kCrossLinking) { if (aux_input_mode == AuxInputMode::kCrossLinking) {
aux_fw_weights_ = AddInput(TensorType_FLOAT32); aux_fw_weights_ = AddInput(tensor_type);
aux_bw_weights_ = AddInput(TensorType_FLOAT32); aux_bw_weights_ = AddInput(tensor_type);
aux_fw_weights_shape = {fw_units, aux_input_size_}; aux_fw_weights_shape = {fw_units, aux_input_size_};
aux_bw_weights_shape = {bw_units, aux_input_size_}; aux_bw_weights_shape = {bw_units, aux_input_size_};
@ -712,12 +716,12 @@ class BidirectionalRNNOpModel : public SingleOpModel {
bw_output_ = AddOutput(TensorType_FLOAT32); bw_output_ = AddOutput(TensorType_FLOAT32);
} }
SetBuiltinOp( SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, BuiltinOptions_BidirectionalSequenceRNNOptions,
BuiltinOptions_BidirectionalSequenceRNNOptions, CreateBidirectionalSequenceRNNOptions(
CreateBidirectionalSequenceRNNOptions( builder_, time_major, ActivationFunctionType_RELU,
builder_, time_major, ActivationFunctionType_RELU, merge_outputs) merge_outputs, asymmetric_quantize_weights)
.Union()); .Union());
BuildInterpreter({ BuildInterpreter({
input_shape, // input input_shape, // input
@ -744,19 +748,35 @@ class BidirectionalRNNOpModel : public SingleOpModel {
} }
void SetFwWeights(const std::vector<float>& f) { void SetFwWeights(const std::vector<float>& f) {
PopulateTensor(fw_weights_, f); if (quantize_weights_) {
SymmetricQuantizeAndPopulate(fw_weights_, f);
} else {
PopulateTensor(fw_weights_, f);
}
} }
void SetBwWeights(const std::vector<float>& f) { void SetBwWeights(const std::vector<float>& f) {
PopulateTensor(bw_weights_, f); if (quantize_weights_) {
SymmetricQuantizeAndPopulate(bw_weights_, f);
} else {
PopulateTensor(bw_weights_, f);
}
} }
void SetFwRecurrentWeights(const std::vector<float>& f) { void SetFwRecurrentWeights(const std::vector<float>& f) {
PopulateTensor(fw_recurrent_weights_, f); if (quantize_weights_) {
SymmetricQuantizeAndPopulate(fw_recurrent_weights_, f);
} else {
PopulateTensor(fw_recurrent_weights_, f);
}
} }
void SetBwRecurrentWeights(const std::vector<float>& f) { void SetBwRecurrentWeights(const std::vector<float>& f) {
PopulateTensor(bw_recurrent_weights_, f); if (quantize_weights_) {
SymmetricQuantizeAndPopulate(bw_recurrent_weights_, f);
} else {
PopulateTensor(bw_recurrent_weights_, f);
}
} }
void SetInput(std::initializer_list<float> data) { void SetInput(std::initializer_list<float> data) {
@ -772,11 +792,19 @@ class BidirectionalRNNOpModel : public SingleOpModel {
} }
void SetAuxFwWeights(const std::vector<float>& f) { void SetAuxFwWeights(const std::vector<float>& f) {
PopulateTensor(aux_fw_weights_, f); if (quantize_weights_) {
SymmetricQuantizeAndPopulate(aux_fw_weights_, f);
} else {
PopulateTensor(aux_fw_weights_, f);
}
} }
void SetAuxBwWeights(const std::vector<float>& f) { void SetAuxBwWeights(const std::vector<float>& f) {
PopulateTensor(aux_bw_weights_, f); if (quantize_weights_) {
SymmetricQuantizeAndPopulate(aux_bw_weights_, f);
} else {
PopulateTensor(aux_bw_weights_, f);
}
} }
std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); } std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); }
@ -811,17 +839,31 @@ class BidirectionalRNNOpModel : public SingleOpModel {
int bw_units_; int bw_units_;
int input_size_; int input_size_;
int aux_input_size_; int aux_input_size_;
bool quantize_weights_;
}; };
// Declare BidirectionalRNNOpTest as a parameterized test.
class BidirectionalRNNOpTest
: public ::testing::TestWithParam<::testing::tuple<bool, bool>> {};
INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, BidirectionalRNNOpTest,
::testing::Combine(
/*quantize_weights*/ ::testing::Bool(),
/*asymmetric_quantize_inputs*/ ::testing::Bool()));
// TODO(mirkov): add another test which directly compares to TF once TOCO // TODO(mirkov): add another test which directly compares to TF once TOCO
// supports the conversion from dynamic_rnn with BasicRNNCell. // supports the conversion from dynamic_rnn with BasicRNNCell.
TEST(BidirectionalRNNOpTest, BlackBoxTest) { TEST_P(BidirectionalRNNOpTest, BlackBoxTest) {
auto params = GetParam();
const bool quantize_weights = std::get<0>(params);
const bool asymmetric_quantize_inputs = std::get<1>(params);
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16, /*fw_units=*/16, /*bw_units=*/16,
/*input_size=*/8, /*aux_input_size=*/0, /*input_size=*/8, /*aux_input_size=*/0,
/*aux_input_mode=*/AuxInputMode::kNoAuxInput, /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
/*time_major=*/false, /*time_major=*/false,
/*merge_outputs=*/false); /*merge_outputs=*/false, quantize_weights,
asymmetric_quantize_inputs);
rnn.SetFwWeights(weights); rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights); rnn.SetBwWeights(weights);
rnn.SetFwBias(biases); rnn.SetFwBias(biases);
@ -843,7 +885,9 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) {
std::vector<float> fw_expected; std::vector<float> fw_expected;
fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected))); EXPECT_THAT(rnn.GetFwOutput(),
ElementsAreArray(ArrayFloatNear(
fw_expected, quantize_weights ? 1.42e-2 : 1e-5)));
float* golden_bw_start = rnn_golden_bw_output; float* golden_bw_start = rnn_golden_bw_output;
float* golden_bw_end = float* golden_bw_end =
@ -851,17 +895,23 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) {
std::vector<float> bw_expected; std::vector<float> bw_expected;
bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end); bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end); bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected))); EXPECT_THAT(rnn.GetBwOutput(),
ElementsAreArray(ArrayFloatNear(
bw_expected, quantize_weights ? 1.42e-2 : 1e-5)));
} }
// Same as BlackBox test, but input is reshuffled to time_major format. // Same as BlackBox test, but input is reshuffled to time_major format.
TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) { TEST_P(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) {
auto params = GetParam();
const bool quantize_weights = std::get<0>(params);
const bool asymmetric_quantize_inputs = std::get<1>(params);
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16, /*fw_units=*/16, /*bw_units=*/16,
/*input_size=*/8, /*aux_input_size=*/0, /*input_size=*/8, /*aux_input_size=*/0,
/*aux_input_mode=*/AuxInputMode::kNoAuxInput, /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
/*time_major=*/true, /*time_major=*/true,
/*merge_outputs=*/false); /*merge_outputs=*/false, quantize_weights,
asymmetric_quantize_inputs);
rnn.SetFwWeights(weights); rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights); rnn.SetBwWeights(weights);
rnn.SetFwBias(biases); rnn.SetFwBias(biases);
@ -889,17 +939,26 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) {
fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
} }
EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected))); constexpr float kHybridTolerance = 3.57e-1;
constexpr float kFloatTolerance = 1e-5;
EXPECT_THAT(
rnn.GetFwOutput(),
ElementsAreArray(ArrayFloatNear(
fw_expected, quantize_weights ? kHybridTolerance : kFloatTolerance)));
} }
// Same as BlackBox test, yet with merged outputs. // Same as BlackBox test, yet with merged outputs.
TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) { TEST_P(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
auto params = GetParam();
const bool quantize_weights = std::get<0>(params);
const bool asymmetric_quantize_inputs = std::get<1>(params);
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16, /*fw_units=*/16, /*bw_units=*/16,
/*input_size=*/8, /*aux_input_size=*/0, /*input_size=*/8, /*aux_input_size=*/0,
/*aux_input_mode=*/AuxInputMode::kNoAuxInput, /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
/*time_major=*/false, /*time_major=*/false,
/*merge_outputs=*/true); /*merge_outputs=*/true, quantize_weights,
asymmetric_quantize_inputs);
rnn.SetFwWeights(weights); rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights); rnn.SetBwWeights(weights);
rnn.SetFwBias(biases); rnn.SetFwBias(biases);
@ -929,7 +988,8 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
} }
} }
EXPECT_THAT(rnn.GetFwOutput(), EXPECT_THAT(rnn.GetFwOutput(),
ElementsAreArray(ArrayFloatNear(merged_expected))); ElementsAreArray(ArrayFloatNear(
merged_expected, quantize_weights ? 1.42e-2 : 1e-5)));
} }
// Same as BlackBox test, but input is reshuffled to time_major format. // Same as BlackBox test, but input is reshuffled to time_major format.

View File

@ -71,6 +71,7 @@ struct OpData {
int32_t output_activation_max; int32_t output_activation_max;
// The index of the temporary tensor where the quantized inputs are cached. // The index of the temporary tensor where the quantized inputs are cached.
int scratch_tensor_index; int scratch_tensor_index;
bool compute_row_sums = false;
}; };
constexpr int kInputTensor = 0; constexpr int kInputTensor = 0;
@ -131,7 +132,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// Instead, we allocate a new object to carry information from Prepare() to // Instead, we allocate a new object to carry information from Prepare() to
// Eval(). // Eval().
auto* op_data = new OpData(); auto* op_data = new OpData();
context->AddTensors(context, /*tensors_to_add=*/3, context->AddTensors(context, /*tensors_to_add=*/5,
&op_data->scratch_tensor_index); &op_data->scratch_tensor_index);
return op_data; return op_data;
} }
@ -144,7 +145,6 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
auto* params = auto* params =
reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data); reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data); OpData* data = reinterpret_cast<OpData*>(node->user_data);
// Check we have all the inputs and outputs we need. // Check we have all the inputs and outputs we need.
TF_LITE_ENSURE(context, node->inputs->size == 2 || node->inputs->size == 3); TF_LITE_ENSURE(context, node->inputs->size == 2 || node->inputs->size == 3);
// Shuffled formats need a workspace to store the shuffled input activations. // Shuffled formats need a workspace to store the shuffled input activations.
@ -208,7 +208,8 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
if (input->type == kTfLiteFloat32 && if (input->type == kTfLiteFloat32 &&
(filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8)) { (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8)) {
TfLiteIntArrayFree(node->temporaries); TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(3); data->compute_row_sums = true;
node->temporaries = TfLiteIntArrayCreate(5);
node->temporaries->data[0] = data->scratch_tensor_index; node->temporaries->data[0] = data->scratch_tensor_index;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
@ -245,6 +246,28 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK( TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, accum_scratch, accum_size)); context, context->ResizeTensor(context, accum_scratch, accum_size));
} }
node->temporaries->data[3] = data->scratch_tensor_index + 3;
TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/3);
input_offsets->type = kTfLiteInt32;
input_offsets->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) {
TfLiteIntArray* input_offsets_size = TfLiteIntArrayCreate(1);
input_offsets_size->data[0] = batch_size;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_offsets,
input_offsets_size));
}
node->temporaries->data[4] = data->scratch_tensor_index + 4;
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/4);
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_dims[1] = {num_units};
if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) {
TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1);
row_sums_size->data[0] = row_sums_dims[0];
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, row_sums, row_sums_size));
}
} }
// Resize output. // Resize output.
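The two temporaries added above have different lifetimes: input_offsets stores one input zero point per batch row and is refilled on every invocation, while row_sums caches the per-output-row sums of the constant weight matrix, so it is placed in the persistent arena and filled only once, guarded by OpData::compute_row_sums. A sketch of that lazy fill, under the assumption that the weights stay constant across invocations (illustrative name, not the TFLite routine):

// Sketch of the lazy row-sum fill implied by the persistent row_sums tensor
// and the compute_row_sums flag.
#include <cstdint>
void MaybeComputeRowSumsSketch(const int8_t* weights, int num_units,
                               int input_size, int32_t* row_sums,
                               bool* compute_row_sums) {
  if (!*compute_row_sums) return;  // Already cached by an earlier Eval().
  for (int r = 0; r < num_units; ++r) {
    int32_t sum = 0;
    for (int c = 0; c < input_size; ++c) {
      sum += weights[r * input_size + c];
    }
    row_sums[r] = sum;
  }
  *compute_row_sums = false;
}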
@ -332,7 +355,9 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
TfLiteFullyConnectedParams* params, OpData* data, TfLiteFullyConnectedParams* params, OpData* data,
const TfLiteTensor* input, const TfLiteTensor* filter, const TfLiteTensor* input, const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* input_quantized, const TfLiteTensor* bias, TfLiteTensor* input_quantized,
TfLiteTensor* scaling_factors, TfLiteTensor* output) { TfLiteTensor* scaling_factors,
TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
TfLiteTensor* input_offsets, TfLiteTensor* output) {
int total_input_size = 1; int total_input_size = 1;
for (int i = 0; i < input->dims->size; i++) { for (int i = 0; i < input->dims->size; i++) {
total_input_size *= input->dims->data[i]; total_input_size *= input->dims->data[i];
@ -363,32 +388,39 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
// Quantize input from float to uint8 + quantization params (scaling factor). // Quantize input from float to uint8 + quantization params (scaling factor).
float unused_min, unused_max; float unused_min, unused_max;
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors); float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
int32_t* input_offset_ptr = nullptr;
int32_t* row_sums_ptr = nullptr;
if (params->asymmetric_quantize_inputs) {
input_offset_ptr = GetTensorData<int32_t>(input_offsets);
row_sums_ptr = GetTensorData<int32_t>(row_sums);
}
int8_t* quant_data = GetTensorData<int8_t>(input_quantized); int8_t* quant_data = GetTensorData<int8_t>(input_quantized);
const int8_t* filter_data = GetTensorData<int8_t>(filter); const int8_t* filter_data = GetTensorData<int8_t>(filter);
const float* input_ptr = GetTensorData<float>(input);
// Quantize each batch independently. // Quantize each batch independently.
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
const int offset = b * input_size; const int offset = b * input_size;
tensor_utils::SymmetricQuantizeFloats( if (params->asymmetric_quantize_inputs) {
GetTensorData<float>(input) + offset, input_size, quant_data + offset, tensor_utils::AsymmetricQuantizeFloats(
&unused_min, &unused_max, &scaling_factors_ptr[b]); input_ptr + offset, input_size, quant_data + offset,
&scaling_factors_ptr[b], &input_offset_ptr[b]);
} else {
tensor_utils::SymmetricQuantizeFloats(
input_ptr + offset, input_size, quant_data + offset, &unused_min,
&unused_max, &scaling_factors_ptr[b]);
}
// Incorporate scaling of the filter. // Incorporate scaling of the filter.
scaling_factors_ptr[b] *= filter->params.scale; scaling_factors_ptr[b] *= filter->params.scale;
} }
// Compute output += weight * quantized_input // Compute output += weight * quantized_input
#ifdef TFLITE_WITH_RUY_GEMV
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2);
int32_t* scratch = GetTensorData<int32_t>(accum_scratch); int32_t* scratch = GetTensorData<int32_t>(accum_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate( tensor_utils::MatrixBatchVectorMultiplyAccumulate(
filter_data, num_units, input_size, quant_data, scaling_factors_ptr, filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
batch_size, scratch, GetTensorData<float>(output), batch_size, GetTensorData<float>(output), /*per_channel_scale=*/nullptr,
input_offset_ptr, scratch, row_sums_ptr, &data->compute_row_sums,
CpuBackendContext::GetFromContext(context)); CpuBackendContext::GetFromContext(context));
#else
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
batch_size, GetTensorData<float>(output));
#endif
// Apply activation function to floats. // Apply activation function to floats.
tensor_utils::ApplyActivationToVector( tensor_utils::ApplyActivationToVector(
GetTensorData<float>(output), batch_size * num_units, params->activation, GetTensorData<float>(output), batch_size * num_units, params->activation,
@ -461,8 +493,12 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
if (input->type == kTfLiteFloat32) { if (input->type == kTfLiteFloat32) {
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1); TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/3);
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/4);
return EvalHybrid(context, node, params, data, input, filter, bias, return EvalHybrid(context, node, params, data, input, filter, bias,
input_quantized, scaling_factors, output); input_quantized, scaling_factors, accum_scratch, row_sums,
input_offsets, output);
} else { } else {
FullyConnectedParams op_params; FullyConnectedParams op_params;
op_params.input_offset = input_offset; op_params.input_offset = input_offset;
@ -590,7 +626,6 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
FullyConnectedParams op_params; FullyConnectedParams op_params;
op_params.float_activation_min = output_activation_min; op_params.float_activation_min = output_activation_min;
op_params.float_activation_max = output_activation_max; op_params.float_activation_max = output_activation_max;
reference_ops::FullyConnected( reference_ops::FullyConnected(
op_params, GetTensorShape(input), GetTensorData<float>(input), op_params, GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(filter), GetTensorData<float>(filter), GetTensorShape(filter), GetTensorData<float>(filter),
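The per-batch quantization loop in EvalHybrid above branches on params->asymmetric_quantize_inputs: the asymmetric helper maps each batch row's [min, max] range onto int8 and reports a zero point alongside the scale, instead of the symmetric max-abs mapping. A rough sketch of that behaviour, assuming plain min/max asymmetric int8 quantization; the real tensor_utils::AsymmetricQuantizeFloats may differ in rounding and zero-point nudging details.

// Rough sketch of asymmetric float -> int8 quantization of one batch row;
// an assumption about the helper's behaviour, not its actual implementation.
#include <algorithm>
#include <cmath>
#include <cstdint>
void AsymmetricQuantizeFloatsSketch(const float* values, int size,
                                    int8_t* quantized, float* scaling_factor,
                                    int32_t* zero_point) {
  float rmin = 0.f, rmax = 0.f;  // Include 0 so zero stays representable.
  for (int i = 0; i < size; ++i) {
    rmin = std::min(rmin, values[i]);
    rmax = std::max(rmax, values[i]);
  }
  if (rmin == rmax) {  // All-zero row: nothing to scale.
    std::fill(quantized, quantized + size, static_cast<int8_t>(0));
    *scaling_factor = 1.f;
    *zero_point = 0;
    return;
  }
  const float scale = (rmax - rmin) / 255.f;
  const int32_t zp =
      static_cast<int32_t>(std::round(-128.f - rmin / scale));  // rmin -> -128
  for (int i = 0; i < size; ++i) {
    const int32_t q = static_cast<int32_t>(std::round(values[i] / scale)) + zp;
    quantized[i] = static_cast<int8_t>(std::min(127, std::max(-128, q)));
  }
  *scaling_factor = scale;
  *zero_point = zp;
}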

View File

@ -286,7 +286,8 @@ class HybridFullyConnectedOpModel : public SingleOpModel {
public: public:
HybridFullyConnectedOpModel(int units, int batches, const TensorData& input, HybridFullyConnectedOpModel(int units, int batches, const TensorData& input,
const TensorData& weights, const TensorData& weights,
const TensorData& output = {TensorType_FLOAT32}) const TensorData& output = {TensorType_FLOAT32},
bool asymmetric_inputs = false)
: batches_(batches), units_(units) { : batches_(batches), units_(units) {
int total_input_size = 1; int total_input_size = 1;
for (size_t i = 0; i < input.shape.size(); ++i) { for (size_t i = 0; i < input.shape.size(); ++i) {
@ -302,10 +303,13 @@ class HybridFullyConnectedOpModel : public SingleOpModel {
output_ = AddOutput(output); output_ = AddOutput(output);
SetBuiltinOp( auto options = CreateFullyConnectedOptions(
BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions, builder_, ActivationFunctionType_RELU,
CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU) tflite::FullyConnectedOptionsWeightsFormat_DEFAULT,
.Union()); false, asymmetric_inputs)
.Union();
SetBuiltinOp(BuiltinOperator_FULLY_CONNECTED,
BuiltinOptions_FullyConnectedOptions, options);
resolver_ = absl::make_unique<SingleOpResolver>( resolver_ = absl::make_unique<SingleOpResolver>(
BuiltinOperator_FULLY_CONNECTED, BuiltinOperator_FULLY_CONNECTED,
ops::builtin::Register_FULLY_CONNECTED_PIE()); ops::builtin::Register_FULLY_CONNECTED_PIE());
@ -867,6 +871,66 @@ TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8) {
/*max_abs_error=*/1.3f))); /*max_abs_error=*/1.3f)));
} }
TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedUint8) {
HybridFullyConnectedOpModel m(
/*units=*/3, /*batches=*/2,
/*input=*/{TensorType_FLOAT32, {2, 10}},
/*weights=*/
{TensorType_UINT8, {3, 10}, 0, 0, 10.0 / 127.0, 0}, {TensorType_FLOAT32},
/*asymmetric_quantize_input*/ true); // Hybrid asymmetric
m.SetWeights({
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
});
m.SetBias({1, 2, 3});
m.SetInput({
1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
{
24, 25, 26, //
58, 59, 60, //
},
/*max_abs_error=*/0.64f)));
}
TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedInt8) {
HybridFullyConnectedOpModel m(
/*units=*/3, /*batches=*/2,
/*input=*/{TensorType_FLOAT32, {2, 10}},
/*weights=*/{TensorType_INT8, {3, 10}, 0, 0, 10.0 / 127.0, 0},
{TensorType_FLOAT32},
/*asymmetric_quantize_input*/ true);
m.SetSignedWeights({
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
});
m.SetBias({1, 2, 3});
m.SetInput({
1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
{
24, 25, 26, //
58, 59, 60, //
},
/*max_abs_error=*/1.3f)));
}
TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) { TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) {
// Note that it is not required that the first dimension be the number of // Note that it is not required that the first dimension be the number of
// batches. All we care is that the input can be evenly distributed in // batches. All we care is that the input can be evenly distributed in

View File

@ -123,7 +123,9 @@ void RnnBatchStep(
int num_units, int batch_size, int output_batch_leading_dim, int num_units, int batch_size, int output_batch_leading_dim,
TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors, int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
float* hidden_state_ptr_batch, float* output_ptr_batch) { float* hidden_state_ptr_batch, float* output_ptr_batch,
bool asymmetric_quantize_inputs, int32_t* zero_points,
int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums) {
RnnBatchStep(input_ptr_batch, input_weights_ptr, input_weights_scale, RnnBatchStep(input_ptr_batch, input_weights_ptr, input_weights_scale,
/*aux_input_ptr_batch=*/nullptr, /*aux_input_ptr_batch=*/nullptr,
/*aux_input_weights_ptr=*/nullptr, /*aux_input_weights_ptr=*/nullptr,
@ -133,7 +135,29 @@ void RnnBatchStep(
output_batch_leading_dim, activation, quantized_input_ptr_batch, output_batch_leading_dim, activation, quantized_input_ptr_batch,
/*aux_quantized_input_ptr_batch=*/nullptr, /*aux_quantized_input_ptr_batch=*/nullptr,
quantized_hidden_state_ptr_batch, scaling_factors, quantized_hidden_state_ptr_batch, scaling_factors,
hidden_state_ptr_batch, output_ptr_batch); hidden_state_ptr_batch, output_ptr_batch,
asymmetric_quantize_inputs, zero_points, accum_scratch, row_sums,
compute_row_sums);
}
void ComputeMatrixSums(int32_t* input_row_sums, int32_t* aux_input_row_sums,
int32_t* recurrent_row_sums, int32_t* row_sums,
const float* aux_input_ptr_batch, int num_units,
int input_size, int aux_input_size,
const int8_t* input_weights_ptr,
const int8_t* aux_input_weights_ptr,
const int8_t* recurrent_weights_ptr) {
memset(input_row_sums, 0, sizeof(int32_t) * num_units);
tensor_utils::ReductionSumVector(input_weights_ptr, input_row_sums, num_units,
input_size);
if (aux_input_ptr_batch) {
memset(aux_input_row_sums, 0, sizeof(int32_t) * num_units);
tensor_utils::ReductionSumVector(aux_input_weights_ptr, aux_input_row_sums,
num_units, aux_input_size);
}
memset(recurrent_row_sums, 0, sizeof(int32_t) * num_units);
tensor_utils::ReductionSumVector(recurrent_weights_ptr, recurrent_row_sums,
num_units, num_units);
} }
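ComputeMatrixSums fills one flat row_sums buffer that RnnBatchStep then partitions as [input_row_sums | aux_input_row_sums (only when an aux input is present) | recurrent_row_sums], each segment num_units long. That matches the callers, which size the temporary as num_row_sums * num_units with num_row_sums of 2 or 3. A tiny sketch of the expected length, stated as a reading of this code rather than an API guarantee:

// Sketch: flat row_sums length expected by the hybrid RnnBatchStep path.
int RnnRowSumsLengthSketch(bool has_aux_input, int num_units) {
  const int num_row_sums = has_aux_input ? 3 : 2;  // input, [aux], recurrent
  return num_row_sums * num_units;
}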
void RnnBatchStep( void RnnBatchStep(
@ -146,9 +170,31 @@ void RnnBatchStep(
TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
int8_t* aux_quantized_input_ptr_batch, int8_t* aux_quantized_input_ptr_batch,
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors, int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
float* hidden_state_ptr_batch, float* output_ptr_batch) { float* hidden_state_ptr_batch, float* output_ptr_batch,
bool asymmetric_quantize_inputs, int32_t* zero_points,
int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums) {
// Since the output batch rows may not be contiguous (output_batch_leading_dim // Since the output batch rows may not be contiguous (output_batch_leading_dim
// != n_output), we unroll the batched operations where this is the case. // != n_output), we unroll the batched operations where this is the case.
int32_t* input_row_sums = nullptr;
int32_t* aux_input_row_sums = nullptr;
int32_t* recurrent_row_sums = nullptr;
if (asymmetric_quantize_inputs) {
input_row_sums = row_sums;
aux_input_row_sums = row_sums;
if (aux_input_ptr_batch) {
aux_input_row_sums += num_units;
}
recurrent_row_sums = aux_input_row_sums + num_units;
if (*compute_row_sums) {
ComputeMatrixSums(input_row_sums, aux_input_row_sums, recurrent_row_sums,
row_sums, aux_input_ptr_batch, num_units, input_size,
aux_input_size, input_weights_ptr,
aux_input_weights_ptr, recurrent_weights_ptr);
*compute_row_sums = false;
}
}
if (output_batch_leading_dim == num_units) { if (output_batch_leading_dim == num_units) {
// Output = bias // Output = bias
tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size, tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
@ -163,17 +209,25 @@ void RnnBatchStep(
// whichever is faster. // whichever is faster.
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
const int offset = b * input_size; const int offset = b * input_size;
tensor_utils::SymmetricQuantizeFloats( if (asymmetric_quantize_inputs) {
input_ptr_batch + offset, input_size, tensor_utils::AsymmetricQuantizeFloats(
quantized_input_ptr_batch + offset, &unused_min, &unused_max, input_ptr_batch + offset, input_size,
&scaling_factors[b]); quantized_input_ptr_batch + offset, &scaling_factors[b],
&zero_points[b]);
} else {
tensor_utils::SymmetricQuantizeFloats(
input_ptr_batch + offset, input_size,
quantized_input_ptr_batch + offset, &unused_min, &unused_max,
&scaling_factors[b]);
}
scaling_factors[b] *= input_weights_scale; scaling_factors[b] *= input_weights_scale;
} }
// Output += input * input_weights // Output += input * input_weights
tensor_utils::MatrixBatchVectorMultiplyAccumulate( tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_weights_ptr, num_units, input_size, quantized_input_ptr_batch, input_weights_ptr, num_units, input_size, quantized_input_ptr_batch,
scaling_factors, batch_size, output_ptr_batch); scaling_factors, batch_size, output_ptr_batch,
/*per_channel_scale=*/nullptr, zero_points, accum_scratch,
input_row_sums, compute_row_sums, /*context=*/nullptr);
} }
if (aux_input_ptr_batch && if (aux_input_ptr_batch &&
@ -182,10 +236,17 @@ void RnnBatchStep(
float unused_min, unused_max; float unused_min, unused_max;
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
const int offset = b * aux_input_size; const int offset = b * aux_input_size;
tensor_utils::SymmetricQuantizeFloats( if (asymmetric_quantize_inputs) {
aux_input_ptr_batch + offset, aux_input_size, tensor_utils::AsymmetricQuantizeFloats(
aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max, aux_input_ptr_batch + offset, aux_input_size,
&scaling_factors[b]); aux_quantized_input_ptr_batch + offset, &scaling_factors[b],
&zero_points[b]);
} else {
tensor_utils::SymmetricQuantizeFloats(
aux_input_ptr_batch + offset, aux_input_size,
aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max,
&scaling_factors[b]);
}
scaling_factors[b] *= aux_input_weights_scale; scaling_factors[b] *= aux_input_weights_scale;
} }
@ -193,7 +254,9 @@ void RnnBatchStep(
tensor_utils::MatrixBatchVectorMultiplyAccumulate( tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_weights_ptr, num_units, aux_input_size, aux_input_weights_ptr, num_units, aux_input_size,
aux_quantized_input_ptr_batch, scaling_factors, batch_size, aux_quantized_input_ptr_batch, scaling_factors, batch_size,
output_ptr_batch); output_ptr_batch, /*per_channel_scale=*/nullptr, zero_points,
accum_scratch, aux_input_row_sums, compute_row_sums,
/*context=*/nullptr);
} }
// Save quantization and matmul computation for all zero input. // Save quantization and matmul computation for all zero input.
@ -203,10 +266,17 @@ void RnnBatchStep(
float unused_min, unused_max; float unused_min, unused_max;
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
const int offset = b * num_units; const int offset = b * num_units;
tensor_utils::SymmetricQuantizeFloats( if (asymmetric_quantize_inputs) {
hidden_state_ptr_batch + offset, num_units, tensor_utils::AsymmetricQuantizeFloats(
quantized_hidden_state_ptr_batch + offset, &unused_min, &unused_max, hidden_state_ptr_batch + offset, num_units,
&scaling_factors[b]); quantized_hidden_state_ptr_batch + offset, &scaling_factors[b],
&zero_points[b]);
} else {
tensor_utils::SymmetricQuantizeFloats(
hidden_state_ptr_batch + offset, num_units,
quantized_hidden_state_ptr_batch + offset, &unused_min,
&unused_max, &scaling_factors[b]);
}
scaling_factors[b] *= recurrent_weights_scale; scaling_factors[b] *= recurrent_weights_scale;
} }
@ -214,7 +284,9 @@ void RnnBatchStep(
tensor_utils::MatrixBatchVectorMultiplyAccumulate( tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_weights_ptr, num_units, num_units, recurrent_weights_ptr, num_units, num_units,
quantized_hidden_state_ptr_batch, scaling_factors, batch_size, quantized_hidden_state_ptr_batch, scaling_factors, batch_size,
output_ptr_batch); output_ptr_batch, /*per_channel_scale=*/nullptr, zero_points,
accum_scratch, recurrent_row_sums, compute_row_sums,
/*context=*/nullptr);
} }
// Output = activation(Output) and update hidden_state // Output = activation(Output) and update hidden_state
@ -238,10 +310,17 @@ void RnnBatchStep(
// whichever is faster. // whichever is faster.
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
const int offset = b * input_size; const int offset = b * input_size;
tensor_utils::SymmetricQuantizeFloats( if (asymmetric_quantize_inputs) {
input_ptr_batch + offset, input_size, tensor_utils::AsymmetricQuantizeFloats(
quantized_input_ptr_batch + offset, &unused_min, &unused_max, input_ptr_batch + offset, input_size,
&scaling_factors[b]); quantized_input_ptr_batch + offset, &scaling_factors[b],
&zero_points[b]);
} else {
tensor_utils::SymmetricQuantizeFloats(
input_ptr_batch + offset, input_size,
quantized_input_ptr_batch + offset, &unused_min, &unused_max,
&scaling_factors[b]);
}
scaling_factors[b] *= input_weights_scale; scaling_factors[b] *= input_weights_scale;
} }
@ -250,7 +329,9 @@ void RnnBatchStep(
tensor_utils::MatrixBatchVectorMultiplyAccumulate( tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_weights_ptr, num_units, input_size, input_weights_ptr, num_units, input_size,
quantized_input_ptr_batch + k * input_size, &scaling_factors[k], quantized_input_ptr_batch + k * input_size, &scaling_factors[k],
/*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim); /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim,
/*per_channel_scale=*/nullptr, zero_points + k, accum_scratch,
input_row_sums, compute_row_sums, /*context=*/nullptr);
} }
} }
@ -260,10 +341,17 @@ void RnnBatchStep(
float unused_min, unused_max; float unused_min, unused_max;
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
const int offset = b * aux_input_size; const int offset = b * aux_input_size;
tensor_utils::SymmetricQuantizeFloats( if (asymmetric_quantize_inputs) {
aux_input_ptr_batch + offset, aux_input_size, tensor_utils::AsymmetricQuantizeFloats(
aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max, aux_input_ptr_batch + offset, aux_input_size,
&scaling_factors[b]); aux_quantized_input_ptr_batch + offset, &scaling_factors[b],
&zero_points[b]);
} else {
tensor_utils::SymmetricQuantizeFloats(
aux_input_ptr_batch + offset, aux_input_size,
aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max,
&scaling_factors[b]);
}
scaling_factors[b] *= aux_input_weights_scale; scaling_factors[b] *= aux_input_weights_scale;
} }
@ -273,7 +361,9 @@ void RnnBatchStep(
aux_input_weights_ptr, num_units, aux_input_size, aux_input_weights_ptr, num_units, aux_input_size,
aux_quantized_input_ptr_batch + k * aux_input_size, aux_quantized_input_ptr_batch + k * aux_input_size,
&scaling_factors[k], &scaling_factors[k],
/*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim); /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim,
/*per_channel_scale=*/nullptr, zero_points + k, accum_scratch,
aux_input_row_sums, compute_row_sums, /*context=*/nullptr);
} }
} }
@ -284,10 +374,17 @@ void RnnBatchStep(
float unused_min, unused_max; float unused_min, unused_max;
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
const int offset = b * num_units; const int offset = b * num_units;
tensor_utils::SymmetricQuantizeFloats( if (asymmetric_quantize_inputs) {
hidden_state_ptr_batch + offset, num_units, tensor_utils::AsymmetricQuantizeFloats(
quantized_hidden_state_ptr_batch + offset, &unused_min, &unused_max, hidden_state_ptr_batch + offset, num_units,
&scaling_factors[b]); quantized_hidden_state_ptr_batch + offset, &scaling_factors[b],
&zero_points[b]);
} else {
tensor_utils::SymmetricQuantizeFloats(
hidden_state_ptr_batch + offset, num_units,
quantized_hidden_state_ptr_batch + offset, &unused_min,
&unused_max, &scaling_factors[b]);
}
scaling_factors[b] *= recurrent_weights_scale; scaling_factors[b] *= recurrent_weights_scale;
} }
@ -296,8 +393,10 @@ void RnnBatchStep(
tensor_utils::MatrixBatchVectorMultiplyAccumulate( tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_weights_ptr, num_units, num_units, recurrent_weights_ptr, num_units, num_units,
quantized_hidden_state_ptr_batch + k * num_units, quantized_hidden_state_ptr_batch + k * num_units,
&scaling_factors[k], &scaling_factors[k], /*n_batch=*/1,
/*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim); output_ptr_batch + k * output_batch_leading_dim,
/*per_channel_scale=*/nullptr, zero_points + k, accum_scratch,
recurrent_row_sums, compute_row_sums, /*context=*/nullptr);
} }
} }

View File

@ -70,7 +70,9 @@ void RnnBatchStep(
int num_units, int batch_size, int output_batch_leading_dim, int num_units, int batch_size, int output_batch_leading_dim,
TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors, int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
float* hidden_state_ptr_batch, float* output_ptr_batch); float* hidden_state_ptr_batch, float* output_ptr_batch,
bool asymmetric_quantize_inputs, int32_t* zero_points,
int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums);
void RnnBatchStep( void RnnBatchStep(
const float* input_ptr_batch, const int8_t* input_weights_ptr, const float* input_ptr_batch, const int8_t* input_weights_ptr,
@ -82,7 +84,9 @@ void RnnBatchStep(
TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
int8_t* aux_quantized_input_ptr_batch, int8_t* aux_quantized_input_ptr_batch,
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors, int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
float* hidden_state_ptr_batch, float* output_ptr_batch); float* hidden_state_ptr_batch, float* output_ptr_batch,
bool asymmetric_quantize_inputs, int32_t* zero_points,
int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums);
} // namespace kernel_utils } // namespace kernel_utils
} // namespace tflite } // namespace tflite

View File

@ -1310,6 +1310,13 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
const int postamble_half_start = m_cols & ~(kWeightsPerNeonLane - 1); const int postamble_half_start = m_cols & ~(kWeightsPerNeonLane - 1);
const int postamble_start = m_cols & ~((kWeightsPerNeonLane >> 1) - 1); const int postamble_start = m_cols & ~((kWeightsPerNeonLane >> 1) - 1);
int32_t* row_sums_ptr = row_sums;
if (row_sums == nullptr) {
row_sums_ptr = static_cast<int32_t*>(malloc(sizeof(int32_t) * m_rows));
memset(row_sums_ptr, 0, sizeof(int32_t) * m_rows);
NeonReductionSumVector(matrix, row_sums_ptr, m_rows, m_cols);
}
for (int batch = 0; batch < n_batch; ++batch) { for (int batch = 0; batch < n_batch; ++batch) {
const float batch_scaling_factor = scaling_factors[batch]; const float batch_scaling_factor = scaling_factors[batch];
const int batch_input_offset = input_offset[batch]; const int batch_input_offset = input_offset[batch];
@ -1327,10 +1334,6 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
// Initialize the dot product sum for the row to 0. // Initialize the dot product sum for the row to 0.
int32x4_t dotprod_32x4 = vmovq_n_s32(0); int32x4_t dotprod_32x4 = vmovq_n_s32(0);
int32x4_t row_sum_32x4;
if (row_sums == nullptr) {
row_sum_32x4 = vmovq_n_s32(0);
}
// Prefetch the row to cache. // Prefetch the row to cache.
__builtin_prefetch(row_ptr, 0 /* prefetch for read */, __builtin_prefetch(row_ptr, 0 /* prefetch for read */,
3 /* temporal locality */); 3 /* temporal locality */);
@ -1358,10 +1361,6 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
prod_16x8 = prod_16x8 =
vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16)); vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16));
dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8); dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
if (row_sums == nullptr) {
const int16x8_t row_sum_16x8 = vpaddlq_s8(s2_8x16);
row_sum_32x4 = vpadalq_s16(row_sum_32x4, row_sum_16x8);
}
} // for col } // for col
// Half iteration dealing only 8 elements // Half iteration dealing only 8 elements
@ -1375,29 +1374,24 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
const int8x8_t s2_8x8 = vld1_s8((const int8_t*)(row_ptr + col)); const int8x8_t s2_8x8 = vld1_s8((const int8_t*)(row_ptr + col));
const int16x8_t prod_16x8 = vmull_s8(s1_8x8, s2_8x8); const int16x8_t prod_16x8 = vmull_s8(s1_8x8, s2_8x8);
dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8); dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
if (row_sums == nullptr) {
const int16x8_t row_sum_16x8 = vmovl_s8(s2_8x8);
row_sum_32x4 = vpadalq_s16(row_sum_32x4, row_sum_16x8);
}
col += (kWeightsPerNeonLane >> 1); col += (kWeightsPerNeonLane >> 1);
} }
int32_t dotprod = AccumulateNeonLane(dotprod_32x4); int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
int32_t row_sum = row_sums == nullptr ? AccumulateNeonLane(row_sum_32x4)
: row_sums[row];
// Postamble loop. // Postamble loop.
for (; col < m_cols; ++col) { for (; col < m_cols; ++col) {
dotprod += row_ptr[col] * aligned_vec[col]; dotprod += row_ptr[col] * aligned_vec[col];
if (row_sums == nullptr) {
row_sum += row_ptr[col];
}
} // for col } // for col
dotprod -= row_sum * batch_input_offset; dotprod -= row_sums_ptr[row] * batch_input_offset;
*result += dotprod * scale; *result += dotprod * scale;
++result; ++result;
} // for row } // for row
} // for batch } // for batch
if (row_sums == nullptr) {
free(row_sums_ptr);
}
if (unaligned) { if (unaligned) {
free(aligned_row_free); free(aligned_row_free);
} }
@ -1410,6 +1404,20 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
int n_batch, float* __restrict__ result, const float* per_channel_scale, int n_batch, float* __restrict__ result, const float* per_channel_scale,
const int32_t* input_offset, int32_t* scratch, int32_t* row_sums, const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
bool* compute_row_sums, CpuBackendContext* context) { bool* compute_row_sums, CpuBackendContext* context) {
if (input_offset == nullptr) {
#ifdef TFLITE_WITH_RUY_GEMV
if (context) {
NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
scaling_factors, n_batch, scratch,
result, context);
return;
}
#endif
NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
scaling_factors, n_batch, result);
return;
}
if (compute_row_sums == nullptr || *compute_row_sums) { if (compute_row_sums == nullptr || *compute_row_sums) {
memset(row_sums, 0, sizeof(int32_t) * m_rows); memset(row_sums, 0, sizeof(int32_t) * m_rows);
NeonReductionSumVector(matrix, row_sums, m_rows, m_cols); NeonReductionSumVector(matrix, row_sums, m_rows, m_cols);
@ -1419,7 +1427,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
} }
#ifdef TFLITE_WITH_RUY_GEMV #ifdef TFLITE_WITH_RUY_GEMV
if (m_rows % 4 == 0) { if (context != nullptr && m_rows % 4 == 0) {
const int32_t* bias = static_cast<const int32_t*>(nullptr); const int32_t* bias = static_cast<const int32_t*>(nullptr);
NeonCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows, 0, NeonCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows, 0,
scratch, context); scratch, context);
@ -1463,9 +1471,9 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
for (; i < total_size; i++) { for (; i < total_size; i++) {
const float batch_scaling_factor = scaling_factors[i / m_rows]; const float batch_scaling_factor = scaling_factors[i / m_rows];
const int32_t zero_point = input_offset[i / m_rows]; const int32_t zero_point = input_offset[i / m_rows];
int32_t x = *(scratch_ptr++); int32_t dotprod = *(scratch_ptr++);
x -= row_sums[i % m_rows] * zero_point; dotprod -= row_sums[i % m_rows] * zero_point;
*result += x * batch_scaling_factor; *result += dotprod * batch_scaling_factor;
++result; ++result;
} }
return; return;
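Two behavioural notes on the widened NEON entry point above: a null input_offset routes straight to the pre-existing symmetric kernels, so callers that never use asymmetric inputs are unaffected, and the RUY GEMV scratch path now also requires a non-null CpuBackendContext. A hedged caller-side sketch of the symmetric fallback, with placeholder identifiers and every new argument nulled out:

// Assumed caller-side usage: null zero points / row sums keep the old
// symmetric behaviour (placeholder names, not taken from the kernels).
#include <cstdint>
void SymmetricCallSketch(const int8_t* weights, int num_units, int input_size,
                         const int8_t* quantized_input,
                         const float* scaling_factors, int batch_size,
                         float* output) {
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      weights, num_units, input_size, quantized_input, scaling_factors,
      batch_size, output,
      /*per_channel_scale=*/nullptr, /*input_offset=*/nullptr,
      /*scratch=*/nullptr, /*row_sums=*/nullptr, /*compute_row_sums=*/nullptr,
      /*context=*/nullptr);
}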

View File

@ -167,6 +167,11 @@ void SseMatrixBatchVectorMultiplyAccumulate(
const float* __restrict__ scaling_factors, int n_batch, const float* __restrict__ scaling_factors, int n_batch,
float* __restrict__ result, const float* __restrict__ per_channel_scale, float* __restrict__ result, const float* __restrict__ per_channel_scale,
const int32_t* __restrict__ input_offset) { const int32_t* __restrict__ input_offset) {
if (input_offset == nullptr) {
SseMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
scaling_factors, n_batch, result);
return;
}
static constexpr std::intptr_t kBlockSize = 16; static constexpr std::intptr_t kBlockSize = 16;
for (std::intptr_t batch = 0; batch < n_batch; ++batch) { for (std::intptr_t batch = 0; batch < n_batch; ++batch) {
const float batch_scaling_factor = scaling_factors[batch]; const float batch_scaling_factor = scaling_factors[batch];

View File

@ -59,9 +59,10 @@ void MatrixBatchVectorMultiplyAccumulate(
int n_batch, float* __restrict__ result, const float* per_channel_scale, int n_batch, float* __restrict__ result, const float* per_channel_scale,
const int32_t* input_offset, int32_t* scratch, int32_t* row_sums, const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
bool* compute_row_sums, CpuBackendContext* context) { bool* compute_row_sums, CpuBackendContext* context) {
NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols, PortableMatrixBatchVectorMultiplyAccumulate(
vectors, scaling_factors, n_batch, result, per_channel_scale, matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
input_offset, scratch, row_sums, compute_row_sums, context); per_channel_scale, input_offset, scratch, row_sums, compute_row_sums,
context);
} }
void MatrixBatchVectorMultiplyAccumulate( void MatrixBatchVectorMultiplyAccumulate(

View File

@ -196,6 +196,11 @@ void PortableMatrixBatchVectorMultiplyAccumulate(
int n_batch, float* __restrict__ result, const float* per_channel_scale, int n_batch, float* __restrict__ result, const float* per_channel_scale,
const int32_t* input_offset, int32_t* scratch, int32_t* row_sums, const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
bool* compute_row_sums, CpuBackendContext* context) { bool* compute_row_sums, CpuBackendContext* context) {
if (input_offset == nullptr) {
PortableMatrixBatchVectorMultiplyAccumulate(
matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result);
return;
}
if (!compute_row_sums || *compute_row_sums) { if (!compute_row_sums || *compute_row_sums) {
memset(row_sums, 0, sizeof(int32_t) * m_rows); memset(row_sums, 0, sizeof(int32_t) * m_rows);
PortableReductionSumVector(matrix, row_sums, m_rows, m_cols); PortableReductionSumVector(matrix, row_sums, m_rows, m_cols);

View File

@ -223,7 +223,8 @@ inline void EvalHybridSVDF(
const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time, const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time,
const TfLiteTensor* bias, const TfLiteSVDFParams* params, const TfLiteTensor* bias, const TfLiteSVDFParams* params,
TfLiteTensor* scratch, TfLiteTensor* scaling_factors, TfLiteTensor* scratch, TfLiteTensor* scaling_factors,
TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output) { TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output,
TfLiteTensor* zero_points, TfLiteTensor* row_sums, bool* compute_row_sums) {
const int rank = params->rank; const int rank = params->rank;
const int batch_size = input->dims->data[0]; const int batch_size = input->dims->data[0];
const int input_size = input->dims->data[1]; const int input_size = input->dims->data[1];
@ -244,6 +245,13 @@ inline void EvalHybridSVDF(
float* output_ptr = GetTensorData<float>(output); float* output_ptr = GetTensorData<float>(output);
int32_t* zero_points_ptr = nullptr;
int32_t* row_sums_ptr = nullptr;
if (params->asymmetric_quantize_inputs && row_sums != nullptr) {
zero_points_ptr = GetTensorData<int32_t>(zero_points);
row_sums_ptr = GetTensorData<int32_t>(row_sums);
}
// Initialize the weights scale. // Initialize the weights scale.
const float weights_feature_scale = weights_feature->params.scale; const float weights_feature_scale = weights_feature->params.scale;
@ -258,21 +266,30 @@ inline void EvalHybridSVDF(
if (!tensor_utils::IsZeroVector(input_ptr, batch_size * input_size)) { if (!tensor_utils::IsZeroVector(input_ptr, batch_size * input_size)) {
// Quantize input from float to int8. // Quantize input from float to int8.
float unused_min, unused_max;
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
const int offset = b * input_size; const int offset = b * input_size;
tensor_utils::SymmetricQuantizeFloats( if (params->asymmetric_quantize_inputs) {
input_ptr + offset, input_size, quantized_input_ptr + offset, tensor_utils::AsymmetricQuantizeFloats(
&unused_min, &unused_max, &scaling_factors_ptr[b]); input_ptr + offset, input_size, quantized_input_ptr + offset,
&scaling_factors_ptr[b], &zero_points_ptr[b]);
} else {
// Quantize input from float to int8.
float unused_min, unused_max;
tensor_utils::SymmetricQuantizeFloats(
input_ptr + offset, input_size, quantized_input_ptr + offset,
&unused_min, &unused_max, &scaling_factors_ptr[b]);
}
scaling_factors_ptr[b] *= weights_feature_scale; scaling_factors_ptr[b] *= weights_feature_scale;
} }
// Compute conv1d(inputs, weights_feature). // Compute conv1d(inputs, weights_feature).
tensor_utils::MatrixBatchVectorMultiplyAccumulate( tensor_utils::MatrixBatchVectorMultiplyAccumulate(
weights_feature_ptr, num_filters, input_size, quantized_input_ptr, weights_feature_ptr, num_filters, input_size, quantized_input_ptr,
scaling_factors_ptr, batch_size, scratch_ptr); scaling_factors_ptr, batch_size, scratch_ptr,
/*per_channel_scale=*/nullptr, zero_points_ptr,
reinterpret_cast<int32_t*>(scratch_ptr), row_sums_ptr, compute_row_sums,
/*context=*/nullptr);
} }
// Copy the latest activation from scratch into activation_state: // Copy the latest activation from scratch into activation_state:
// The last, i.e. (memory_size-1)th entry for each batch, and filter. // The last, i.e. (memory_size-1)th entry for each batch, and filter.
for (int i = 0; i < batch_size * num_filters; ++i) { for (int i = 0; i < batch_size * num_filters; ++i) {

View File

@@ -55,6 +55,7 @@ struct OpData {
   // These fields are only used by full kernel.
   int scratch_tensor_index;
   lstm_eval::IntegerLstmParameter integer_lstm_param;
+  bool compute_row_sums;
 };
 namespace full {
@@ -727,7 +728,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   auto* op_data = new OpData();
   op_data->kernel_type = kTfLiteLSTMFullKernel;
-  context->AddTensors(context, /*tensors_to_add=*/8,
+  context->AddTensors(context, /*tensors_to_add=*/10,
                       &op_data->scratch_tensor_index);
   return op_data;
 }
@@ -1236,7 +1237,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TfLiteIntArrayFree(node->temporaries);
   if (is_hybrid_op) {
-    node->temporaries = TfLiteIntArrayCreate(8);
+    node->temporaries = TfLiteIntArrayCreate(10);
   } else if (is_integer) {
     if (is_8x8_16) {
       node->temporaries = TfLiteIntArrayCreate(6);
@@ -1273,6 +1274,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   }
   if (is_hybrid_op) {
+    op_data->compute_row_sums = true;
     // Allocate temporary tensors to store quantized values of input,
     // activation_state and cell_state tensors.
     node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
@@ -1370,6 +1372,41 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
       TF_LITE_ENSURE_OK(
           context, context->ResizeTensor(context, accum_scratch, accum_size));
     }
+    node->temporaries->data[8] = op_data->scratch_tensor_index + 8;
+    TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/8);
+    zero_points->type = kTfLiteFloat32;
+    zero_points->allocation_type = kTfLiteArenaRw;
+    int zero_points_dims[1] = {n_batch};
+    if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
+      TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
+      zero_points_size->data[0] = n_batch;
+      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
+                                                       zero_points_size));
+    }
+    node->temporaries->data[9] = op_data->scratch_tensor_index + 9;
+    const TfLiteTensor* input_to_input_weights =
+        GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+    const bool use_cifg = (input_to_input_weights == nullptr);
+    int row_sums_rows = use_cifg ? 6 : 8;
+    const TfLiteTensor* projection_weights =
+        GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+    if (projection_weights != nullptr) {
+      row_sums_rows += ceil(n_output / n_cell);
+    }
+    TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/9);
+    row_sums->type = kTfLiteInt32;
+    row_sums->allocation_type = kTfLiteArenaRwPersistent;
+    const int row_sums_dims[2] = {row_sums_rows, n_cell};
+    if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
+      TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
+      row_sums_size->data[0] = row_sums_dims[0];
+      row_sums_size->data[1] = row_sums_dims[1];
+      TF_LITE_ENSURE_OK(
+          context, context->ResizeTensor(context, row_sums, row_sums_size));
+    }
   }
   if (is_integer) {
@@ -1556,6 +1593,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
         GetTemporary(context, node, /*index=*/6);
     TfLiteTensor* output_scratch_buffer =
         GetTemporary(context, node, /*index=*/7);
+    TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/8);
+    TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/9);
+    const int row_sums_size = row_sums->dims->data[0];
     return lstm_eval::EvalHybrid(
         input, input_to_input_weights, input_to_forget_weights,
         input_to_cell_weights, input_to_output_weights,
@@ -1577,7 +1617,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
         input_quantized,
         /*aux_input_quantized=*/nullptr, activation_state_quantized,
         cell_state_quantized, activation_state, cell_state,
-        output_scratch_buffer, output,
+        output_scratch_buffer, output, zero_points, row_sums, row_sums_size,
+        &op_data->compute_row_sums,
         CpuBackendContext::GetFromContext(context));
   } else {
     const int num_intermediate_tensors = node->intermediates->size;
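Worked example for the row_sums temporary allocated above: a non-CIFG LSTM has four input-to-gate and four recurrent-to-gate weight matrices, so row_sums_rows starts at 8 (6 under CIFG, which has no *_to_input matrices), and a projection matrix with n_output == n_cell adds one more row, giving a persistent {9, n_cell} int32 buffer. zero_points is ordinary kTfLiteArenaRw scratch with one entry per batch row, while row_sums is kTfLiteArenaRwPersistent because the sums depend only on the constant weights; they are refilled only while op_data->compute_row_sums is true, which Prepare resets.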
@@ -33,24 +33,93 @@ namespace builtin {
 namespace lstm_eval {
 namespace {
+void ComputeRowSums(
+    int32_t* input_to_input_row_sums, int32_t* input_to_forget_row_sums,
+    int32_t* input_to_cell_row_sums, int32_t* input_to_output_row_sums,
+    int32_t* aux_input_to_input_row_sums, int32_t* aux_input_to_forget_row_sums,
+    int32_t* aux_input_to_cell_row_sums, int32_t* aux_input_to_output_row_sums,
+    int32_t* recurrent_to_input_row_sums, int32_t* recurrent_to_forget_row_sums,
+    int32_t* recurrent_to_cell_row_sums, int32_t* recurrent_to_output_row_sums,
+    int32_t* projection_weights_row_sums, int32_t* row_sums, int n_cell,
+    int n_input, int n_aux_input, int n_output,
+    const int8_t* input_to_input_weights_ptr,
+    const int8_t* input_to_forget_weights_ptr,
+    const int8_t* input_to_cell_weights_ptr,
+    const int8_t* input_to_output_weights_ptr,
+    const int8_t* aux_input_to_input_weights_ptr,
+    const int8_t* aux_input_to_forget_weights_ptr,
+    const int8_t* aux_input_to_cell_weights_ptr,
+    const int8_t* aux_input_to_output_weights_ptr,
+    const int8_t* recurrent_to_input_weights_ptr,
+    const int8_t* recurrent_to_forget_weights_ptr,
+    const int8_t* recurrent_to_cell_weights_ptr,
+    const int8_t* recurrent_to_output_weights_ptr,
+    const int8_t* projection_weights_ptr, bool use_cifg,
+    const float* aux_input_ptr) {
+  // Compute the row sums for dequantization
+  if (!use_cifg) {
+    memset(input_to_input_row_sums, 0, sizeof(int32_t) * n_cell);
+    tensor_utils::ReductionSumVector(input_to_input_weights_ptr,
+                                     input_to_input_row_sums, n_cell, n_input);
+  }
+  memset(input_to_forget_row_sums, 0, sizeof(int32_t) * n_cell);
+  tensor_utils::ReductionSumVector(input_to_forget_weights_ptr,
+                                   input_to_forget_row_sums, n_cell, n_input);
+  memset(input_to_cell_row_sums, 0, sizeof(int32_t) * n_cell);
+  tensor_utils::ReductionSumVector(input_to_cell_weights_ptr,
+                                   input_to_cell_row_sums, n_cell, n_input);
+  memset(input_to_output_row_sums, 0, sizeof(int32_t) * n_cell);
+  tensor_utils::ReductionSumVector(input_to_output_weights_ptr,
+                                   input_to_output_row_sums, n_cell, n_input);
+  if (aux_input_ptr) {
+    if (!use_cifg) {
+      memset(aux_input_to_input_row_sums, 0, sizeof(int32_t) * n_cell);
+      tensor_utils::ReductionSumVector(aux_input_to_input_weights_ptr,
+                                       aux_input_to_input_row_sums, n_cell,
+                                       n_aux_input);
+    }
+    memset(aux_input_to_forget_row_sums, 0, sizeof(int32_t) * n_cell);
+    tensor_utils::ReductionSumVector(aux_input_to_forget_weights_ptr,
+                                     aux_input_to_forget_row_sums, n_cell,
+                                     n_aux_input);
+    memset(aux_input_to_cell_row_sums, 0, sizeof(int32_t) * n_cell);
+    tensor_utils::ReductionSumVector(aux_input_to_cell_weights_ptr,
+                                     aux_input_to_cell_row_sums, n_cell,
+                                     n_aux_input);
+    memset(aux_input_to_output_row_sums, 0, sizeof(int32_t) * n_cell);
+    tensor_utils::ReductionSumVector(aux_input_to_output_weights_ptr,
+                                     aux_input_to_output_row_sums, n_cell,
+                                     n_aux_input);
+  }
+  if (!use_cifg) {
+    memset(recurrent_to_input_row_sums, 0, sizeof(int32_t) * n_cell);
+    tensor_utils::ReductionSumVector(recurrent_to_input_weights_ptr,
+                                     recurrent_to_input_row_sums, n_cell,
+                                     n_output);
+  }
+  memset(recurrent_to_forget_row_sums, 0, sizeof(int32_t) * n_cell);
+  tensor_utils::ReductionSumVector(recurrent_to_forget_weights_ptr,
+                                   recurrent_to_forget_row_sums, n_cell,
+                                   n_output);
+  memset(recurrent_to_cell_row_sums, 0, sizeof(int32_t) * n_cell);
+  tensor_utils::ReductionSumVector(recurrent_to_cell_weights_ptr,
+                                   recurrent_to_cell_row_sums, n_cell,
+                                   n_output);
+  memset(recurrent_to_output_row_sums, 0, sizeof(int32_t) * n_cell);
+  tensor_utils::ReductionSumVector(recurrent_to_output_weights_ptr,
+                                   recurrent_to_output_row_sums, n_cell,
+                                   n_output);
+  if (projection_weights_ptr != nullptr) {
+    memset(projection_weights_row_sums, 0, sizeof(int32_t) * n_output);
+    tensor_utils::ReductionSumVector(
+        projection_weights_ptr, projection_weights_row_sums, n_output, n_cell);
+  }
+}
+
 inline float GetTensorScale(const TfLiteTensor* tensor) {
   return tensor == nullptr ? 1.0f : tensor->params.scale;
 }
-inline void MatrixBatchVectorMultiplyAccumulate(
-    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
-    const int8_t* __restrict__ vectors, const float* scaling_factors,
-    int n_batch, int32_t* scratch, float* __restrict__ result,
-    CpuBackendContext* context) {
-  // TODO(b/148289189) Remove when Ruy GEMV is the default.
-#ifdef TFLITE_WITH_RUY_GEMV
-  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, scratch,
-      result, context);
-#else
-  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result);
-#endif
-}
 // Performs an LSTM batch inference step for input specified by input_ptr.
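What the ComputeRowSums buffers above are for: with asymmetric inputs each quantized value is q = round(x / s) + z, so an int8 dot product against a weight row w can be corrected as sum_i(w_i * (q_i - z)) = sum_i(w_i * q_i) - z * sum_i(w_i). The per-row weight sums depend only on the constant weights, so they are computed once (guarded by *compute_row_sums) and the zero-point term is subtracted per batch at matmul time. A small self-contained sketch of that correction, illustrative rather than the optimized tensor_utils kernel:

#include <cstdint>

// accum[r] += scale * (sum_c w[r][c] * q[c] - zero_point * row_sums[r]),
// where row_sums[r] = sum_c w[r][c] is precomputed from the constant weights.
void AccumulateWithZeroPointCorrection(const int8_t* weights, int rows,
                                       int cols, const int8_t* quantized_input,
                                       float scale, int32_t zero_point,
                                       const int32_t* row_sums, float* accum) {
  for (int r = 0; r < rows; ++r) {
    int32_t dot = 0;
    for (int c = 0; c < cols; ++c) {
      dot += static_cast<int32_t>(weights[r * cols + c]) *
             static_cast<int32_t>(quantized_input[c]);
    }
    dot -= zero_point * row_sums[r];
    accum[r] += scale * dot;
  }
}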
@@ -473,6 +542,8 @@ inline void LstmStepHybrid(
     int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr,
     int8_t* quantized_cell_state_ptr, float* output_state_ptr,
     float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr,
+    int32_t* zero_points, int32_t* row_sums, int row_sums_size,
+    bool* compute_row_sums, bool asymmetric_quantize_inputs,
     CpuBackendContext* context) {
   ruy::profiler::ScopeLabel label("LstmStepHybrid");
   // Since we have already checked that weights are all there or none, we
@@ -503,53 +574,131 @@ inline void LstmStepHybrid(
         output_gate_scratch);
   }
-  // For each batch and cell: compute input_weight * input.
-  // Skip if input is all zeros.
+  int32_t* input_to_input_row_sums = nullptr;
+  int32_t* input_to_forget_row_sums = nullptr;
+  int32_t* input_to_cell_row_sums = nullptr;
+  int32_t* input_to_output_row_sums = nullptr;
+  int32_t* aux_input_to_input_row_sums = nullptr;
+  int32_t* aux_input_to_forget_row_sums = nullptr;
+  int32_t* aux_input_to_cell_row_sums = nullptr;
+  int32_t* aux_input_to_output_row_sums = nullptr;
+  int32_t* recurrent_to_input_row_sums = nullptr;
+  int32_t* recurrent_to_forget_row_sums = nullptr;
+  int32_t* recurrent_to_cell_row_sums = nullptr;
+  int32_t* recurrent_to_output_row_sums = nullptr;
+  int32_t* projection_weights_row_sums = nullptr;
+  if (asymmetric_quantize_inputs) {
+    int num_row_sums = use_cifg ? 6 : 8;
+    if (aux_input_ptr != nullptr) {
+      num_row_sums += use_cifg ? 3 : 4;
+    }
+    if (projection_weights_ptr != nullptr) {
+      num_row_sums += ceil(n_output / n_cell);
+    }
+    TF_LITE_ASSERT(row_sums_size == num_row_sums);
+    input_to_input_row_sums = row_sums;
+    input_to_forget_row_sums =
+        use_cifg ? input_to_input_row_sums : input_to_input_row_sums + n_cell;
+    input_to_cell_row_sums = input_to_forget_row_sums + n_cell;
+    input_to_output_row_sums = input_to_cell_row_sums + n_cell;
+    if (aux_input_ptr != nullptr) {
+      aux_input_to_input_row_sums = input_to_output_row_sums + n_cell;
+      aux_input_to_forget_row_sums = use_cifg
+                                         ? aux_input_to_input_row_sums
+                                         : aux_input_to_input_row_sums + n_cell;
+      aux_input_to_cell_row_sums = aux_input_to_forget_row_sums + n_cell;
+      aux_input_to_output_row_sums = aux_input_to_cell_row_sums + n_cell;
+    }
+    recurrent_to_input_row_sums = aux_input_ptr
+                                      ? aux_input_to_output_row_sums + n_cell
+                                      : input_to_output_row_sums + n_cell;
+    recurrent_to_forget_row_sums = use_cifg
+                                       ? recurrent_to_input_row_sums
+                                       : recurrent_to_input_row_sums + n_cell;
+    recurrent_to_cell_row_sums = recurrent_to_forget_row_sums + n_cell;
+    recurrent_to_output_row_sums = recurrent_to_cell_row_sums + n_cell;
+    if (projection_weights_ptr != nullptr) {
+      projection_weights_row_sums = recurrent_to_output_row_sums + n_cell;
+    }
+    if (*compute_row_sums) {
+      ComputeRowSums(
+          input_to_input_row_sums, input_to_forget_row_sums,
+          input_to_cell_row_sums, input_to_output_row_sums,
+          aux_input_to_input_row_sums, aux_input_to_forget_row_sums,
+          aux_input_to_cell_row_sums, aux_input_to_output_row_sums,
+          recurrent_to_input_row_sums, recurrent_to_forget_row_sums,
+          recurrent_to_cell_row_sums, recurrent_to_output_row_sums,
+          projection_weights_row_sums, row_sums, n_cell, n_input, n_aux_input,
+          n_output, input_to_input_weights_ptr, input_to_forget_weights_ptr,
+          input_to_cell_weights_ptr, input_to_output_weights_ptr,
+          aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
+          aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
+          recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
+          recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
+          projection_weights_ptr, use_cifg, aux_input_ptr);
+      *compute_row_sums = false;
+    }
+  }
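Worked example for the pointer layout above: with CIFG disabled, an auxiliary input present, and a projection matrix whose n_output equals n_cell, num_row_sums is 8 + 4 + 1 = 13, and the named pointers carve the {13, n_cell} row_sums buffer into consecutive n_cell-sized segments in the same order the weight matrices are consumed; under CIFG the *_to_forget pointers reuse the slot that would otherwise hold *_to_input, so no space is reserved for the input gate.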
   if (!tensor_utils::IsZeroVector(input_ptr, n_batch * n_input)) {
     for (int b = 0; b < n_batch; ++b) {
       const int offset = b * n_input;
-      float unused_min, unused_max;
-      tensor_utils::SymmetricQuantizeFloats(
-          input_ptr + offset, n_input, quantized_input_ptr + offset,
-          &unused_min, &unused_max, &scaling_factors[b]);
+      if (asymmetric_quantize_inputs) {
+        tensor_utils::AsymmetricQuantizeFloats(
+            input_ptr + offset, n_input, quantized_input_ptr + offset,
+            &scaling_factors[b], &zero_points[b]);
+      } else {
+        float unused_min, unused_max;
+        tensor_utils::SymmetricQuantizeFloats(
+            input_ptr + offset, n_input, quantized_input_ptr + offset,
+            &unused_min, &unused_max, &scaling_factors[b]);
+      }
     }
     if (!use_cifg) {
       for (int b = 0; b < n_batch; ++b) {
         product_scaling_factors[b] =
             scaling_factors[b] * input_to_input_weights_scale;
       }
-      MatrixBatchVectorMultiplyAccumulate(
+      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
           input_to_input_weights_ptr, n_cell, n_input, quantized_input_ptr,
-          product_scaling_factors, n_batch, accum_scratch_ptr,
-          input_gate_scratch, context);
+          product_scaling_factors, n_batch, input_gate_scratch,
+          /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+          input_to_input_row_sums, compute_row_sums, context);
     }
     for (int b = 0; b < n_batch; ++b) {
       product_scaling_factors[b] =
           scaling_factors[b] * input_to_forget_weights_scale;
     }
-    MatrixBatchVectorMultiplyAccumulate(
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr,
-        product_scaling_factors, n_batch, accum_scratch_ptr,
-        forget_gate_scratch, context);
+        product_scaling_factors, n_batch, forget_gate_scratch,
+        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        input_to_forget_row_sums, compute_row_sums, context);
     for (int b = 0; b < n_batch; ++b) {
       product_scaling_factors[b] =
          scaling_factors[b] * input_to_cell_weights_scale;
     }
-    MatrixBatchVectorMultiplyAccumulate(
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr,
-        product_scaling_factors, n_batch, accum_scratch_ptr, cell_scratch,
-        context);
+        product_scaling_factors, n_batch, cell_scratch,
+        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        input_to_cell_row_sums, compute_row_sums, context);
     for (int b = 0; b < n_batch; ++b) {
      product_scaling_factors[b] =
          scaling_factors[b] * input_to_output_weights_scale;
     }
-    MatrixBatchVectorMultiplyAccumulate(
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr,
-        product_scaling_factors, n_batch, accum_scratch_ptr,
-        output_gate_scratch, context);
+        product_scaling_factors, n_batch, output_gate_scratch,
+        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        input_to_output_row_sums, compute_row_sums, context);
   }
   // For each batch and cell: compute aux_input_weight * aux_input.
@@ -558,59 +707,84 @@ inline void LstmStepHybrid(
       !tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input)) {
     for (int b = 0; b < n_batch; ++b) {
       const int offset = b * n_aux_input;
-      float unused_min, unused_max;
-      tensor_utils::SymmetricQuantizeFloats(
-          aux_input_ptr + offset, n_aux_input, quantized_aux_input_ptr + offset,
-          &unused_min, &unused_max, &scaling_factors[b]);
+      if (asymmetric_quantize_inputs) {
+        tensor_utils::AsymmetricQuantizeFloats(
+            aux_input_ptr + offset, n_aux_input,
+            quantized_aux_input_ptr + offset, &scaling_factors[b],
+            &zero_points[b]);
+      } else {
+        float unused_min, unused_max;
+        tensor_utils::SymmetricQuantizeFloats(
+            aux_input_ptr + offset, n_aux_input,
+            quantized_aux_input_ptr + offset, &unused_min, &unused_max,
+            &scaling_factors[b]);
+      }
     }
     if (!use_cifg) {
       for (int b = 0; b < n_batch; ++b) {
         product_scaling_factors[b] =
             scaling_factors[b] * aux_input_to_input_weights_scale;
       }
-      MatrixBatchVectorMultiplyAccumulate(
+      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
           aux_input_to_input_weights_ptr, n_cell, n_aux_input,
           quantized_aux_input_ptr, product_scaling_factors, n_batch,
-          accum_scratch_ptr, input_gate_scratch, context);
+          input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
+          accum_scratch_ptr, aux_input_to_input_row_sums, compute_row_sums,
+          context);
     }
     for (int b = 0; b < n_batch; ++b) {
       product_scaling_factors[b] =
           scaling_factors[b] * aux_input_to_forget_weights_scale;
     }
-    MatrixBatchVectorMultiplyAccumulate(
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
         quantized_aux_input_ptr, product_scaling_factors, n_batch,
-        accum_scratch_ptr, forget_gate_scratch, context);
+        forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
+        accum_scratch_ptr, aux_input_to_forget_row_sums, compute_row_sums,
+        context);
+    row_sums += n_cell;
     for (int b = 0; b < n_batch; ++b) {
       product_scaling_factors[b] =
           scaling_factors[b] * aux_input_to_cell_weights_scale;
     }
-    MatrixBatchVectorMultiplyAccumulate(
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         aux_input_to_cell_weights_ptr, n_cell, n_aux_input,
-        quantized_aux_input_ptr, product_scaling_factors, n_batch,
-        accum_scratch_ptr, cell_scratch, context);
+        quantized_aux_input_ptr, product_scaling_factors, n_batch, cell_scratch,
+        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        aux_input_to_cell_row_sums, compute_row_sums, context);
     for (int b = 0; b < n_batch; ++b) {
       product_scaling_factors[b] =
           scaling_factors[b] * aux_input_to_output_weights_scale;
     }
-    MatrixBatchVectorMultiplyAccumulate(
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         aux_input_to_output_weights_ptr, n_cell, n_aux_input,
         quantized_aux_input_ptr, product_scaling_factors, n_batch,
-        accum_scratch_ptr, output_gate_scratch, context);
+        output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
+        accum_scratch_ptr, aux_input_to_output_row_sums, compute_row_sums,
+        context);
   }
   if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
     // Save quantization and matmul computation for all zero input.
     for (int b = 0; b < n_batch; ++b) {
       const int offset = b * n_output;
-      float unused_min, unused_max;
-      tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
-                                            quantized_output_state_ptr + offset,
-                                            &unused_min, &unused_max,
-                                            &scaling_factors[b]);
+      if (asymmetric_quantize_inputs) {
+        tensor_utils::AsymmetricQuantizeFloats(
+            output_state_ptr + offset, n_output,
+            quantized_output_state_ptr + offset, &scaling_factors[b],
+            &zero_points[b]);
+      } else {
+        float unused_min, unused_max;
+        tensor_utils::SymmetricQuantizeFloats(
+            output_state_ptr + offset, n_output,
+            quantized_output_state_ptr + offset, &unused_min, &unused_max,
+            &scaling_factors[b]);
+      }
     }
     // For each batch and cell: compute recurrent_weight * output_state.
     if (!use_cifg) {
@@ -618,38 +792,46 @@ inline void LstmStepHybrid(
         product_scaling_factors[b] =
             scaling_factors[b] * recurrent_to_input_weights_scale;
       }
-      MatrixBatchVectorMultiplyAccumulate(
+      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          recurrent_to_input_weights_ptr, n_cell, n_output,
          quantized_output_state_ptr, product_scaling_factors, n_batch,
-          accum_scratch_ptr, input_gate_scratch, context);
+          input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
+          accum_scratch_ptr, recurrent_to_input_row_sums, compute_row_sums,
+          context);
     }
     for (int b = 0; b < n_batch; ++b) {
       product_scaling_factors[b] =
           scaling_factors[b] * recurrent_to_forget_weights_scale;
     }
-    MatrixBatchVectorMultiplyAccumulate(
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         recurrent_to_forget_weights_ptr, n_cell, n_output,
         quantized_output_state_ptr, product_scaling_factors, n_batch,
-        accum_scratch_ptr, forget_gate_scratch, context);
+        forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
+        accum_scratch_ptr, recurrent_to_forget_row_sums, compute_row_sums,
+        context);
     for (int b = 0; b < n_batch; ++b) {
       product_scaling_factors[b] =
           scaling_factors[b] * recurrent_to_cell_weights_scale;
     }
-    MatrixBatchVectorMultiplyAccumulate(
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         recurrent_to_cell_weights_ptr, n_cell, n_output,
         quantized_output_state_ptr, product_scaling_factors, n_batch,
-        accum_scratch_ptr, cell_scratch, context);
+        cell_scratch, /*per_channel_scale=*/nullptr, zero_points,
+        accum_scratch_ptr, recurrent_to_cell_row_sums, compute_row_sums,
+        context);
     for (int b = 0; b < n_batch; ++b) {
       product_scaling_factors[b] =
           scaling_factors[b] * recurrent_to_output_weights_scale;
     }
-    MatrixBatchVectorMultiplyAccumulate(
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         recurrent_to_output_weights_ptr, n_cell, n_output,
         quantized_output_state_ptr, product_scaling_factors, n_batch,
-        accum_scratch_ptr, output_gate_scratch, context);
+        output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
+        accum_scratch_ptr, recurrent_to_output_row_sums, compute_row_sums,
+        context);
   }
   // For each batch and cell: update input gate.
@@ -770,22 +952,32 @@ inline void LstmStepHybrid(
     // Save quantization and matmul computation for all zero input.
     for (int b = 0; b < n_batch; ++b) {
       const int offset = b * n_cell;
-      float unused_min, unused_max;
-      tensor_utils::SymmetricQuantizeFloats(
-          output_gate_scratch + offset, n_cell,
-          quantized_cell_state_ptr + offset, &unused_min, &unused_max,
-          &scaling_factors[b]);
+      if (asymmetric_quantize_inputs) {
+        tensor_utils::AsymmetricQuantizeFloats(
+            output_gate_scratch + offset, n_cell,
+            quantized_cell_state_ptr + offset, &scaling_factors[b],
+            &zero_points[b]);
+      } else {
+        float unused_min, unused_max;
+        tensor_utils::SymmetricQuantizeFloats(
+            output_gate_scratch + offset, n_cell,
+            quantized_cell_state_ptr + offset, &unused_min, &unused_max,
+            &scaling_factors[b]);
+      }
     }
     for (int b = 0; b < n_batch; ++b) {
       product_scaling_factors[b] =
           scaling_factors[b] * projection_weights_scale;
     }
     for (int b = 0; b < n_batch; b++) {
-      MatrixBatchVectorMultiplyAccumulate(
+      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
           projection_weights_ptr, n_output, n_cell,
           quantized_cell_state_ptr + b * n_cell, &product_scaling_factors[b],
-          /*n_batch=*/1, accum_scratch_ptr,
-          output_ptr + b * output_batch_leading_dim, context);
+          /*n_batch=*/1, output_ptr + b * output_batch_leading_dim,
+          /*per_channel_scale=*/nullptr,
+          asymmetric_quantize_inputs ? &zero_points[b] : nullptr,
+          accum_scratch_ptr, projection_weights_row_sums, compute_row_sums,
+          context);
     }
   }
   if (params->proj_clip > 0.0) {
@@ -1615,7 +1807,8 @@ TfLiteStatus EvalHybrid(
     TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
     TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
     TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer,
-    TfLiteTensor* output, CpuBackendContext* context) {
+    TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* row_sums,
+    int row_sums_size, bool* compute_row_sums, CpuBackendContext* context) {
   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
   const int n_input = input->dims->data[input->dims->size - 1];
   int max_time, n_batch;
@@ -1654,6 +1847,14 @@ TfLiteStatus EvalHybrid(
   const int output_batch_leading_dim =
       output->dims->data[output->dims->size - 1];
+  int32_t* zero_points_ptr = nullptr;
+  int32_t* row_sums_ptr = nullptr;
+  if (params->asymmetric_quantize_inputs) {
+    zero_points_ptr = GetTensorData<int32_t>(zero_points);
+    row_sums_ptr = GetTensorData<int32_t>(row_sums);
+  }
   if (time_major) {
     // Feed the sequence into the LSTM step-by-step.
     const int input_step = n_batch * n_input;
@@ -1721,7 +1922,9 @@ TfLiteStatus EvalHybrid(
           GetTensorData<int8_t>(output_state_quantized),
           GetTensorData<int8_t>(cell_state_quantized),
           GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
-          GetTensorData<int32_t>(output_scratch_buffer), output_ptr, context);
+          GetTensorData<int32_t>(output_scratch_buffer), output_ptr,
+          zero_points_ptr, row_sums_ptr, row_sums_size, compute_row_sums,
+          params->asymmetric_quantize_inputs, context);
     }
   } else {
     for (int b = 0; b < n_batch; b++) {
@@ -1806,7 +2009,8 @@ TfLiteStatus EvalHybrid(
           GetTensorData<int8_t>(output_state_quantized),
           GetTensorData<int8_t>(cell_state_quantized), output_state_ptr,
           cell_state_ptr, GetTensorData<int32_t>(output_scratch_buffer),
-          output_ptr, context);
+          output_ptr, zero_points_ptr, row_sums_ptr, row_sums_size,
+          compute_row_sums, params->asymmetric_quantize_inputs, context);
     }
   }
 }


@@ -156,7 +156,8 @@ TfLiteStatus EvalHybrid(
     TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
     TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
     TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer,
-    TfLiteTensor* output, CpuBackendContext* context);
+    TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* row_sums,
+    int row_sums_size, bool* compute_row_sums, CpuBackendContext* context);
 TfLiteStatus EvalInteger8x8_16(
     const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,


@@ -38,7 +38,8 @@ class LSTMOpModel : public SingleOpModel {
              bool use_peephole, bool use_projection_weights,
              bool use_projection_bias, float cell_clip, float proj_clip,
              const std::vector<std::vector<int>>& input_shapes,
-             const TensorType weight_type, bool is_layer_norm)
+             const TensorType weight_type, bool is_layer_norm,
+             bool asymmetric_quantize_inputs = false)
       : n_batch_(n_batch),
         n_input_(n_input),
         n_cell_(n_cell),
@@ -129,10 +130,12 @@ class LSTMOpModel : public SingleOpModel {
     output_ = AddOutput(TensorType_FLOAT32);
-    SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
-                 CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
-                                   cell_clip, proj_clip)
-                     .Union());
+    SetBuiltinOp(
+        BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
+        CreateLSTMOptions(builder_, ActivationFunctionType_TANH, cell_clip,
+                          proj_clip, ::tflite::LSTMKernelType_FULL,
+                          asymmetric_quantize_inputs)
+            .Union());
     // Do not apply delegate yet since tensor values are not known (and more
     // specifically scales in quantized tensors are not known).
@@ -315,7 +318,7 @@ class LSTMOpModel : public SingleOpModel {
   const TensorType weight_type_;
 };
-class BaseLstmTest : public ::testing::Test {
+class BaseLstmTest : public ::testing::TestWithParam<bool> {
  protected:
   // Weights of the LSTM model. Some are optional.
   std::vector<float> input_to_input_weights_;
@@ -565,8 +568,11 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest,
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
 }
-TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
+TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
        HybridLstmBlackBoxTestUint8) {
+  if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
+    return;
+  }
   const int n_batch = 1;
   const int n_input = 2;
   // n_cell and n_output have the same size when there is no projection.
@@ -604,14 +610,20 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
           {0},  // projection_bias tensor
       },
       /*weight_type=*/TensorType_UINT8,
-      /*is_layer_norm=*/false);
+      /*is_layer_norm=*/false, GetParam());
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
                 /*tolerance=*/0.0157651);
 }
-TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
+class NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test
+    : public NoCifgNoPeepholeNoProjectionNoClippingLstmTest {};
+TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test,
        HybridLstmBlackBoxTestInt8) {
+  if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
+    return;
+  }
   const int n_batch = 1;
   const int n_input = 2;
   // n_cell and n_output have the same size when there is no projection.
@@ -649,7 +661,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
           {0},  // projection_bias tensor
       },
       /*weight_type=*/TensorType_INT8,
-      /*is_layer_norm=*/false);
+      /*is_layer_norm=*/false, GetParam());
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
                 /*tolerance=*/0.0157651);
@@ -745,8 +757,11 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
 }
-TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest,
+TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmTest,
        HybridLstmBlackBoxTestUint8) {
+  if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
+    return;
+  }
   const int n_batch = 1;
   const int n_input = 2;
   // n_cell and n_output have the same size when there is no projection.
@@ -784,13 +799,18 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest,
           {0},  // projection_bias tensor
       },
       /*weight_type=*/TensorType_UINT8,
-      /*is_layer_norm=*/false);
+      /*is_layer_norm=*/false, GetParam());
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
 }
+class CifgNoPeepholeNoProjectionNoClippingLstmInt8Test
+    : public CifgNoPeepholeNoProjectionNoClippingLstmTest {};
-TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest,
+TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmInt8Test,
        HybridLstmBlackBoxTestInt8) {
+  if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
+    return;
+  }
   const int n_batch = 1;
   const int n_input = 2;
   // n_cell and n_output have the same size when there is no projection.
@@ -828,7 +848,7 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest,
           {0},  // projection_bias tensor
       },
       /*weight_type=*/TensorType_INT8,
-      /*is_layer_norm=*/false);
+      /*is_layer_norm=*/false, GetParam());
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
 }
@@ -1474,50 +1494,11 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) {
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
 }
-TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTestInt8) {
-  const int n_batch = 2;
-  const int n_input = 5;
-  const int n_cell = 20;
-  const int n_output = 16;
-  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
-                   /*use_cifg=*/false, /*use_peephole=*/true,
-                   /*use_projection_weights=*/true,
-                   /*use_projection_bias=*/false,
-                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
-                   {
-                       {n_batch, n_input},  // input tensor
-                       {n_cell, n_input},   // input_to_input_weight tensor
-                       {n_cell, n_input},   // input_to_forget_weight tensor
-                       {n_cell, n_input},   // input_to_cell_weight tensor
-                       {n_cell, n_input},   // input_to_output_weight tensor
-                       {n_cell, n_output},  // recurrent_to_input_weight tensor
-                       {n_cell, n_output},  // recurrent_to_forget_weight tensor
-                       {n_cell, n_output},  // recurrent_to_cell_weight tensor
-                       {n_cell, n_output},  // recurrent_to_output_weight tensor
-                       {n_cell},            // cell_to_input_weight tensor
-                       {n_cell},            // cell_to_forget_weight tensor
-                       {n_cell},            // cell_to_output_weight tensor
-                       {n_cell},            // input_gate_bias tensor
-                       {n_cell},            // forget_gate_bias tensor
-                       {n_cell},            // cell_bias tensor
-                       {n_cell},            // output_gate_bias tensor
-                       {n_output, n_cell},  // projection_weight tensor
-                       {0},                 // projection_bias tensor
-                   },
-                   /*weight_type=*/TensorType_INT8,
-                   /*is_layer_norm=*/false);
-  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
-}
-TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest,
+TEST_P(NoCifgPeepholeProjectionNoClippingLstmTest,
        HybridLstmBlackBoxTestUint8) {
+  if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
+    return;
+  }
   const int n_batch = 2;
   const int n_input = 5;
   const int n_cell = 20;
@@ -1554,11 +1535,60 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest,
           {0},  // projection_bias tensor
       },
       /*weight_type=*/TensorType_UINT8,
-      /*is_layer_norm=*/false);
+      /*is_layer_norm=*/false, GetParam());
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
 }
+class NoCifgPeepholeProjectionNoClippingLstmInt8Test
+    : public NoCifgPeepholeProjectionNoClippingLstmTest {};
+TEST_P(NoCifgPeepholeProjectionNoClippingLstmInt8Test,
+       HybridLstmBlackBoxTestInt8) {
+  if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
+    return;
+  }
+  const int n_batch = 2;
+  const int n_input = 5;
+  const int n_cell = 20;
+  const int n_output = 16;
+  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+                   /*use_cifg=*/false, /*use_peephole=*/true,
+                   /*use_projection_weights=*/true,
+                   /*use_projection_bias=*/false,
+                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+                   {
+                       {n_batch, n_input},  // input tensor
+                       {n_cell, n_input},   // input_to_input_weight tensor
+                       {n_cell, n_input},   // input_to_forget_weight tensor
+                       {n_cell, n_input},   // input_to_cell_weight tensor
+                       {n_cell, n_input},   // input_to_output_weight tensor
+                       {n_cell, n_output},  // recurrent_to_input_weight tensor
+                       {n_cell, n_output},  // recurrent_to_forget_weight tensor
+                       {n_cell, n_output},  // recurrent_to_cell_weight tensor
+                       {n_cell, n_output},  // recurrent_to_output_weight tensor
+                       {n_cell},            // cell_to_input_weight tensor
+                       {n_cell},            // cell_to_forget_weight tensor
+                       {n_cell},            // cell_to_output_weight tensor
+                       {n_cell},            // input_gate_bias tensor
+                       {n_cell},            // forget_gate_bias tensor
+                       {n_cell},            // cell_bias tensor
+                       {n_cell},            // output_gate_bias tensor
+                       {n_output, n_cell},  // projection_weight tensor
+                       {0},                 // projection_bias tensor
+                   },
+                   /*weight_type=*/TensorType_INT8,
+                   /*is_layer_norm=*/false, GetParam());
+  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.0015);
+}
 class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
     : public BaseLstmTest {
   void SetUp() override {
@@ -1693,8 +1723,11 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
   VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
 }
-TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
+TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
        HybridLayerNormLstmBlackBoxTestUint8) {
+  if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
+    return;
+  }
   const int n_batch = 2;
   const int n_input = 5;
   const int n_cell = 4;
@@ -1741,7 +1774,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
           {n_cell},  // output_layer_norm_coefficient tensor
       },
       /*weight_type=*/TensorType_UINT8,
-      /*is_layer_norm=*/true);
+      /*is_layer_norm=*/true, GetParam());
   lstm_golden_output_ = {{
       // Batch0: 3 (input_sequence_size) * 3 (n_output)
@@ -1760,8 +1793,14 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
                 /*tolerance=*/0.0010907);
 }
-TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
+class NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test
+    : public NoCifgPeepholeProjectionNoClippingLayerNormLstmTest {};
+TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
       HybridLayerNormLstmBlackBoxTestInt8) {
+  if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
+    return;
+  }
   const int n_batch = 2;
   const int n_input = 5;
   const int n_cell = 4;
@@ -1808,22 +1847,24 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
           {n_cell},  // output_layer_norm_coefficient tensor
       },
       /*weight_type=*/TensorType_INT8,
-      /*is_layer_norm=*/true);
+      /*is_layer_norm=*/true, GetParam());
+  // Goldens are calculated from weight_type=TensorType_FLOAT32.
   lstm_golden_output_ = {{
       // Batch0: 3 (input_sequence_size) * 3 (n_output)
-      0.0244576, 0.127847, -0.00181765,  // seq 0
-      0.0137518, 0.140892, 0.0402234,    // seq 1
-      -0.0048839, 0.155096, 0.0840309,   // seq 2
+      0.0244077, 0.128027, -0.00170918,  // seq 0
+      0.0137642, 0.140751, 0.0395835,    // seq 1
+      -0.00459233, 0.155278, 0.0837378,  // seq 2
   },
   {
       // Batch1: 3 (input_sequence_size) * 3 (n_output)
-      -0.00728636, 0.0843957, 0.0634786,  // seq 0
-      -0.00448382, 0.139278, 0.0737372,   // seq 1
-      0.00734616, 0.161793, 0.0560238,    // seq 2
+      -0.00692428, 0.0848741, 0.063445,  // seq 0
+      -0.00403911, 0.139963, 0.072681,   // seq 1
+      0.00752708, 0.161903, 0.0561371,   // seq 2
   }};
-  VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
+  VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm,
+                /*tolerance=*/1.06e-3);
 }
 class CifgPeepholeProjectionNoClippingLayerNormLstmTest : public BaseLstmTest {
@@ -1940,8 +1981,11 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
   VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
 }
-TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
+TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
        HybridLayerNormLstmBlackBoxTestUint8) {
+  if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
+    return;
+  }
   const int n_batch = 2;
   const int n_input = 5;
   const int n_cell = 4;
@@ -1988,7 +2032,7 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
           {n_cell},  // output_layer_norm_coefficient tensor
       },
       /*weight_type=*/TensorType_UINT8,
-      /*is_layer_norm=*/true);
+      /*is_layer_norm=*/true, GetParam());
   // Verify the final output.
   lstm_golden_output_ = {
@@ -2009,7 +2053,10 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
                 /*tolerance=*/0.000902065);
 }
-TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
+class CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test
+    : public CifgPeepholeProjectionNoClippingLayerNormLstmTest {};
+TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
        HybridLayerNormLstmBlackBoxTestInt8) {
   const int n_batch = 2;
   const int n_input = 5;
@@ -2057,24 +2104,24 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
           {n_cell},  // output_layer_norm_coefficient tensor
       },
       /*weight_type=*/TensorType_INT8,
-      /*is_layer_norm=*/true);
-  // Verify the final output.
-  lstm_golden_output_ = {
-      {
-          // Batch0: 3 (input_sequence_size) * 3 (n_output)
-          0.0212250091, 0.140474007, 0.0115012666,   // seq 0
-          0.0130806509, 0.152660668, 0.0347516984,   // seq 1
-          -0.0124010444, 0.166042402, 0.0898982584,  // seq 2
-      },
-      {
-          // Batch1: 3 (input_sequence_size) * 3 (n_output)
-          -0.0228835996, 0.0917588323, 0.0778886303,  // seq 0
-          -0.0275101066, 0.148769245, 0.0938384682,   // seq 1
-          -0.0103605557, 0.172605693, 0.0728750974,   // seq 2
-      }};
-  VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
+      /*is_layer_norm=*/true, GetParam());
+  // Goldens are results using FLOAT32 inference.
+  lstm_golden_output_ = {{
+      // Batch0: 3 (input_sequence_size) * 3 (n_output)
+      0.0212971, 0.140816, 0.0112733,    // seq 0
+      0.0132302, 0.152308, 0.0346313,    // seq 1
+      -0.0123688, 0.16579, 0.0893078,    // seq 2
+  },
+  {
+      // Batch1: 3 (input_sequence_size) * 3 (n_output)
+      -0.0226351, 0.0916948, 0.0769176,  // seq 0
+      -0.0269967, 0.149708, 0.0941492,   // seq 1
+      -0.0103429, 0.173016, 0.0720509,   // seq 2
+  }};
+  VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm,
+                /*tolerance=*/1e-3);
 }
 class LSTMIntegerOpModel : public SingleOpModel {
@@ -3311,5 +3358,22 @@ TEST(LSTMOpModel, InvalidTypeTest) {
                "");
 }
 #endif
+#define QUANTIZE_PARAMETER_TEST(test) \
+  INSTANTIATE_TEST_SUITE_P(test, test, ::testing::Bool())
+QUANTIZE_PARAMETER_TEST(NoCifgNoPeepholeNoProjectionNoClippingLstmTest);
+QUANTIZE_PARAMETER_TEST(NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test);
+QUANTIZE_PARAMETER_TEST(CifgNoPeepholeNoProjectionNoClippingLstmTest);
+QUANTIZE_PARAMETER_TEST(CifgNoPeepholeNoProjectionNoClippingLstmInt8Test);
+QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionNoClippingLstmTest);
+QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionNoClippingLstmInt8Test);
+QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest);
+QUANTIZE_PARAMETER_TEST(
+    NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test);
+QUANTIZE_PARAMETER_TEST(CifgPeepholeProjectionNoClippingLayerNormLstmTest);
+QUANTIZE_PARAMETER_TEST(CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test);
+#undef QUANTIZE_PARAMETER_TEST
 }  // namespace
 }  // namespace tflite
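The pattern used throughout the updated tests: fixtures now derive from ::testing::TestWithParam<bool>, TEST_F becomes TEST_P, and the QUANTIZE_PARAMETER_TEST macro instantiates each suite over ::testing::Bool(), so every hybrid black-box test runs once with symmetric and once with asymmetric input quantization. A minimal gtest sketch of the same structure, with illustrative names:

#include <gtest/gtest.h>

class HybridQuantizationTest : public ::testing::TestWithParam<bool> {};

TEST_P(HybridQuantizationTest, RunsInBothQuantizationModes) {
  const bool asymmetric_quantize_inputs = GetParam();
  // A real test would build the op model with this flag and compare against
  // float goldens; here we only show how the parameter is consumed.
  EXPECT_TRUE(asymmetric_quantize_inputs || !asymmetric_quantize_inputs);
}

// Runs the suite twice, once with false and once with true.
INSTANTIATE_TEST_SUITE_P(HybridQuantizationTest, HybridQuantizationTest,
                         ::testing::Bool());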


@@ -43,6 +43,7 @@ struct OpData {
   int effective_scale_1_b;
   int32 effective_scale_2_a;
   int effective_scale_2_b;
+  bool compute_row_sums = false;
 };
 }  // namespace
@@ -61,8 +62,8 @@ constexpr int kOutputTensor = 0;
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   auto* op_data = new OpData();
   op_data->float_weights_time_initialized = false;
-  // Note: only needs 4 scratch tensors when is_hybrid_op, only 1 otherwise.
-  context->AddTensors(context, /*tensors_to_add=*/4,
+  // Note: only needs 6 scratch tensors when is_hybrid_op, only 1 otherwise.
+  context->AddTensors(context, /*tensors_to_add=*/6,
                       &op_data->scratch_tensor_index);
   return op_data;
 }
@@ -130,7 +131,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Resize scratch.
   TfLiteIntArrayFree(node->temporaries);
   if (is_hybrid_op) {
-    node->temporaries = TfLiteIntArrayCreate(4);
+    node->temporaries = TfLiteIntArrayCreate(6);
   } else if (is_full_integer) {
     node->temporaries = TfLiteIntArrayCreate(2);
   } else {
@@ -156,6 +157,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                                    scratch_size_array));
   if (is_hybrid_op) {
+    op_data->compute_row_sums = true;
     // Tell interpreter to allocate temporary tensors to store quantized values
     // of input tensors.
     node->temporaries->data[1] = scratch_tensor_index + 1;
@@ -195,6 +197,30 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                         context->ResizeTensor(context, float_weights_time,
                                               float_weights_time_size));
     }
+    node->temporaries->data[4] = scratch_tensor_index + 4;
+    TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
+    zero_points->type = kTfLiteFloat32;
+    zero_points->allocation_type = kTfLiteArenaRw;
+    int zero_points_dims[1] = {batch_size};
+    if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
+      TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
+      zero_points_size->data[0] = zero_points_dims[0];
+      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
+                                                       zero_points_size));
+    }
+    node->temporaries->data[5] = scratch_tensor_index + 5;
+    TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
+    row_sums->type = kTfLiteFloat32;
+    row_sums->allocation_type = kTfLiteArenaRwPersistent;
+    int row_sums_dims[1] = {num_filters};
+    if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) {
+      TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1);
+      row_sums_size->data[0] = row_sums_dims[0];
+      TF_LITE_ENSURE_OK(
+          context, context->ResizeTensor(context, row_sums, row_sums_size));
+    }
   }
   if (is_full_integer) {
     // Allocated one extra tensor.
@@ -267,7 +293,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           GetTemporary(context, node, /*index=*/2);
       TfLiteTensor* float_weights_time =
           GetTemporary(context, node, /*index=*/3);
+      TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
+      TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
       // Dequantize weights time.
       // TODO(alanchiao): this dequantization initialization only needs to
       // happen once per model and should theoretically be placed in either
@@ -285,10 +312,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
         }
         op_data->float_weights_time_initialized = true;
       }
-      reference_ops::EvalHybridSVDF(context, node, input, weights_feature,
-                                    float_weights_time, bias, params, scratch,
-                                    scaling_factors, input_quantized,
-                                    activation_state, output);
+      reference_ops::EvalHybridSVDF(
+          context, node, input, weights_feature, float_weights_time, bias,
+          params, scratch, scaling_factors, input_quantized, activation_state,
+          output, zero_points, row_sums, &op_data->compute_row_sums);
       return kTfLiteOk;
     } else {
       auto* input_params = reinterpret_cast<TfLiteAffineQuantization*>(
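Note on the SVDF sizes above: only the input to the feature-weights matmul is quantized on the fly (weights_time is dequantized to float earlier in Eval), so the op needs just one zero point per batch row and one row-sums vector with num_filters entries, both fetched from the two extra temporaries allocated in Prepare. For example, with batches=2, units=4 and rank=2 the new temporaries are zero_points with 2 entries and row_sums with 8 entries (num_filters = units * rank), and row_sums is only recomputed while op_data->compute_row_sums is true.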


@@ -131,7 +131,8 @@ class BaseSVDFOpModel : public SingleOpModel {
   BaseSVDFOpModel(int batches, int units, int input_size, int memory_size,
                   int rank,
                   TensorType weights_feature_type = TensorType_FLOAT32,
-                  TensorType weights_time_type = TensorType_FLOAT32)
+                  TensorType weights_time_type = TensorType_FLOAT32,
+                  bool asymmetric_quantize_inputs = false)
       : batches_(batches),
         units_(units),
         input_size_(input_size),
@@ -146,9 +147,10 @@ class BaseSVDFOpModel : public SingleOpModel {
         TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}},
         /*is_variable=*/true);
     output_ = AddOutput(TensorType_FLOAT32);
-    SetBuiltinOp(
-        BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
-        CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union());
+    SetBuiltinOp(BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
+                 CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE,
+                                   asymmetric_quantize_inputs)
+                     .Union());
     BuildInterpreter({
         {batches_, input_size_},       // input tensor
         {units_ * rank, input_size_},  // weights_feature tensor
@@ -203,9 +205,10 @@ class SVDFOpModel : public BaseSVDFOpModel {
 class HybridSVDFOpModel : public BaseSVDFOpModel {
  public:
   HybridSVDFOpModel(int batches, int units, int input_size, int memory_size,
-                    int rank, TensorType tensor_type)
+                    int rank, TensorType tensor_type,
+                    bool asymmetric_quantize_inputs)
       : BaseSVDFOpModel(batches, units, input_size, memory_size, rank,
-                        tensor_type, tensor_type) {
+                        tensor_type, tensor_type, asymmetric_quantize_inputs) {
     tensor_type_ = tensor_type;
   }
@@ -229,7 +232,7 @@ class HybridSVDFOpModel : public BaseSVDFOpModel {
   TensorType tensor_type_;
 };
-class SVDFOpTest : public ::testing::Test {
+class SVDFOpTest : public ::testing::TestWithParam<bool> {
  protected:
   void VerifyGoldens(float golden_input[], float golden_output[],
                      int golden_size, BaseSVDFOpModel* svdf,
@@ -262,6 +265,9 @@ class SVDFOpTest : public ::testing::Test {
   }
 };
+INSTANTIATE_TEST_SUITE_P(SVDFOpTest, SVDFOpTest,
+                         ::testing::ValuesIn({false, true}));
 TEST_F(SVDFOpTest, BlackBoxTestRank1) {
   SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
                    /*memory_size=*/10, /*rank=*/1);
@@ -325,9 +331,10 @@ TEST_F(SVDFOpTest, BlackBoxTestRank2) {
                 &svdf);
 }
-TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Uint8) {
+TEST_P(SVDFOpTest, BlackBoxTestHybridRank1Uint8) {
   HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
-                         /*memory_size=*/10, /*rank=*/1, TensorType_UINT8);
+                         /*memory_size=*/10, /*rank=*/1, TensorType_UINT8,
+                         GetParam());
   svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
                           0.22197971, 0.12416199, 0.27901134, 0.27557442,
                           0.3905206, -0.36137494, -0.06634006, -0.10640851});
@@ -347,12 +354,13 @@ TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Uint8) {
   VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
                 &svdf,
-                /*tolerance=*/0.002945);
+                /*tolerance=*/0.004285);
 }
-TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Uint8) {
+TEST_P(SVDFOpTest, BlackBoxTestHybridRank2Uint8) {
   HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
-                         /*memory_size=*/10, /*rank=*/2, TensorType_UINT8);
+                         /*memory_size=*/10, /*rank=*/2, TensorType_UINT8,
+                         GetParam());
   svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347,
                           0.12416199, 0.15785322, 0.27901134, 0.3905206,
                           0.21931258, -0.36137494, -0.10640851, 0.31053296,
@@ -387,12 +395,13 @@ TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Uint8) {
   VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
                 &svdf,
-                /*tolerance=*/0.00625109);
+                /*tolerance=*/0.007175);
 }
-TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Int8) {
+TEST_P(SVDFOpTest, BlackBoxTestHybridRank1Int8) {
   HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
-                         /*memory_size=*/10, /*rank=*/1, TensorType_INT8);
+                         /*memory_size=*/10, /*rank=*/1, TensorType_INT8,
+                         GetParam());
   svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
                           0.22197971, 0.12416199, 0.27901134, 0.27557442,
                           0.3905206, -0.36137494, -0.06634006, -0.10640851});
@@ -412,12 +421,13 @@ TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Int8) {
   VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
                 &svdf,
-                /*tolerance=*/0.002945);
+                /*tolerance=*/0.004285);
 }
-TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Int8) {
+TEST_P(SVDFOpTest, BlackBoxTestHybridRank2Int8) {
   HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
-                         /*memory_size=*/10, /*rank=*/2, TensorType_INT8);
+                         /*memory_size=*/10, /*rank=*/2, TensorType_INT8,
+                         GetParam());
   svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347,
                           0.12416199, 0.15785322, 0.27901134, 0.3905206,
                           0.21931258, -0.36137494, -0.10640851, 0.31053296,
@@ -452,7 +462,7 @@ TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Int8) {
   VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
                 &svdf,
-                /*tolerance=*/0.00625109);
+                /*tolerance=*/0.007175);
 }
 // Test case for full integer quantization of SVDF.

View File

@ -33,6 +33,7 @@ struct OpData {
bool is_layer_norm_lstm; bool is_layer_norm_lstm;
// The scratch tensor index. // The scratch tensor index.
int scratch_tensor_index; int scratch_tensor_index;
bool compute_row_sums = false;
}; };
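The new compute_row_sums flag lets the kernel defer the weight row-sum computation to the first hybrid evaluation and then reuse the persistent row_sums buffer on every later call. A hedged stand-alone sketch of that idea (the helper name is illustrative, not the routine this kernel actually calls):

#include <cstdint>

void MaybeComputeRowSums(const int8_t* weights, int rows, int cols,
                         int32_t* row_sums, bool* compute_row_sums) {
  if (compute_row_sums == nullptr || !*compute_row_sums) return;
  for (int r = 0; r < rows; ++r) {
    int32_t sum = 0;
    for (int c = 0; c < cols; ++c) sum += weights[r * cols + c];
    row_sums[r] = sum;
  }
  // The weights are constant tensors, so the sums stay valid for the life of
  // the node and only need to be computed once.
  *compute_row_sums = false;
}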
// Input Tensors of size {max_time, n_batch, n_input} // Input Tensors of size {max_time, n_batch, n_input}
@ -92,7 +93,9 @@ enum TemporaryTensor {
kProductScalingFactors = 5, kProductScalingFactors = 5,
kRecoveredCellWeights = 6, kRecoveredCellWeights = 6,
kAccumScratch = 7, kAccumScratch = 7,
kNumTemporaryTensors kZeroPoints = 8,
kRowSums = 9,
kNumTemporaryTensors = 10
}; };
void* Init(TfLiteContext* context, const char* buffer, size_t length) { void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@ -408,6 +411,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
scratch_buffer_size)); scratch_buffer_size));
if (IsHybridOp(input, input_to_output_weights)) { if (IsHybridOp(input, input_to_output_weights)) {
op_data->compute_row_sums = true;
// Allocate temporary tensors to store quantized values of input, // Allocate temporary tensors to store quantized values of input,
// activation_state and cell_state tensors. // activation_state and cell_state tensors.
node->temporaries->data[kInputQuantized] = node->temporaries->data[kInputQuantized] =
@ -515,6 +519,34 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK( TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, accum_scratch, accum_size)); context, context->ResizeTensor(context, accum_scratch, accum_size));
} }
node->temporaries->data[kZeroPoints] = scratch_tensor_index + kZeroPoints;
TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
zero_points->type = kTfLiteFloat32;
zero_points->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, scaling_dims)) {
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
zero_points_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
zero_points_size));
}
node->temporaries->data[kRowSums] = scratch_tensor_index + kRowSums;
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_rows = use_cifg ? 6 : 8;
const TfLiteTensor* projection_weights =
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
if (projection_weights != nullptr) {
row_sums_rows += ceil(static_cast<float>(n_output) / n_cell);
}
int row_sums_dims[2] = {row_sums_rows, n_cell};
if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
row_sums_size->data[0] = row_sums_dims[0];
row_sums_size->data[1] = row_sums_dims[1];
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, row_sums, row_sums_size));
}
} }
return kTfLiteOk; return kTfLiteOk;
} }
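Why Prepare() now reserves a per-batch zero_points buffer and a persistent row_sums buffer: with asymmetric input quantization, x is approximated by input_scale * (q_x - zero_point), so every hybrid dot product picks up a correction term of zero_point times the sum of the weight row, and that row sum depends only on the constant weights. A stand-alone sketch of the arithmetic (illustrative helper, not the optimized kernel this patch dispatches to):

#include <cstdint>

float HybridDotAsymmetric(const int8_t* weight_row,
                          const int8_t* quantized_input, int size,
                          float weights_scale, float input_scale,
                          int32_t input_zero_point, int32_t weight_row_sum) {
  int32_t acc = 0;
  for (int i = 0; i < size; ++i) acc += weight_row[i] * quantized_input[i];
  // Undo the offset introduced by the input zero point; weight_row_sum is the
  // cached sum of this row of int8 weights (see the row_sums tensor above).
  acc -= input_zero_point * weight_row_sum;
  return weights_scale * input_scale * static_cast<float>(acc);
}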
@ -600,6 +632,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
lstm_params.activation = params->activation; lstm_params.activation = params->activation;
lstm_params.cell_clip = params->cell_clip; lstm_params.cell_clip = params->cell_clip;
lstm_params.proj_clip = params->proj_clip; lstm_params.proj_clip = params->proj_clip;
lstm_params.asymmetric_quantize_inputs = params->asymmetric_quantize_inputs;
switch (input_to_output_weights->type) { switch (input_to_output_weights->type) {
case kTfLiteFloat32: { case kTfLiteFloat32: {
@ -623,6 +656,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} }
case kTfLiteUInt8: case kTfLiteUInt8:
case kTfLiteInt8: { case kTfLiteInt8: {
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* activation_state_quantized = TfLiteTensor* activation_state_quantized =
GetTemporary(context, node, /*index=*/2); GetTemporary(context, node, /*index=*/2);
@ -635,6 +669,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, /*index=*/6); GetTemporary(context, node, /*index=*/6);
TfLiteTensor* accum_scratch = TfLiteTensor* accum_scratch =
GetTemporary(context, node, /*index=*/kAccumScratch); GetTemporary(context, node, /*index=*/kAccumScratch);
TfLiteTensor* zero_points =
GetTemporary(context, node, /*index=*/kZeroPoints);
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/kRowSums);
const int row_sums_size = row_sums->dims->data[0];
return lstm_eval::EvalHybrid( return lstm_eval::EvalHybrid(
input, input_to_input_weights, input_to_forget_weights, input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights, input_to_cell_weights, input_to_output_weights,
@ -654,7 +692,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
prod_scaling_factors, recovered_cell_weights, input_quantized, prod_scaling_factors, recovered_cell_weights, input_quantized,
/*aux_input_quantized=*/nullptr, activation_state_quantized, /*aux_input_quantized=*/nullptr, activation_state_quantized,
cell_state_quantized, activation_state, cell_state, accum_scratch, cell_state_quantized, activation_state, cell_state, accum_scratch,
output, CpuBackendContext::GetFromContext(context)); output, zero_points, row_sums, row_sums_size,
&op_data->compute_row_sums,
CpuBackendContext::GetFromContext(context));
} }
default: default:
context->ReportError(context, "Type %d is not currently supported.", context->ReportError(context, "Type %d is not currently supported.",

View File

@ -38,7 +38,8 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
float proj_clip, float proj_clip,
const std::vector<std::vector<int>>& input_shapes, const std::vector<std::vector<int>>& input_shapes,
const TensorType& weights_type = TensorType_FLOAT32, const TensorType& weights_type = TensorType_FLOAT32,
bool is_layer_norm = false) bool is_layer_norm = false,
bool asymmetric_quantize_inputs = false)
: n_batch_(n_batch), : n_batch_(n_batch),
n_input_(n_input), n_input_(n_input),
n_cell_(n_cell), n_cell_(n_cell),
@ -131,7 +132,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
BuiltinOptions_UnidirectionalSequenceLSTMOptions, BuiltinOptions_UnidirectionalSequenceLSTMOptions,
CreateUnidirectionalSequenceLSTMOptions( CreateUnidirectionalSequenceLSTMOptions(
builder_, ActivationFunctionType_TANH, cell_clip, builder_, ActivationFunctionType_TANH, cell_clip,
proj_clip, time_major) proj_clip, time_major, asymmetric_quantize_inputs)
.Union()); .Union());
BuildInterpreter(input_shapes); BuildInterpreter(input_shapes);
} }
@ -292,11 +293,12 @@ class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
bool time_major, bool use_cifg, bool use_peephole, bool time_major, bool use_cifg, bool use_peephole,
bool use_projection_weights, bool use_projection_bias, float cell_clip, bool use_projection_weights, bool use_projection_bias, float cell_clip,
float proj_clip, const std::vector<std::vector<int>>& input_shapes, float proj_clip, const std::vector<std::vector<int>>& input_shapes,
TensorType tensor_type) TensorType tensor_type, bool asymmetric_quantize_inputs)
: UnidirectionalLSTMOpModel( : UnidirectionalLSTMOpModel(
n_batch, n_input, n_cell, n_output, sequence_length, time_major, n_batch, n_input, n_cell, n_output, sequence_length, time_major,
use_cifg, use_peephole, use_projection_weights, use_projection_bias, use_cifg, use_peephole, use_projection_weights, use_projection_bias,
cell_clip, proj_clip, input_shapes, tensor_type) { cell_clip, proj_clip, input_shapes, tensor_type, false,
asymmetric_quantize_inputs) {
tensor_type_ = tensor_type; tensor_type_ = tensor_type;
} }
@ -360,7 +362,7 @@ class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
TensorType tensor_type_; TensorType tensor_type_;
}; };
class BaseUnidirectionalLstmTest : public ::testing::Test { class BaseUnidirectionalLstmTest : public ::testing::TestWithParam<bool> {
protected: protected:
// Weights of the LSTM model. Some are optional. // Weights of the LSTM model. Some are optional.
std::vector<float> input_to_input_weights_; std::vector<float> input_to_input_weights_;
@ -626,7 +628,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
/*time_major=*/false); /*time_major=*/false);
} }
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest, TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
HybridLstmBlackBoxTestUint8) { HybridLstmBlackBoxTestUint8) {
const int n_batch = 1; const int n_batch = 1;
const int n_input = 2; const int n_input = 2;
@ -668,7 +670,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{n_batch, n_output}, // activation_state tensor {n_batch, n_output}, // activation_state tensor
{n_batch, n_cell}, // cell_state tensor {n_batch, n_cell}, // cell_state tensor
}, },
TensorType_UINT8); TensorType_UINT8, GetParam());
lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToInputWeights(input_to_input_weights_);
lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_);
@ -689,7 +691,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
/*tolerance=*/0.0157651); /*tolerance=*/0.0157651);
} }
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest, TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
HybridLstmBlackBoxTestInt8) { HybridLstmBlackBoxTestInt8) {
const int n_batch = 1; const int n_batch = 1;
const int n_input = 2; const int n_input = 2;
@ -731,7 +733,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{n_batch, n_output}, // activation_state tensor {n_batch, n_output}, // activation_state tensor
{n_batch, n_cell}, // cell_state tensor {n_batch, n_cell}, // cell_state tensor
}, },
TensorType_INT8); TensorType_INT8, GetParam());
lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToInputWeights(input_to_input_weights_);
lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_);
@ -862,7 +864,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
} }
TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
HybridLstmBlackBoxTestUint8) { HybridLstmBlackBoxTestUint8) {
const int n_batch = 1; const int n_batch = 1;
const int n_input = 2; const int n_input = 2;
@ -880,11 +882,10 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{ {
{sequence_length, n_batch, n_input}, // input tensor {sequence_length, n_batch, n_input}, // input tensor
{0, 0}, // input_to_input_weight tensor {0, 0}, // input_to_input_weight tensor
{n_cell, n_input}, // input_to_forget_weight tensor {n_cell, n_input}, // input_to_forget_weight tensor
{n_cell, n_input}, // input_to_cell_weight tensor {n_cell, n_input}, // input_to_cell_weight tensor
{n_cell, n_input}, // input_to_output_weight tensor {n_cell, n_input}, // input_to_output_weight tensor
{0, 0}, // recurrent_to_input_weight tensor {0, 0}, // recurrent_to_input_weight tensor
{n_cell, n_output}, // recurrent_to_forget_weight tensor {n_cell, n_output}, // recurrent_to_forget_weight tensor
{n_cell, n_output}, // recurrent_to_cell_weight tensor {n_cell, n_output}, // recurrent_to_cell_weight tensor
@ -905,7 +906,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{n_batch, n_output}, // activation_state tensor {n_batch, n_output}, // activation_state tensor
{n_batch, n_cell}, // cell_state tensor {n_batch, n_cell}, // cell_state tensor
}, },
TensorType_UINT8); TensorType_UINT8, GetParam());
lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_);
lstm.SetInputToForgetWeights(input_to_forget_weights_); lstm.SetInputToForgetWeights(input_to_forget_weights_);
@ -925,7 +926,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
} }
TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
HybridLstmBlackBoxTestInt8) { HybridLstmBlackBoxTestInt8) {
const int n_batch = 1; const int n_batch = 1;
const int n_input = 2; const int n_input = 2;
@ -968,7 +969,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{n_batch, n_output}, // activation_state tensor {n_batch, n_output}, // activation_state tensor
{n_batch, n_cell}, // cell_state tensor {n_batch, n_cell}, // cell_state tensor
}, },
TensorType_INT8); TensorType_INT8, GetParam());
lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_);
lstm.SetInputToForgetWeights(input_to_forget_weights_); lstm.SetInputToForgetWeights(input_to_forget_weights_);
@ -1655,14 +1656,16 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
} }
TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest, TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
HybridLstmBlackBoxTestUint8) { HybridLstmBlackBoxTestUint8) {
const int n_batch = 2; const int n_batch = 2;
const int n_input = 5; const int n_input = 5;
const int n_cell = 20; const int n_cell = 20;
const int n_output = 16; const int n_output = 16;
const int sequence_length = 4; const int sequence_length = 4;
if (GetParam()) {
return;
}
HybridUnidirectionalLSTMOpModel lstm( HybridUnidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, n_batch, n_input, n_cell, n_output, sequence_length,
/*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true, /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true,
@ -1697,7 +1700,7 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
{n_batch, n_output}, // activation_state tensor {n_batch, n_output}, // activation_state tensor
{n_batch, n_cell}, // cell_state tensor {n_batch, n_cell}, // cell_state tensor
}, },
TensorType_UINT8); TensorType_UINT8, GetParam());
lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToInputWeights(input_to_input_weights_);
lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_);
@ -1723,8 +1726,11 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
} }
TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest, TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
HybridLstmBlackBoxTestInt8) { HybridLstmBlackBoxTestInt8) {
if (GetParam()) {
return;
}
const int n_batch = 2; const int n_batch = 2;
const int n_input = 5; const int n_input = 5;
const int n_cell = 20; const int n_cell = 20;
@ -1765,7 +1771,7 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
{n_batch, n_output}, // activation_state tensor {n_batch, n_output}, // activation_state tensor
{n_batch, n_cell}, // cell_state tensor {n_batch, n_cell}, // cell_state tensor
}, },
TensorType_INT8); TensorType_INT8, GetParam());
lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToInputWeights(input_to_input_weights_);
lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_);
@ -2737,5 +2743,14 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
} }
#define QUANTIZE_PARAMETER_TEST(test) \
INSTANTIATE_TEST_SUITE_P(test, test, ::testing::ValuesIn({false, true}));
QUANTIZE_PARAMETER_TEST(
CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest);
QUANTIZE_PARAMETER_TEST(
NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest);
QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest);
#undef QUANTIZE_PARAMETER_TEST
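For reference, each invocation of the helper macro above is just an ordinary gtest instantiation; written out, the first one expands to:

INSTANTIATE_TEST_SUITE_P(
    CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
    CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
    ::testing::ValuesIn({false, true}));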
} // namespace } // namespace
} // namespace tflite } // namespace tflite

View File

@ -26,6 +26,15 @@ namespace ops {
namespace builtin { namespace builtin {
namespace unidirectional_sequence_rnn { namespace unidirectional_sequence_rnn {
namespace {
struct OpData {
int scratch_tensor_index;
bool compute_row_sums = false;
};
} // namespace
// Input tensors. // Input tensors.
constexpr int kInputTensor = 0; constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1; constexpr int kWeightsTensor = 1;
@ -37,13 +46,14 @@ constexpr int kHiddenStateTensor = 4;
constexpr int kOutputTensor = 0; constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) { void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* scratch_tensor_index = new int; auto* op_data = new OpData();
context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index); context->AddTensors(context, /*tensors_to_add=*/6,
return scratch_tensor_index; &op_data->scratch_tensor_index);
return op_data;
} }
void Free(TfLiteContext* context, void* buffer) { void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<int*>(buffer); delete reinterpret_cast<OpData*>(buffer);
} }
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
@ -96,10 +106,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate temporary tensors to store quantized values of input and // Allocate temporary tensors to store quantized values of input and
// hidden_state tensors. // hidden_state tensors.
if (is_hybrid) { if (is_hybrid) {
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data); auto* op_data = reinterpret_cast<OpData*>(node->user_data);
op_data->compute_row_sums = true;
TfLiteIntArrayFree(node->temporaries); TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(3); node->temporaries = TfLiteIntArrayCreate(6);
node->temporaries->data[0] = *scratch_tensor_index; node->temporaries->data[0] = op_data->scratch_tensor_index;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
input_quantized->type = input_weights->type; input_quantized->type = input_weights->type;
input_quantized->allocation_type = kTfLiteArenaRw; input_quantized->allocation_type = kTfLiteArenaRw;
@ -108,7 +119,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
input_quantized_size)); input_quantized_size));
} }
node->temporaries->data[1] = *scratch_tensor_index + 1; node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
TfLiteTensor* hidden_state_quantized = TfLiteTensor* hidden_state_quantized =
GetTemporary(context, node, /*index=*/1); GetTemporary(context, node, /*index=*/1);
hidden_state_quantized->type = input_weights->type; hidden_state_quantized->type = input_weights->type;
@ -121,7 +132,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ResizeTensor(context, hidden_state_quantized, context->ResizeTensor(context, hidden_state_quantized,
hidden_state_quantized_size)); hidden_state_quantized_size));
} }
node->temporaries->data[2] = *scratch_tensor_index + 2; node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
scaling_factors->type = kTfLiteFloat32; scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw; scaling_factors->allocation_type = kTfLiteArenaRw;
@ -132,6 +143,42 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
scaling_factors_size)); scaling_factors_size));
} }
node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/3);
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {num_units, batch_size};
if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
accum_scratch_dims)) {
TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
accum_scratch_size->data[0] = accum_scratch_dims[0];
accum_scratch_size->data[1] = accum_scratch_dims[1];
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
accum_scratch_size));
}
node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
zero_points->type = kTfLiteInt32;
zero_points->allocation_type = kTfLiteArenaRw;
int zero_points_dims[1] = {batch_size};
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
zero_points_size->data[0] = batch_size;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
zero_points_size));
}
node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_dims[2] = {2, num_units};
if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
row_sums_size->data[0] = row_sums_dims[0];
row_sums_size->data[1] = row_sums_dims[1];
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, row_sums, row_sums_size));
}
} }
return kTfLiteOk; return kTfLiteOk;
} }
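The zero_points and scaling_factors temporaries hold one entry per batch because hybrid kernels quantize each input row independently at evaluation time. A hedged sketch of that per-row asymmetric quantization (illustrative helper; the kernels use their own shared quantization utilities):

#include <algorithm>
#include <cmath>
#include <cstdint>

void AsymmetricQuantizeRow(const float* input, int size, int8_t* quantized,
                           float* scale, int32_t* zero_point) {
  const auto minmax = std::minmax_element(input, input + size);
  const float rmin = std::min(*minmax.first, 0.0f);   // range must include 0
  const float rmax = std::max(*minmax.second, 0.0f);
  if (rmin == rmax) {  // all-zero row
    *scale = 1.0f;
    *zero_point = 0;
    std::fill(quantized, quantized + size, 0);
    return;
  }
  *scale = (rmax - rmin) / 255.0f;  // map the range onto 256 int8 levels
  *zero_point = static_cast<int32_t>(std::round(-128.0f - rmin / *scale));
  for (int i = 0; i < size; ++i) {
    const int32_t q =
        static_cast<int32_t>(std::round(input[i] / *scale)) + *zero_point;
    quantized[i] = static_cast<int8_t>(std::min(127, std::max(-128, q)));
  }
}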
@ -202,7 +249,9 @@ TfLiteStatus EvalHybrid(
const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias, const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias,
const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch, const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch,
TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors, TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors,
TfLiteTensor* hidden_state, TfLiteTensor* output) { TfLiteTensor* hidden_state, TfLiteTensor* output, TfLiteTensor* zero_points,
TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
bool* compute_row_sums) {
const bool time_major = params->time_major; const bool time_major = params->time_major;
const int batch_size = const int batch_size =
(time_major) ? input->dims->data[1] : input->dims->data[0]; (time_major) ? input->dims->data[1] : input->dims->data[0];
@ -227,6 +276,14 @@ TfLiteStatus EvalHybrid(
float input_weights_scale = input_weights->params.scale; float input_weights_scale = input_weights->params.scale;
float recurrent_weights_scale = recurrent_weights->params.scale; float recurrent_weights_scale = recurrent_weights->params.scale;
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors); float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
int32_t* zero_points_ptr = nullptr;
int32_t* row_sums_ptr = nullptr;
if (params->asymmetric_quantize_inputs) {
zero_points_ptr = GetTensorData<int32_t>(zero_points);
row_sums_ptr = GetTensorData<int32_t>(row_sums);
}
if (time_major) { if (time_major) {
// Initialize the pointer to hidden state. // Initialize the pointer to hidden state.
@ -244,7 +301,9 @@ TfLiteStatus EvalHybrid(
recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
num_units, batch_size, num_units, params->activation, num_units, batch_size, num_units, params->activation,
quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr, quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr,
hidden_state_ptr_batch, output_ptr_batch); hidden_state_ptr_batch, output_ptr_batch,
params->asymmetric_quantize_inputs, zero_points_ptr,
accum_scratch_ptr, row_sums_ptr, compute_row_sums);
} }
} else { } else {
// For each batch // For each batch
@ -259,13 +318,14 @@ TfLiteStatus EvalHybrid(
s * input_size; s * input_size;
float* output_ptr_batch = GetTensorData<float>(output) + float* output_ptr_batch = GetTensorData<float>(output) +
b * num_units * max_time + s * num_units; b * num_units * max_time + s * num_units;
kernel_utils::RnnBatchStep( kernel_utils::RnnBatchStep(
input_ptr_batch, input_weights_ptr, input_weights_scale, input_ptr_batch, input_weights_ptr, input_weights_scale,
recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, recurrent_weights_ptr, recurrent_weights_scale, bias_ptr,
input_size, num_units, /*batch_size=*/1, num_units, input_size, num_units, /*batch_size=*/1, num_units,
params->activation, quantized_input_ptr, quantized_hidden_state_ptr, params->activation, quantized_input_ptr, quantized_hidden_state_ptr,
scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch); scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch,
params->asymmetric_quantize_inputs, zero_points_ptr,
accum_scratch_ptr, row_sums_ptr, compute_row_sums);
} }
} }
} }
@ -274,7 +334,6 @@ TfLiteStatus EvalHybrid(
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data); auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor); const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* recurrent_weights = const TfLiteTensor* recurrent_weights =
@ -292,12 +351,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteUInt8: case kTfLiteUInt8:
case kTfLiteInt8: { case kTfLiteInt8: {
// TODO(mirkov): implement eval with quantized inputs as well. // TODO(mirkov): implement eval with quantized inputs as well.
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* input_quantized = GetTemporary(context, node, 0); TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1); TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
TfLiteTensor* scaling_factors = GetTemporary(context, node, 2); TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
TfLiteTensor* accum_scratch = GetTemporary(context, node, 3);
TfLiteTensor* zero_points = GetTemporary(context, node, 4);
TfLiteTensor* row_sums = GetTemporary(context, node, 5);
return EvalHybrid(input, input_weights, recurrent_weights, bias, params, return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
input_quantized, hidden_state_quantized, input_quantized, hidden_state_quantized,
scaling_factors, hidden_state, output); scaling_factors, hidden_state, output, zero_points,
accum_scratch, row_sums, &op_data->compute_row_sums);
} }
default: default:
context->ReportError(context, "Type %d not currently supported.", context->ReportError(context, "Type %d not currently supported.",

View File

@ -174,7 +174,8 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
UnidirectionalRNNOpModel( UnidirectionalRNNOpModel(
int batches, int sequence_len, int units, int size, bool time_major, int batches, int sequence_len, int units, int size, bool time_major,
const TensorType& weights = TensorType_FLOAT32, const TensorType& weights = TensorType_FLOAT32,
const TensorType& recurrent_weights = TensorType_FLOAT32) const TensorType& recurrent_weights = TensorType_FLOAT32,
bool asymmetric_quantize_inputs = false)
: batches_(batches), : batches_(batches),
sequence_len_(sequence_len), sequence_len_(sequence_len),
units_(units), units_(units),
@ -188,7 +189,8 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
BuiltinOptions_SequenceRNNOptions, BuiltinOptions_SequenceRNNOptions,
CreateSequenceRNNOptions(builder_, time_major, CreateSequenceRNNOptions(builder_, time_major,
ActivationFunctionType_RELU) ActivationFunctionType_RELU,
asymmetric_quantize_inputs)
.Union()); .Union());
if (time_major) { if (time_major) {
BuildInterpreter({{sequence_len_, batches_, input_size_}, BuildInterpreter({{sequence_len_, batches_, input_size_},
@ -249,9 +251,11 @@ class HybridUnidirectionalRNNOpModel : public UnidirectionalRNNOpModel {
public: public:
HybridUnidirectionalRNNOpModel(int batches, int sequence_len, int units, HybridUnidirectionalRNNOpModel(int batches, int sequence_len, int units,
int size, bool time_major, int size, bool time_major,
TensorType tensor_type) TensorType tensor_type,
bool asymmetric_quantize_inputs)
: UnidirectionalRNNOpModel(batches, sequence_len, units, size, time_major, : UnidirectionalRNNOpModel(batches, sequence_len, units, size, time_major,
tensor_type, tensor_type) { tensor_type, tensor_type,
asymmetric_quantize_inputs) {
tensor_type_ = tensor_type; tensor_type_ = tensor_type;
} }
@ -297,10 +301,14 @@ TEST(UnidirectionalRNNOpTest, BlackBoxTest) {
EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
} }
TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestUint8) { class HybridUnidirectionalRNNOpModelOpTest
: public ::testing::TestWithParam<bool> {};
TEST_P(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestUint8) {
HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*units=*/16, /*size=*/8, /*units=*/16, /*size=*/8,
/*time_major=*/false, TensorType_UINT8); /*time_major=*/false, TensorType_UINT8,
GetParam());
rnn.SetWeights(rnn_weights); rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias); rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights); rnn.SetRecurrentWeights(rnn_recurrent_weights);
@ -323,10 +331,11 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestUint8) {
expected, /*max_abs_error=*/0.013))); expected, /*max_abs_error=*/0.013)));
} }
TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestInt8) { TEST_P(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestInt8) {
HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*units=*/16, /*size=*/8, /*units=*/16, /*size=*/8,
/*time_major=*/false, TensorType_INT8); /*time_major=*/false, TensorType_INT8,
GetParam());
rnn.SetWeights(rnn_weights); rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias); rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights); rnn.SetRecurrentWeights(rnn_recurrent_weights);
@ -378,10 +387,11 @@ TEST(UnidirectionalRNNOpTest, TimeMajorBlackBoxTest) {
EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
} }
TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestUint8) { TEST_P(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestUint8) {
HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*units=*/16, /*size=*/8, /*units=*/16, /*size=*/8,
/*time_major=*/true, TensorType_UINT8); /*time_major=*/true, TensorType_UINT8,
GetParam());
rnn.SetWeights(rnn_weights); rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias); rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights); rnn.SetRecurrentWeights(rnn_recurrent_weights);
@ -408,10 +418,11 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestUint8) {
expected, /*max_abs_error=*/0.013))); expected, /*max_abs_error=*/0.013)));
} }
TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestInt8) { TEST_P(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestInt8) {
HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*units=*/16, /*size=*/8, /*units=*/16, /*size=*/8,
/*time_major=*/true, TensorType_INT8); /*time_major=*/true, TensorType_INT8,
GetParam());
rnn.SetWeights(rnn_weights); rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias); rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights); rnn.SetRecurrentWeights(rnn_recurrent_weights);
@ -438,5 +449,9 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestInt8) {
expected, /*max_abs_error=*/0.013))); expected, /*max_abs_error=*/0.013)));
} }
INSTANTIATE_TEST_SUITE_P(HybridUnidirectionalRNNOpModelOpTest,
HybridUnidirectionalRNNOpModelOpTest,
::testing::ValuesIn({true, false}));
} // namespace } // namespace
} // namespace tflite } // namespace tflite

View File

@ -519,17 +519,22 @@ table LSHProjectionOptions {
table SVDFOptions { table SVDFOptions {
rank:int; rank:int;
fused_activation_function:ActivationFunctionType; fused_activation_function:ActivationFunctionType;
// For weights-only quantization, use asymmetric quantization for non
// constant inputs at evaluation time.
asymmetric_quantize_inputs:bool;
} }
// An implementation of TensorFlow RNNCell. // An implementation of TensorFlow RNNCell.
table RNNOptions { table RNNOptions {
fused_activation_function:ActivationFunctionType; fused_activation_function:ActivationFunctionType;
asymmetric_quantize_inputs:bool;
} }
// An implementation of TensorFlow dynamic_rnn with RNNCell. // An implementation of TensorFlow dynamic_rnn with RNNCell.
table SequenceRNNOptions { table SequenceRNNOptions {
time_major:bool; time_major:bool;
fused_activation_function:ActivationFunctionType; fused_activation_function:ActivationFunctionType;
asymmetric_quantize_inputs:bool;
} }
// An implementation of TensorFlow bidirectional_dynamic_rnn with RNNCell. // An implementation of TensorFlow bidirectional_dynamic_rnn with RNNCell.
@ -537,6 +542,7 @@ table BidirectionalSequenceRNNOptions {
time_major:bool; time_major:bool;
fused_activation_function:ActivationFunctionType; fused_activation_function:ActivationFunctionType;
merge_outputs: bool; merge_outputs: bool;
asymmetric_quantize_inputs:bool;
} }
enum FullyConnectedOptionsWeightsFormat: byte { enum FullyConnectedOptionsWeightsFormat: byte {
@ -556,6 +562,11 @@ table FullyConnectedOptions {
// If set to true, then the number of dimension is preserved. Furthermore, // If set to true, then the number of dimension is preserved. Furthermore,
// all but the last dimension of the input and output shapes will be equal. // all but the last dimension of the input and output shapes will be equal.
keep_num_dims: bool; keep_num_dims: bool;
// Parameters for FullyConnected version 7 or above.
// If set to true, then a hybrid (weights-only quantized) op will use
// asymmetric quantization for its non-constant inputs at evaluation time.
asymmetric_quantize_inputs: bool;
} }
table SoftmaxOptions { table SoftmaxOptions {
@ -604,6 +615,9 @@ table LSTMOptions {
// Parameters for LSTM version 2 or above. // Parameters for LSTM version 2 or above.
// Basic kernel is only supported in version 2 or above. // Basic kernel is only supported in version 2 or above.
kernel_type: LSTMKernelType = FULL; kernel_type: LSTMKernelType = FULL;
// Parameters for LSTM version 4 or above.
asymmetric_quantize_inputs: bool;
} }
// An implementation of TensorFlow dynamic_rnn with LSTMCell. // An implementation of TensorFlow dynamic_rnn with LSTMCell.
@ -614,6 +628,9 @@ table UnidirectionalSequenceLSTMOptions {
// If true then first dimension is sequence, otherwise batch. // If true then first dimension is sequence, otherwise batch.
time_major:bool; time_major:bool;
// Parameter for Unidirectional Sequence LSTM version 4.
asymmetric_quantize_inputs:bool;
} }
table BidirectionalSequenceLSTMOptions { table BidirectionalSequenceLSTMOptions {
@ -630,6 +647,9 @@ table BidirectionalSequenceLSTMOptions {
// Version 1 implementations assumed time_major to be true, so this default // Version 1 implementations assumed time_major to be true, so this default
// value should never change. // value should never change.
time_major: bool = true; time_major: bool = true;
// Parameters for version 3 or above.
asymmetric_quantize_inputs:bool;
} }
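Because each new field is a trailing optional scalar with a false default, models serialized before this change keep their existing symmetric behaviour: a reader built against the updated schema simply sees the default when the field is absent. A small illustrative reader-side check (assumes the generated header; the function is a sketch, not part of the patch):

#include "tensorflow/lite/schema/schema_generated.h"

bool UsesAsymmetricInputs(const tflite::SVDFOptions* options) {
  // For options tables written before this field existed, the generated
  // accessor returns the schema default (false).
  return options != nullptr && options->asymmetric_quantize_inputs();
}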
table ResizeBilinearOptions { table ResizeBilinearOptions {

View File

@ -4216,9 +4216,11 @@ struct SVDFOptionsT : public flatbuffers::NativeTable {
typedef SVDFOptions TableType; typedef SVDFOptions TableType;
int32_t rank; int32_t rank;
tflite::ActivationFunctionType fused_activation_function; tflite::ActivationFunctionType fused_activation_function;
bool asymmetric_quantize_inputs;
SVDFOptionsT() SVDFOptionsT()
: rank(0), : rank(0),
fused_activation_function(tflite::ActivationFunctionType_NONE) { fused_activation_function(tflite::ActivationFunctionType_NONE),
asymmetric_quantize_inputs(false) {
} }
}; };
@ -4226,7 +4228,8 @@ struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef SVDFOptionsT NativeTableType; typedef SVDFOptionsT NativeTableType;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
VT_RANK = 4, VT_RANK = 4,
VT_FUSED_ACTIVATION_FUNCTION = 6 VT_FUSED_ACTIVATION_FUNCTION = 6,
VT_ASYMMETRIC_QUANTIZE_INPUTS = 8
}; };
int32_t rank() const { int32_t rank() const {
return GetField<int32_t>(VT_RANK, 0); return GetField<int32_t>(VT_RANK, 0);
@ -4234,10 +4237,14 @@ struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
tflite::ActivationFunctionType fused_activation_function() const { tflite::ActivationFunctionType fused_activation_function() const {
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
} }
bool asymmetric_quantize_inputs() const {
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const { bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) && return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, VT_RANK) && VerifyField<int32_t>(verifier, VT_RANK) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
verifier.EndTable(); verifier.EndTable();
} }
SVDFOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; SVDFOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@ -4254,6 +4261,9 @@ struct SVDFOptionsBuilder {
void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
fbb_.AddElement<int8_t>(SVDFOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0); fbb_.AddElement<int8_t>(SVDFOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
} }
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
fbb_.AddElement<uint8_t>(SVDFOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
}
explicit SVDFOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) explicit SVDFOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) { : fbb_(_fbb) {
start_ = fbb_.StartTable(); start_ = fbb_.StartTable();
@ -4269,9 +4279,11 @@ struct SVDFOptionsBuilder {
inline flatbuffers::Offset<SVDFOptions> CreateSVDFOptions( inline flatbuffers::Offset<SVDFOptions> CreateSVDFOptions(
flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::FlatBufferBuilder &_fbb,
int32_t rank = 0, int32_t rank = 0,
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
bool asymmetric_quantize_inputs = false) {
SVDFOptionsBuilder builder_(_fbb); SVDFOptionsBuilder builder_(_fbb);
builder_.add_rank(rank); builder_.add_rank(rank);
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
builder_.add_fused_activation_function(fused_activation_function); builder_.add_fused_activation_function(fused_activation_function);
return builder_.Finish(); return builder_.Finish();
} }
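Writer-side, the extended CreateSVDFOptions keeps its earlier defaults, so existing call sites compile unchanged and new converters can opt in explicitly. A hedged usage sketch (the rank and activation values are illustrative):

#include "tensorflow/lite/schema/schema_generated.h"

flatbuffers::Offset<tflite::SVDFOptions> BuildSvdfOptions(
    flatbuffers::FlatBufferBuilder& fbb) {
  return tflite::CreateSVDFOptions(fbb, /*rank=*/1,
                                   tflite::ActivationFunctionType_NONE,
                                   /*asymmetric_quantize_inputs=*/true);
}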
@ -4281,22 +4293,29 @@ flatbuffers::Offset<SVDFOptions> CreateSVDFOptions(flatbuffers::FlatBufferBuilde
struct RNNOptionsT : public flatbuffers::NativeTable { struct RNNOptionsT : public flatbuffers::NativeTable {
typedef RNNOptions TableType; typedef RNNOptions TableType;
tflite::ActivationFunctionType fused_activation_function; tflite::ActivationFunctionType fused_activation_function;
bool asymmetric_quantize_inputs;
RNNOptionsT() RNNOptionsT()
: fused_activation_function(tflite::ActivationFunctionType_NONE) { : fused_activation_function(tflite::ActivationFunctionType_NONE),
asymmetric_quantize_inputs(false) {
} }
}; };
struct RNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct RNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef RNNOptionsT NativeTableType; typedef RNNOptionsT NativeTableType;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
VT_FUSED_ACTIVATION_FUNCTION = 4 VT_FUSED_ACTIVATION_FUNCTION = 4,
VT_ASYMMETRIC_QUANTIZE_INPUTS = 6
}; };
tflite::ActivationFunctionType fused_activation_function() const { tflite::ActivationFunctionType fused_activation_function() const {
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
} }
bool asymmetric_quantize_inputs() const {
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const { bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) && return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
verifier.EndTable(); verifier.EndTable();
} }
RNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; RNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@ -4310,6 +4329,9 @@ struct RNNOptionsBuilder {
void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
fbb_.AddElement<int8_t>(RNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0); fbb_.AddElement<int8_t>(RNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
} }
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
fbb_.AddElement<uint8_t>(RNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
}
explicit RNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) explicit RNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) { : fbb_(_fbb) {
start_ = fbb_.StartTable(); start_ = fbb_.StartTable();
@ -4324,8 +4346,10 @@ struct RNNOptionsBuilder {
inline flatbuffers::Offset<RNNOptions> CreateRNNOptions( inline flatbuffers::Offset<RNNOptions> CreateRNNOptions(
flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::FlatBufferBuilder &_fbb,
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
bool asymmetric_quantize_inputs = false) {
RNNOptionsBuilder builder_(_fbb); RNNOptionsBuilder builder_(_fbb);
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
builder_.add_fused_activation_function(fused_activation_function); builder_.add_fused_activation_function(fused_activation_function);
return builder_.Finish(); return builder_.Finish();
} }
@ -4336,9 +4360,11 @@ struct SequenceRNNOptionsT : public flatbuffers::NativeTable {
typedef SequenceRNNOptions TableType; typedef SequenceRNNOptions TableType;
bool time_major; bool time_major;
tflite::ActivationFunctionType fused_activation_function; tflite::ActivationFunctionType fused_activation_function;
bool asymmetric_quantize_inputs;
SequenceRNNOptionsT() SequenceRNNOptionsT()
: time_major(false), : time_major(false),
fused_activation_function(tflite::ActivationFunctionType_NONE) { fused_activation_function(tflite::ActivationFunctionType_NONE),
asymmetric_quantize_inputs(false) {
} }
}; };
@ -4346,7 +4372,8 @@ struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef SequenceRNNOptionsT NativeTableType; typedef SequenceRNNOptionsT NativeTableType;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
VT_TIME_MAJOR = 4, VT_TIME_MAJOR = 4,
VT_FUSED_ACTIVATION_FUNCTION = 6 VT_FUSED_ACTIVATION_FUNCTION = 6,
VT_ASYMMETRIC_QUANTIZE_INPUTS = 8
}; };
bool time_major() const { bool time_major() const {
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0; return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
@ -4354,10 +4381,14 @@ struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
tflite::ActivationFunctionType fused_activation_function() const { tflite::ActivationFunctionType fused_activation_function() const {
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
} }
bool asymmetric_quantize_inputs() const {
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const { bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) && return VerifyTableStart(verifier) &&
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) && VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
verifier.EndTable(); verifier.EndTable();
} }
SequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; SequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@ -4374,6 +4405,9 @@ struct SequenceRNNOptionsBuilder {
void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
fbb_.AddElement<int8_t>(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0); fbb_.AddElement<int8_t>(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
} }
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
fbb_.AddElement<uint8_t>(SequenceRNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
}
explicit SequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) explicit SequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) { : fbb_(_fbb) {
start_ = fbb_.StartTable(); start_ = fbb_.StartTable();
@ -4389,8 +4423,10 @@ struct SequenceRNNOptionsBuilder {
inline flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions( inline flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions(
flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::FlatBufferBuilder &_fbb,
bool time_major = false, bool time_major = false,
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
bool asymmetric_quantize_inputs = false) {
SequenceRNNOptionsBuilder builder_(_fbb); SequenceRNNOptionsBuilder builder_(_fbb);
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
builder_.add_fused_activation_function(fused_activation_function); builder_.add_fused_activation_function(fused_activation_function);
builder_.add_time_major(time_major); builder_.add_time_major(time_major);
return builder_.Finish(); return builder_.Finish();
@ -4403,10 +4439,12 @@ struct BidirectionalSequenceRNNOptionsT : public flatbuffers::NativeTable {
bool time_major; bool time_major;
tflite::ActivationFunctionType fused_activation_function; tflite::ActivationFunctionType fused_activation_function;
bool merge_outputs; bool merge_outputs;
bool asymmetric_quantize_inputs;
BidirectionalSequenceRNNOptionsT() BidirectionalSequenceRNNOptionsT()
: time_major(false), : time_major(false),
fused_activation_function(tflite::ActivationFunctionType_NONE), fused_activation_function(tflite::ActivationFunctionType_NONE),
merge_outputs(false) { merge_outputs(false),
asymmetric_quantize_inputs(false) {
} }
}; };
@ -4415,7 +4453,8 @@ struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuf
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
VT_TIME_MAJOR = 4, VT_TIME_MAJOR = 4,
VT_FUSED_ACTIVATION_FUNCTION = 6, VT_FUSED_ACTIVATION_FUNCTION = 6,
VT_MERGE_OUTPUTS = 8 VT_MERGE_OUTPUTS = 8,
VT_ASYMMETRIC_QUANTIZE_INPUTS = 10
}; };
bool time_major() const { bool time_major() const {
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0; return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
@ -4426,11 +4465,15 @@ struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuf
bool merge_outputs() const { bool merge_outputs() const {
return GetField<uint8_t>(VT_MERGE_OUTPUTS, 0) != 0; return GetField<uint8_t>(VT_MERGE_OUTPUTS, 0) != 0;
} }
bool asymmetric_quantize_inputs() const {
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const { bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) && return VerifyTableStart(verifier) &&
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) && VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) && VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) &&
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
verifier.EndTable(); verifier.EndTable();
} }
BidirectionalSequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; BidirectionalSequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@ -4450,6 +4493,9 @@ struct BidirectionalSequenceRNNOptionsBuilder {
void add_merge_outputs(bool merge_outputs) { void add_merge_outputs(bool merge_outputs) {
fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_MERGE_OUTPUTS, static_cast<uint8_t>(merge_outputs), 0); fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_MERGE_OUTPUTS, static_cast<uint8_t>(merge_outputs), 0);
} }
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
}
explicit BidirectionalSequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) explicit BidirectionalSequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) { : fbb_(_fbb) {
start_ = fbb_.StartTable(); start_ = fbb_.StartTable();
@ -4466,8 +4512,10 @@ inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> CreateBidirectionalS
flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::FlatBufferBuilder &_fbb,
bool time_major = false, bool time_major = false,
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
bool merge_outputs = false) { bool merge_outputs = false,
bool asymmetric_quantize_inputs = false) {
BidirectionalSequenceRNNOptionsBuilder builder_(_fbb); BidirectionalSequenceRNNOptionsBuilder builder_(_fbb);
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
builder_.add_merge_outputs(merge_outputs); builder_.add_merge_outputs(merge_outputs);
builder_.add_fused_activation_function(fused_activation_function); builder_.add_fused_activation_function(fused_activation_function);
builder_.add_time_major(time_major); builder_.add_time_major(time_major);
@ -4481,10 +4529,12 @@ struct FullyConnectedOptionsT : public flatbuffers::NativeTable {
tflite::ActivationFunctionType fused_activation_function; tflite::ActivationFunctionType fused_activation_function;
tflite::FullyConnectedOptionsWeightsFormat weights_format; tflite::FullyConnectedOptionsWeightsFormat weights_format;
bool keep_num_dims; bool keep_num_dims;
bool asymmetric_quantize_inputs;
FullyConnectedOptionsT() FullyConnectedOptionsT()
: fused_activation_function(tflite::ActivationFunctionType_NONE), : fused_activation_function(tflite::ActivationFunctionType_NONE),
weights_format(tflite::FullyConnectedOptionsWeightsFormat_DEFAULT), weights_format(tflite::FullyConnectedOptionsWeightsFormat_DEFAULT),
keep_num_dims(false) { keep_num_dims(false),
asymmetric_quantize_inputs(false) {
} }
}; };
@ -4493,7 +4543,8 @@ struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tabl
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
VT_FUSED_ACTIVATION_FUNCTION = 4, VT_FUSED_ACTIVATION_FUNCTION = 4,
VT_WEIGHTS_FORMAT = 6, VT_WEIGHTS_FORMAT = 6,
VT_KEEP_NUM_DIMS = 8 VT_KEEP_NUM_DIMS = 8,
VT_ASYMMETRIC_QUANTIZE_INPUTS = 10
}; };
tflite::ActivationFunctionType fused_activation_function() const { tflite::ActivationFunctionType fused_activation_function() const {
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
@ -4504,11 +4555,15 @@ struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tabl
bool keep_num_dims() const { bool keep_num_dims() const {
return GetField<uint8_t>(VT_KEEP_NUM_DIMS, 0) != 0; return GetField<uint8_t>(VT_KEEP_NUM_DIMS, 0) != 0;
} }
bool asymmetric_quantize_inputs() const {
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const { bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) && return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
VerifyField<int8_t>(verifier, VT_WEIGHTS_FORMAT) && VerifyField<int8_t>(verifier, VT_WEIGHTS_FORMAT) &&
VerifyField<uint8_t>(verifier, VT_KEEP_NUM_DIMS) && VerifyField<uint8_t>(verifier, VT_KEEP_NUM_DIMS) &&
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
verifier.EndTable(); verifier.EndTable();
} }
FullyConnectedOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; FullyConnectedOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@ -4528,6 +4583,9 @@ struct FullyConnectedOptionsBuilder {
void add_keep_num_dims(bool keep_num_dims) { void add_keep_num_dims(bool keep_num_dims) {
fbb_.AddElement<uint8_t>(FullyConnectedOptions::VT_KEEP_NUM_DIMS, static_cast<uint8_t>(keep_num_dims), 0); fbb_.AddElement<uint8_t>(FullyConnectedOptions::VT_KEEP_NUM_DIMS, static_cast<uint8_t>(keep_num_dims), 0);
} }
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
fbb_.AddElement<uint8_t>(FullyConnectedOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
}
explicit FullyConnectedOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) explicit FullyConnectedOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) { : fbb_(_fbb) {
start_ = fbb_.StartTable(); start_ = fbb_.StartTable();
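
// Illustrative sketch, not generated code: the new add_asymmetric_quantize_inputs
// setter can also be used directly on the builder; setters may be called in any
// order because the vtable records each field's offset, and Finish() is the
// standard generated method that closes the table. The helper name is
// hypothetical.
// (Assumes "tensorflow/lite/schema/schema_generated.h" is included.)
flatbuffers::Offset<tflite::FullyConnectedOptions> BuildFcOptions(
    flatbuffers::FlatBufferBuilder& fbb) {
  tflite::FullyConnectedOptionsBuilder builder(fbb);
  builder.add_fused_activation_function(tflite::ActivationFunctionType_RELU);
  builder.add_keep_num_dims(true);
  builder.add_asymmetric_quantize_inputs(true);
  return builder.Finish();
}
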
@ -4544,8 +4602,10 @@ inline flatbuffers::Offset<FullyConnectedOptions> CreateFullyConnectedOptions(
flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::FlatBufferBuilder &_fbb,
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
tflite::FullyConnectedOptionsWeightsFormat weights_format = tflite::FullyConnectedOptionsWeightsFormat_DEFAULT, tflite::FullyConnectedOptionsWeightsFormat weights_format = tflite::FullyConnectedOptionsWeightsFormat_DEFAULT,
bool keep_num_dims = false) { bool keep_num_dims = false,
bool asymmetric_quantize_inputs = false) {
FullyConnectedOptionsBuilder builder_(_fbb); FullyConnectedOptionsBuilder builder_(_fbb);
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
builder_.add_keep_num_dims(keep_num_dims); builder_.add_keep_num_dims(keep_num_dims);
builder_.add_weights_format(weights_format); builder_.add_weights_format(weights_format);
builder_.add_fused_activation_function(fused_activation_function); builder_.add_fused_activation_function(fused_activation_function);
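
// Illustrative sketch of what the flag requests, not the TFLite kernel code:
// with asymmetric quantization a non-constant float input is mapped to int8
// using both a scale and a zero point derived from its observed min/max, so
// the quantized range does not have to be symmetric around zero. The function
// below is a simplified stand-in for the evaluation-time quantization step.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

void AsymmetricQuantize(const std::vector<float>& input, float min_val,
                        float max_val, std::vector<int8_t>* output,
                        float* scale, int32_t* zero_point) {
  // Ensure zero is representable (needed e.g. for zero padding).
  min_val = std::min(min_val, 0.0f);
  max_val = std::max(max_val, 0.0f);
  *scale = (max_val - min_val) / 255.0f;  // int8 spans 256 levels: [-128, 127]
  if (*scale == 0.0f) *scale = 1.0f;      // all-zero input
  *zero_point =
      static_cast<int32_t>(std::round(-128.0f - min_val / *scale));
  output->resize(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    const int32_t q =
        *zero_point + static_cast<int32_t>(std::round(input[i] / *scale));
    (*output)[i] = static_cast<int8_t>(std::min(127, std::max(-128, q)));
  }
}
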
@ -4932,11 +4992,13 @@ struct LSTMOptionsT : public flatbuffers::NativeTable {
float cell_clip; float cell_clip;
float proj_clip; float proj_clip;
tflite::LSTMKernelType kernel_type; tflite::LSTMKernelType kernel_type;
bool asymmetric_quantize_inputs;
LSTMOptionsT() LSTMOptionsT()
: fused_activation_function(tflite::ActivationFunctionType_NONE), : fused_activation_function(tflite::ActivationFunctionType_NONE),
cell_clip(0.0f), cell_clip(0.0f),
proj_clip(0.0f), proj_clip(0.0f),
kernel_type(tflite::LSTMKernelType_FULL) { kernel_type(tflite::LSTMKernelType_FULL),
asymmetric_quantize_inputs(false) {
} }
}; };
@ -4946,7 +5008,8 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VT_FUSED_ACTIVATION_FUNCTION = 4, VT_FUSED_ACTIVATION_FUNCTION = 4,
VT_CELL_CLIP = 6, VT_CELL_CLIP = 6,
VT_PROJ_CLIP = 8, VT_PROJ_CLIP = 8,
VT_KERNEL_TYPE = 10 VT_KERNEL_TYPE = 10,
VT_ASYMMETRIC_QUANTIZE_INPUTS = 12
}; };
tflite::ActivationFunctionType fused_activation_function() const { tflite::ActivationFunctionType fused_activation_function() const {
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
@ -4960,12 +5023,16 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
tflite::LSTMKernelType kernel_type() const { tflite::LSTMKernelType kernel_type() const {
return static_cast<tflite::LSTMKernelType>(GetField<int8_t>(VT_KERNEL_TYPE, 0)); return static_cast<tflite::LSTMKernelType>(GetField<int8_t>(VT_KERNEL_TYPE, 0));
} }
bool asymmetric_quantize_inputs() const {
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const { bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) && return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
VerifyField<float>(verifier, VT_CELL_CLIP) && VerifyField<float>(verifier, VT_CELL_CLIP) &&
VerifyField<float>(verifier, VT_PROJ_CLIP) && VerifyField<float>(verifier, VT_PROJ_CLIP) &&
VerifyField<int8_t>(verifier, VT_KERNEL_TYPE) && VerifyField<int8_t>(verifier, VT_KERNEL_TYPE) &&
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
verifier.EndTable(); verifier.EndTable();
} }
LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@ -4988,6 +5055,9 @@ struct LSTMOptionsBuilder {
void add_kernel_type(tflite::LSTMKernelType kernel_type) { void add_kernel_type(tflite::LSTMKernelType kernel_type) {
fbb_.AddElement<int8_t>(LSTMOptions::VT_KERNEL_TYPE, static_cast<int8_t>(kernel_type), 0); fbb_.AddElement<int8_t>(LSTMOptions::VT_KERNEL_TYPE, static_cast<int8_t>(kernel_type), 0);
} }
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
fbb_.AddElement<uint8_t>(LSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
}
explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) { : fbb_(_fbb) {
start_ = fbb_.StartTable(); start_ = fbb_.StartTable();
@ -5005,10 +5075,12 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
float cell_clip = 0.0f, float cell_clip = 0.0f,
float proj_clip = 0.0f, float proj_clip = 0.0f,
tflite::LSTMKernelType kernel_type = tflite::LSTMKernelType_FULL) { tflite::LSTMKernelType kernel_type = tflite::LSTMKernelType_FULL,
bool asymmetric_quantize_inputs = false) {
LSTMOptionsBuilder builder_(_fbb); LSTMOptionsBuilder builder_(_fbb);
builder_.add_proj_clip(proj_clip); builder_.add_proj_clip(proj_clip);
builder_.add_cell_clip(cell_clip); builder_.add_cell_clip(cell_clip);
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
builder_.add_kernel_type(kernel_type); builder_.add_kernel_type(kernel_type);
builder_.add_fused_activation_function(fused_activation_function); builder_.add_fused_activation_function(fused_activation_function);
return builder_.Finish(); return builder_.Finish();
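
// Illustrative sketch, not generated code: callers simply pass the flag as the
// new trailing default argument. Inside the helper above, flatc emits the
// add_* calls in decreasing field size (floats before single-byte fields) to
// keep the table tightly packed, which is why the new byte-sized setter is
// grouped with the other single-byte fields after the floats. The helper name
// below is hypothetical.
// (Assumes "tensorflow/lite/schema/schema_generated.h" is included.)
flatbuffers::Offset<tflite::LSTMOptions> BuildLstmOptions(
    flatbuffers::FlatBufferBuilder& fbb) {
  return tflite::CreateLSTMOptions(
      fbb, tflite::ActivationFunctionType_TANH, /*cell_clip=*/10.0f,
      /*proj_clip=*/0.0f, tflite::LSTMKernelType_FULL,
      /*asymmetric_quantize_inputs=*/true);
}
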
@ -5022,11 +5094,13 @@ struct UnidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
float cell_clip; float cell_clip;
float proj_clip; float proj_clip;
bool time_major; bool time_major;
bool asymmetric_quantize_inputs;
UnidirectionalSequenceLSTMOptionsT() UnidirectionalSequenceLSTMOptionsT()
: fused_activation_function(tflite::ActivationFunctionType_NONE), : fused_activation_function(tflite::ActivationFunctionType_NONE),
cell_clip(0.0f), cell_clip(0.0f),
proj_clip(0.0f), proj_clip(0.0f),
time_major(false) { time_major(false),
asymmetric_quantize_inputs(false) {
} }
}; };
@ -5036,7 +5110,8 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb
VT_FUSED_ACTIVATION_FUNCTION = 4, VT_FUSED_ACTIVATION_FUNCTION = 4,
VT_CELL_CLIP = 6, VT_CELL_CLIP = 6,
VT_PROJ_CLIP = 8, VT_PROJ_CLIP = 8,
VT_TIME_MAJOR = 10 VT_TIME_MAJOR = 10,
VT_ASYMMETRIC_QUANTIZE_INPUTS = 12
}; };
tflite::ActivationFunctionType fused_activation_function() const { tflite::ActivationFunctionType fused_activation_function() const {
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
@ -5050,12 +5125,16 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb
bool time_major() const { bool time_major() const {
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0; return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
} }
bool asymmetric_quantize_inputs() const {
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const { bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) && return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
VerifyField<float>(verifier, VT_CELL_CLIP) && VerifyField<float>(verifier, VT_CELL_CLIP) &&
VerifyField<float>(verifier, VT_PROJ_CLIP) && VerifyField<float>(verifier, VT_PROJ_CLIP) &&
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) && VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
verifier.EndTable(); verifier.EndTable();
} }
UnidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; UnidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@ -5078,6 +5157,9 @@ struct UnidirectionalSequenceLSTMOptionsBuilder {
void add_time_major(bool time_major) { void add_time_major(bool time_major) {
fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major), 0); fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major), 0);
} }
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
}
explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) { : fbb_(_fbb) {
start_ = fbb_.StartTable(); start_ = fbb_.StartTable();
@ -5095,10 +5177,12 @@ inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirection
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
float cell_clip = 0.0f, float cell_clip = 0.0f,
float proj_clip = 0.0f, float proj_clip = 0.0f,
bool time_major = false) { bool time_major = false,
bool asymmetric_quantize_inputs = false) {
UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb); UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
builder_.add_proj_clip(proj_clip); builder_.add_proj_clip(proj_clip);
builder_.add_cell_clip(cell_clip); builder_.add_cell_clip(cell_clip);
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
builder_.add_time_major(time_major); builder_.add_time_major(time_major);
builder_.add_fused_activation_function(fused_activation_function); builder_.add_fused_activation_function(fused_activation_function);
return builder_.Finish(); return builder_.Finish();
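
// Illustrative sketch, not generated code: the object API ("T" native tables)
// carries the new member as well, so a Pack/UnPack round trip preserves it.
// The helper name is hypothetical.
// (Assumes "tensorflow/lite/schema/schema_generated.h" and <memory> are
// included.)
bool RoundTripUniLstmOptions() {
  flatbuffers::FlatBufferBuilder fbb;
  tflite::UnidirectionalSequenceLSTMOptionsT native;
  native.time_major = true;
  native.asymmetric_quantize_inputs = true;
  fbb.Finish(tflite::UnidirectionalSequenceLSTMOptions::Pack(fbb, &native));
  const auto* table =
      flatbuffers::GetRoot<tflite::UnidirectionalSequenceLSTMOptions>(
          fbb.GetBufferPointer());
  std::unique_ptr<tflite::UnidirectionalSequenceLSTMOptionsT> round_trip(
      table->UnPack());
  return round_trip->asymmetric_quantize_inputs;  // true
}
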
@ -5113,12 +5197,14 @@ struct BidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
float proj_clip; float proj_clip;
bool merge_outputs; bool merge_outputs;
bool time_major; bool time_major;
bool asymmetric_quantize_inputs;
BidirectionalSequenceLSTMOptionsT() BidirectionalSequenceLSTMOptionsT()
: fused_activation_function(tflite::ActivationFunctionType_NONE), : fused_activation_function(tflite::ActivationFunctionType_NONE),
cell_clip(0.0f), cell_clip(0.0f),
proj_clip(0.0f), proj_clip(0.0f),
merge_outputs(false), merge_outputs(false),
time_major(true) { time_major(true),
asymmetric_quantize_inputs(false) {
} }
}; };
@ -5129,7 +5215,8 @@ struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbu
VT_CELL_CLIP = 6, VT_CELL_CLIP = 6,
VT_PROJ_CLIP = 8, VT_PROJ_CLIP = 8,
VT_MERGE_OUTPUTS = 10, VT_MERGE_OUTPUTS = 10,
VT_TIME_MAJOR = 12 VT_TIME_MAJOR = 12,
VT_ASYMMETRIC_QUANTIZE_INPUTS = 14
}; };
tflite::ActivationFunctionType fused_activation_function() const { tflite::ActivationFunctionType fused_activation_function() const {
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
@ -5146,6 +5233,9 @@ struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbu
bool time_major() const { bool time_major() const {
return GetField<uint8_t>(VT_TIME_MAJOR, 1) != 0; return GetField<uint8_t>(VT_TIME_MAJOR, 1) != 0;
} }
bool asymmetric_quantize_inputs() const {
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const { bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) && return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
@ -5153,6 +5243,7 @@ struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbu
VerifyField<float>(verifier, VT_PROJ_CLIP) && VerifyField<float>(verifier, VT_PROJ_CLIP) &&
VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) && VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) &&
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) && VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
verifier.EndTable(); verifier.EndTable();
} }
BidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; BidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@ -5178,6 +5269,9 @@ struct BidirectionalSequenceLSTMOptionsBuilder {
void add_time_major(bool time_major) { void add_time_major(bool time_major) {
fbb_.AddElement<uint8_t>(BidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major), 1); fbb_.AddElement<uint8_t>(BidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major), 1);
} }
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
fbb_.AddElement<uint8_t>(BidirectionalSequenceLSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
}
explicit BidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) explicit BidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) { : fbb_(_fbb) {
start_ = fbb_.StartTable(); start_ = fbb_.StartTable();
@ -5196,10 +5290,12 @@ inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> CreateBidirectional
float cell_clip = 0.0f, float cell_clip = 0.0f,
float proj_clip = 0.0f, float proj_clip = 0.0f,
bool merge_outputs = false, bool merge_outputs = false,
bool time_major = true) { bool time_major = true,
bool asymmetric_quantize_inputs = false) {
BidirectionalSequenceLSTMOptionsBuilder builder_(_fbb); BidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
builder_.add_proj_clip(proj_clip); builder_.add_proj_clip(proj_clip);
builder_.add_cell_clip(cell_clip); builder_.add_cell_clip(cell_clip);
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
builder_.add_time_major(time_major); builder_.add_time_major(time_major);
builder_.add_merge_outputs(merge_outputs); builder_.add_merge_outputs(merge_outputs);
builder_.add_fused_activation_function(fused_activation_function); builder_.add_fused_activation_function(fused_activation_function);
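
// Illustrative sketch, not generated code: note that the schema defaults
// differ per field. time_major defaults to true (GetField default 1 above)
// while the new flag defaults to false, so a buffer written without either
// field still reads back with the historical behaviour. The helper name is
// hypothetical.
// (Assumes "tensorflow/lite/schema/schema_generated.h" is included.)
bool DefaultsExample(flatbuffers::FlatBufferBuilder& fbb) {
  fbb.Finish(tflite::CreateBidirectionalSequenceLSTMOptions(fbb));  // all defaults
  const auto* opts =
      flatbuffers::GetRoot<tflite::BidirectionalSequenceLSTMOptions>(
          fbb.GetBufferPointer());
  return opts->time_major() && !opts->asymmetric_quantize_inputs();  // true
}
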
@ -11034,6 +11130,7 @@ inline void SVDFOptions::UnPackTo(SVDFOptionsT *_o, const flatbuffers::resolver_
(void)_resolver; (void)_resolver;
{ auto _e = rank(); _o->rank = _e; } { auto _e = rank(); _o->rank = _e; }
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; } { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
} }
inline flatbuffers::Offset<SVDFOptions> SVDFOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { inline flatbuffers::Offset<SVDFOptions> SVDFOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11046,10 +11143,12 @@ inline flatbuffers::Offset<SVDFOptions> CreateSVDFOptions(flatbuffers::FlatBuffe
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SVDFOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SVDFOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _rank = _o->rank; auto _rank = _o->rank;
auto _fused_activation_function = _o->fused_activation_function; auto _fused_activation_function = _o->fused_activation_function;
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
return tflite::CreateSVDFOptions( return tflite::CreateSVDFOptions(
_fbb, _fbb,
_rank, _rank,
_fused_activation_function); _fused_activation_function,
_asymmetric_quantize_inputs);
} }
inline RNNOptionsT *RNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { inline RNNOptionsT *RNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -11062,6 +11161,7 @@ inline void RNNOptions::UnPackTo(RNNOptionsT *_o, const flatbuffers::resolver_fu
(void)_o; (void)_o;
(void)_resolver; (void)_resolver;
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; } { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
} }
inline flatbuffers::Offset<RNNOptions> RNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { inline flatbuffers::Offset<RNNOptions> RNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11073,9 +11173,11 @@ inline flatbuffers::Offset<RNNOptions> CreateRNNOptions(flatbuffers::FlatBufferB
(void)_o; (void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _fused_activation_function = _o->fused_activation_function; auto _fused_activation_function = _o->fused_activation_function;
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
return tflite::CreateRNNOptions( return tflite::CreateRNNOptions(
_fbb, _fbb,
_fused_activation_function); _fused_activation_function,
_asymmetric_quantize_inputs);
} }
inline SequenceRNNOptionsT *SequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { inline SequenceRNNOptionsT *SequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -11089,6 +11191,7 @@ inline void SequenceRNNOptions::UnPackTo(SequenceRNNOptionsT *_o, const flatbuff
(void)_resolver; (void)_resolver;
{ auto _e = time_major(); _o->time_major = _e; } { auto _e = time_major(); _o->time_major = _e; }
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; } { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
} }
inline flatbuffers::Offset<SequenceRNNOptions> SequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { inline flatbuffers::Offset<SequenceRNNOptions> SequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11101,10 +11204,12 @@ inline flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions(flatbuff
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SequenceRNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SequenceRNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _time_major = _o->time_major; auto _time_major = _o->time_major;
auto _fused_activation_function = _o->fused_activation_function; auto _fused_activation_function = _o->fused_activation_function;
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
return tflite::CreateSequenceRNNOptions( return tflite::CreateSequenceRNNOptions(
_fbb, _fbb,
_time_major, _time_major,
_fused_activation_function); _fused_activation_function,
_asymmetric_quantize_inputs);
} }
inline BidirectionalSequenceRNNOptionsT *BidirectionalSequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { inline BidirectionalSequenceRNNOptionsT *BidirectionalSequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -11119,6 +11224,7 @@ inline void BidirectionalSequenceRNNOptions::UnPackTo(BidirectionalSequenceRNNOp
{ auto _e = time_major(); _o->time_major = _e; } { auto _e = time_major(); _o->time_major = _e; }
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; } { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
{ auto _e = merge_outputs(); _o->merge_outputs = _e; } { auto _e = merge_outputs(); _o->merge_outputs = _e; }
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
} }
inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> BidirectionalSequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> BidirectionalSequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11132,11 +11238,13 @@ inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> CreateBidirectionalS
auto _time_major = _o->time_major; auto _time_major = _o->time_major;
auto _fused_activation_function = _o->fused_activation_function; auto _fused_activation_function = _o->fused_activation_function;
auto _merge_outputs = _o->merge_outputs; auto _merge_outputs = _o->merge_outputs;
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
return tflite::CreateBidirectionalSequenceRNNOptions( return tflite::CreateBidirectionalSequenceRNNOptions(
_fbb, _fbb,
_time_major, _time_major,
_fused_activation_function, _fused_activation_function,
_merge_outputs); _merge_outputs,
_asymmetric_quantize_inputs);
} }
inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
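
// Illustrative sketch, not generated code: because the UnPackTo and
// Create*Options(..., const *OptionsT*) overloads above copy the new member,
// tooling that round-trips a whole model through the object API keeps the flag
// without further changes. The helper name is hypothetical and `buffer` is
// assumed to hold a valid serialized model.
// (Assumes "tensorflow/lite/schema/schema_generated.h", <cstdint>, <memory>,
// and <vector> are included.)
std::vector<uint8_t> RoundTripModel(const std::vector<uint8_t>& buffer) {
  const tflite::Model* model = tflite::GetModel(buffer.data());
  std::unique_ptr<tflite::ModelT> mutable_model(model->UnPack());
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(tflite::Model::Pack(fbb, mutable_model.get()),
             tflite::ModelIdentifier());
  return std::vector<uint8_t>(fbb.GetBufferPointer(),
                              fbb.GetBufferPointer() + fbb.GetSize());
}
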
@ -11151,6 +11259,7 @@ inline void FullyConnectedOptions::UnPackTo(FullyConnectedOptionsT *_o, const fl
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; } { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
{ auto _e = weights_format(); _o->weights_format = _e; } { auto _e = weights_format(); _o->weights_format = _e; }
{ auto _e = keep_num_dims(); _o->keep_num_dims = _e; } { auto _e = keep_num_dims(); _o->keep_num_dims = _e; }
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
} }
inline flatbuffers::Offset<FullyConnectedOptions> FullyConnectedOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { inline flatbuffers::Offset<FullyConnectedOptions> FullyConnectedOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11164,11 +11273,13 @@ inline flatbuffers::Offset<FullyConnectedOptions> CreateFullyConnectedOptions(fl
auto _fused_activation_function = _o->fused_activation_function; auto _fused_activation_function = _o->fused_activation_function;
auto _weights_format = _o->weights_format; auto _weights_format = _o->weights_format;
auto _keep_num_dims = _o->keep_num_dims; auto _keep_num_dims = _o->keep_num_dims;
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
return tflite::CreateFullyConnectedOptions( return tflite::CreateFullyConnectedOptions(
_fbb, _fbb,
_fused_activation_function, _fused_activation_function,
_weights_format, _weights_format,
_keep_num_dims); _keep_num_dims,
_asymmetric_quantize_inputs);
} }
inline SoftmaxOptionsT *SoftmaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { inline SoftmaxOptionsT *SoftmaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
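
// Illustrative sketch, not generated code: a consumer that already holds an
// Operator table can read the new flag through the builtin_options union
// accessor; models written before this change simply report false. The helper
// name is hypothetical.
// (Assumes "tensorflow/lite/schema/schema_generated.h" is included.)
bool UsesAsymmetricInputs(const tflite::Operator& op) {
  if (const auto* fc = op.builtin_options_as_FullyConnectedOptions()) {
    return fc->asymmetric_quantize_inputs();
  }
  return false;
}
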
@ -11352,6 +11463,7 @@ inline void LSTMOptions::UnPackTo(LSTMOptionsT *_o, const flatbuffers::resolver_
{ auto _e = cell_clip(); _o->cell_clip = _e; } { auto _e = cell_clip(); _o->cell_clip = _e; }
{ auto _e = proj_clip(); _o->proj_clip = _e; } { auto _e = proj_clip(); _o->proj_clip = _e; }
{ auto _e = kernel_type(); _o->kernel_type = _e; } { auto _e = kernel_type(); _o->kernel_type = _e; }
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
} }
inline flatbuffers::Offset<LSTMOptions> LSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { inline flatbuffers::Offset<LSTMOptions> LSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11366,12 +11478,14 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(flatbuffers::FlatBuffe
auto _cell_clip = _o->cell_clip; auto _cell_clip = _o->cell_clip;
auto _proj_clip = _o->proj_clip; auto _proj_clip = _o->proj_clip;
auto _kernel_type = _o->kernel_type; auto _kernel_type = _o->kernel_type;
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
return tflite::CreateLSTMOptions( return tflite::CreateLSTMOptions(
_fbb, _fbb,
_fused_activation_function, _fused_activation_function,
_cell_clip, _cell_clip,
_proj_clip, _proj_clip,
_kernel_type); _kernel_type,
_asymmetric_quantize_inputs);
} }
inline UnidirectionalSequenceLSTMOptionsT *UnidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { inline UnidirectionalSequenceLSTMOptionsT *UnidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -11387,6 +11501,7 @@ inline void UnidirectionalSequenceLSTMOptions::UnPackTo(UnidirectionalSequenceLS
{ auto _e = cell_clip(); _o->cell_clip = _e; } { auto _e = cell_clip(); _o->cell_clip = _e; }
{ auto _e = proj_clip(); _o->proj_clip = _e; } { auto _e = proj_clip(); _o->proj_clip = _e; }
{ auto _e = time_major(); _o->time_major = _e; } { auto _e = time_major(); _o->time_major = _e; }
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
} }
inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> UnidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> UnidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11401,12 +11516,14 @@ inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirection
auto _cell_clip = _o->cell_clip; auto _cell_clip = _o->cell_clip;
auto _proj_clip = _o->proj_clip; auto _proj_clip = _o->proj_clip;
auto _time_major = _o->time_major; auto _time_major = _o->time_major;
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
return tflite::CreateUnidirectionalSequenceLSTMOptions( return tflite::CreateUnidirectionalSequenceLSTMOptions(
_fbb, _fbb,
_fused_activation_function, _fused_activation_function,
_cell_clip, _cell_clip,
_proj_clip, _proj_clip,
_time_major); _time_major,
_asymmetric_quantize_inputs);
} }
inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -11423,6 +11540,7 @@ inline void BidirectionalSequenceLSTMOptions::UnPackTo(BidirectionalSequenceLSTM
{ auto _e = proj_clip(); _o->proj_clip = _e; } { auto _e = proj_clip(); _o->proj_clip = _e; }
{ auto _e = merge_outputs(); _o->merge_outputs = _e; } { auto _e = merge_outputs(); _o->merge_outputs = _e; }
{ auto _e = time_major(); _o->time_major = _e; } { auto _e = time_major(); _o->time_major = _e; }
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
} }
inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> BidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> BidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11438,13 +11556,15 @@ inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> CreateBidirectional
auto _proj_clip = _o->proj_clip; auto _proj_clip = _o->proj_clip;
auto _merge_outputs = _o->merge_outputs; auto _merge_outputs = _o->merge_outputs;
auto _time_major = _o->time_major; auto _time_major = _o->time_major;
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
return tflite::CreateBidirectionalSequenceLSTMOptions( return tflite::CreateBidirectionalSequenceLSTMOptions(
_fbb, _fbb,
_fused_activation_function, _fused_activation_function,
_cell_clip, _cell_clip,
_proj_clip, _proj_clip,
_merge_outputs, _merge_outputs,
_time_major); _time_major,
_asymmetric_quantize_inputs);
} }
inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {