Adds asymmetric quantized inputs for hybrid ops in future models.

PiperOrigin-RevId: 303262193
Change-Id: I13e2bddee0e9bf10af9d5911d004ca31be430401
This commit is contained in:
A. Unique TensorFlower 2020-03-26 22:22:07 -07:00 committed by TensorFlower Gardener
parent 857f0c9557
commit e8dbf1de1a
28 changed files with 490 additions and 1719 deletions

View File

@ -124,33 +124,21 @@ typedef struct {
typedef struct { typedef struct {
int rank; int rank;
TfLiteFusedActivation activation; TfLiteFusedActivation activation;
// Parameter for SVDF version 4.
bool asymmetric_quantize_inputs;
} TfLiteSVDFParams; } TfLiteSVDFParams;
typedef struct { typedef struct {
TfLiteFusedActivation activation; TfLiteFusedActivation activation;
// Parameter for RNN version 3.
bool asymmetric_quantize_inputs;
} TfLiteRNNParams; } TfLiteRNNParams;
typedef struct { typedef struct {
bool time_major; bool time_major;
TfLiteFusedActivation activation; TfLiteFusedActivation activation;
// Parameter for Sequence RNN version 3.
bool asymmetric_quantize_inputs;
} TfLiteSequenceRNNParams; } TfLiteSequenceRNNParams;
typedef struct { typedef struct {
bool time_major; bool time_major;
TfLiteFusedActivation activation; TfLiteFusedActivation activation;
bool merge_outputs; bool merge_outputs;
// Parameter for Bidirectional RNN verison 3.
bool asymmetric_quantize_inputs;
} TfLiteBidirectionalSequenceRNNParams; } TfLiteBidirectionalSequenceRNNParams;
typedef enum { typedef enum {
@ -170,11 +158,6 @@ typedef struct {
// tensors are the same. Furthermore, all but the last dimension of the input // tensors are the same. Furthermore, all but the last dimension of the input
// and output shapes will be equal. // and output shapes will be equal.
bool keep_num_dims; bool keep_num_dims;
// Parameters for FullyConnected version 7 or above.
// If set to true and the weights are quantized, then non constant inputs
// are quantized at evaluation time with asymmetric quantization.
bool asymmetric_quantize_inputs;
} TfLiteFullyConnectedParams; } TfLiteFullyConnectedParams;
typedef enum { typedef enum {
@ -245,9 +228,6 @@ typedef struct {
// Parameters for LSTM version 2. // Parameters for LSTM version 2.
// kTfLiteLSTMBasicKernel is only supported in version 2 or above. // kTfLiteLSTMBasicKernel is only supported in version 2 or above.
TfLiteLSTMKernelType kernel_type; TfLiteLSTMKernelType kernel_type;
// Parameters for LSTM version 4.
bool asymmetric_quantize_inputs;
} TfLiteLSTMParams; } TfLiteLSTMParams;
typedef struct { typedef struct {
@ -258,9 +238,6 @@ typedef struct {
// If set to true then the first dimension is time, otherwise batch. // If set to true then the first dimension is time, otherwise batch.
bool time_major; bool time_major;
// Parameter for unidirectional sequence RNN version 3.
bool asymmetric_quantize_inputs;
} TfLiteUnidirectionalSequenceLSTMParams; } TfLiteUnidirectionalSequenceLSTMParams;
typedef struct { typedef struct {
@ -276,10 +253,6 @@ typedef struct {
// Parameters supported by version 2: // Parameters supported by version 2:
// If set to true then the first dimension is time, otherwise batch. // If set to true then the first dimension is time, otherwise batch.
bool time_major; bool time_major;
// Parameters supported by version 4:
// If set to true, then hybrid ops use asymmetric quantization for inputs.
bool asymmetric_quantize_inputs;
} TfLiteBidirectionalSequenceLSTMParams; } TfLiteBidirectionalSequenceLSTMParams;
typedef struct { typedef struct {

View File

@ -269,8 +269,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
params->rank = svdf_params->rank(); params->rank = svdf_params->rank();
params->activation = params->activation =
parse_activation(svdf_params->fused_activation_function()); parse_activation(svdf_params->fused_activation_function());
params->asymmetric_quantize_inputs =
svdf_params->asymmetric_quantize_inputs();
} }
*builtin_data = reinterpret_cast<void*>(params.release()); *builtin_data = reinterpret_cast<void*>(params.release());
break; break;
@ -282,8 +280,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
params->activation = params->activation =
parse_activation(sequence_rnn_params->fused_activation_function()); parse_activation(sequence_rnn_params->fused_activation_function());
params->time_major = sequence_rnn_params->time_major(); params->time_major = sequence_rnn_params->time_major();
params->asymmetric_quantize_inputs =
sequence_rnn_params->asymmetric_quantize_inputs();
} }
*builtin_data = reinterpret_cast<void*>(params.release()); *builtin_data = reinterpret_cast<void*>(params.release());
break; break;
@ -297,8 +293,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
bidi_sequence_rnn_params->fused_activation_function()); bidi_sequence_rnn_params->fused_activation_function());
params->time_major = bidi_sequence_rnn_params->time_major(); params->time_major = bidi_sequence_rnn_params->time_major();
params->merge_outputs = bidi_sequence_rnn_params->merge_outputs(); params->merge_outputs = bidi_sequence_rnn_params->merge_outputs();
params->asymmetric_quantize_inputs =
bidi_sequence_rnn_params->asymmetric_quantize_inputs();
} }
*builtin_data = reinterpret_cast<void*>(params.release()); *builtin_data = reinterpret_cast<void*>(params.release());
break; break;
@ -308,8 +302,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
if (const auto* rnn_params = op->builtin_options_as_RNNOptions()) { if (const auto* rnn_params = op->builtin_options_as_RNNOptions()) {
params->activation = params->activation =
parse_activation(rnn_params->fused_activation_function()); parse_activation(rnn_params->fused_activation_function());
params->asymmetric_quantize_inputs =
rnn_params->asymmetric_quantize_inputs();
} }
*builtin_data = reinterpret_cast<void*>(params.release()); *builtin_data = reinterpret_cast<void*>(params.release());
break; break;
@ -331,8 +323,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
params->activation = parse_activation( params->activation = parse_activation(
fully_connected_params->fused_activation_function()); fully_connected_params->fused_activation_function());
params->keep_num_dims = fully_connected_params->keep_num_dims(); params->keep_num_dims = fully_connected_params->keep_num_dims();
params->asymmetric_quantize_inputs =
fully_connected_params->asymmetric_quantize_inputs();
switch (fully_connected_params->weights_format()) { switch (fully_connected_params->weights_format()) {
case FullyConnectedOptionsWeightsFormat_DEFAULT: case FullyConnectedOptionsWeightsFormat_DEFAULT:
params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault; params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault;
@ -450,8 +440,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
lstm_params->kernel_type()); lstm_params->kernel_type());
return kTfLiteError; return kTfLiteError;
} }
params->asymmetric_quantize_inputs =
lstm_params->asymmetric_quantize_inputs();
} else { } else {
TF_LITE_REPORT_ERROR(error_reporter, TF_LITE_REPORT_ERROR(error_reporter,
"No valid LSTM builtin options exist"); "No valid LSTM builtin options exist");
@ -470,8 +458,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
params->cell_clip = seq_lstm_params->cell_clip(); params->cell_clip = seq_lstm_params->cell_clip();
params->proj_clip = seq_lstm_params->proj_clip(); params->proj_clip = seq_lstm_params->proj_clip();
params->time_major = seq_lstm_params->time_major(); params->time_major = seq_lstm_params->time_major();
params->asymmetric_quantize_inputs =
seq_lstm_params->asymmetric_quantize_inputs();
} }
*builtin_data = reinterpret_cast<void*>(params.release()); *builtin_data = reinterpret_cast<void*>(params.release());
break; break;
@ -487,8 +473,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
params->proj_clip = bidi_lstm_params->proj_clip(); params->proj_clip = bidi_lstm_params->proj_clip();
params->merge_outputs = bidi_lstm_params->merge_outputs(); params->merge_outputs = bidi_lstm_params->merge_outputs();
params->time_major = bidi_lstm_params->time_major(); params->time_major = bidi_lstm_params->time_major();
params->asymmetric_quantize_inputs =
bidi_lstm_params->asymmetric_quantize_inputs();
} }
*builtin_data = reinterpret_cast<void*>(params.release()); *builtin_data = reinterpret_cast<void*>(params.release());
break; break;

View File

@ -26,15 +26,6 @@ namespace ops {
namespace builtin { namespace builtin {
namespace rnn { namespace rnn {
namespace {
struct OpData {
int scratch_tensor_index;
bool compute_row_sums = false;
};
} // namespace
constexpr int kInputTensor = 0; constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1; constexpr int kWeightsTensor = 1;
constexpr int kRecurrentWeightsTensor = 2; constexpr int kRecurrentWeightsTensor = 2;
@ -45,14 +36,13 @@ constexpr int kHiddenStateTensor = 4;
constexpr int kOutputTensor = 0; constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) { void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData(); auto* scratch_tensor_index = new int;
context->AddTensors(context, /*tensors_to_add=*/6, context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index);
&op_data->scratch_tensor_index); return scratch_tensor_index;
return op_data;
} }
void Free(TfLiteContext* context, void* buffer) { void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<OpData*>(buffer); delete reinterpret_cast<int*>(buffer);
} }
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
@ -99,11 +89,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate temporary tensors to store quantized values of input and // Allocate temporary tensors to store quantized values of input and
// hidden_state tensors. // hidden_state tensors.
if (is_hybrid) { if (is_hybrid) {
auto* op_data = reinterpret_cast<OpData*>(node->user_data); int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
op_data->compute_row_sums = true;
TfLiteIntArrayFree(node->temporaries); TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(6); node->temporaries = TfLiteIntArrayCreate(3);
node->temporaries->data[0] = op_data->scratch_tensor_index; node->temporaries->data[0] = *scratch_tensor_index;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
input_quantized->type = input_weights->type; input_quantized->type = input_weights->type;
input_quantized->allocation_type = kTfLiteArenaRw; input_quantized->allocation_type = kTfLiteArenaRw;
@ -112,7 +101,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
input_quantized_size)); input_quantized_size));
} }
node->temporaries->data[1] = op_data->scratch_tensor_index + 1; node->temporaries->data[1] = *scratch_tensor_index + 1;
TfLiteTensor* hidden_state_quantized = TfLiteTensor* hidden_state_quantized =
GetTemporary(context, node, /*index=*/1); GetTemporary(context, node, /*index=*/1);
hidden_state_quantized->type = input_weights->type; hidden_state_quantized->type = input_weights->type;
@ -125,7 +114,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ResizeTensor(context, hidden_state_quantized, context->ResizeTensor(context, hidden_state_quantized,
hidden_state_quantized_size)); hidden_state_quantized_size));
} }
node->temporaries->data[2] = op_data->scratch_tensor_index + 2; node->temporaries->data[2] = *scratch_tensor_index + 2;
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
scaling_factors->type = kTfLiteFloat32; scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw; scaling_factors->allocation_type = kTfLiteArenaRw;
@ -136,43 +125,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
scaling_factors_size)); scaling_factors_size));
} }
node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/3);
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {num_units, batch_size};
if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
accum_scratch_dims)) {
TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
accum_scratch_size->data[0] = accum_scratch_dims[0];
accum_scratch_size->data[1] = accum_scratch_dims[1];
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
accum_scratch_size));
}
node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
zero_points->type = kTfLiteInt32;
zero_points->allocation_type = kTfLiteArenaRw;
int zero_points_dims[1] = {batch_size};
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
zero_points_size->data[0] = batch_size;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
zero_points_size));
}
node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_dims[2] = {2, num_units};
if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
row_sums_size->data[0] = row_sums_dims[0];
row_sums_size->data[1] = row_sums_dims[1];
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, row_sums, row_sums_size));
}
} }
return kTfLiteOk; return kTfLiteOk;
} }
@ -211,9 +165,7 @@ TfLiteStatus EvalHybrid(const TfLiteTensor* input,
TfLiteTensor* input_scratch, TfLiteTensor* input_scratch,
TfLiteTensor* hidden_state_scratch, TfLiteTensor* hidden_state_scratch,
TfLiteTensor* scaling_factors, TfLiteTensor* scaling_factors,
TfLiteTensor* hidden_state, TfLiteTensor* output, TfLiteTensor* hidden_state, TfLiteTensor* output) {
TfLiteTensor* zero_points, TfLiteTensor* accum_scratch,
TfLiteTensor* row_sums, bool* compute_row_sums) {
const int batch_size = input->dims->data[0]; const int batch_size = input->dims->data[0];
const int num_units = input_weights->dims->data[0]; const int num_units = input_weights->dims->data[0];
const int input_size = input->dims->data[1]; const int input_size = input->dims->data[1];
@ -238,34 +190,26 @@ TfLiteStatus EvalHybrid(const TfLiteTensor* input,
int8_t* quantized_hidden_state_ptr = int8_t* quantized_hidden_state_ptr =
GetTensorData<int8_t>(hidden_state_scratch); GetTensorData<int8_t>(hidden_state_scratch);
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors); float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
int32_t* zero_points_ptr = nullptr;
int32_t* row_sums_ptr = nullptr;
if (params->asymmetric_quantize_inputs) {
zero_points_ptr = GetTensorData<int32_t>(zero_points);
row_sums_ptr = GetTensorData<int32_t>(row_sums);
}
kernel_utils::RnnBatchStep( kernel_utils::RnnBatchStep(
input_ptr_batch, input_weights_ptr, input_weights_scale, input_ptr_batch, input_weights_ptr, input_weights_scale,
recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
num_units, batch_size, output_batch_leading_dim, params->activation, num_units, batch_size, output_batch_leading_dim, params->activation,
quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr, quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr,
hidden_state_ptr_batch, output_ptr_batch, hidden_state_ptr_batch, output_ptr_batch);
params->asymmetric_quantize_inputs, zero_points_ptr, accum_scratch_ptr,
row_sums_ptr, compute_row_sums);
return kTfLiteOk; return kTfLiteOk;
} }
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data); auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data);
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor); const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* recurrent_weights = const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor); GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor); const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
TfLiteTensor* hidden_state = TfLiteTensor* hidden_state =
&context->tensors[node->inputs->data[kHiddenStateTensor]]; GetVariableInput(context, node, kHiddenStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// We already checked that weight types are consistent, so branch on one. // We already checked that weight types are consistent, so branch on one.
@ -279,13 +223,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* input_quantized = GetTemporary(context, node, 0); TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1); TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
TfLiteTensor* scaling_factors = GetTemporary(context, node, 2); TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
TfLiteTensor* accum_scratch = GetTemporary(context, node, 3);
TfLiteTensor* zero_points = GetTemporary(context, node, 4);
TfLiteTensor* row_sums = GetTemporary(context, node, 5);
return EvalHybrid(input, input_weights, recurrent_weights, bias, params, return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
input_quantized, hidden_state_quantized, input_quantized, hidden_state_quantized,
scaling_factors, hidden_state, output, zero_points, scaling_factors, hidden_state, output);
accum_scratch, row_sums, &op_data->compute_row_sums);
} }
default: default:
context->ReportError(context, "Type %d not currently supported.", context->ReportError(context, "Type %d not currently supported.",

View File

@ -175,8 +175,7 @@ class RNNOpModel : public SingleOpModel {
public: public:
RNNOpModel(int batches, int units, int size, RNNOpModel(int batches, int units, int size,
const TensorType& weights = TensorType_FLOAT32, const TensorType& weights = TensorType_FLOAT32,
const TensorType& recurrent_weights = TensorType_FLOAT32, const TensorType& recurrent_weights = TensorType_FLOAT32)
bool asymmetric_quantize_inputs = false)
: batches_(batches), units_(units), input_size_(size) { : batches_(batches), units_(units), input_size_(size) {
input_ = AddInput(TensorType_FLOAT32); input_ = AddInput(TensorType_FLOAT32);
weights_ = AddInput(weights); weights_ = AddInput(weights);
@ -184,10 +183,9 @@ class RNNOpModel : public SingleOpModel {
bias_ = AddInput(TensorType_FLOAT32); bias_ = AddInput(TensorType_FLOAT32);
hidden_state_ = AddInput(TensorType_FLOAT32, true); hidden_state_ = AddInput(TensorType_FLOAT32, true);
output_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_RNN, BuiltinOptions_RNNOptions, SetBuiltinOp(
CreateRNNOptions(builder_, ActivationFunctionType_RELU, BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
asymmetric_quantize_inputs) CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
.Union());
BuildInterpreter({{batches_, input_size_}, // input tensor BuildInterpreter({{batches_, input_size_}, // input tensor
{units_, input_size_}, // weights tensor {units_, input_size_}, // weights tensor
{units_, units_}, // recurrent weights tensor {units_, units_}, // recurrent weights tensor
@ -235,10 +233,8 @@ class RNNOpModel : public SingleOpModel {
// The hybrid model has quantized weights and recurrent_weights. // The hybrid model has quantized weights and recurrent_weights.
class HybridRNNOpModel : public RNNOpModel { class HybridRNNOpModel : public RNNOpModel {
public: public:
HybridRNNOpModel(int batches, int units, int size, TensorType tensor_type, HybridRNNOpModel(int batches, int units, int size, TensorType tensor_type)
bool asymmetric_quantize_inputs) : RNNOpModel(batches, units, size, tensor_type, tensor_type) {
: RNNOpModel(batches, units, size, tensor_type, tensor_type,
asymmetric_quantize_inputs) {
tensor_type_ = tensor_type; tensor_type_ = tensor_type;
} }
@ -286,10 +282,8 @@ TEST(RnnOpTest, BlackBoxTest) {
} }
} }
class HybridRnnOpTest : public ::testing::TestWithParam<bool> {}; TEST(HybridRnnOpTest, BlackBoxTestUint8) {
HybridRNNOpModel rnn(2, 16, 8, TensorType_UINT8);
TEST_P(HybridRnnOpTest, BlackBoxTestUint8) {
HybridRNNOpModel rnn(2, 16, 8, TensorType_UINT8, GetParam());
rnn.SetWeights(rnn_weights); rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias); rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights); rnn.SetRecurrentWeights(rnn_recurrent_weights);
@ -316,8 +310,8 @@ TEST_P(HybridRnnOpTest, BlackBoxTestUint8) {
} }
} }
TEST_P(HybridRnnOpTest, BlackBoxTestInt8) { TEST(HybridRnnOpTest, BlackBoxTestInt8) {
HybridRNNOpModel rnn(2, 16, 8, TensorType_INT8, GetParam()); HybridRNNOpModel rnn(2, 16, 8, TensorType_INT8);
rnn.SetWeights(rnn_weights); rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias); rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights); rnn.SetRecurrentWeights(rnn_recurrent_weights);
@ -344,8 +338,5 @@ TEST_P(HybridRnnOpTest, BlackBoxTestInt8) {
} }
} }
INSTANTIATE_TEST_SUITE_P(HybridRnnOpTest, HybridRnnOpTest,
::testing::ValuesIn({false, true}));
} // namespace } // namespace
} // namespace tflite } // namespace tflite

View File

@ -139,28 +139,18 @@ enum TemporaryTensor {
kProductScalingFactors = 8, kProductScalingFactors = 8,
kRecoveredCellWeights = 9, kRecoveredCellWeights = 9,
kAccumScratchBuffer = 10, kAccumScratchBuffer = 10,
kZeroPoints = 11, kAuxInputQuantized = 11, // Optional, quantized tensor for auxiliary input.
kFwRowSums = 12, kNumTemporaryTensors
kBwRowSums = 13,
kAuxInputQuantized = 14, // Optional, quantized tensor for auxiliary input.
kNumTemporaryTensors = 15
};
struct OpData {
int scratch_tensor_index;
bool compute_fw_row_sums = false;
bool compute_bw_row_sums = false;
}; };
void* Init(TfLiteContext* context, const char* buffer, size_t length) { void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData(); auto* scratch_tensor_index = new int;
context->AddTensors(context, kNumTemporaryTensors, context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
&op_data->scratch_tensor_index); return scratch_tensor_index;
return op_data;
} }
void Free(TfLiteContext* context, void* buffer) { void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<OpData*>(buffer); delete reinterpret_cast<int*>(buffer);
} }
// Check that input tensor dimensions matches with each other. // Check that input tensor dimensions matches with each other.
@ -395,7 +385,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
// Resize the output and scratch tensors based on the sizes of the input // Resize the output and scratch tensors based on the sizes of the input
// tensors. Also check that the size of the input tensors match each other. // tensors. Also check that the size of the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* op_data = reinterpret_cast<OpData*>(node->user_data); int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>( const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
node->builtin_data); node->builtin_data);
@ -532,7 +522,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
node->temporaries = TfLiteIntArrayCreate(2); // the two scratch buffers. node->temporaries = TfLiteIntArrayCreate(2); // the two scratch buffers.
} }
// Create a scratch buffer tensor. // Create a scratch buffer tensor.
node->temporaries->data[kFwScratchBuffer] = op_data->scratch_tensor_index; node->temporaries->data[kFwScratchBuffer] = *scratch_tensor_index;
TfLiteTensor* fw_scratch_buffer = TfLiteTensor* fw_scratch_buffer =
GetTemporary(context, node, kFwScratchBuffer); GetTemporary(context, node, kFwScratchBuffer);
fw_scratch_buffer->type = input->type; fw_scratch_buffer->type = input->type;
@ -591,7 +581,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Create a scratch buffer tensor. // Create a scratch buffer tensor.
node->temporaries->data[kBwScratchBuffer] = node->temporaries->data[kBwScratchBuffer] =
op_data->scratch_tensor_index + kBwScratchBuffer; *(scratch_tensor_index) + kBwScratchBuffer;
TfLiteTensor* bw_scratch_buffer = TfLiteTensor* bw_scratch_buffer =
GetTemporary(context, node, kBwScratchBuffer); GetTemporary(context, node, kBwScratchBuffer);
bw_scratch_buffer->type = input->type; bw_scratch_buffer->type = input->type;
@ -616,13 +606,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer, TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer,
bw_scratch_buffer_size)); bw_scratch_buffer_size));
if (is_hybrid_op) { if (is_hybrid_op) {
// Compute the row sums for cached zero_point offset calculation.
op_data->compute_fw_row_sums = true;
op_data->compute_bw_row_sums = true;
// Allocate temporary tensors to store quantized values of input, aux_input // Allocate temporary tensors to store quantized values of input, aux_input
// (if present), activation_state and cell_state tensors. // (if present), activation_state and cell_state tensors.
node->temporaries->data[kInputQuantized] = node->temporaries->data[kInputQuantized] =
op_data->scratch_tensor_index + kInputQuantized; *scratch_tensor_index + kInputQuantized;
TfLiteTensor* input_quantized = TfLiteTensor* input_quantized =
GetTemporary(context, node, kInputQuantized); GetTemporary(context, node, kInputQuantized);
input_quantized->type = fw_input_to_output_weights->type; input_quantized->type = fw_input_to_output_weights->type;
@ -634,7 +621,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} }
node->temporaries->data[kFwActivationStateQuantized] = node->temporaries->data[kFwActivationStateQuantized] =
op_data->scratch_tensor_index + kFwActivationStateQuantized; *scratch_tensor_index + kFwActivationStateQuantized;
TfLiteTensor* fw_activation_state_quantized = TfLiteTensor* fw_activation_state_quantized =
GetTemporary(context, node, kFwActivationStateQuantized); GetTemporary(context, node, kFwActivationStateQuantized);
fw_activation_state_quantized->type = fw_input_to_output_weights->type; fw_activation_state_quantized->type = fw_input_to_output_weights->type;
@ -648,7 +635,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
fw_activation_state_quantized_size)); fw_activation_state_quantized_size));
} }
node->temporaries->data[kBwActivationStateQuantized] = node->temporaries->data[kBwActivationStateQuantized] =
op_data->scratch_tensor_index + kBwActivationStateQuantized; *scratch_tensor_index + kBwActivationStateQuantized;
TfLiteTensor* bw_activation_state_quantized = TfLiteTensor* bw_activation_state_quantized =
GetTemporary(context, node, kBwActivationStateQuantized); GetTemporary(context, node, kBwActivationStateQuantized);
bw_activation_state_quantized->type = fw_input_to_output_weights->type; bw_activation_state_quantized->type = fw_input_to_output_weights->type;
@ -662,7 +649,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
bw_activation_state_quantized_size)); bw_activation_state_quantized_size));
} }
node->temporaries->data[kFwCellStateQuantized] = node->temporaries->data[kFwCellStateQuantized] =
op_data->scratch_tensor_index + kFwCellStateQuantized; *scratch_tensor_index + kFwCellStateQuantized;
TfLiteTensor* fw_cell_state_quantized = TfLiteTensor* fw_cell_state_quantized =
GetTemporary(context, node, kFwCellStateQuantized); GetTemporary(context, node, kFwCellStateQuantized);
fw_cell_state_quantized->type = fw_input_to_output_weights->type; fw_cell_state_quantized->type = fw_input_to_output_weights->type;
@ -676,7 +663,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
fw_cell_state_quantized_size)); fw_cell_state_quantized_size));
} }
node->temporaries->data[kBwCellStateQuantized] = node->temporaries->data[kBwCellStateQuantized] =
op_data->scratch_tensor_index + kBwCellStateQuantized; *scratch_tensor_index + kBwCellStateQuantized;
TfLiteTensor* bw_cell_state_quantized = TfLiteTensor* bw_cell_state_quantized =
GetTemporary(context, node, kBwCellStateQuantized); GetTemporary(context, node, kBwCellStateQuantized);
bw_cell_state_quantized->type = fw_input_to_output_weights->type; bw_cell_state_quantized->type = fw_input_to_output_weights->type;
@ -696,7 +683,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// different matrices (which requires multiplying the scaling factors with // different matrices (which requires multiplying the scaling factors with
// the scaling factor of the matrix). // the scaling factor of the matrix).
node->temporaries->data[kScalingFactors] = node->temporaries->data[kScalingFactors] =
op_data->scratch_tensor_index + kScalingFactors; *scratch_tensor_index + kScalingFactors;
TfLiteTensor* scaling_factors = TfLiteTensor* scaling_factors =
GetTemporary(context, node, kScalingFactors); GetTemporary(context, node, kScalingFactors);
scaling_factors->type = kTfLiteFloat32; scaling_factors->type = kTfLiteFloat32;
@ -709,7 +696,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
scaling_factors_size)); scaling_factors_size));
} }
node->temporaries->data[kProductScalingFactors] = node->temporaries->data[kProductScalingFactors] =
op_data->scratch_tensor_index + kProductScalingFactors; *scratch_tensor_index + kProductScalingFactors;
TfLiteTensor* prod_scaling_factors = TfLiteTensor* prod_scaling_factors =
GetTemporary(context, node, kProductScalingFactors); GetTemporary(context, node, kProductScalingFactors);
prod_scaling_factors->type = kTfLiteFloat32; prod_scaling_factors->type = kTfLiteFloat32;
@ -726,7 +713,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate a temporary tensor to store the recovered cell weights. Since // Allocate a temporary tensor to store the recovered cell weights. Since
// this is used for diagonal matrices, only need to store n_cell values. // this is used for diagonal matrices, only need to store n_cell values.
node->temporaries->data[kRecoveredCellWeights] = node->temporaries->data[kRecoveredCellWeights] =
op_data->scratch_tensor_index + kRecoveredCellWeights; *scratch_tensor_index + kRecoveredCellWeights;
TfLiteTensor* recovered_cell_weights = TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, kRecoveredCellWeights); GetTemporary(context, node, kRecoveredCellWeights);
recovered_cell_weights->type = kTfLiteFloat32; recovered_cell_weights->type = kTfLiteFloat32;
@ -743,7 +730,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate a temporary tensor to store the accumulated int32 values. // Allocate a temporary tensor to store the accumulated int32 values.
node->temporaries->data[kAccumScratchBuffer] = node->temporaries->data[kAccumScratchBuffer] =
op_data->scratch_tensor_index + kAccumScratchBuffer; *scratch_tensor_index + kAccumScratchBuffer;
TfLiteTensor* accum_scratch = TfLiteTensor* accum_scratch =
GetTemporary(context, node, kAccumScratchBuffer); GetTemporary(context, node, kAccumScratchBuffer);
accum_scratch->type = kTfLiteInt32; accum_scratch->type = kTfLiteInt32;
@ -763,72 +750,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context, context->ResizeTensor(context, accum_scratch, accum_size)); context, context->ResizeTensor(context, accum_scratch, accum_size));
} }
// Allocate temporary tensors for storing zero-points.
node->temporaries->data[kZeroPoints] =
op_data->scratch_tensor_index + kZeroPoints;
TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
zero_points->type = kTfLiteFloat32;
zero_points->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, scaling_dims)) {
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
zero_points_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
zero_points_size));
}
// Allocate temporary tensors for caching row sums for hybrid zero-point
// calculations.
int fw_row_sums_rows = fw_use_cifg ? 6 : 8;
if (has_aux_input) {
fw_row_sums_rows += fw_use_cifg ? 3 : 4;
}
const TfLiteTensor* fw_projection_weights =
GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor);
if (fw_projection_weights != nullptr) {
fw_row_sums_rows += ceil(n_fw_output / n_fw_cell);
}
node->temporaries->data[kFwRowSums] =
op_data->scratch_tensor_index + kFwRowSums;
TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
fw_row_sums->type = kTfLiteInt32;
fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
int fw_row_sums_dims[2] = {fw_row_sums_rows, n_fw_cell};
if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) {
TfLiteIntArray* fw_hybrid_scratch_size = TfLiteIntArrayCreate(2);
fw_hybrid_scratch_size->data[0] = fw_row_sums_dims[0];
fw_hybrid_scratch_size->data[1] = fw_row_sums_dims[1];
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums,
fw_hybrid_scratch_size));
}
int bw_row_sums_rows = bw_use_cifg ? 6 : 8;
if (has_aux_input) {
bw_row_sums_rows += bw_use_cifg ? 3 : 4;
}
const TfLiteTensor* bw_projection_weights =
GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor);
if (bw_projection_weights != nullptr) {
bw_row_sums_rows += ceil(n_bw_output / n_bw_cell);
}
node->temporaries->data[kBwRowSums] =
op_data->scratch_tensor_index + kBwRowSums;
TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
bw_row_sums->type = kTfLiteInt32;
bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
int bw_row_sums_dims[2] = {bw_row_sums_rows, n_bw_cell};
if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) {
TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2);
bw_row_sums_size->data[0] = bw_row_sums_dims[0];
bw_row_sums_size->data[1] = bw_row_sums_dims[1];
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums,
bw_row_sums_size));
}
// Only allocate a temporary tensor for quantized auxiliary input if we are // Only allocate a temporary tensor for quantized auxiliary input if we are
// actually going to use it. // actually going to use it.
if (has_aux_input) { if (has_aux_input) {
node->temporaries->data[kAuxInputQuantized] = node->temporaries->data[kAuxInputQuantized] =
op_data->scratch_tensor_index + kAuxInputQuantized; *scratch_tensor_index + kAuxInputQuantized;
TfLiteTensor* aux_input_quantized = TfLiteTensor* aux_input_quantized =
GetTemporary(context, node, kAuxInputQuantized); GetTemporary(context, node, kAuxInputQuantized);
aux_input_quantized->type = fw_input_to_output_weights->type; aux_input_quantized->type = fw_input_to_output_weights->type;
@ -849,7 +775,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>( const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
node->builtin_data); node->builtin_data);
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
// Input tensor. // Input tensor.
const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* input = GetInput(context, node, kInputTensor);
@ -983,8 +909,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Populate a TfLiteLSTMParams struct for the evaluation functions. // Populate a TfLiteLSTMParams struct for the evaluation functions.
TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip, TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip,
params->proj_clip, kTfLiteLSTMFullKernel, params->proj_clip, kTfLiteLSTMFullKernel};
params->asymmetric_quantize_inputs};
const int bw_output_offset = const int bw_output_offset =
params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0; params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0;
@ -1078,11 +1003,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
: nullptr; : nullptr;
TfLiteTensor* accum_scratch = TfLiteTensor* accum_scratch =
GetTemporary(context, node, kAccumScratchBuffer); GetTemporary(context, node, kAccumScratchBuffer);
TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
const int fw_row_sums_size = fw_row_sums->dims->data[0];
const int bw_row_sums_size = bw_row_sums->dims->data[0];
TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid( TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
input, fw_input_to_input_weights, fw_input_to_forget_weights, input, fw_input_to_input_weights, fw_input_to_forget_weights,
fw_input_to_cell_weights, fw_input_to_output_weights, fw_input_to_cell_weights, fw_input_to_output_weights,
@ -1104,8 +1025,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
recovered_cell_weights, input_quantized, aux_input_quantized, recovered_cell_weights, input_quantized, aux_input_quantized,
fw_activation_state_quantized, fw_cell_state_quantized, fw_activation_state_quantized, fw_cell_state_quantized,
fw_activation_state, fw_cell_state, accum_scratch, fw_output, fw_activation_state, fw_cell_state, accum_scratch, fw_output,
zero_points, fw_row_sums, fw_row_sums_size,
&op_data->compute_fw_row_sums,
CpuBackendContext::GetFromContext(context)); CpuBackendContext::GetFromContext(context));
TF_LITE_ENSURE_OK(context, fw_pass_status); TF_LITE_ENSURE_OK(context, fw_pass_status);
@ -1130,8 +1049,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
recovered_cell_weights, input_quantized, aux_input_quantized, recovered_cell_weights, input_quantized, aux_input_quantized,
bw_activation_state_quantized, bw_cell_state_quantized, bw_activation_state_quantized, bw_cell_state_quantized,
bw_activation_state, bw_cell_state, accum_scratch, actual_bw_output, bw_activation_state, bw_cell_state, accum_scratch, actual_bw_output,
zero_points, bw_row_sums, bw_row_sums_size,
&op_data->compute_bw_row_sums,
CpuBackendContext::GetFromContext(context)); CpuBackendContext::GetFromContext(context));
TF_LITE_ENSURE_OK(context, bw_pass_status); TF_LITE_ENSURE_OK(context, bw_pass_status);
return kTfLiteOk; return kTfLiteOk;

View File

@ -40,8 +40,7 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
bool use_projection_bias, bool merge_outputs, bool use_projection_bias, bool merge_outputs,
bool use_aux_input, float cell_clip, float proj_clip, bool use_aux_input, float cell_clip, float proj_clip,
bool quantize_weights, bool time_major, bool quantize_weights, bool time_major,
const std::vector<std::vector<int>>& input_shapes, const std::vector<std::vector<int>>& input_shapes)
bool asymmetric_quantize_inputs = false)
: n_batch_(n_batch), : n_batch_(n_batch),
n_input_(n_input), n_input_(n_input),
n_fw_cell_(n_cell), n_fw_cell_(n_cell),
@ -208,12 +207,11 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
bw_aux_input_to_output_weights_ = AddNullInput(); bw_aux_input_to_output_weights_ = AddNullInput();
} }
SetBuiltinOp( SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
BuiltinOptions_BidirectionalSequenceLSTMOptions, BuiltinOptions_BidirectionalSequenceLSTMOptions,
CreateBidirectionalSequenceLSTMOptions( CreateBidirectionalSequenceLSTMOptions(
builder_, ActivationFunctionType_TANH, cell_clip, proj_clip, builder_, ActivationFunctionType_TANH, cell_clip,
merge_outputs, time_major, asymmetric_quantize_inputs) proj_clip, merge_outputs, time_major)
.Union()); .Union());
BuildInterpreter(input_shapes); BuildInterpreter(input_shapes);
} }
@ -426,14 +424,11 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
bool quantize_weights_; bool quantize_weights_;
}; };
// Declare LSTMOpTest as a parameterized test. // Declare LSTMOpTest as a parameterized test, where the parameter is a boolean
class LSTMOpTest // indicating whether to use quantization or not.
: public ::testing::TestWithParam<::testing::tuple<bool, bool>> {}; class LSTMOpTest : public ::testing::TestWithParam<bool> {};
INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, LSTMOpTest, INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, LSTMOpTest, ::testing::Bool());
::testing::Combine(
/*quantize_weights*/ ::testing::Bool(),
/*asymmetric_quantize_inputs*/ ::testing::Bool()));
TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
const int n_batch = 1; const int n_batch = 1;
@ -442,9 +437,7 @@ TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
const int n_cell = 4; const int n_cell = 4;
const int n_output = 4; const int n_output = 4;
const int sequence_length = 3; const int sequence_length = 3;
auto params = GetParam(); const bool quantize_weights = GetParam();
const bool quantize_weights = std::get<0>(params);
const bool asymmetric_quantize_inputs = std::get<1>(params);
BidirectionalLSTMOpModel lstm( BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
@ -516,8 +509,7 @@ TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
{0}, // aux_bw_input_to_forget tensor {0}, // aux_bw_input_to_forget tensor
{0}, // aux_bw_input_to_cell tensor {0}, // aux_bw_input_to_cell tensor
{0}, // aux_bw_input_to_output tensor {0}, // aux_bw_input_to_output tensor
}, });
asymmetric_quantize_inputs);
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
-0.34550029, 0.04266912, -0.15680569, -0.34550029, 0.04266912, -0.15680569,
@ -608,9 +600,7 @@ TEST_P(LSTMOpTest, BlackBoxTestMergedOutput) {
const int n_cell = 4; const int n_cell = 4;
const int n_output = 4; const int n_output = 4;
const int sequence_length = 3; const int sequence_length = 3;
auto params = GetParam(); const bool quantize_weights = GetParam();
const bool quantize_weights = std::get<0>(params);
const bool asymmetric_quantize_inputs = std::get<1>(params);
BidirectionalLSTMOpModel lstm( BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
@ -682,8 +672,7 @@ TEST_P(LSTMOpTest, BlackBoxTestMergedOutput) {
{0}, // aux_bw_input_to_forget tensor {0}, // aux_bw_input_to_forget tensor
{0}, // aux_bw_input_to_cell tensor {0}, // aux_bw_input_to_cell tensor
{0}, // aux_bw_input_to_output tensor {0}, // aux_bw_input_to_output tensor
}, });
asymmetric_quantize_inputs);
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
-0.34550029, 0.04266912, -0.15680569, -0.34550029, 0.04266912, -0.15680569,
@ -2642,9 +2631,7 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInputZeroAuxWeight) {
const int n_cell = 4; const int n_cell = 4;
const int n_output = 4; const int n_output = 4;
const int sequence_length = 3; const int sequence_length = 3;
auto params = GetParam(); const bool quantize_weights = GetParam();
const bool quantize_weights = std::get<0>(params);
const bool asymmetric_quantize_inputs = std::get<1>(params);
BidirectionalLSTMOpModel lstm( BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
@ -2716,8 +2703,7 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInputZeroAuxWeight) {
{n_cell, n_input}, // aux_bw_input_to_forget tensor {n_cell, n_input}, // aux_bw_input_to_forget tensor
{n_cell, n_input}, // aux_bw_input_to_cell tensor {n_cell, n_input}, // aux_bw_input_to_cell tensor
{n_cell, n_input}, // aux_bw_input_to_output tensor {n_cell, n_input}, // aux_bw_input_to_output tensor
}, });
asymmetric_quantize_inputs);
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
-0.34550029, 0.04266912, -0.15680569, -0.34550029, 0.04266912, -0.15680569,
@ -2816,9 +2802,7 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInput) {
const int n_cell = 4; const int n_cell = 4;
const int n_output = 4; const int n_output = 4;
const int sequence_length = 3; const int sequence_length = 3;
auto params = GetParam(); const bool quantize_weights = GetParam();
const bool quantize_weights = std::get<0>(params);
const bool asymmetric_quantize_inputs = std::get<1>(params);
BidirectionalLSTMOpModel lstm( BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
@ -2890,8 +2874,7 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInput) {
{n_cell, n_input}, // aux_bw_input_to_forget tensor {n_cell, n_input}, // aux_bw_input_to_forget tensor
{n_cell, n_input}, // aux_bw_input_to_cell tensor {n_cell, n_input}, // aux_bw_input_to_cell tensor
{n_cell, n_input}, // aux_bw_input_to_output tensor {n_cell, n_input}, // aux_bw_input_to_output tensor
}, });
asymmetric_quantize_inputs);
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
-0.34550029, 0.04266912, -0.15680569, -0.34550029, 0.04266912, -0.15680569,

View File

@ -27,16 +27,6 @@ namespace ops {
namespace builtin { namespace builtin {
namespace bidirectional_sequence_rnn { namespace bidirectional_sequence_rnn {
namespace {
struct OpData {
int scratch_tensor_index;
bool fw_compute_row_sums = false;
bool bw_compute_row_sums = false;
};
} // namespace
// LINT.IfChange // LINT.IfChange
constexpr int kInputTensor = 0; constexpr int kInputTensor = 0;
@ -68,23 +58,18 @@ enum TemporaryTensor {
kFwHiddenStateQuantized = 1, kFwHiddenStateQuantized = 1,
kBwHiddenStateQuantized = 2, kBwHiddenStateQuantized = 2,
kScalingFactors = 3, kScalingFactors = 3,
kAccumScratch = 4, kAuxInputQuantized = 4,
kZeroPoints = 5, kNumTemporaryTensors = 5
kFwRowSums = 6,
kBwRowSums = 7,
kAuxInputQuantized = 8,
kNumTemporaryTensors = 9
}; };
void* Init(TfLiteContext* context, const char* buffer, size_t length) { void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData(); auto* scratch_tensor_index = new int;
context->AddTensors(context, kNumTemporaryTensors, context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
&op_data->scratch_tensor_index); return scratch_tensor_index;
return op_data;
} }
void Free(TfLiteContext* context, void* buffer) { void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<OpData*>(buffer); delete reinterpret_cast<int*>(buffer);
} }
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
@ -172,9 +157,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} }
if (IsHybridOp(input, fw_input_weights)) { if (IsHybridOp(input, fw_input_weights)) {
OpData* op_data = reinterpret_cast<OpData*>(node->user_data); int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
op_data->fw_compute_row_sums = true;
op_data->bw_compute_row_sums = true;
TfLiteIntArrayFree(node->temporaries); TfLiteIntArrayFree(node->temporaries);
if (has_aux_input) { if (has_aux_input) {
node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors); node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
@ -184,7 +168,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} }
node->temporaries->data[kInputQuantized] = node->temporaries->data[kInputQuantized] =
op_data->scratch_tensor_index + kInputQuantized; *scratch_tensor_index + kInputQuantized;
TfLiteTensor* input_quantized = TfLiteTensor* input_quantized =
GetTemporary(context, node, kInputQuantized); GetTemporary(context, node, kInputQuantized);
input_quantized->type = fw_input_weights->type; input_quantized->type = fw_input_weights->type;
@ -196,7 +180,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} }
node->temporaries->data[kFwHiddenStateQuantized] = node->temporaries->data[kFwHiddenStateQuantized] =
op_data->scratch_tensor_index + kFwHiddenStateQuantized; *scratch_tensor_index + kFwHiddenStateQuantized;
TfLiteTensor* fw_hidden_state_quantized = TfLiteTensor* fw_hidden_state_quantized =
GetTemporary(context, node, kFwHiddenStateQuantized); GetTemporary(context, node, kFwHiddenStateQuantized);
fw_hidden_state_quantized->type = fw_input_weights->type; fw_hidden_state_quantized->type = fw_input_weights->type;
@ -211,7 +195,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} }
node->temporaries->data[kBwHiddenStateQuantized] = node->temporaries->data[kBwHiddenStateQuantized] =
op_data->scratch_tensor_index + kBwHiddenStateQuantized; *scratch_tensor_index + kBwHiddenStateQuantized;
TfLiteTensor* bw_hidden_state_quantized = TfLiteTensor* bw_hidden_state_quantized =
GetTemporary(context, node, kBwHiddenStateQuantized); GetTemporary(context, node, kBwHiddenStateQuantized);
bw_hidden_state_quantized->type = fw_input_weights->type; bw_hidden_state_quantized->type = fw_input_weights->type;
@ -227,7 +211,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate temporary tensors to store scaling factors of quantization. // Allocate temporary tensors to store scaling factors of quantization.
node->temporaries->data[kScalingFactors] = node->temporaries->data[kScalingFactors] =
op_data->scratch_tensor_index + kScalingFactors; *scratch_tensor_index + kScalingFactors;
TfLiteTensor* scaling_factors = TfLiteTensor* scaling_factors =
GetTemporary(context, node, kScalingFactors); GetTemporary(context, node, kScalingFactors);
scaling_factors->type = kTfLiteFloat32; scaling_factors->type = kTfLiteFloat32;
@ -239,66 +223,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
scaling_factors_size)); scaling_factors_size));
} }
node->temporaries->data[kAccumScratch] =
op_data->scratch_tensor_index + kAccumScratch;
TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch);
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {std::max(fw_num_units, bw_num_units),
batch_size};
if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
accum_scratch_dims)) {
TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
accum_scratch_size->data[0] = accum_scratch_dims[0];
accum_scratch_size->data[1] = accum_scratch_dims[1];
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
accum_scratch_size));
}
node->temporaries->data[kZeroPoints] =
op_data->scratch_tensor_index + kZeroPoints;
TfLiteTensor* zero_points =
GetTemporary(context, node, /*index=*/kZeroPoints);
zero_points->type = kTfLiteInt32;
zero_points->allocation_type = kTfLiteArenaRw;
int zero_points_dims[1] = {batch_size};
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
zero_points_size->data[0] = batch_size;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
zero_points_size));
}
const int num_row_sums = has_aux_input ? 3 : 2;
node->temporaries->data[kFwRowSums] =
op_data->scratch_tensor_index + kFwRowSums;
TfLiteTensor* fw_row_sums =
GetTemporary(context, node, /*index=*/kFwRowSums);
fw_row_sums->type = kTfLiteInt32;
fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
int fw_row_sums_dims[2] = {num_row_sums, fw_num_units};
if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) {
TfLiteIntArray* fw_row_sums_size = TfLiteIntArrayCreate(2);
fw_row_sums_size->data[0] = fw_row_sums_dims[0];
fw_row_sums_size->data[1] = fw_row_sums_dims[1];
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums,
fw_row_sums_size));
}
node->temporaries->data[kBwRowSums] =
op_data->scratch_tensor_index + kBwRowSums;
TfLiteTensor* bw_row_sums = GetTemporary(context, node,
/*index=*/kBwRowSums);
bw_row_sums->type = kTfLiteInt32;
bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
int bw_row_sums_dims[2] = {num_row_sums, bw_num_units};
if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) {
TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2);
bw_row_sums_size->data[0] = bw_row_sums_dims[0];
bw_row_sums_size->data[1] = bw_row_sums_dims[1];
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums,
bw_row_sums_size));
}
if (has_aux_input) { if (has_aux_input) {
node->temporaries->data[kAuxInputQuantized] = node->temporaries->data[kAuxInputQuantized] =
op_data->scratch_tensor_index + kAuxInputQuantized; *scratch_tensor_index + kAuxInputQuantized;
TfLiteTensor* aux_input_quantized = TfLiteTensor* aux_input_quantized =
GetTemporary(context, node, kAuxInputQuantized); GetTemporary(context, node, kAuxInputQuantized);
aux_input_quantized->type = fw_input_weights->type; aux_input_quantized->type = fw_input_weights->type;
@ -490,10 +418,7 @@ TfLiteStatus EvalHybrid(
TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized, TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized,
TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output, TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
TfLiteTensor* bw_output, TfLiteTensor* zero_points, TfLiteTensor* bw_output) {
TfLiteTensor* accum_scratch, TfLiteTensor* fw_row_sums,
TfLiteTensor* bw_row_sums, bool* fw_compute_row_sums,
bool* bw_compute_row_sums) {
const bool time_major = params->time_major; const bool time_major = params->time_major;
const int batch_size = const int batch_size =
(time_major) ? input->dims->data[1] : input->dims->data[0]; (time_major) ? input->dims->data[1] : input->dims->data[0];
@ -539,20 +464,11 @@ TfLiteStatus EvalHybrid(
int8_t* bw_quantized_hidden_state_ptr = int8_t* bw_quantized_hidden_state_ptr =
GetTensorData<int8_t>(bw_hidden_state_quantized); GetTensorData<int8_t>(bw_hidden_state_quantized);
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors); float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
int32_t* zero_points_ptr = nullptr;
int32_t* fw_row_sums_ptr = nullptr;
int32_t* bw_row_sums_ptr = nullptr;
if (params->asymmetric_quantize_inputs) {
zero_points_ptr = GetTensorData<int32_t>(zero_points);
fw_row_sums_ptr = GetTensorData<int32_t>(fw_row_sums);
bw_row_sums_ptr = GetTensorData<int32_t>(bw_row_sums);
}
const int fw_output_step = const int fw_output_step =
params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units; params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
const int bw_output_step = const int bw_output_step =
params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units; params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
if (time_major) { if (time_major) {
for (int t = 0; t < max_time; t++) { for (int t = 0; t < max_time; t++) {
// Forward cell. // Forward cell.
@ -575,9 +491,7 @@ TfLiteStatus EvalHybrid(
fw_num_units, batch_size, fw_output_step, params->activation, fw_num_units, batch_size, fw_output_step, params->activation,
quantized_input_ptr, aux_quantized_input_ptr, quantized_input_ptr, aux_quantized_input_ptr,
fw_quantized_hidden_state_ptr, scaling_factors_ptr, fw_quantized_hidden_state_ptr, scaling_factors_ptr,
fw_hidden_state_ptr_batch, output_ptr_batch, fw_hidden_state_ptr_batch, output_ptr_batch);
params->asymmetric_quantize_inputs, zero_points_ptr,
accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums);
} }
// Backward cell. // Backward cell.
float* bw_hidden_state_ptr_batch = GetTensorData<float>(bw_hidden_state); float* bw_hidden_state_ptr_batch = GetTensorData<float>(bw_hidden_state);
@ -602,9 +516,7 @@ TfLiteStatus EvalHybrid(
bw_num_units, batch_size, bw_output_step, params->activation, bw_num_units, batch_size, bw_output_step, params->activation,
quantized_input_ptr, aux_quantized_input_ptr, quantized_input_ptr, aux_quantized_input_ptr,
bw_quantized_hidden_state_ptr, scaling_factors_ptr, bw_quantized_hidden_state_ptr, scaling_factors_ptr,
bw_hidden_state_ptr_batch, output_ptr_batch, bw_hidden_state_ptr_batch, output_ptr_batch);
params->asymmetric_quantize_inputs, zero_points_ptr,
accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums);
} }
} }
} else { } else {
@ -633,9 +545,7 @@ TfLiteStatus EvalHybrid(
fw_num_units, /*batch_size=*/1, fw_output_step, params->activation, fw_num_units, /*batch_size=*/1, fw_output_step, params->activation,
quantized_input_ptr, aux_quantized_input_ptr, quantized_input_ptr, aux_quantized_input_ptr,
fw_quantized_hidden_state_ptr, scaling_factors_ptr, fw_quantized_hidden_state_ptr, scaling_factors_ptr,
fw_hidden_state_ptr_batch, output_ptr_batch, fw_hidden_state_ptr_batch, output_ptr_batch);
params->asymmetric_quantize_inputs, zero_points_ptr,
accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums);
} }
// Backward cell. // Backward cell.
float* bw_hidden_state_ptr_batch = float* bw_hidden_state_ptr_batch =
@ -664,9 +574,7 @@ TfLiteStatus EvalHybrid(
bw_num_units, /*batch_size=*/1, bw_output_step, params->activation, bw_num_units, /*batch_size=*/1, bw_output_step, params->activation,
quantized_input_ptr, aux_quantized_input_ptr, quantized_input_ptr, aux_quantized_input_ptr,
bw_quantized_hidden_state_ptr, scaling_factors_ptr, bw_quantized_hidden_state_ptr, scaling_factors_ptr,
bw_hidden_state_ptr_batch, output_ptr_batch, bw_hidden_state_ptr_batch, output_ptr_batch);
params->asymmetric_quantize_inputs, zero_points_ptr,
accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums);
} }
} }
} }
@ -748,23 +656,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, kBwHiddenStateQuantized); GetTemporary(context, node, kBwHiddenStateQuantized);
TfLiteTensor* scaling_factors = TfLiteTensor* scaling_factors =
GetTemporary(context, node, kScalingFactors); GetTemporary(context, node, kScalingFactors);
TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch);
TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
TfLiteTensor* aux_input_quantized = TfLiteTensor* aux_input_quantized =
use_aux_input ? GetTemporary(context, node, kAuxInputQuantized) use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
: nullptr; : nullptr;
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
return EvalHybrid( return EvalHybrid(input, bw_input, fw_input_weights, fw_recurrent_weights,
input, bw_input, fw_input_weights, fw_recurrent_weights, fw_bias, fw_bias, bw_input_weights, bw_recurrent_weights,
bw_input_weights, bw_recurrent_weights, bw_bias, real_aux_input, bw_bias, real_aux_input, fw_aux_input_weights,
fw_aux_input_weights, bw_aux_input_weights, params, scaling_factors, bw_aux_input_weights, params, scaling_factors,
input_quantized, aux_input_quantized, fw_hidden_state_quantized, input_quantized, aux_input_quantized,
fw_hidden_state, fw_output, bw_hidden_state_quantized, fw_hidden_state_quantized, fw_hidden_state, fw_output,
bw_hidden_state, bw_output, zero_points, accum_scratch, fw_row_sums, bw_hidden_state_quantized, bw_hidden_state, bw_output);
bw_row_sums, &op_data->fw_compute_row_sums,
&op_data->bw_compute_row_sums);
} }
default: default:
context->ReportError(context, "Type not currently supported."); context->ReportError(context, "Type not currently supported.");

View File

@ -662,24 +662,20 @@ class BidirectionalRNNOpModel : public SingleOpModel {
BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units, BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units,
int bw_units, int input_size, int aux_input_size, int bw_units, int input_size, int aux_input_size,
AuxInputMode aux_input_mode, bool time_major, AuxInputMode aux_input_mode, bool time_major,
bool merge_outputs, bool quantize_weights = false, bool merge_outputs)
bool asymmetric_quantize_weights = false)
: batches_(batches), : batches_(batches),
sequence_len_(sequence_len), sequence_len_(sequence_len),
fw_units_(fw_units), fw_units_(fw_units),
bw_units_(bw_units), bw_units_(bw_units),
input_size_(input_size), input_size_(input_size),
aux_input_size_(aux_input_size), aux_input_size_(aux_input_size) {
quantize_weights_(quantize_weights) {
const TensorType tensor_type =
quantize_weights ? TensorType_UINT8 : TensorType_FLOAT32;
input_ = AddInput(TensorType_FLOAT32); input_ = AddInput(TensorType_FLOAT32);
fw_weights_ = AddInput(tensor_type); fw_weights_ = AddInput(TensorType_FLOAT32);
fw_recurrent_weights_ = AddInput(tensor_type); fw_recurrent_weights_ = AddInput(TensorType_FLOAT32);
fw_bias_ = AddInput(TensorType_FLOAT32); fw_bias_ = AddInput(TensorType_FLOAT32);
fw_hidden_state_ = AddInput(TensorType_FLOAT32, true); fw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
bw_weights_ = AddInput(tensor_type); bw_weights_ = AddInput(TensorType_FLOAT32);
bw_recurrent_weights_ = AddInput(tensor_type); bw_recurrent_weights_ = AddInput(TensorType_FLOAT32);
bw_bias_ = AddInput(TensorType_FLOAT32); bw_bias_ = AddInput(TensorType_FLOAT32);
bw_hidden_state_ = AddInput(TensorType_FLOAT32, true); bw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
@ -701,8 +697,8 @@ class BidirectionalRNNOpModel : public SingleOpModel {
} }
if (aux_input_mode == AuxInputMode::kCrossLinking) { if (aux_input_mode == AuxInputMode::kCrossLinking) {
aux_fw_weights_ = AddInput(tensor_type); aux_fw_weights_ = AddInput(TensorType_FLOAT32);
aux_bw_weights_ = AddInput(tensor_type); aux_bw_weights_ = AddInput(TensorType_FLOAT32);
aux_fw_weights_shape = {fw_units, aux_input_size_}; aux_fw_weights_shape = {fw_units, aux_input_size_};
aux_bw_weights_shape = {bw_units, aux_input_size_}; aux_bw_weights_shape = {bw_units, aux_input_size_};
@ -716,11 +712,11 @@ class BidirectionalRNNOpModel : public SingleOpModel {
bw_output_ = AddOutput(TensorType_FLOAT32); bw_output_ = AddOutput(TensorType_FLOAT32);
} }
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, SetBuiltinOp(
BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
BuiltinOptions_BidirectionalSequenceRNNOptions, BuiltinOptions_BidirectionalSequenceRNNOptions,
CreateBidirectionalSequenceRNNOptions( CreateBidirectionalSequenceRNNOptions(
builder_, time_major, ActivationFunctionType_RELU, builder_, time_major, ActivationFunctionType_RELU, merge_outputs)
merge_outputs, asymmetric_quantize_weights)
.Union()); .Union());
BuildInterpreter({ BuildInterpreter({
@ -748,36 +744,20 @@ class BidirectionalRNNOpModel : public SingleOpModel {
} }
void SetFwWeights(const std::vector<float>& f) { void SetFwWeights(const std::vector<float>& f) {
if (quantize_weights_) {
SymmetricQuantizeAndPopulate(fw_weights_, f);
} else {
PopulateTensor(fw_weights_, f); PopulateTensor(fw_weights_, f);
} }
}
void SetBwWeights(const std::vector<float>& f) { void SetBwWeights(const std::vector<float>& f) {
if (quantize_weights_) {
SymmetricQuantizeAndPopulate(bw_weights_, f);
} else {
PopulateTensor(bw_weights_, f); PopulateTensor(bw_weights_, f);
} }
}
void SetFwRecurrentWeights(const std::vector<float>& f) { void SetFwRecurrentWeights(const std::vector<float>& f) {
if (quantize_weights_) {
SymmetricQuantizeAndPopulate(fw_recurrent_weights_, f);
} else {
PopulateTensor(fw_recurrent_weights_, f); PopulateTensor(fw_recurrent_weights_, f);
} }
}
void SetBwRecurrentWeights(const std::vector<float>& f) { void SetBwRecurrentWeights(const std::vector<float>& f) {
if (quantize_weights_) {
SymmetricQuantizeAndPopulate(bw_recurrent_weights_, f);
} else {
PopulateTensor(bw_recurrent_weights_, f); PopulateTensor(bw_recurrent_weights_, f);
} }
}
void SetInput(std::initializer_list<float> data) { void SetInput(std::initializer_list<float> data) {
PopulateTensor(input_, data); PopulateTensor(input_, data);
@ -792,20 +772,12 @@ class BidirectionalRNNOpModel : public SingleOpModel {
} }
void SetAuxFwWeights(const std::vector<float>& f) { void SetAuxFwWeights(const std::vector<float>& f) {
if (quantize_weights_) {
SymmetricQuantizeAndPopulate(aux_fw_weights_, f);
} else {
PopulateTensor(aux_fw_weights_, f); PopulateTensor(aux_fw_weights_, f);
} }
}
void SetAuxBwWeights(const std::vector<float>& f) { void SetAuxBwWeights(const std::vector<float>& f) {
if (quantize_weights_) {
SymmetricQuantizeAndPopulate(aux_bw_weights_, f);
} else {
PopulateTensor(aux_bw_weights_, f); PopulateTensor(aux_bw_weights_, f);
} }
}
std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); } std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); }
std::vector<float> GetBwOutput() { return ExtractVector<float>(bw_output_); } std::vector<float> GetBwOutput() { return ExtractVector<float>(bw_output_); }
@ -839,31 +811,17 @@ class BidirectionalRNNOpModel : public SingleOpModel {
int bw_units_; int bw_units_;
int input_size_; int input_size_;
int aux_input_size_; int aux_input_size_;
bool quantize_weights_;
}; };
// Declare LSTMOpTest as a parameterized test.
class BidirectionalRNNOpTest
: public ::testing::TestWithParam<::testing::tuple<bool, bool>> {};
INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, BidirectionalRNNOpTest,
::testing::Combine(
/*quantize_weights*/ ::testing::Bool(),
/*asymmetric_quantize_inputs*/ ::testing::Bool()));
// TODO(mirkov): add another test which directly compares to TF once TOCO // TODO(mirkov): add another test which directly compares to TF once TOCO
// supports the conversion from dynamic_rnn with BasicRNNCell. // supports the conversion from dynamic_rnn with BasicRNNCell.
TEST_P(BidirectionalRNNOpTest, BlackBoxTest) { TEST(BidirectionalRNNOpTest, BlackBoxTest) {
auto params = GetParam();
const bool quantize_weights = std::get<0>(params);
const bool asymmetric_quantize_inputs = std::get<1>(params);
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16, /*fw_units=*/16, /*bw_units=*/16,
/*input_size=*/8, /*aux_input_size=*/0, /*input_size=*/8, /*aux_input_size=*/0,
/*aux_input_mode=*/AuxInputMode::kNoAuxInput, /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
/*time_major=*/false, /*time_major=*/false,
/*merge_outputs=*/false, quantize_weights, /*merge_outputs=*/false);
asymmetric_quantize_inputs);
rnn.SetFwWeights(weights); rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights); rnn.SetBwWeights(weights);
rnn.SetFwBias(biases); rnn.SetFwBias(biases);
@ -885,9 +843,7 @@ TEST_P(BidirectionalRNNOpTest, BlackBoxTest) {
std::vector<float> fw_expected; std::vector<float> fw_expected;
fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
EXPECT_THAT(rnn.GetFwOutput(), EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
ElementsAreArray(ArrayFloatNear(
fw_expected, quantize_weights ? 1.42e-2 : 1e-5)));
float* golden_bw_start = rnn_golden_bw_output; float* golden_bw_start = rnn_golden_bw_output;
float* golden_bw_end = float* golden_bw_end =
@ -895,23 +851,17 @@ TEST_P(BidirectionalRNNOpTest, BlackBoxTest) {
std::vector<float> bw_expected; std::vector<float> bw_expected;
bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end); bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end); bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
EXPECT_THAT(rnn.GetBwOutput(), EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
ElementsAreArray(ArrayFloatNear(
bw_expected, quantize_weights ? 1.42e-2 : 1e-5)));
} }
// Same as BlackBox test, but input is reshuffled to time_major format. // Same as BlackBox test, but input is reshuffled to time_major format.
TEST_P(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) { TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) {
auto params = GetParam();
const bool quantize_weights = std::get<0>(params);
const bool asymmetric_quantize_inputs = std::get<1>(params);
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16, /*fw_units=*/16, /*bw_units=*/16,
/*input_size=*/8, /*aux_input_size=*/0, /*input_size=*/8, /*aux_input_size=*/0,
/*aux_input_mode=*/AuxInputMode::kNoAuxInput, /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
/*time_major=*/true, /*time_major=*/true,
/*merge_outputs=*/false, quantize_weights, /*merge_outputs=*/false);
asymmetric_quantize_inputs);
rnn.SetFwWeights(weights); rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights); rnn.SetBwWeights(weights);
rnn.SetFwBias(biases); rnn.SetFwBias(biases);
@ -939,26 +889,17 @@ TEST_P(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) {
fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
} }
constexpr float kHybridTolerance = 3.57e-1; EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
constexpr float kFloatTolerance = 1e-5;
EXPECT_THAT(
rnn.GetFwOutput(),
ElementsAreArray(ArrayFloatNear(
fw_expected, quantize_weights ? kHybridTolerance : kFloatTolerance)));
} }
// Same as BlackBox test, yet with merged outputs. // Same as BlackBox test, yet with merged outputs.
TEST_P(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) { TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
auto params = GetParam();
const bool quantize_weights = std::get<0>(params);
const bool asymmetric_quantize_inputs = std::get<1>(params);
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16, /*fw_units=*/16, /*bw_units=*/16,
/*input_size=*/8, /*aux_input_size=*/0, /*input_size=*/8, /*aux_input_size=*/0,
/*aux_input_mode=*/AuxInputMode::kNoAuxInput, /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
/*time_major=*/false, /*time_major=*/false,
/*merge_outputs=*/true, quantize_weights, /*merge_outputs=*/true);
asymmetric_quantize_inputs);
rnn.SetFwWeights(weights); rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights); rnn.SetBwWeights(weights);
rnn.SetFwBias(biases); rnn.SetFwBias(biases);
@ -988,8 +929,7 @@ TEST_P(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
} }
} }
EXPECT_THAT(rnn.GetFwOutput(), EXPECT_THAT(rnn.GetFwOutput(),
ElementsAreArray(ArrayFloatNear( ElementsAreArray(ArrayFloatNear(merged_expected)));
merged_expected, quantize_weights ? 1.42e-2 : 1e-5)));
} }
// Same as BlackBox test, but input is reshuffled to time_major format. // Same as BlackBox test, but input is reshuffled to time_major format.

View File

@ -71,7 +71,6 @@ struct OpData {
int32_t output_activation_max; int32_t output_activation_max;
// The index of the temporary tensor where the quantized inputs are cached. // The index of the temporary tensor where the quantized inputs are cached.
int scratch_tensor_index; int scratch_tensor_index;
bool compute_row_sums = false;
}; };
constexpr int kInputTensor = 0; constexpr int kInputTensor = 0;
@ -132,7 +131,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// Instead, we allocate a new object to carry information from Prepare() to // Instead, we allocate a new object to carry information from Prepare() to
// Eval(). // Eval().
auto* op_data = new OpData(); auto* op_data = new OpData();
context->AddTensors(context, /*tensors_to_add=*/5, context->AddTensors(context, /*tensors_to_add=*/3,
&op_data->scratch_tensor_index); &op_data->scratch_tensor_index);
return op_data; return op_data;
} }
@ -145,6 +144,7 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
auto* params = auto* params =
reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data); reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data); OpData* data = reinterpret_cast<OpData*>(node->user_data);
// Check we have all the inputs and outputs we need. // Check we have all the inputs and outputs we need.
TF_LITE_ENSURE(context, node->inputs->size == 2 || node->inputs->size == 3); TF_LITE_ENSURE(context, node->inputs->size == 2 || node->inputs->size == 3);
// Shuffled formats need a workspace to store the shuffled input activations. // Shuffled formats need a workspace to store the shuffled input activations.
@ -208,8 +208,7 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
if (input->type == kTfLiteFloat32 && if (input->type == kTfLiteFloat32 &&
(filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8)) { (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8)) {
TfLiteIntArrayFree(node->temporaries); TfLiteIntArrayFree(node->temporaries);
data->compute_row_sums = true; node->temporaries = TfLiteIntArrayCreate(3);
node->temporaries = TfLiteIntArrayCreate(5);
node->temporaries->data[0] = data->scratch_tensor_index; node->temporaries->data[0] = data->scratch_tensor_index;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
@ -246,28 +245,6 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK( TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, accum_scratch, accum_size)); context, context->ResizeTensor(context, accum_scratch, accum_size));
} }
node->temporaries->data[3] = data->scratch_tensor_index + 3;
TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/3);
input_offsets->type = kTfLiteInt32;
input_offsets->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) {
TfLiteIntArray* input_offsets_size = TfLiteIntArrayCreate(1);
input_offsets_size->data[0] = batch_size;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_offsets,
input_offsets_size));
}
node->temporaries->data[4] = data->scratch_tensor_index + 4;
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/4);
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_dims[1] = {num_units};
if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) {
TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1);
row_sums_size->data[0] = row_sums_dims[0];
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, row_sums, row_sums_size));
}
} }
// Resize output. // Resize output.
@ -355,9 +332,7 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
TfLiteFullyConnectedParams* params, OpData* data, TfLiteFullyConnectedParams* params, OpData* data,
const TfLiteTensor* input, const TfLiteTensor* filter, const TfLiteTensor* input, const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* input_quantized, const TfLiteTensor* bias, TfLiteTensor* input_quantized,
TfLiteTensor* scaling_factors, TfLiteTensor* scaling_factors, TfLiteTensor* output) {
TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
TfLiteTensor* input_offsets, TfLiteTensor* output) {
int total_input_size = 1; int total_input_size = 1;
for (int i = 0; i < input->dims->size; i++) { for (int i = 0; i < input->dims->size; i++) {
total_input_size *= input->dims->data[i]; total_input_size *= input->dims->data[i];
@ -388,39 +363,32 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
// Quantize input from float to uint8 + quantization params (scaling factor). // Quantize input from float to uint8 + quantization params (scaling factor).
float unused_min, unused_max; float unused_min, unused_max;
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors); float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
int32_t* input_offset_ptr = nullptr;
int32_t* row_sums_ptr = nullptr;
if (params->asymmetric_quantize_inputs) {
input_offset_ptr = GetTensorData<int32_t>(input_offsets);
row_sums_ptr = GetTensorData<int32_t>(row_sums);
}
int8_t* quant_data = GetTensorData<int8_t>(input_quantized); int8_t* quant_data = GetTensorData<int8_t>(input_quantized);
const int8_t* filter_data = GetTensorData<int8_t>(filter); const int8_t* filter_data = GetTensorData<int8_t>(filter);
const float* input_ptr = GetTensorData<float>(input);
// Quantize each batch independently. // Quantize each batch independently.
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
const int offset = b * input_size; const int offset = b * input_size;
if (params->asymmetric_quantize_inputs) {
tensor_utils::AsymmetricQuantizeFloats(
input_ptr + offset, input_size, quant_data + offset,
&scaling_factors_ptr[b], &input_offset_ptr[b]);
} else {
tensor_utils::SymmetricQuantizeFloats( tensor_utils::SymmetricQuantizeFloats(
input_ptr + offset, input_size, quant_data + offset, &unused_min, GetTensorData<float>(input) + offset, input_size, quant_data + offset,
&unused_max, &scaling_factors_ptr[b]); &unused_min, &unused_max, &scaling_factors_ptr[b]);
}
// Incorporate scaling of the filter. // Incorporate scaling of the filter.
scaling_factors_ptr[b] *= filter->params.scale; scaling_factors_ptr[b] *= filter->params.scale;
} }
// Compute output += weight * quantized_input // Compute output += weight * quantized_input
#ifdef TFLITE_WITH_RUY_GEMV
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2);
int32_t* scratch = GetTensorData<int32_t>(accum_scratch); int32_t* scratch = GetTensorData<int32_t>(accum_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate( tensor_utils::MatrixBatchVectorMultiplyAccumulate(
filter_data, num_units, input_size, quant_data, scaling_factors_ptr, filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
batch_size, GetTensorData<float>(output), /*per_channel_scale=*/nullptr, batch_size, scratch, GetTensorData<float>(output),
input_offset_ptr, scratch, row_sums_ptr, &data->compute_row_sums,
CpuBackendContext::GetFromContext(context)); CpuBackendContext::GetFromContext(context));
#else
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
batch_size, GetTensorData<float>(output));
#endif
// Apply activation function to floats. // Apply activation function to floats.
tensor_utils::ApplyActivationToVector( tensor_utils::ApplyActivationToVector(
GetTensorData<float>(output), batch_size * num_units, params->activation, GetTensorData<float>(output), batch_size * num_units, params->activation,
@ -493,12 +461,8 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
if (input->type == kTfLiteFloat32) { if (input->type == kTfLiteFloat32) {
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1); TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/3);
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/4);
return EvalHybrid(context, node, params, data, input, filter, bias, return EvalHybrid(context, node, params, data, input, filter, bias,
input_quantized, scaling_factors, accum_scratch, row_sums, input_quantized, scaling_factors, output);
input_offsets, output);
} else { } else {
FullyConnectedParams op_params; FullyConnectedParams op_params;
op_params.input_offset = input_offset; op_params.input_offset = input_offset;
@ -626,6 +590,7 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
FullyConnectedParams op_params; FullyConnectedParams op_params;
op_params.float_activation_min = output_activation_min; op_params.float_activation_min = output_activation_min;
op_params.float_activation_max = output_activation_max; op_params.float_activation_max = output_activation_max;
reference_ops::FullyConnected( reference_ops::FullyConnected(
op_params, GetTensorShape(input), GetTensorData<float>(input), op_params, GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(filter), GetTensorData<float>(filter), GetTensorShape(filter), GetTensorData<float>(filter),

View File

@ -286,8 +286,7 @@ class HybridFullyConnectedOpModel : public SingleOpModel {
public: public:
HybridFullyConnectedOpModel(int units, int batches, const TensorData& input, HybridFullyConnectedOpModel(int units, int batches, const TensorData& input,
const TensorData& weights, const TensorData& weights,
const TensorData& output = {TensorType_FLOAT32}, const TensorData& output = {TensorType_FLOAT32})
bool asymmetric_inputs = false)
: batches_(batches), units_(units) { : batches_(batches), units_(units) {
int total_input_size = 1; int total_input_size = 1;
for (size_t i = 0; i < input.shape.size(); ++i) { for (size_t i = 0; i < input.shape.size(); ++i) {
@ -303,13 +302,10 @@ class HybridFullyConnectedOpModel : public SingleOpModel {
output_ = AddOutput(output); output_ = AddOutput(output);
auto options = CreateFullyConnectedOptions( SetBuiltinOp(
builder_, ActivationFunctionType_RELU, BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
tflite::FullyConnectedOptionsWeightsFormat_DEFAULT, CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
false, asymmetric_inputs) .Union());
.Union();
SetBuiltinOp(BuiltinOperator_FULLY_CONNECTED,
BuiltinOptions_FullyConnectedOptions, options);
resolver_ = absl::make_unique<SingleOpResolver>( resolver_ = absl::make_unique<SingleOpResolver>(
BuiltinOperator_FULLY_CONNECTED, BuiltinOperator_FULLY_CONNECTED,
ops::builtin::Register_FULLY_CONNECTED_PIE()); ops::builtin::Register_FULLY_CONNECTED_PIE());
@ -871,66 +867,6 @@ TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8) {
/*max_abs_error=*/1.3f))); /*max_abs_error=*/1.3f)));
} }
TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedUint8) {
HybridFullyConnectedOpModel m(
/*units=*/3, /*batches=*/2,
/*input=*/{TensorType_FLOAT32, {2, 10}},
/*weights=*/
{TensorType_UINT8, {3, 10}, 0, 0, 10.0 / 127.0, 0}, {TensorType_FLOAT32},
/*asymmetric_quantize_input*/ true); // Hybrid asymmetric
m.SetWeights({
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
});
m.SetBias({1, 2, 3});
m.SetInput({
1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
{
24, 25, 26, //
58, 59, 60, //
},
/*max_abs_error=*/0.64f)));
}
TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedInt8) {
HybridFullyConnectedOpModel m(
/*units=*/3, /*batches=*/2,
/*input=*/{TensorType_FLOAT32, {2, 10}},
/*weights=*/{TensorType_INT8, {3, 10}, 0, 0, 10.0 / 127.0, 0},
{TensorType_FLOAT32},
/*asymmetric_quantize_input*/ true);
m.SetSignedWeights({
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
});
m.SetBias({1, 2, 3});
m.SetInput({
1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
{
24, 25, 26, //
58, 59, 60, //
},
/*max_abs_error=*/1.3f)));
}
TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) { TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) {
// Note that it is not required that the first dimension be the number of // Note that it is not required that the first dimension be the number of
// batches. All we care is that the input can be evenly distributed in // batches. All we care is that the input can be evenly distributed in

View File

@ -123,9 +123,7 @@ void RnnBatchStep(
int num_units, int batch_size, int output_batch_leading_dim, int num_units, int batch_size, int output_batch_leading_dim,
TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors, int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
float* hidden_state_ptr_batch, float* output_ptr_batch, float* hidden_state_ptr_batch, float* output_ptr_batch) {
bool asymmetric_quantize_inputs, int32_t* zero_points,
int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums) {
RnnBatchStep(input_ptr_batch, input_weights_ptr, input_weights_scale, RnnBatchStep(input_ptr_batch, input_weights_ptr, input_weights_scale,
/*aux_input_ptr_batch=*/nullptr, /*aux_input_ptr_batch=*/nullptr,
/*aux_input_weights_ptr=*/nullptr, /*aux_input_weights_ptr=*/nullptr,
@ -135,29 +133,7 @@ void RnnBatchStep(
output_batch_leading_dim, activation, quantized_input_ptr_batch, output_batch_leading_dim, activation, quantized_input_ptr_batch,
/*aux_quantized_input_ptr_batch=*/nullptr, /*aux_quantized_input_ptr_batch=*/nullptr,
quantized_hidden_state_ptr_batch, scaling_factors, quantized_hidden_state_ptr_batch, scaling_factors,
hidden_state_ptr_batch, output_ptr_batch, hidden_state_ptr_batch, output_ptr_batch);
asymmetric_quantize_inputs, zero_points, accum_scratch, row_sums,
compute_row_sums);
}
void ComputeMatrixSums(int32_t* input_row_sums, int32_t* aux_input_row_sums,
int32_t* recurrent_row_sums, int32_t* row_sums,
const float* aux_input_ptr_batch, int num_units,
int input_size, int aux_input_size,
const int8_t* input_weights_ptr,
const int8_t* aux_input_weights_ptr,
const int8_t* recurrent_weights_ptr) {
memset(input_row_sums, 0, sizeof(int32_t) * num_units);
tensor_utils::ReductionSumVector(input_weights_ptr, input_row_sums, num_units,
input_size);
if (aux_input_ptr_batch) {
memset(aux_input_row_sums, 0, sizeof(int32_t) * num_units);
tensor_utils::ReductionSumVector(aux_input_weights_ptr, aux_input_row_sums,
num_units, aux_input_size);
}
memset(recurrent_row_sums, 0, sizeof(int32_t) * num_units);
tensor_utils::ReductionSumVector(recurrent_weights_ptr, recurrent_row_sums,
num_units, num_units);
} }
void RnnBatchStep( void RnnBatchStep(
@ -170,31 +146,9 @@ void RnnBatchStep(
TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
int8_t* aux_quantized_input_ptr_batch, int8_t* aux_quantized_input_ptr_batch,
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors, int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
float* hidden_state_ptr_batch, float* output_ptr_batch, float* hidden_state_ptr_batch, float* output_ptr_batch) {
bool asymmetric_quantize_inputs, int32_t* zero_points,
int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums) {
// Since the output batch rows may not be contiguous (output_batch_leading_dim // Since the output batch rows may not be contiguous (output_batch_leading_dim
// != n_output), we unroll the batched operations where this is the case. // != n_output), we unroll the batched operations where this is the case.
int32_t* input_row_sums = nullptr;
int32_t* aux_input_row_sums = nullptr;
int32_t* recurrent_row_sums = nullptr;
if (asymmetric_quantize_inputs) {
input_row_sums = row_sums;
aux_input_row_sums = row_sums;
if (aux_input_ptr_batch) {
aux_input_row_sums += num_units;
}
recurrent_row_sums = aux_input_row_sums + num_units;
if (*compute_row_sums) {
ComputeMatrixSums(input_row_sums, aux_input_row_sums, recurrent_row_sums,
row_sums, aux_input_ptr_batch, num_units, input_size,
aux_input_size, input_weights_ptr,
aux_input_weights_ptr, recurrent_weights_ptr);
*compute_row_sums = false;
}
}
if (output_batch_leading_dim == num_units) { if (output_batch_leading_dim == num_units) {
// Output = bias // Output = bias
tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size, tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
@ -209,25 +163,17 @@ void RnnBatchStep(
// whichever is faster. // whichever is faster.
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
const int offset = b * input_size; const int offset = b * input_size;
if (asymmetric_quantize_inputs) {
tensor_utils::AsymmetricQuantizeFloats(
input_ptr_batch + offset, input_size,
quantized_input_ptr_batch + offset, &scaling_factors[b],
&zero_points[b]);
} else {
tensor_utils::SymmetricQuantizeFloats( tensor_utils::SymmetricQuantizeFloats(
input_ptr_batch + offset, input_size, input_ptr_batch + offset, input_size,
quantized_input_ptr_batch + offset, &unused_min, &unused_max, quantized_input_ptr_batch + offset, &unused_min, &unused_max,
&scaling_factors[b]); &scaling_factors[b]);
}
scaling_factors[b] *= input_weights_scale; scaling_factors[b] *= input_weights_scale;
} }
// Output += input * input_weights // Output += input * input_weights
tensor_utils::MatrixBatchVectorMultiplyAccumulate( tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_weights_ptr, num_units, input_size, quantized_input_ptr_batch, input_weights_ptr, num_units, input_size, quantized_input_ptr_batch,
scaling_factors, batch_size, output_ptr_batch, scaling_factors, batch_size, output_ptr_batch);
/*per_channel_scale=*/nullptr, zero_points, accum_scratch,
input_row_sums, compute_row_sums, /*context=*/nullptr);
} }
if (aux_input_ptr_batch && if (aux_input_ptr_batch &&
@ -236,17 +182,10 @@ void RnnBatchStep(
float unused_min, unused_max; float unused_min, unused_max;
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
const int offset = b * aux_input_size; const int offset = b * aux_input_size;
if (asymmetric_quantize_inputs) {
tensor_utils::AsymmetricQuantizeFloats(
aux_input_ptr_batch + offset, aux_input_size,
aux_quantized_input_ptr_batch + offset, &scaling_factors[b],
&zero_points[b]);
} else {
tensor_utils::SymmetricQuantizeFloats( tensor_utils::SymmetricQuantizeFloats(
aux_input_ptr_batch + offset, aux_input_size, aux_input_ptr_batch + offset, aux_input_size,
aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max, aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max,
&scaling_factors[b]); &scaling_factors[b]);
}
scaling_factors[b] *= aux_input_weights_scale; scaling_factors[b] *= aux_input_weights_scale;
} }
@ -254,9 +193,7 @@ void RnnBatchStep(
tensor_utils::MatrixBatchVectorMultiplyAccumulate( tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_weights_ptr, num_units, aux_input_size, aux_input_weights_ptr, num_units, aux_input_size,
aux_quantized_input_ptr_batch, scaling_factors, batch_size, aux_quantized_input_ptr_batch, scaling_factors, batch_size,
output_ptr_batch, /*per_channel_scale=*/nullptr, zero_points, output_ptr_batch);
accum_scratch, aux_input_row_sums, compute_row_sums,
/*context=*/nullptr);
} }
// Save quantization and matmul computation for all zero input. // Save quantization and matmul computation for all zero input.
@ -266,17 +203,10 @@ void RnnBatchStep(
float unused_min, unused_max; float unused_min, unused_max;
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
const int offset = b * num_units; const int offset = b * num_units;
if (asymmetric_quantize_inputs) {
tensor_utils::AsymmetricQuantizeFloats(
hidden_state_ptr_batch + offset, num_units,
quantized_hidden_state_ptr_batch + offset, &scaling_factors[b],
&zero_points[b]);
} else {
tensor_utils::SymmetricQuantizeFloats( tensor_utils::SymmetricQuantizeFloats(
hidden_state_ptr_batch + offset, num_units, hidden_state_ptr_batch + offset, num_units,
quantized_hidden_state_ptr_batch + offset, &unused_min, quantized_hidden_state_ptr_batch + offset, &unused_min, &unused_max,
&unused_max, &scaling_factors[b]); &scaling_factors[b]);
}
scaling_factors[b] *= recurrent_weights_scale; scaling_factors[b] *= recurrent_weights_scale;
} }
@ -284,9 +214,7 @@ void RnnBatchStep(
tensor_utils::MatrixBatchVectorMultiplyAccumulate( tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_weights_ptr, num_units, num_units, recurrent_weights_ptr, num_units, num_units,
quantized_hidden_state_ptr_batch, scaling_factors, batch_size, quantized_hidden_state_ptr_batch, scaling_factors, batch_size,
output_ptr_batch, /*per_channel_scale=*/nullptr, zero_points, output_ptr_batch);
accum_scratch, recurrent_row_sums, compute_row_sums,
/*context=*/nullptr);
} }
// Output = activation(Output) and update hidden_state // Output = activation(Output) and update hidden_state
@ -310,17 +238,10 @@ void RnnBatchStep(
// whichever is faster. // whichever is faster.
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
const int offset = b * input_size; const int offset = b * input_size;
if (asymmetric_quantize_inputs) {
tensor_utils::AsymmetricQuantizeFloats(
input_ptr_batch + offset, input_size,
quantized_input_ptr_batch + offset, &scaling_factors[b],
&zero_points[b]);
} else {
tensor_utils::SymmetricQuantizeFloats( tensor_utils::SymmetricQuantizeFloats(
input_ptr_batch + offset, input_size, input_ptr_batch + offset, input_size,
quantized_input_ptr_batch + offset, &unused_min, &unused_max, quantized_input_ptr_batch + offset, &unused_min, &unused_max,
&scaling_factors[b]); &scaling_factors[b]);
}
scaling_factors[b] *= input_weights_scale; scaling_factors[b] *= input_weights_scale;
} }
@ -329,9 +250,7 @@ void RnnBatchStep(
tensor_utils::MatrixBatchVectorMultiplyAccumulate( tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_weights_ptr, num_units, input_size, input_weights_ptr, num_units, input_size,
quantized_input_ptr_batch + k * input_size, &scaling_factors[k], quantized_input_ptr_batch + k * input_size, &scaling_factors[k],
/*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim, /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim);
/*per_channel_scale=*/nullptr, zero_points + k, accum_scratch,
input_row_sums, compute_row_sums, /*context=*/nullptr);
} }
} }
@ -341,17 +260,10 @@ void RnnBatchStep(
float unused_min, unused_max; float unused_min, unused_max;
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
const int offset = b * aux_input_size; const int offset = b * aux_input_size;
if (asymmetric_quantize_inputs) {
tensor_utils::AsymmetricQuantizeFloats(
aux_input_ptr_batch + offset, aux_input_size,
aux_quantized_input_ptr_batch + offset, &scaling_factors[b],
&zero_points[b]);
} else {
tensor_utils::SymmetricQuantizeFloats( tensor_utils::SymmetricQuantizeFloats(
aux_input_ptr_batch + offset, aux_input_size, aux_input_ptr_batch + offset, aux_input_size,
aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max, aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max,
&scaling_factors[b]); &scaling_factors[b]);
}
scaling_factors[b] *= aux_input_weights_scale; scaling_factors[b] *= aux_input_weights_scale;
} }
@ -361,9 +273,7 @@ void RnnBatchStep(
aux_input_weights_ptr, num_units, aux_input_size, aux_input_weights_ptr, num_units, aux_input_size,
aux_quantized_input_ptr_batch + k * aux_input_size, aux_quantized_input_ptr_batch + k * aux_input_size,
&scaling_factors[k], &scaling_factors[k],
/*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim, /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim);
/*per_channel_scale=*/nullptr, zero_points + k, accum_scratch,
aux_input_row_sums, compute_row_sums, /*context=*/nullptr);
} }
} }
@ -374,17 +284,10 @@ void RnnBatchStep(
float unused_min, unused_max; float unused_min, unused_max;
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
const int offset = b * num_units; const int offset = b * num_units;
if (asymmetric_quantize_inputs) {
tensor_utils::AsymmetricQuantizeFloats(
hidden_state_ptr_batch + offset, num_units,
quantized_hidden_state_ptr_batch + offset, &scaling_factors[b],
&zero_points[b]);
} else {
tensor_utils::SymmetricQuantizeFloats( tensor_utils::SymmetricQuantizeFloats(
hidden_state_ptr_batch + offset, num_units, hidden_state_ptr_batch + offset, num_units,
quantized_hidden_state_ptr_batch + offset, &unused_min, quantized_hidden_state_ptr_batch + offset, &unused_min, &unused_max,
&unused_max, &scaling_factors[b]); &scaling_factors[b]);
}
scaling_factors[b] *= recurrent_weights_scale; scaling_factors[b] *= recurrent_weights_scale;
} }
@ -393,10 +296,8 @@ void RnnBatchStep(
tensor_utils::MatrixBatchVectorMultiplyAccumulate( tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_weights_ptr, num_units, num_units, recurrent_weights_ptr, num_units, num_units,
quantized_hidden_state_ptr_batch + k * num_units, quantized_hidden_state_ptr_batch + k * num_units,
&scaling_factors[k], /*n_batch=*/1, &scaling_factors[k],
output_ptr_batch + k * output_batch_leading_dim, /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim);
/*per_channel_scale=*/nullptr, zero_points + k, accum_scratch,
recurrent_row_sums, compute_row_sums, /*context=*/nullptr);
} }
} }

View File

@ -70,9 +70,7 @@ void RnnBatchStep(
int num_units, int batch_size, int output_batch_leading_dim, int num_units, int batch_size, int output_batch_leading_dim,
TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors, int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
float* hidden_state_ptr_batch, float* output_ptr_batch, float* hidden_state_ptr_batch, float* output_ptr_batch);
bool asymmetric_quantize_inputs, int32_t* zero_points,
int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums);
void RnnBatchStep( void RnnBatchStep(
const float* input_ptr_batch, const int8_t* input_weights_ptr, const float* input_ptr_batch, const int8_t* input_weights_ptr,
@ -84,9 +82,7 @@ void RnnBatchStep(
TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
int8_t* aux_quantized_input_ptr_batch, int8_t* aux_quantized_input_ptr_batch,
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors, int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
float* hidden_state_ptr_batch, float* output_ptr_batch, float* hidden_state_ptr_batch, float* output_ptr_batch);
bool asymmetric_quantize_inputs, int32_t* zero_points,
int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums);
} // namespace kernel_utils } // namespace kernel_utils
} // namespace tflite } // namespace tflite

View File

@ -1310,13 +1310,6 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
const int postamble_half_start = m_cols & ~(kWeightsPerNeonLane - 1); const int postamble_half_start = m_cols & ~(kWeightsPerNeonLane - 1);
const int postamble_start = m_cols & ~((kWeightsPerNeonLane >> 1) - 1); const int postamble_start = m_cols & ~((kWeightsPerNeonLane >> 1) - 1);
int32_t* row_sums_ptr = row_sums;
if (row_sums == nullptr) {
row_sums_ptr = static_cast<int32_t*>(malloc(sizeof(int32_t) * m_rows));
memset(row_sums_ptr, 0, sizeof(int32_t) * m_rows);
NeonReductionSumVector(matrix, row_sums_ptr, m_rows, m_cols);
}
for (int batch = 0; batch < n_batch; ++batch) { for (int batch = 0; batch < n_batch; ++batch) {
const float batch_scaling_factor = scaling_factors[batch]; const float batch_scaling_factor = scaling_factors[batch];
const int batch_input_offset = input_offset[batch]; const int batch_input_offset = input_offset[batch];
@ -1334,6 +1327,10 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
// Initialize the dot product sum for the row to 0. // Initialize the dot product sum for the row to 0.
int32x4_t dotprod_32x4 = vmovq_n_s32(0); int32x4_t dotprod_32x4 = vmovq_n_s32(0);
int32x4_t row_sum_32x4;
if (row_sums == nullptr) {
row_sum_32x4 = vmovq_n_s32(0);
}
// Prefetch the row to cache. // Prefetch the row to cache.
__builtin_prefetch(row_ptr, 0 /* prefetch for read */, __builtin_prefetch(row_ptr, 0 /* prefetch for read */,
3 /* temporal locality */); 3 /* temporal locality */);
@ -1361,6 +1358,10 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
prod_16x8 = prod_16x8 =
vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16)); vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16));
dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8); dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
if (row_sums == nullptr) {
const int16x8_t row_sum_16x8 = vpaddlq_s8(s2_8x16);
row_sum_32x4 = vpadalq_s16(row_sum_32x4, row_sum_16x8);
}
} // for col } // for col
// Half iteration dealing only 8 elements // Half iteration dealing only 8 elements
@ -1374,24 +1375,29 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
const int8x8_t s2_8x8 = vld1_s8((const int8_t*)(row_ptr + col)); const int8x8_t s2_8x8 = vld1_s8((const int8_t*)(row_ptr + col));
const int16x8_t prod_16x8 = vmull_s8(s1_8x8, s2_8x8); const int16x8_t prod_16x8 = vmull_s8(s1_8x8, s2_8x8);
dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8); dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
if (row_sums == nullptr) {
const int16x8_t row_sum_16x8 = vmovl_s8(s2_8x8);
row_sum_32x4 = vpadalq_s16(row_sum_32x4, row_sum_16x8);
}
col += (kWeightsPerNeonLane >> 1); col += (kWeightsPerNeonLane >> 1);
} }
int32_t dotprod = AccumulateNeonLane(dotprod_32x4); int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
int32_t row_sum = row_sums == nullptr ? AccumulateNeonLane(row_sum_32x4)
: row_sums[row];
// Postamble loop. // Postamble loop.
for (; col < m_cols; ++col) { for (; col < m_cols; ++col) {
dotprod += row_ptr[col] * aligned_vec[col]; dotprod += row_ptr[col] * aligned_vec[col];
if (row_sums == nullptr) {
row_sum += row_ptr[col];
}
} // for col } // for col
dotprod -= row_sums_ptr[row] * batch_input_offset; dotprod -= row_sum * batch_input_offset;
*result += dotprod * scale; *result += dotprod * scale;
++result; ++result;
} // for row } // for row
} // for batch } // for batch
if (row_sums == nullptr) {
free(row_sums_ptr);
}
if (unaligned) { if (unaligned) {
free(aligned_row_free); free(aligned_row_free);
} }
@ -1404,20 +1410,6 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
int n_batch, float* __restrict__ result, const float* per_channel_scale, int n_batch, float* __restrict__ result, const float* per_channel_scale,
const int32_t* input_offset, int32_t* scratch, int32_t* row_sums, const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
bool* compute_row_sums, CpuBackendContext* context) { bool* compute_row_sums, CpuBackendContext* context) {
if (input_offset == nullptr) {
#ifdef TFLITE_WITH_RUY_GEMV
if (context) {
NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
scaling_factors, n_batch, scratch,
result, context);
return;
}
#endif
NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
scaling_factors, n_batch, result);
return;
}
if (compute_row_sums == nullptr || *compute_row_sums) { if (compute_row_sums == nullptr || *compute_row_sums) {
memset(row_sums, 0, sizeof(int32_t) * m_rows); memset(row_sums, 0, sizeof(int32_t) * m_rows);
NeonReductionSumVector(matrix, row_sums, m_rows, m_cols); NeonReductionSumVector(matrix, row_sums, m_rows, m_cols);
@ -1427,7 +1419,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
} }
#ifdef TFLITE_WITH_RUY_GEMV #ifdef TFLITE_WITH_RUY_GEMV
if (context != nullptr && m_rows % 4 == 0) { if (m_rows % 4 == 0) {
const int32_t* bias = static_cast<const int32_t*>(nullptr); const int32_t* bias = static_cast<const int32_t*>(nullptr);
NeonCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows, 0, NeonCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows, 0,
scratch, context); scratch, context);
@ -1471,9 +1463,9 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
for (; i < total_size; i++) { for (; i < total_size; i++) {
const float batch_scaling_factor = scaling_factors[i / m_rows]; const float batch_scaling_factor = scaling_factors[i / m_rows];
const int32_t zero_point = input_offset[i / m_rows]; const int32_t zero_point = input_offset[i / m_rows];
int32_t dotprod = *(scratch_ptr++); int32_t x = *(scratch_ptr++);
dotprod -= row_sums[i % m_rows] * zero_point; x -= row_sums[i % m_rows] * zero_point;
*result += dotprod * batch_scaling_factor; *result += x * batch_scaling_factor;
++result; ++result;
} }
return; return;

View File

@ -167,11 +167,6 @@ void SseMatrixBatchVectorMultiplyAccumulate(
const float* __restrict__ scaling_factors, int n_batch, const float* __restrict__ scaling_factors, int n_batch,
float* __restrict__ result, const float* __restrict__ per_channel_scale, float* __restrict__ result, const float* __restrict__ per_channel_scale,
const int32_t* __restrict__ input_offset) { const int32_t* __restrict__ input_offset) {
if (input_offset == nullptr) {
SseMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
scaling_factors, n_batch, result);
return;
}
static constexpr std::intptr_t kBlockSize = 16; static constexpr std::intptr_t kBlockSize = 16;
for (std::intptr_t batch = 0; batch < n_batch; ++batch) { for (std::intptr_t batch = 0; batch < n_batch; ++batch) {
const float batch_scaling_factor = scaling_factors[batch]; const float batch_scaling_factor = scaling_factors[batch];

View File

@ -196,11 +196,6 @@ void PortableMatrixBatchVectorMultiplyAccumulate(
int n_batch, float* __restrict__ result, const float* per_channel_scale, int n_batch, float* __restrict__ result, const float* per_channel_scale,
const int32_t* input_offset, int32_t* scratch, int32_t* row_sums, const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
bool* compute_row_sums, CpuBackendContext* context) { bool* compute_row_sums, CpuBackendContext* context) {
if (input_offset == nullptr) {
PortableMatrixBatchVectorMultiplyAccumulate(
matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result);
return;
}
if (!compute_row_sums || *compute_row_sums) { if (!compute_row_sums || *compute_row_sums) {
memset(row_sums, 0, sizeof(int32_t) * m_rows); memset(row_sums, 0, sizeof(int32_t) * m_rows);
PortableReductionSumVector(matrix, row_sums, m_rows, m_cols); PortableReductionSumVector(matrix, row_sums, m_rows, m_cols);

View File

@ -223,8 +223,7 @@ inline void EvalHybridSVDF(
const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time, const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time,
const TfLiteTensor* bias, const TfLiteSVDFParams* params, const TfLiteTensor* bias, const TfLiteSVDFParams* params,
TfLiteTensor* scratch, TfLiteTensor* scaling_factors, TfLiteTensor* scratch, TfLiteTensor* scaling_factors,
TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output, TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output) {
TfLiteTensor* zero_points, TfLiteTensor* row_sums, bool* compute_row_sums) {
const int rank = params->rank; const int rank = params->rank;
const int batch_size = input->dims->data[0]; const int batch_size = input->dims->data[0];
const int input_size = input->dims->data[1]; const int input_size = input->dims->data[1];
@ -245,13 +244,6 @@ inline void EvalHybridSVDF(
float* output_ptr = GetTensorData<float>(output); float* output_ptr = GetTensorData<float>(output);
int32_t* zero_points_ptr = nullptr;
int32_t* row_sums_ptr = nullptr;
if (params->asymmetric_quantize_inputs && row_sums != nullptr) {
zero_points_ptr = GetTensorData<int32_t>(zero_points);
row_sums_ptr = GetTensorData<int32_t>(row_sums);
}
// Initialize the weights scale. // Initialize the weights scale.
const float weights_feature_scale = weights_feature->params.scale; const float weights_feature_scale = weights_feature->params.scale;
@ -266,30 +258,21 @@ inline void EvalHybridSVDF(
if (!tensor_utils::IsZeroVector(input_ptr, batch_size * input_size)) { if (!tensor_utils::IsZeroVector(input_ptr, batch_size * input_size)) {
// Quantize input from float to int8. // Quantize input from float to int8.
float unused_min, unused_max;
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
const int offset = b * input_size; const int offset = b * input_size;
if (params->asymmetric_quantize_inputs) {
tensor_utils::AsymmetricQuantizeFloats(
input_ptr + offset, input_size, quantized_input_ptr + offset,
&scaling_factors_ptr[b], &zero_points_ptr[b]);
} else {
// Quantize input from float to int8.
float unused_min, unused_max;
tensor_utils::SymmetricQuantizeFloats( tensor_utils::SymmetricQuantizeFloats(
input_ptr + offset, input_size, quantized_input_ptr + offset, input_ptr + offset, input_size, quantized_input_ptr + offset,
&unused_min, &unused_max, &scaling_factors_ptr[b]); &unused_min, &unused_max, &scaling_factors_ptr[b]);
}
scaling_factors_ptr[b] *= weights_feature_scale; scaling_factors_ptr[b] *= weights_feature_scale;
} }
// Compute conv1d(inputs, weights_feature). // Compute conv1d(inputs, weights_feature).
tensor_utils::MatrixBatchVectorMultiplyAccumulate( tensor_utils::MatrixBatchVectorMultiplyAccumulate(
weights_feature_ptr, num_filters, input_size, quantized_input_ptr, weights_feature_ptr, num_filters, input_size, quantized_input_ptr,
scaling_factors_ptr, batch_size, scratch_ptr, scaling_factors_ptr, batch_size, scratch_ptr);
/*per_channel_scale=*/nullptr, zero_points_ptr,
reinterpret_cast<int32_t*>(scratch_ptr), row_sums_ptr, compute_row_sums,
/*context=*/nullptr);
} }
// Copy the latest activation from scratch into activation_state: // Copy the latest activation from scratch into activation_state:
// The last, i.e. (memory_size-1)th entry for each batch, and filter. // The last, i.e. (memory_size-1)th entry for each batch, and filter.
for (int i = 0; i < batch_size * num_filters; ++i) { for (int i = 0; i < batch_size * num_filters; ++i) {

View File

@ -55,7 +55,6 @@ struct OpData {
// These fields are only used by full kernel. // These fields are only used by full kernel.
int scratch_tensor_index; int scratch_tensor_index;
lstm_eval::IntegerLstmParameter integer_lstm_param; lstm_eval::IntegerLstmParameter integer_lstm_param;
bool compute_row_sums;
}; };
namespace full { namespace full {
@ -728,7 +727,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
void* Init(TfLiteContext* context, const char* buffer, size_t length) { void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData(); auto* op_data = new OpData();
op_data->kernel_type = kTfLiteLSTMFullKernel; op_data->kernel_type = kTfLiteLSTMFullKernel;
context->AddTensors(context, /*tensors_to_add=*/10, context->AddTensors(context, /*tensors_to_add=*/8,
&op_data->scratch_tensor_index); &op_data->scratch_tensor_index);
return op_data; return op_data;
} }
@ -1237,7 +1236,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArrayFree(node->temporaries); TfLiteIntArrayFree(node->temporaries);
if (is_hybrid_op) { if (is_hybrid_op) {
node->temporaries = TfLiteIntArrayCreate(10); node->temporaries = TfLiteIntArrayCreate(8);
} else if (is_integer) { } else if (is_integer) {
if (is_8x8_16) { if (is_8x8_16) {
node->temporaries = TfLiteIntArrayCreate(6); node->temporaries = TfLiteIntArrayCreate(6);
@ -1274,7 +1273,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} }
if (is_hybrid_op) { if (is_hybrid_op) {
op_data->compute_row_sums = true;
// Allocate temporary tensors to store quantized values of input, // Allocate temporary tensors to store quantized values of input,
// activation_state and cell_state tensors. // activation_state and cell_state tensors.
node->temporaries->data[1] = op_data->scratch_tensor_index + 1; node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
@ -1372,41 +1370,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK( TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, accum_scratch, accum_size)); context, context->ResizeTensor(context, accum_scratch, accum_size));
} }
node->temporaries->data[8] = op_data->scratch_tensor_index + 8;
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/8);
zero_points->type = kTfLiteFloat32;
zero_points->allocation_type = kTfLiteArenaRw;
int zero_points_dims[1] = {n_batch};
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
zero_points_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
zero_points_size));
}
node->temporaries->data[9] = op_data->scratch_tensor_index + 9;
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
const bool use_cifg = (input_to_input_weights == nullptr);
int row_sums_rows = use_cifg ? 6 : 8;
const TfLiteTensor* projection_weights =
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
if (projection_weights != nullptr) {
row_sums_rows += ceil(n_output / n_cell);
}
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/9);
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
const int row_sums_dims[2] = {row_sums_rows, n_cell};
if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
row_sums_size->data[0] = row_sums_dims[0];
row_sums_size->data[1] = row_sums_dims[1];
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, row_sums, row_sums_size));
}
} }
if (is_integer) { if (is_integer) {
@ -1593,9 +1556,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, /*index=*/6); GetTemporary(context, node, /*index=*/6);
TfLiteTensor* output_scratch_buffer = TfLiteTensor* output_scratch_buffer =
GetTemporary(context, node, /*index=*/7); GetTemporary(context, node, /*index=*/7);
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/8);
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/9);
const int row_sums_size = row_sums->dims->data[0];
return lstm_eval::EvalHybrid( return lstm_eval::EvalHybrid(
input, input_to_input_weights, input_to_forget_weights, input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights, input_to_cell_weights, input_to_output_weights,
@ -1617,8 +1577,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
input_quantized, input_quantized,
/*aux_input_quantized=*/nullptr, activation_state_quantized, /*aux_input_quantized=*/nullptr, activation_state_quantized,
cell_state_quantized, activation_state, cell_state, cell_state_quantized, activation_state, cell_state,
output_scratch_buffer, output, zero_points, row_sums, row_sums_size, output_scratch_buffer, output,
&op_data->compute_row_sums,
CpuBackendContext::GetFromContext(context)); CpuBackendContext::GetFromContext(context));
} else { } else {
const int num_intermediate_tensors = node->intermediates->size; const int num_intermediate_tensors = node->intermediates->size;

View File

@ -33,95 +33,26 @@ namespace builtin {
namespace lstm_eval { namespace lstm_eval {
namespace { namespace {
void ComputeRowSums(
int32_t* input_to_input_row_sums, int32_t* input_to_forget_row_sums,
int32_t* input_to_cell_row_sums, int32_t* input_to_output_row_sums,
int32_t* aux_input_to_input_row_sums, int32_t* aux_input_to_forget_row_sums,
int32_t* aux_input_to_cell_row_sums, int32_t* aux_input_to_output_row_sums,
int32_t* recurrent_to_input_row_sums, int32_t* recurrent_to_forget_row_sums,
int32_t* recurrent_to_cell_row_sums, int32_t* recurrent_to_output_row_sums,
int32_t* projection_weights_row_sums, int32_t* row_sums, int n_cell,
int n_input, int n_aux_input, int n_output,
const int8_t* input_to_input_weights_ptr,
const int8_t* input_to_forget_weights_ptr,
const int8_t* input_to_cell_weights_ptr,
const int8_t* input_to_output_weights_ptr,
const int8_t* aux_input_to_input_weights_ptr,
const int8_t* aux_input_to_forget_weights_ptr,
const int8_t* aux_input_to_cell_weights_ptr,
const int8_t* aux_input_to_output_weights_ptr,
const int8_t* recurrent_to_input_weights_ptr,
const int8_t* recurrent_to_forget_weights_ptr,
const int8_t* recurrent_to_cell_weights_ptr,
const int8_t* recurrent_to_output_weights_ptr,
const int8_t* projection_weights_ptr, bool use_cifg,
const float* aux_input_ptr) {
// Compute the row sums for dequantization
if (!use_cifg) {
memset(input_to_input_row_sums, 0, sizeof(int32_t) * n_cell);
tensor_utils::ReductionSumVector(input_to_input_weights_ptr,
input_to_input_row_sums, n_cell, n_input);
}
memset(input_to_forget_row_sums, 0, sizeof(int32_t) * n_cell);
tensor_utils::ReductionSumVector(input_to_forget_weights_ptr,
input_to_forget_row_sums, n_cell, n_input);
memset(input_to_cell_row_sums, 0, sizeof(int32_t) * n_cell);
tensor_utils::ReductionSumVector(input_to_cell_weights_ptr,
input_to_cell_row_sums, n_cell, n_input);
memset(input_to_output_row_sums, 0, sizeof(int32_t) * n_cell);
tensor_utils::ReductionSumVector(input_to_output_weights_ptr,
input_to_output_row_sums, n_cell, n_input);
if (aux_input_ptr) {
if (!use_cifg) {
memset(aux_input_to_input_row_sums, 0, sizeof(int32_t) * n_cell);
tensor_utils::ReductionSumVector(aux_input_to_input_weights_ptr,
aux_input_to_input_row_sums, n_cell,
n_aux_input);
}
memset(aux_input_to_forget_row_sums, 0, sizeof(int32_t) * n_cell);
tensor_utils::ReductionSumVector(aux_input_to_forget_weights_ptr,
aux_input_to_forget_row_sums, n_cell,
n_aux_input);
memset(aux_input_to_cell_row_sums, 0, sizeof(int32_t) * n_cell);
tensor_utils::ReductionSumVector(aux_input_to_cell_weights_ptr,
aux_input_to_cell_row_sums, n_cell,
n_aux_input);
memset(aux_input_to_output_row_sums, 0, sizeof(int32_t) * n_cell);
tensor_utils::ReductionSumVector(aux_input_to_output_weights_ptr,
aux_input_to_output_row_sums, n_cell,
n_aux_input);
}
if (!use_cifg) {
memset(recurrent_to_input_row_sums, 0, sizeof(int32_t) * n_cell);
tensor_utils::ReductionSumVector(recurrent_to_input_weights_ptr,
recurrent_to_input_row_sums, n_cell,
n_output);
}
memset(recurrent_to_forget_row_sums, 0, sizeof(int32_t) * n_cell);
tensor_utils::ReductionSumVector(recurrent_to_forget_weights_ptr,
recurrent_to_forget_row_sums, n_cell,
n_output);
memset(recurrent_to_cell_row_sums, 0, sizeof(int32_t) * n_cell);
tensor_utils::ReductionSumVector(recurrent_to_cell_weights_ptr,
recurrent_to_cell_row_sums, n_cell,
n_output);
memset(recurrent_to_output_row_sums, 0, sizeof(int32_t) * n_cell);
tensor_utils::ReductionSumVector(recurrent_to_output_weights_ptr,
recurrent_to_output_row_sums, n_cell,
n_output);
if (projection_weights_ptr != nullptr) {
memset(projection_weights_row_sums, 0, sizeof(int32_t) * n_output);
tensor_utils::ReductionSumVector(
projection_weights_ptr, projection_weights_row_sums, n_output, n_cell);
}
}
inline float GetTensorScale(const TfLiteTensor* tensor) { inline float GetTensorScale(const TfLiteTensor* tensor) {
return tensor == nullptr ? 1.0f : tensor->params.scale; return tensor == nullptr ? 1.0f : tensor->params.scale;
} }
inline void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, int32_t* scratch, float* __restrict__ result,
CpuBackendContext* context) {
// TODO(b/148289189) Remove when Ruy GEMV is the default.
#ifdef TFLITE_WITH_RUY_GEMV
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, scratch,
result, context);
#else
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result);
#endif
}
// Performs an LSTM batch inference step for input specified by input_ptr. // Performs an LSTM batch inference step for input specified by input_ptr.
// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and // The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
// biases (*_bias_ptr), and buffers (*_scratch), along with additional // biases (*_bias_ptr), and buffers (*_scratch), along with additional
@ -542,8 +473,6 @@ inline void LstmStepHybrid(
int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr, int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr,
int8_t* quantized_cell_state_ptr, float* output_state_ptr, int8_t* quantized_cell_state_ptr, float* output_state_ptr,
float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr,
int32_t* zero_points, int32_t* row_sums, int row_sums_size,
bool* compute_row_sums, bool asymmetric_quantize_inputs,
CpuBackendContext* context) { CpuBackendContext* context) {
ruy::profiler::ScopeLabel label("LstmStepHybrid"); ruy::profiler::ScopeLabel label("LstmStepHybrid");
// Since we have already checked that weights are all there or none, we // Since we have already checked that weights are all there or none, we
@ -574,131 +503,53 @@ inline void LstmStepHybrid(
output_gate_scratch); output_gate_scratch);
} }
int32_t* input_to_input_row_sums = nullptr; // For each batch and cell: compute input_weight * input.
int32_t* input_to_forget_row_sums = nullptr; // Skip if input is all zeros.
int32_t* input_to_cell_row_sums = nullptr;
int32_t* input_to_output_row_sums = nullptr;
int32_t* aux_input_to_input_row_sums = nullptr;
int32_t* aux_input_to_forget_row_sums = nullptr;
int32_t* aux_input_to_cell_row_sums = nullptr;
int32_t* aux_input_to_output_row_sums = nullptr;
int32_t* recurrent_to_input_row_sums = nullptr;
int32_t* recurrent_to_forget_row_sums = nullptr;
int32_t* recurrent_to_cell_row_sums = nullptr;
int32_t* recurrent_to_output_row_sums = nullptr;
int32_t* projection_weights_row_sums = nullptr;
if (asymmetric_quantize_inputs) {
int num_row_sums = use_cifg ? 6 : 8;
if (aux_input_ptr != nullptr) {
num_row_sums += use_cifg ? 3 : 4;
}
if (projection_weights_ptr != nullptr) {
num_row_sums += ceil(n_output / n_cell);
}
TF_LITE_ASSERT(row_sums_size == num_row_sums);
input_to_input_row_sums = row_sums;
input_to_forget_row_sums =
use_cifg ? input_to_input_row_sums : input_to_input_row_sums + n_cell;
input_to_cell_row_sums = input_to_forget_row_sums + n_cell;
input_to_output_row_sums = input_to_cell_row_sums + n_cell;
if (aux_input_ptr != nullptr) {
aux_input_to_input_row_sums = input_to_output_row_sums + n_cell;
aux_input_to_forget_row_sums = use_cifg
? aux_input_to_input_row_sums
: aux_input_to_input_row_sums + n_cell;
aux_input_to_cell_row_sums = aux_input_to_forget_row_sums + n_cell;
aux_input_to_output_row_sums = aux_input_to_cell_row_sums + n_cell;
}
recurrent_to_input_row_sums = aux_input_ptr
? aux_input_to_output_row_sums + n_cell
: input_to_output_row_sums + n_cell;
recurrent_to_forget_row_sums = use_cifg
? recurrent_to_input_row_sums
: recurrent_to_input_row_sums + n_cell;
recurrent_to_cell_row_sums = recurrent_to_forget_row_sums + n_cell;
recurrent_to_output_row_sums = recurrent_to_cell_row_sums + n_cell;
if (projection_weights_ptr != nullptr) {
projection_weights_row_sums = recurrent_to_output_row_sums + n_cell;
}
if (*compute_row_sums) {
ComputeRowSums(
input_to_input_row_sums, input_to_forget_row_sums,
input_to_cell_row_sums, input_to_output_row_sums,
aux_input_to_input_row_sums, aux_input_to_forget_row_sums,
aux_input_to_cell_row_sums, aux_input_to_output_row_sums,
recurrent_to_input_row_sums, recurrent_to_forget_row_sums,
recurrent_to_cell_row_sums, recurrent_to_output_row_sums,
projection_weights_row_sums, row_sums, n_cell, n_input, n_aux_input,
n_output, input_to_input_weights_ptr, input_to_forget_weights_ptr,
input_to_cell_weights_ptr, input_to_output_weights_ptr,
aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
projection_weights_ptr, use_cifg, aux_input_ptr);
*compute_row_sums = false;
}
}
if (!tensor_utils::IsZeroVector(input_ptr, n_batch * n_input)) { if (!tensor_utils::IsZeroVector(input_ptr, n_batch * n_input)) {
for (int b = 0; b < n_batch; ++b) { for (int b = 0; b < n_batch; ++b) {
const int offset = b * n_input; const int offset = b * n_input;
if (asymmetric_quantize_inputs) {
tensor_utils::AsymmetricQuantizeFloats(
input_ptr + offset, n_input, quantized_input_ptr + offset,
&scaling_factors[b], &zero_points[b]);
} else {
float unused_min, unused_max; float unused_min, unused_max;
tensor_utils::SymmetricQuantizeFloats( tensor_utils::SymmetricQuantizeFloats(
input_ptr + offset, n_input, quantized_input_ptr + offset, input_ptr + offset, n_input, quantized_input_ptr + offset,
&unused_min, &unused_max, &scaling_factors[b]); &unused_min, &unused_max, &scaling_factors[b]);
} }
}
if (!use_cifg) { if (!use_cifg) {
for (int b = 0; b < n_batch; ++b) { for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] = product_scaling_factors[b] =
scaling_factors[b] * input_to_input_weights_scale; scaling_factors[b] * input_to_input_weights_scale;
} }
tensor_utils::MatrixBatchVectorMultiplyAccumulate( MatrixBatchVectorMultiplyAccumulate(
input_to_input_weights_ptr, n_cell, n_input, quantized_input_ptr, input_to_input_weights_ptr, n_cell, n_input, quantized_input_ptr,
product_scaling_factors, n_batch, input_gate_scratch, product_scaling_factors, n_batch, accum_scratch_ptr,
/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, input_gate_scratch, context);
input_to_input_row_sums, compute_row_sums, context);
} }
for (int b = 0; b < n_batch; ++b) { for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] = product_scaling_factors[b] =
scaling_factors[b] * input_to_forget_weights_scale; scaling_factors[b] * input_to_forget_weights_scale;
} }
MatrixBatchVectorMultiplyAccumulate(
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr, input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr,
product_scaling_factors, n_batch, forget_gate_scratch, product_scaling_factors, n_batch, accum_scratch_ptr,
/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, forget_gate_scratch, context);
input_to_forget_row_sums, compute_row_sums, context);
for (int b = 0; b < n_batch; ++b) { for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] = product_scaling_factors[b] =
scaling_factors[b] * input_to_cell_weights_scale; scaling_factors[b] * input_to_cell_weights_scale;
} }
MatrixBatchVectorMultiplyAccumulate(
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr, input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr,
product_scaling_factors, n_batch, cell_scratch, product_scaling_factors, n_batch, accum_scratch_ptr, cell_scratch,
/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, context);
input_to_cell_row_sums, compute_row_sums, context);
for (int b = 0; b < n_batch; ++b) { for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] = product_scaling_factors[b] =
scaling_factors[b] * input_to_output_weights_scale; scaling_factors[b] * input_to_output_weights_scale;
} }
MatrixBatchVectorMultiplyAccumulate(
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr, input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr,
product_scaling_factors, n_batch, output_gate_scratch, product_scaling_factors, n_batch, accum_scratch_ptr,
/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, output_gate_scratch, context);
input_to_output_row_sums, compute_row_sums, context);
} }
// For each batch and cell: compute aux_input_weight * aux_input. // For each batch and cell: compute aux_input_weight * aux_input.
@ -707,131 +558,98 @@ inline void LstmStepHybrid(
!tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input)) { !tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input)) {
for (int b = 0; b < n_batch; ++b) { for (int b = 0; b < n_batch; ++b) {
const int offset = b * n_aux_input; const int offset = b * n_aux_input;
if (asymmetric_quantize_inputs) {
tensor_utils::AsymmetricQuantizeFloats(
aux_input_ptr + offset, n_aux_input,
quantized_aux_input_ptr + offset, &scaling_factors[b],
&zero_points[b]);
} else {
float unused_min, unused_max; float unused_min, unused_max;
tensor_utils::SymmetricQuantizeFloats( tensor_utils::SymmetricQuantizeFloats(
aux_input_ptr + offset, n_aux_input, aux_input_ptr + offset, n_aux_input, quantized_aux_input_ptr + offset,
quantized_aux_input_ptr + offset, &unused_min, &unused_max, &unused_min, &unused_max, &scaling_factors[b]);
&scaling_factors[b]);
} }
}
if (!use_cifg) { if (!use_cifg) {
for (int b = 0; b < n_batch; ++b) { for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] = product_scaling_factors[b] =
scaling_factors[b] * aux_input_to_input_weights_scale; scaling_factors[b] * aux_input_to_input_weights_scale;
} }
tensor_utils::MatrixBatchVectorMultiplyAccumulate( MatrixBatchVectorMultiplyAccumulate(
aux_input_to_input_weights_ptr, n_cell, n_aux_input, aux_input_to_input_weights_ptr, n_cell, n_aux_input,
quantized_aux_input_ptr, product_scaling_factors, n_batch, quantized_aux_input_ptr, product_scaling_factors, n_batch,
input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, input_gate_scratch, context);
accum_scratch_ptr, aux_input_to_input_row_sums, compute_row_sums,
context);
} }
for (int b = 0; b < n_batch; ++b) { for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] = product_scaling_factors[b] =
scaling_factors[b] * aux_input_to_forget_weights_scale; scaling_factors[b] * aux_input_to_forget_weights_scale;
} }
tensor_utils::MatrixBatchVectorMultiplyAccumulate( MatrixBatchVectorMultiplyAccumulate(
aux_input_to_forget_weights_ptr, n_cell, n_aux_input, aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
quantized_aux_input_ptr, product_scaling_factors, n_batch, quantized_aux_input_ptr, product_scaling_factors, n_batch,
forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, forget_gate_scratch, context);
accum_scratch_ptr, aux_input_to_forget_row_sums, compute_row_sums,
context);
row_sums += n_cell;
for (int b = 0; b < n_batch; ++b) { for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] = product_scaling_factors[b] =
scaling_factors[b] * aux_input_to_cell_weights_scale; scaling_factors[b] * aux_input_to_cell_weights_scale;
} }
tensor_utils::MatrixBatchVectorMultiplyAccumulate( MatrixBatchVectorMultiplyAccumulate(
aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_to_cell_weights_ptr, n_cell, n_aux_input,
quantized_aux_input_ptr, product_scaling_factors, n_batch, cell_scratch, quantized_aux_input_ptr, product_scaling_factors, n_batch,
/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, accum_scratch_ptr, cell_scratch, context);
aux_input_to_cell_row_sums, compute_row_sums, context);
for (int b = 0; b < n_batch; ++b) { for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] = product_scaling_factors[b] =
scaling_factors[b] * aux_input_to_output_weights_scale; scaling_factors[b] * aux_input_to_output_weights_scale;
} }
MatrixBatchVectorMultiplyAccumulate(
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_output_weights_ptr, n_cell, n_aux_input, aux_input_to_output_weights_ptr, n_cell, n_aux_input,
quantized_aux_input_ptr, product_scaling_factors, n_batch, quantized_aux_input_ptr, product_scaling_factors, n_batch,
output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, output_gate_scratch, context);
accum_scratch_ptr, aux_input_to_output_row_sums, compute_row_sums,
context);
} }
if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) { if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
// Save quantization and matmul computation for all zero input. // Save quantization and matmul computation for all zero input.
for (int b = 0; b < n_batch; ++b) { for (int b = 0; b < n_batch; ++b) {
const int offset = b * n_output; const int offset = b * n_output;
if (asymmetric_quantize_inputs) {
tensor_utils::AsymmetricQuantizeFloats(
output_state_ptr + offset, n_output,
quantized_output_state_ptr + offset, &scaling_factors[b],
&zero_points[b]);
} else {
float unused_min, unused_max; float unused_min, unused_max;
tensor_utils::SymmetricQuantizeFloats( tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
output_state_ptr + offset, n_output, quantized_output_state_ptr + offset,
quantized_output_state_ptr + offset, &unused_min, &unused_max, &unused_min, &unused_max,
&scaling_factors[b]); &scaling_factors[b]);
} }
}
// For each batch and cell: compute recurrent_weight * output_state. // For each batch and cell: compute recurrent_weight * output_state.
if (!use_cifg) { if (!use_cifg) {
for (int b = 0; b < n_batch; ++b) { for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] = product_scaling_factors[b] =
scaling_factors[b] * recurrent_to_input_weights_scale; scaling_factors[b] * recurrent_to_input_weights_scale;
} }
tensor_utils::MatrixBatchVectorMultiplyAccumulate( MatrixBatchVectorMultiplyAccumulate(
recurrent_to_input_weights_ptr, n_cell, n_output, recurrent_to_input_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, product_scaling_factors, n_batch, quantized_output_state_ptr, product_scaling_factors, n_batch,
input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, input_gate_scratch, context);
accum_scratch_ptr, recurrent_to_input_row_sums, compute_row_sums,
context);
} }
for (int b = 0; b < n_batch; ++b) { for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] = product_scaling_factors[b] =
scaling_factors[b] * recurrent_to_forget_weights_scale; scaling_factors[b] * recurrent_to_forget_weights_scale;
} }
tensor_utils::MatrixBatchVectorMultiplyAccumulate( MatrixBatchVectorMultiplyAccumulate(
recurrent_to_forget_weights_ptr, n_cell, n_output, recurrent_to_forget_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, product_scaling_factors, n_batch, quantized_output_state_ptr, product_scaling_factors, n_batch,
forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, forget_gate_scratch, context);
accum_scratch_ptr, recurrent_to_forget_row_sums, compute_row_sums,
context);
for (int b = 0; b < n_batch; ++b) { for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] = product_scaling_factors[b] =
scaling_factors[b] * recurrent_to_cell_weights_scale; scaling_factors[b] * recurrent_to_cell_weights_scale;
} }
tensor_utils::MatrixBatchVectorMultiplyAccumulate( MatrixBatchVectorMultiplyAccumulate(
recurrent_to_cell_weights_ptr, n_cell, n_output, recurrent_to_cell_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, product_scaling_factors, n_batch, quantized_output_state_ptr, product_scaling_factors, n_batch,
cell_scratch, /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, cell_scratch, context);
accum_scratch_ptr, recurrent_to_cell_row_sums, compute_row_sums,
context);
for (int b = 0; b < n_batch; ++b) { for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] = product_scaling_factors[b] =
scaling_factors[b] * recurrent_to_output_weights_scale; scaling_factors[b] * recurrent_to_output_weights_scale;
} }
tensor_utils::MatrixBatchVectorMultiplyAccumulate( MatrixBatchVectorMultiplyAccumulate(
recurrent_to_output_weights_ptr, n_cell, n_output, recurrent_to_output_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, product_scaling_factors, n_batch, quantized_output_state_ptr, product_scaling_factors, n_batch,
output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, output_gate_scratch, context);
accum_scratch_ptr, recurrent_to_output_row_sums, compute_row_sums,
context);
} }
// For each batch and cell: update input gate. // For each batch and cell: update input gate.
@ -952,32 +770,22 @@ inline void LstmStepHybrid(
// Save quantization and matmul computation for all zero input. // Save quantization and matmul computation for all zero input.
for (int b = 0; b < n_batch; ++b) { for (int b = 0; b < n_batch; ++b) {
const int offset = b * n_cell; const int offset = b * n_cell;
if (asymmetric_quantize_inputs) {
tensor_utils::AsymmetricQuantizeFloats(
output_gate_scratch + offset, n_cell,
quantized_cell_state_ptr + offset, &scaling_factors[b],
&zero_points[b]);
} else {
float unused_min, unused_max; float unused_min, unused_max;
tensor_utils::SymmetricQuantizeFloats( tensor_utils::SymmetricQuantizeFloats(
output_gate_scratch + offset, n_cell, output_gate_scratch + offset, n_cell,
quantized_cell_state_ptr + offset, &unused_min, &unused_max, quantized_cell_state_ptr + offset, &unused_min, &unused_max,
&scaling_factors[b]); &scaling_factors[b]);
} }
}
for (int b = 0; b < n_batch; ++b) { for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] = product_scaling_factors[b] =
scaling_factors[b] * projection_weights_scale; scaling_factors[b] * projection_weights_scale;
} }
for (int b = 0; b < n_batch; b++) { for (int b = 0; b < n_batch; b++) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate( MatrixBatchVectorMultiplyAccumulate(
projection_weights_ptr, n_output, n_cell, projection_weights_ptr, n_output, n_cell,
quantized_cell_state_ptr + b * n_cell, &product_scaling_factors[b], quantized_cell_state_ptr + b * n_cell, &product_scaling_factors[b],
/*n_batch=*/1, output_ptr + b * output_batch_leading_dim, /*n_batch=*/1, accum_scratch_ptr,
/*per_channel_scale=*/nullptr, output_ptr + b * output_batch_leading_dim, context);
asymmetric_quantize_inputs ? &zero_points[b] : nullptr,
accum_scratch_ptr, projection_weights_row_sums, compute_row_sums,
context);
} }
} }
if (params->proj_clip > 0.0) { if (params->proj_clip > 0.0) {
@ -1807,8 +1615,7 @@ TfLiteStatus EvalHybrid(
TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized, TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state, TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer, TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer,
TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* row_sums, TfLiteTensor* output, CpuBackendContext* context) {
int row_sums_size, bool* compute_row_sums, CpuBackendContext* context) {
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3); TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
const int n_input = input->dims->data[input->dims->size - 1]; const int n_input = input->dims->data[input->dims->size - 1];
int max_time, n_batch; int max_time, n_batch;
@ -1847,14 +1654,6 @@ TfLiteStatus EvalHybrid(
const int output_batch_leading_dim = const int output_batch_leading_dim =
output->dims->data[output->dims->size - 1]; output->dims->data[output->dims->size - 1];
int32_t* zero_points_ptr = nullptr;
int32_t* row_sums_ptr = nullptr;
if (params->asymmetric_quantize_inputs) {
zero_points_ptr = GetTensorData<int32_t>(zero_points);
row_sums_ptr = GetTensorData<int32_t>(row_sums);
}
if (time_major) { if (time_major) {
// Feed the sequence into the LSTM step-by-step. // Feed the sequence into the LSTM step-by-step.
const int input_step = n_batch * n_input; const int input_step = n_batch * n_input;
@ -1922,9 +1721,7 @@ TfLiteStatus EvalHybrid(
GetTensorData<int8_t>(output_state_quantized), GetTensorData<int8_t>(output_state_quantized),
GetTensorData<int8_t>(cell_state_quantized), GetTensorData<int8_t>(cell_state_quantized),
GetTensorData<float>(output_state), GetTensorData<float>(cell_state), GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
GetTensorData<int32_t>(output_scratch_buffer), output_ptr, GetTensorData<int32_t>(output_scratch_buffer), output_ptr, context);
zero_points_ptr, row_sums_ptr, row_sums_size, compute_row_sums,
params->asymmetric_quantize_inputs, context);
} }
} else { } else {
for (int b = 0; b < n_batch; b++) { for (int b = 0; b < n_batch; b++) {
@ -2009,8 +1806,7 @@ TfLiteStatus EvalHybrid(
GetTensorData<int8_t>(output_state_quantized), GetTensorData<int8_t>(output_state_quantized),
GetTensorData<int8_t>(cell_state_quantized), output_state_ptr, GetTensorData<int8_t>(cell_state_quantized), output_state_ptr,
cell_state_ptr, GetTensorData<int32_t>(output_scratch_buffer), cell_state_ptr, GetTensorData<int32_t>(output_scratch_buffer),
output_ptr, zero_points_ptr, row_sums_ptr, row_sums_size, output_ptr, context);
compute_row_sums, params->asymmetric_quantize_inputs, context);
} }
} }
} }

View File

@ -156,8 +156,7 @@ TfLiteStatus EvalHybrid(
TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized, TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state, TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer, TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer,
TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* row_sums, TfLiteTensor* output, CpuBackendContext* context);
int row_sums_size, bool* compute_row_sums, CpuBackendContext* context);
TfLiteStatus EvalInteger8x8_16( TfLiteStatus EvalInteger8x8_16(
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,

View File

@ -38,8 +38,7 @@ class LSTMOpModel : public SingleOpModel {
bool use_peephole, bool use_projection_weights, bool use_peephole, bool use_projection_weights,
bool use_projection_bias, float cell_clip, float proj_clip, bool use_projection_bias, float cell_clip, float proj_clip,
const std::vector<std::vector<int>>& input_shapes, const std::vector<std::vector<int>>& input_shapes,
const TensorType weight_type, bool is_layer_norm, const TensorType weight_type, bool is_layer_norm)
bool asymmetric_quantize_inputs = false)
: n_batch_(n_batch), : n_batch_(n_batch),
n_input_(n_input), n_input_(n_input),
n_cell_(n_cell), n_cell_(n_cell),
@ -130,11 +129,9 @@ class LSTMOpModel : public SingleOpModel {
output_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp( SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
CreateLSTMOptions(builder_, ActivationFunctionType_TANH, cell_clip, cell_clip, proj_clip)
proj_clip, ::tflite::LSTMKernelType_FULL,
asymmetric_quantize_inputs)
.Union()); .Union());
// Do not apply delegate yet since tensor values are not known (and more // Do not apply delegate yet since tensor values are not known (and more
@ -318,7 +315,7 @@ class LSTMOpModel : public SingleOpModel {
const TensorType weight_type_; const TensorType weight_type_;
}; };
class BaseLstmTest : public ::testing::TestWithParam<bool> { class BaseLstmTest : public ::testing::Test {
protected: protected:
// Weights of the LSTM model. Some are optional. // Weights of the LSTM model. Some are optional.
std::vector<float> input_to_input_weights_; std::vector<float> input_to_input_weights_;
@ -568,11 +565,8 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest,
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
} }
TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
HybridLstmBlackBoxTestUint8) { HybridLstmBlackBoxTestUint8) {
if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
return;
}
const int n_batch = 1; const int n_batch = 1;
const int n_input = 2; const int n_input = 2;
// n_cell and n_output have the same size when there is no projection. // n_cell and n_output have the same size when there is no projection.
@ -610,20 +604,14 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
{0}, // projection_bias tensor {0}, // projection_bias tensor
}, },
/*weight_type=*/TensorType_UINT8, /*weight_type=*/TensorType_UINT8,
/*is_layer_norm=*/false, GetParam()); /*is_layer_norm=*/false);
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
/*tolerance=*/0.0157651); /*tolerance=*/0.0157651);
} }
class NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
: public NoCifgNoPeepholeNoProjectionNoClippingLstmTest {};
TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test,
HybridLstmBlackBoxTestInt8) { HybridLstmBlackBoxTestInt8) {
if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
return;
}
const int n_batch = 1; const int n_batch = 1;
const int n_input = 2; const int n_input = 2;
// n_cell and n_output have the same size when there is no projection. // n_cell and n_output have the same size when there is no projection.
@ -661,7 +649,7 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test,
{0}, // projection_bias tensor {0}, // projection_bias tensor
}, },
/*weight_type=*/TensorType_INT8, /*weight_type=*/TensorType_INT8,
/*is_layer_norm=*/false, GetParam()); /*is_layer_norm=*/false);
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
/*tolerance=*/0.0157651); /*tolerance=*/0.0157651);
@ -757,11 +745,8 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
} }
TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmTest, TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest,
HybridLstmBlackBoxTestUint8) { HybridLstmBlackBoxTestUint8) {
if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
return;
}
const int n_batch = 1; const int n_batch = 1;
const int n_input = 2; const int n_input = 2;
// n_cell and n_output have the same size when there is no projection. // n_cell and n_output have the same size when there is no projection.
@ -799,18 +784,13 @@ TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmTest,
{0}, // projection_bias tensor {0}, // projection_bias tensor
}, },
/*weight_type=*/TensorType_UINT8, /*weight_type=*/TensorType_UINT8,
/*is_layer_norm=*/false, GetParam()); /*is_layer_norm=*/false);
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
} }
class CifgNoPeepholeNoProjectionNoClippingLstmInt8Test
: public CifgNoPeepholeNoProjectionNoClippingLstmTest {};
TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmInt8Test, TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest,
HybridLstmBlackBoxTestInt8) { HybridLstmBlackBoxTestInt8) {
if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
return;
}
const int n_batch = 1; const int n_batch = 1;
const int n_input = 2; const int n_input = 2;
// n_cell and n_output have the same size when there is no projection. // n_cell and n_output have the same size when there is no projection.
@ -848,7 +828,7 @@ TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmInt8Test,
{0}, // projection_bias tensor {0}, // projection_bias tensor
}, },
/*weight_type=*/TensorType_INT8, /*weight_type=*/TensorType_INT8,
/*is_layer_norm=*/false, GetParam()); /*is_layer_norm=*/false);
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
} }
@ -1494,60 +1474,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) {
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
} }
TEST_P(NoCifgPeepholeProjectionNoClippingLstmTest, TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTestInt8) {
HybridLstmBlackBoxTestUint8) {
if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
return;
}
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 20;
const int n_output = 16;
LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
/*use_cifg=*/false, /*use_peephole=*/true,
/*use_projection_weights=*/true,
/*use_projection_bias=*/false,
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
{
{n_batch, n_input}, // input tensor
{n_cell, n_input}, // input_to_input_weight tensor
{n_cell, n_input}, // input_to_forget_weight tensor
{n_cell, n_input}, // input_to_cell_weight tensor
{n_cell, n_input}, // input_to_output_weight tensor
{n_cell, n_output}, // recurrent_to_input_weight tensor
{n_cell, n_output}, // recurrent_to_forget_weight tensor
{n_cell, n_output}, // recurrent_to_cell_weight tensor
{n_cell, n_output}, // recurrent_to_output_weight tensor
{n_cell}, // cell_to_input_weight tensor
{n_cell}, // cell_to_forget_weight tensor
{n_cell}, // cell_to_output_weight tensor
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
},
/*weight_type=*/TensorType_UINT8,
/*is_layer_norm=*/false, GetParam());
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
}
class NoCifgPeepholeProjectionNoClippingLstmInt8Test
: public NoCifgPeepholeProjectionNoClippingLstmTest {};
TEST_P(NoCifgPeepholeProjectionNoClippingLstmInt8Test,
HybridLstmBlackBoxTestInt8) {
if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
return;
}
const int n_batch = 2; const int n_batch = 2;
const int n_input = 5; const int n_input = 5;
const int n_cell = 20; const int n_cell = 20;
@ -1584,9 +1511,52 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLstmInt8Test,
{0}, // projection_bias tensor {0}, // projection_bias tensor
}, },
/*weight_type=*/TensorType_INT8, /*weight_type=*/TensorType_INT8,
/*is_layer_norm=*/false, GetParam()); /*is_layer_norm=*/false);
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.0015); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
}
TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest,
HybridLstmBlackBoxTestUint8) {
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 20;
const int n_output = 16;
LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
/*use_cifg=*/false, /*use_peephole=*/true,
/*use_projection_weights=*/true,
/*use_projection_bias=*/false,
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
{
{n_batch, n_input}, // input tensor
{n_cell, n_input}, // input_to_input_weight tensor
{n_cell, n_input}, // input_to_forget_weight tensor
{n_cell, n_input}, // input_to_cell_weight tensor
{n_cell, n_input}, // input_to_output_weight tensor
{n_cell, n_output}, // recurrent_to_input_weight tensor
{n_cell, n_output}, // recurrent_to_forget_weight tensor
{n_cell, n_output}, // recurrent_to_cell_weight tensor
{n_cell, n_output}, // recurrent_to_output_weight tensor
{n_cell}, // cell_to_input_weight tensor
{n_cell}, // cell_to_forget_weight tensor
{n_cell}, // cell_to_output_weight tensor
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
},
/*weight_type=*/TensorType_UINT8,
/*is_layer_norm=*/false);
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
} }
class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
@ -1723,11 +1693,8 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm); VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
} }
TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
HybridLayerNormLstmBlackBoxTestUint8) { HybridLayerNormLstmBlackBoxTestUint8) {
if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
return;
}
const int n_batch = 2; const int n_batch = 2;
const int n_input = 5; const int n_input = 5;
const int n_cell = 4; const int n_cell = 4;
@ -1774,7 +1741,7 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
{n_cell}, // output_layer_norm_coefficient tensor {n_cell}, // output_layer_norm_coefficient tensor
}, },
/*weight_type=*/TensorType_UINT8, /*weight_type=*/TensorType_UINT8,
/*is_layer_norm=*/true, GetParam()); /*is_layer_norm=*/true);
lstm_golden_output_ = {{ lstm_golden_output_ = {{
// Batch0: 3 (input_sequence_size) * 3 (n_output) // Batch0: 3 (input_sequence_size) * 3 (n_output)
@ -1793,14 +1760,8 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
/*tolerance=*/0.0010907); /*tolerance=*/0.0010907);
} }
class NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
: public NoCifgPeepholeProjectionNoClippingLayerNormLstmTest {};
TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
HybridLayerNormLstmBlackBoxTestInt8) { HybridLayerNormLstmBlackBoxTestInt8) {
if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
return;
}
const int n_batch = 2; const int n_batch = 2;
const int n_input = 5; const int n_input = 5;
const int n_cell = 4; const int n_cell = 4;
@ -1847,24 +1808,22 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
{n_cell}, // output_layer_norm_coefficient tensor {n_cell}, // output_layer_norm_coefficient tensor
}, },
/*weight_type=*/TensorType_INT8, /*weight_type=*/TensorType_INT8,
/*is_layer_norm=*/true, GetParam()); /*is_layer_norm=*/true);
// Goldens are calculated from weight_type=TensorType_FLOAT32.
lstm_golden_output_ = {{ lstm_golden_output_ = {{
// Batch0: 3 (input_sequence_size) * 3 (n_output) // Batch0: 3 (input_sequence_size) * 3 (n_output)
0.0244077, 0.128027, -0.00170918, // seq 0 0.0244576, 0.127847, -0.00181765, // seq 0
0.0137642, 0.140751, 0.0395835, // seq 1 0.0137518, 0.140892, 0.0402234, // seq 1
-0.00459233, 0.155278, 0.0837378, // seq 2 -0.0048839, 0.155096, 0.0840309, // seq 2
}, },
{ {
// Batch1: 3 (input_sequence_size) * 3 (n_output) // Batch1: 3 (input_sequence_size) * 3 (n_output)
-0.00692428, 0.0848741, 0.063445, // seq 0 -0.00728636, 0.0843957, 0.0634786, // seq 0
-0.00403911, 0.139963, 0.072681, // seq 1 -0.00448382, 0.139278, 0.0737372, // seq 1
0.00752708, 0.161903, 0.0561371, // seq 2 0.00734616, 0.161793, 0.0560238, // seq 2
}}; }};
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm, VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
/*tolerance=*/1.06e-3);
} }
class CifgPeepholeProjectionNoClippingLayerNormLstmTest : public BaseLstmTest { class CifgPeepholeProjectionNoClippingLayerNormLstmTest : public BaseLstmTest {
@ -1981,11 +1940,8 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm); VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
} }
TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmTest, TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
HybridLayerNormLstmBlackBoxTestUint8) { HybridLayerNormLstmBlackBoxTestUint8) {
if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
return;
}
const int n_batch = 2; const int n_batch = 2;
const int n_input = 5; const int n_input = 5;
const int n_cell = 4; const int n_cell = 4;
@ -2032,7 +1988,7 @@ TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
{n_cell}, // output_layer_norm_coefficient tensor {n_cell}, // output_layer_norm_coefficient tensor
}, },
/*weight_type=*/TensorType_UINT8, /*weight_type=*/TensorType_UINT8,
/*is_layer_norm=*/true, GetParam()); /*is_layer_norm=*/true);
// Verify the final output. // Verify the final output.
lstm_golden_output_ = { lstm_golden_output_ = {
@ -2053,10 +2009,7 @@ TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
/*tolerance=*/0.000902065); /*tolerance=*/0.000902065);
} }
class CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
: public CifgPeepholeProjectionNoClippingLayerNormLstmTest {};
TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
HybridLayerNormLstmBlackBoxTestInt8) { HybridLayerNormLstmBlackBoxTestInt8) {
const int n_batch = 2; const int n_batch = 2;
const int n_input = 5; const int n_input = 5;
@ -2104,24 +2057,24 @@ TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
{n_cell}, // output_layer_norm_coefficient tensor {n_cell}, // output_layer_norm_coefficient tensor
}, },
/*weight_type=*/TensorType_INT8, /*weight_type=*/TensorType_INT8,
/*is_layer_norm=*/true, GetParam()); /*is_layer_norm=*/true);
// Goldens are results using FLOAT32 inference. // Verify the final output.
lstm_golden_output_ = {{ lstm_golden_output_ = {
{
// Batch0: 3 (input_sequence_size) * 3 (n_output) // Batch0: 3 (input_sequence_size) * 3 (n_output)
0.0212971, 0.140816, 0.0112733, // seq 0 0.0212250091, 0.140474007, 0.0115012666, // seq 0
0.0132302, 0.152308, 0.0346313, // seq 1 0.0130806509, 0.152660668, 0.0347516984, // seq 1
-0.0123688, 0.16579, 0.0893078, // seq 2 -0.0124010444, 0.166042402, 0.0898982584, // seq 2
}, },
{ {
// Batch1: 3 (input_sequence_size) * 3 (n_output) // Batch1: 3 (input_sequence_size) * 3 (n_output)
-0.0226351, 0.0916948, 0.0769176, // seq 0 -0.0228835996, 0.0917588323, 0.0778886303, // seq 0
-0.0269967, 0.149708, 0.0941492, // seq 1 -0.0275101066, 0.148769245, 0.0938384682, // seq 1
-0.0103429, 0.173016, 0.0720509, // seq 2 -0.0103605557, 0.172605693, 0.0728750974, // seq 2
}}; }};
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm, VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
/*tolerance=*/1e-3);
} }
class LSTMIntegerOpModel : public SingleOpModel { class LSTMIntegerOpModel : public SingleOpModel {
@ -3358,22 +3311,5 @@ TEST(LSTMOpModel, InvalidTypeTest) {
""); "");
} }
#endif #endif
#define QUANTIZE_PARAMETER_TEST(test) \
INSTANTIATE_TEST_SUITE_P(test, test, ::testing::Bool())
QUANTIZE_PARAMETER_TEST(NoCifgNoPeepholeNoProjectionNoClippingLstmTest);
QUANTIZE_PARAMETER_TEST(NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test);
QUANTIZE_PARAMETER_TEST(CifgNoPeepholeNoProjectionNoClippingLstmTest);
QUANTIZE_PARAMETER_TEST(CifgNoPeepholeNoProjectionNoClippingLstmInt8Test);
QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionNoClippingLstmTest);
QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionNoClippingLstmInt8Test);
QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest);
QUANTIZE_PARAMETER_TEST(
NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test);
QUANTIZE_PARAMETER_TEST(CifgPeepholeProjectionNoClippingLayerNormLstmTest);
QUANTIZE_PARAMETER_TEST(CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test);
#undef QUANTIZE_PARAMETER_TEST
} // namespace } // namespace
} // namespace tflite } // namespace tflite

View File

@ -43,7 +43,6 @@ struct OpData {
int effective_scale_1_b; int effective_scale_1_b;
int32 effective_scale_2_a; int32 effective_scale_2_a;
int effective_scale_2_b; int effective_scale_2_b;
bool compute_row_sums = false;
}; };
} // namespace } // namespace
@ -62,8 +61,8 @@ constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) { void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData(); auto* op_data = new OpData();
op_data->float_weights_time_initialized = false; op_data->float_weights_time_initialized = false;
// Note: only needs 6 scratch tensors when is_hybrid_op, only 1 otherwise. // Note: only needs 4 scratch tensors when is_hybrid_op, only 1 otherwise.
context->AddTensors(context, /*tensors_to_add=*/6, context->AddTensors(context, /*tensors_to_add=*/4,
&op_data->scratch_tensor_index); &op_data->scratch_tensor_index);
return op_data; return op_data;
} }
@ -131,7 +130,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Resize scratch. // Resize scratch.
TfLiteIntArrayFree(node->temporaries); TfLiteIntArrayFree(node->temporaries);
if (is_hybrid_op) { if (is_hybrid_op) {
node->temporaries = TfLiteIntArrayCreate(6); node->temporaries = TfLiteIntArrayCreate(4);
} else if (is_full_integer) { } else if (is_full_integer) {
node->temporaries = TfLiteIntArrayCreate(2); node->temporaries = TfLiteIntArrayCreate(2);
} else { } else {
@ -157,7 +156,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
scratch_size_array)); scratch_size_array));
if (is_hybrid_op) { if (is_hybrid_op) {
op_data->compute_row_sums = true;
// Tell interpreter to allocate temporary tensors to store quantized values // Tell interpreter to allocate temporary tensors to store quantized values
// of input tensors. // of input tensors.
node->temporaries->data[1] = scratch_tensor_index + 1; node->temporaries->data[1] = scratch_tensor_index + 1;
@ -197,30 +195,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ResizeTensor(context, float_weights_time, context->ResizeTensor(context, float_weights_time,
float_weights_time_size)); float_weights_time_size));
} }
node->temporaries->data[4] = scratch_tensor_index + 4;
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
zero_points->type = kTfLiteFloat32;
zero_points->allocation_type = kTfLiteArenaRw;
int zero_points_dims[1] = {batch_size};
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
zero_points_size->data[0] = zero_points_dims[0];
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
zero_points_size));
}
node->temporaries->data[5] = scratch_tensor_index + 5;
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
row_sums->type = kTfLiteFloat32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_dims[1] = {num_filters};
if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) {
TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1);
row_sums_size->data[0] = row_sums_dims[0];
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, row_sums, row_sums_size));
}
} }
if (is_full_integer) { if (is_full_integer) {
// Allocated one extra tensor. // Allocated one extra tensor.
@ -293,8 +267,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, /*index=*/2); GetTemporary(context, node, /*index=*/2);
TfLiteTensor* float_weights_time = TfLiteTensor* float_weights_time =
GetTemporary(context, node, /*index=*/3); GetTemporary(context, node, /*index=*/3);
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
// Dequantize weights time. // Dequantize weights time.
// TODO(alanchiao): this dequantization initialization only needs to // TODO(alanchiao): this dequantization initialization only needs to
// happen once per model and should theoretically be placed in either // happen once per model and should theoretically be placed in either
@ -312,11 +285,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} }
op_data->float_weights_time_initialized = true; op_data->float_weights_time_initialized = true;
} }
reference_ops::EvalHybridSVDF(context, node, input, weights_feature,
reference_ops::EvalHybridSVDF( float_weights_time, bias, params, scratch,
context, node, input, weights_feature, float_weights_time, bias, scaling_factors, input_quantized,
params, scratch, scaling_factors, input_quantized, activation_state, activation_state, output);
output, zero_points, row_sums, &op_data->compute_row_sums);
return kTfLiteOk; return kTfLiteOk;
} else { } else {
auto* input_params = reinterpret_cast<TfLiteAffineQuantization*>( auto* input_params = reinterpret_cast<TfLiteAffineQuantization*>(

View File

@ -131,8 +131,7 @@ class BaseSVDFOpModel : public SingleOpModel {
BaseSVDFOpModel(int batches, int units, int input_size, int memory_size, BaseSVDFOpModel(int batches, int units, int input_size, int memory_size,
int rank, int rank,
TensorType weights_feature_type = TensorType_FLOAT32, TensorType weights_feature_type = TensorType_FLOAT32,
TensorType weights_time_type = TensorType_FLOAT32, TensorType weights_time_type = TensorType_FLOAT32)
bool asymmetric_quantize_inputs = false)
: batches_(batches), : batches_(batches),
units_(units), units_(units),
input_size_(input_size), input_size_(input_size),
@ -147,10 +146,9 @@ class BaseSVDFOpModel : public SingleOpModel {
TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}}, TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}},
/*is_variable=*/true); /*is_variable=*/true);
output_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions, SetBuiltinOp(
CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE, BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
asymmetric_quantize_inputs) CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union());
.Union());
BuildInterpreter({ BuildInterpreter({
{batches_, input_size_}, // input tensor {batches_, input_size_}, // input tensor
{units_ * rank, input_size_}, // weights_feature tensor {units_ * rank, input_size_}, // weights_feature tensor
@ -205,10 +203,9 @@ class SVDFOpModel : public BaseSVDFOpModel {
class HybridSVDFOpModel : public BaseSVDFOpModel { class HybridSVDFOpModel : public BaseSVDFOpModel {
public: public:
HybridSVDFOpModel(int batches, int units, int input_size, int memory_size, HybridSVDFOpModel(int batches, int units, int input_size, int memory_size,
int rank, TensorType tensor_type, int rank, TensorType tensor_type)
bool asymmetric_quantize_inputs)
: BaseSVDFOpModel(batches, units, input_size, memory_size, rank, : BaseSVDFOpModel(batches, units, input_size, memory_size, rank,
tensor_type, tensor_type, asymmetric_quantize_inputs) { tensor_type, tensor_type) {
tensor_type_ = tensor_type; tensor_type_ = tensor_type;
} }
@ -232,7 +229,7 @@ class HybridSVDFOpModel : public BaseSVDFOpModel {
TensorType tensor_type_; TensorType tensor_type_;
}; };
class SVDFOpTest : public ::testing::TestWithParam<bool> { class SVDFOpTest : public ::testing::Test {
protected: protected:
void VerifyGoldens(float golden_input[], float golden_output[], void VerifyGoldens(float golden_input[], float golden_output[],
int golden_size, BaseSVDFOpModel* svdf, int golden_size, BaseSVDFOpModel* svdf,
@ -265,9 +262,6 @@ class SVDFOpTest : public ::testing::TestWithParam<bool> {
} }
}; };
INSTANTIATE_TEST_SUITE_P(SVDFOpTest, SVDFOpTest,
::testing::ValuesIn({false, true}));
TEST_F(SVDFOpTest, BlackBoxTestRank1) { TEST_F(SVDFOpTest, BlackBoxTestRank1) {
SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
/*memory_size=*/10, /*rank=*/1); /*memory_size=*/10, /*rank=*/1);
@ -331,10 +325,9 @@ TEST_F(SVDFOpTest, BlackBoxTestRank2) {
&svdf); &svdf);
} }
TEST_P(SVDFOpTest, BlackBoxTestHybridRank1Uint8) { TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Uint8) {
HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
/*memory_size=*/10, /*rank=*/1, TensorType_UINT8, /*memory_size=*/10, /*rank=*/1, TensorType_UINT8);
GetParam());
svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
0.22197971, 0.12416199, 0.27901134, 0.27557442, 0.22197971, 0.12416199, 0.27901134, 0.27557442,
0.3905206, -0.36137494, -0.06634006, -0.10640851}); 0.3905206, -0.36137494, -0.06634006, -0.10640851});
@ -354,13 +347,12 @@ TEST_P(SVDFOpTest, BlackBoxTestHybridRank1Uint8) {
VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input), VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
&svdf, &svdf,
/*tolerance=*/0.004285); /*tolerance=*/0.002945);
} }
TEST_P(SVDFOpTest, BlackBoxTestHybridRank2Uint8) { TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Uint8) {
HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
/*memory_size=*/10, /*rank=*/2, TensorType_UINT8, /*memory_size=*/10, /*rank=*/2, TensorType_UINT8);
GetParam());
svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347,
0.12416199, 0.15785322, 0.27901134, 0.3905206, 0.12416199, 0.15785322, 0.27901134, 0.3905206,
0.21931258, -0.36137494, -0.10640851, 0.31053296, 0.21931258, -0.36137494, -0.10640851, 0.31053296,
@ -395,13 +387,12 @@ TEST_P(SVDFOpTest, BlackBoxTestHybridRank2Uint8) {
VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input), VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
&svdf, &svdf,
/*tolerance=*/0.007175); /*tolerance=*/0.00625109);
} }
TEST_P(SVDFOpTest, BlackBoxTestHybridRank1Int8) { TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Int8) {
HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
/*memory_size=*/10, /*rank=*/1, TensorType_INT8, /*memory_size=*/10, /*rank=*/1, TensorType_INT8);
GetParam());
svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
0.22197971, 0.12416199, 0.27901134, 0.27557442, 0.22197971, 0.12416199, 0.27901134, 0.27557442,
0.3905206, -0.36137494, -0.06634006, -0.10640851}); 0.3905206, -0.36137494, -0.06634006, -0.10640851});
@ -421,13 +412,12 @@ TEST_P(SVDFOpTest, BlackBoxTestHybridRank1Int8) {
VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input), VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
&svdf, &svdf,
/*tolerance=*/0.004285); /*tolerance=*/0.002945);
} }
TEST_P(SVDFOpTest, BlackBoxTestHybridRank2Int8) { TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Int8) {
HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
/*memory_size=*/10, /*rank=*/2, TensorType_INT8, /*memory_size=*/10, /*rank=*/2, TensorType_INT8);
GetParam());
svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347,
0.12416199, 0.15785322, 0.27901134, 0.3905206, 0.12416199, 0.15785322, 0.27901134, 0.3905206,
0.21931258, -0.36137494, -0.10640851, 0.31053296, 0.21931258, -0.36137494, -0.10640851, 0.31053296,
@ -462,7 +452,7 @@ TEST_P(SVDFOpTest, BlackBoxTestHybridRank2Int8) {
VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input), VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
&svdf, &svdf,
/*tolerance=*/0.007175); /*tolerance=*/0.00625109);
} }
// Test case for full integer quantization of SVDF. // Test case for full integer quantization of SVDF.

View File

@ -33,7 +33,6 @@ struct OpData {
bool is_layer_norm_lstm; bool is_layer_norm_lstm;
// The scratch tensor index. // The scratch tensor index.
int scratch_tensor_index; int scratch_tensor_index;
bool compute_row_sums = false;
}; };
// Input Tensors of size {max_time, n_batch, n_input} // Input Tensors of size {max_time, n_batch, n_input}
@ -93,9 +92,7 @@ enum TemporaryTensor {
kProductScalingFactors = 5, kProductScalingFactors = 5,
kRecoveredCellWeights = 6, kRecoveredCellWeights = 6,
kAccumScratch = 7, kAccumScratch = 7,
kZeroPoints = 8, kNumTemporaryTensors
kRowSums = 9,
kNumTemporaryTensors = 10
}; };
void* Init(TfLiteContext* context, const char* buffer, size_t length) { void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@ -411,7 +408,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
scratch_buffer_size)); scratch_buffer_size));
if (IsHybridOp(input, input_to_output_weights)) { if (IsHybridOp(input, input_to_output_weights)) {
op_data->compute_row_sums = true;
// Allocate temporary tensors to store quantized values of input, // Allocate temporary tensors to store quantized values of input,
// activation_state and cell_state tensors. // activation_state and cell_state tensors.
node->temporaries->data[kInputQuantized] = node->temporaries->data[kInputQuantized] =
@ -519,34 +515,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK( TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, accum_scratch, accum_size)); context, context->ResizeTensor(context, accum_scratch, accum_size));
} }
node->temporaries->data[kZeroPoints] = scratch_tensor_index + kZeroPoints;
TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
zero_points->type = kTfLiteFloat32;
zero_points->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, scaling_dims)) {
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
zero_points_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
zero_points_size));
}
node->temporaries->data[kRowSums] = scratch_tensor_index + kRowSums;
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_rows = use_cifg ? 6 : 8;
const TfLiteTensor* projection_weights =
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
if (projection_weights != nullptr) {
row_sums_rows += ceil(n_output / n_cell);
}
int row_sums_dims[2] = {row_sums_rows, n_cell};
if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
row_sums_size->data[0] = row_sums_dims[0];
row_sums_size->data[1] = row_sums_dims[1];
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, row_sums, row_sums_size));
}
} }
return kTfLiteOk; return kTfLiteOk;
} }
@ -632,7 +600,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
lstm_params.activation = params->activation; lstm_params.activation = params->activation;
lstm_params.cell_clip = params->cell_clip; lstm_params.cell_clip = params->cell_clip;
lstm_params.proj_clip = params->proj_clip; lstm_params.proj_clip = params->proj_clip;
lstm_params.asymmetric_quantize_inputs = params->asymmetric_quantize_inputs;
switch (input_to_output_weights->type) { switch (input_to_output_weights->type) {
case kTfLiteFloat32: { case kTfLiteFloat32: {
@ -656,7 +623,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} }
case kTfLiteUInt8: case kTfLiteUInt8:
case kTfLiteInt8: { case kTfLiteInt8: {
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* activation_state_quantized = TfLiteTensor* activation_state_quantized =
GetTemporary(context, node, /*index=*/2); GetTemporary(context, node, /*index=*/2);
@ -669,10 +635,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, /*index=*/6); GetTemporary(context, node, /*index=*/6);
TfLiteTensor* accum_scratch = TfLiteTensor* accum_scratch =
GetTemporary(context, node, /*index=*/kAccumScratch); GetTemporary(context, node, /*index=*/kAccumScratch);
TfLiteTensor* zero_points =
GetTemporary(context, node, /*index=*/kZeroPoints);
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/kRowSums);
const int row_sums_size = row_sums->dims->data[0];
return lstm_eval::EvalHybrid( return lstm_eval::EvalHybrid(
input, input_to_input_weights, input_to_forget_weights, input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights, input_to_cell_weights, input_to_output_weights,
@ -692,9 +654,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
prod_scaling_factors, recovered_cell_weights, input_quantized, prod_scaling_factors, recovered_cell_weights, input_quantized,
/*aux_input_quantized=*/nullptr, activation_state_quantized, /*aux_input_quantized=*/nullptr, activation_state_quantized,
cell_state_quantized, activation_state, cell_state, accum_scratch, cell_state_quantized, activation_state, cell_state, accum_scratch,
output, zero_points, row_sums, row_sums_size, output, CpuBackendContext::GetFromContext(context));
&op_data->compute_row_sums,
CpuBackendContext::GetFromContext(context));
} }
default: default:
context->ReportError(context, "Type %d is not currently supported.", context->ReportError(context, "Type %d is not currently supported.",

View File

@ -38,8 +38,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
float proj_clip, float proj_clip,
const std::vector<std::vector<int>>& input_shapes, const std::vector<std::vector<int>>& input_shapes,
const TensorType& weights_type = TensorType_FLOAT32, const TensorType& weights_type = TensorType_FLOAT32,
bool is_layer_norm = false, bool is_layer_norm = false)
bool asymmetric_quantize_inputs = false)
: n_batch_(n_batch), : n_batch_(n_batch),
n_input_(n_input), n_input_(n_input),
n_cell_(n_cell), n_cell_(n_cell),
@ -132,7 +131,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
BuiltinOptions_UnidirectionalSequenceLSTMOptions, BuiltinOptions_UnidirectionalSequenceLSTMOptions,
CreateUnidirectionalSequenceLSTMOptions( CreateUnidirectionalSequenceLSTMOptions(
builder_, ActivationFunctionType_TANH, cell_clip, builder_, ActivationFunctionType_TANH, cell_clip,
proj_clip, time_major, asymmetric_quantize_inputs) proj_clip, time_major)
.Union()); .Union());
BuildInterpreter(input_shapes); BuildInterpreter(input_shapes);
} }
@ -293,12 +292,11 @@ class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
bool time_major, bool use_cifg, bool use_peephole, bool time_major, bool use_cifg, bool use_peephole,
bool use_projection_weights, bool use_projection_bias, float cell_clip, bool use_projection_weights, bool use_projection_bias, float cell_clip,
float proj_clip, const std::vector<std::vector<int>>& input_shapes, float proj_clip, const std::vector<std::vector<int>>& input_shapes,
TensorType tensor_type, bool asymmetric_quantize_inputs) TensorType tensor_type)
: UnidirectionalLSTMOpModel( : UnidirectionalLSTMOpModel(
n_batch, n_input, n_cell, n_output, sequence_length, time_major, n_batch, n_input, n_cell, n_output, sequence_length, time_major,
use_cifg, use_peephole, use_projection_weights, use_projection_bias, use_cifg, use_peephole, use_projection_weights, use_projection_bias,
cell_clip, proj_clip, input_shapes, tensor_type, false, cell_clip, proj_clip, input_shapes, tensor_type) {
asymmetric_quantize_inputs) {
tensor_type_ = tensor_type; tensor_type_ = tensor_type;
} }
@ -362,7 +360,7 @@ class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
TensorType tensor_type_; TensorType tensor_type_;
}; };
class BaseUnidirectionalLstmTest : public ::testing::TestWithParam<bool> { class BaseUnidirectionalLstmTest : public ::testing::Test {
protected: protected:
// Weights of the LSTM model. Some are optional. // Weights of the LSTM model. Some are optional.
std::vector<float> input_to_input_weights_; std::vector<float> input_to_input_weights_;
@ -628,7 +626,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
/*time_major=*/false); /*time_major=*/false);
} }
TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest, TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
HybridLstmBlackBoxTestUint8) { HybridLstmBlackBoxTestUint8) {
const int n_batch = 1; const int n_batch = 1;
const int n_input = 2; const int n_input = 2;
@ -670,7 +668,7 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{n_batch, n_output}, // activation_state tensor {n_batch, n_output}, // activation_state tensor
{n_batch, n_cell}, // cell_state tensor {n_batch, n_cell}, // cell_state tensor
}, },
TensorType_UINT8, GetParam()); TensorType_UINT8);
lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToInputWeights(input_to_input_weights_);
lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_);
@ -691,7 +689,7 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
/*tolerance=*/0.0157651); /*tolerance=*/0.0157651);
} }
TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest, TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
HybridLstmBlackBoxTestInt8) { HybridLstmBlackBoxTestInt8) {
const int n_batch = 1; const int n_batch = 1;
const int n_input = 2; const int n_input = 2;
@ -733,7 +731,7 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{n_batch, n_output}, // activation_state tensor {n_batch, n_output}, // activation_state tensor
{n_batch, n_cell}, // cell_state tensor {n_batch, n_cell}, // cell_state tensor
}, },
TensorType_INT8, GetParam()); TensorType_INT8);
lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToInputWeights(input_to_input_weights_);
lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_);
@ -864,7 +862,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
} }
TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
HybridLstmBlackBoxTestUint8) { HybridLstmBlackBoxTestUint8) {
const int n_batch = 1; const int n_batch = 1;
const int n_input = 2; const int n_input = 2;
@ -886,6 +884,7 @@ TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{n_cell, n_input}, // input_to_forget_weight tensor {n_cell, n_input}, // input_to_forget_weight tensor
{n_cell, n_input}, // input_to_cell_weight tensor {n_cell, n_input}, // input_to_cell_weight tensor
{n_cell, n_input}, // input_to_output_weight tensor {n_cell, n_input}, // input_to_output_weight tensor
{0, 0}, // recurrent_to_input_weight tensor {0, 0}, // recurrent_to_input_weight tensor
{n_cell, n_output}, // recurrent_to_forget_weight tensor {n_cell, n_output}, // recurrent_to_forget_weight tensor
{n_cell, n_output}, // recurrent_to_cell_weight tensor {n_cell, n_output}, // recurrent_to_cell_weight tensor
@ -906,7 +905,7 @@ TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{n_batch, n_output}, // activation_state tensor {n_batch, n_output}, // activation_state tensor
{n_batch, n_cell}, // cell_state tensor {n_batch, n_cell}, // cell_state tensor
}, },
TensorType_UINT8, GetParam()); TensorType_UINT8);
lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_);
lstm.SetInputToForgetWeights(input_to_forget_weights_); lstm.SetInputToForgetWeights(input_to_forget_weights_);
@ -926,7 +925,7 @@ TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
} }
TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
HybridLstmBlackBoxTestInt8) { HybridLstmBlackBoxTestInt8) {
const int n_batch = 1; const int n_batch = 1;
const int n_input = 2; const int n_input = 2;
@ -969,7 +968,7 @@ TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{n_batch, n_output}, // activation_state tensor {n_batch, n_output}, // activation_state tensor
{n_batch, n_cell}, // cell_state tensor {n_batch, n_cell}, // cell_state tensor
}, },
TensorType_INT8, GetParam()); TensorType_INT8);
lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_);
lstm.SetInputToForgetWeights(input_to_forget_weights_); lstm.SetInputToForgetWeights(input_to_forget_weights_);
@ -1656,16 +1655,14 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
} }
TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest, TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
HybridLstmBlackBoxTestUint8) { HybridLstmBlackBoxTestUint8) {
const int n_batch = 2; const int n_batch = 2;
const int n_input = 5; const int n_input = 5;
const int n_cell = 20; const int n_cell = 20;
const int n_output = 16; const int n_output = 16;
const int sequence_length = 4; const int sequence_length = 4;
if (GetParam()) {
return;
}
HybridUnidirectionalLSTMOpModel lstm( HybridUnidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, n_batch, n_input, n_cell, n_output, sequence_length,
/*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true, /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true,
@ -1700,7 +1697,7 @@ TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
{n_batch, n_output}, // activation_state tensor {n_batch, n_output}, // activation_state tensor
{n_batch, n_cell}, // cell_state tensor {n_batch, n_cell}, // cell_state tensor
}, },
TensorType_UINT8, GetParam()); TensorType_UINT8);
lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToInputWeights(input_to_input_weights_);
lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_);
@ -1726,11 +1723,8 @@ TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
} }
TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest, TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
HybridLstmBlackBoxTestInt8) { HybridLstmBlackBoxTestInt8) {
if (GetParam()) {
return;
}
const int n_batch = 2; const int n_batch = 2;
const int n_input = 5; const int n_input = 5;
const int n_cell = 20; const int n_cell = 20;
@ -1771,7 +1765,7 @@ TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
{n_batch, n_output}, // activation_state tensor {n_batch, n_output}, // activation_state tensor
{n_batch, n_cell}, // cell_state tensor {n_batch, n_cell}, // cell_state tensor
}, },
TensorType_INT8, GetParam()); TensorType_INT8);
lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToInputWeights(input_to_input_weights_);
lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_);
@ -2743,14 +2737,5 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
} }
#define QUANTIZE_PARAMETER_TEST(test) \
INSTANTIATE_TEST_SUITE_P(test, test, ::testing::ValuesIn({false, true}));
QUANTIZE_PARAMETER_TEST(
CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest);
QUANTIZE_PARAMETER_TEST(
NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest);
QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest);
#undef QUANTIZE_PARAMETER_TEST
} // namespace } // namespace
} // namespace tflite } // namespace tflite

View File

@ -26,15 +26,6 @@ namespace ops {
namespace builtin { namespace builtin {
namespace unidirectional_sequence_rnn { namespace unidirectional_sequence_rnn {
namespace {
struct OpData {
int scratch_tensor_index;
bool compute_row_sums = false;
};
} // namespace
// Input tensors. // Input tensors.
constexpr int kInputTensor = 0; constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1; constexpr int kWeightsTensor = 1;
@ -46,14 +37,13 @@ constexpr int kHiddenStateTensor = 4;
constexpr int kOutputTensor = 0; constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) { void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData(); auto* scratch_tensor_index = new int;
context->AddTensors(context, /*tensors_to_add=*/6, context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index);
&op_data->scratch_tensor_index); return scratch_tensor_index;
return op_data;
} }
void Free(TfLiteContext* context, void* buffer) { void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<OpData*>(buffer); delete reinterpret_cast<int*>(buffer);
} }
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
@ -106,11 +96,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate temporary tensors to store quantized values of input and // Allocate temporary tensors to store quantized values of input and
// hidden_state tensors. // hidden_state tensors.
if (is_hybrid) { if (is_hybrid) {
auto* op_data = reinterpret_cast<OpData*>(node->user_data); int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
op_data->compute_row_sums = true;
TfLiteIntArrayFree(node->temporaries); TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(6); node->temporaries = TfLiteIntArrayCreate(3);
node->temporaries->data[0] = op_data->scratch_tensor_index; node->temporaries->data[0] = *scratch_tensor_index;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
input_quantized->type = input_weights->type; input_quantized->type = input_weights->type;
input_quantized->allocation_type = kTfLiteArenaRw; input_quantized->allocation_type = kTfLiteArenaRw;
@ -119,7 +108,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
input_quantized_size)); input_quantized_size));
} }
node->temporaries->data[1] = op_data->scratch_tensor_index + 1; node->temporaries->data[1] = *scratch_tensor_index + 1;
TfLiteTensor* hidden_state_quantized = TfLiteTensor* hidden_state_quantized =
GetTemporary(context, node, /*index=*/1); GetTemporary(context, node, /*index=*/1);
hidden_state_quantized->type = input_weights->type; hidden_state_quantized->type = input_weights->type;
@ -132,7 +121,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ResizeTensor(context, hidden_state_quantized, context->ResizeTensor(context, hidden_state_quantized,
hidden_state_quantized_size)); hidden_state_quantized_size));
} }
node->temporaries->data[2] = op_data->scratch_tensor_index + 2; node->temporaries->data[2] = *scratch_tensor_index + 2;
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
scaling_factors->type = kTfLiteFloat32; scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw; scaling_factors->allocation_type = kTfLiteArenaRw;
@ -143,42 +132,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
scaling_factors_size)); scaling_factors_size));
} }
node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/3);
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {num_units, batch_size};
if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
accum_scratch_dims)) {
TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
accum_scratch_size->data[0] = accum_scratch_dims[0];
accum_scratch_size->data[1] = accum_scratch_dims[1];
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
accum_scratch_size));
}
node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
zero_points->type = kTfLiteInt32;
zero_points->allocation_type = kTfLiteArenaRw;
int zero_points_dims[1] = {batch_size};
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
zero_points_size->data[0] = batch_size;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
zero_points_size));
}
node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_dims[2] = {2, num_units};
if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
row_sums_size->data[0] = row_sums_dims[0];
row_sums_size->data[1] = row_sums_dims[1];
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, row_sums, row_sums_size));
}
} }
return kTfLiteOk; return kTfLiteOk;
} }
@ -249,9 +202,7 @@ TfLiteStatus EvalHybrid(
const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias, const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias,
const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch, const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch,
TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors, TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors,
TfLiteTensor* hidden_state, TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* hidden_state, TfLiteTensor* output) {
TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
bool* compute_row_sums) {
const bool time_major = params->time_major; const bool time_major = params->time_major;
const int batch_size = const int batch_size =
(time_major) ? input->dims->data[1] : input->dims->data[0]; (time_major) ? input->dims->data[1] : input->dims->data[0];
@ -276,14 +227,6 @@ TfLiteStatus EvalHybrid(
float input_weights_scale = input_weights->params.scale; float input_weights_scale = input_weights->params.scale;
float recurrent_weights_scale = recurrent_weights->params.scale; float recurrent_weights_scale = recurrent_weights->params.scale;
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors); float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
int32_t* zero_points_ptr = nullptr;
int32_t* row_sums_ptr = nullptr;
if (params->asymmetric_quantize_inputs) {
zero_points_ptr = GetTensorData<int32_t>(zero_points);
row_sums_ptr = GetTensorData<int32_t>(row_sums);
}
if (time_major) { if (time_major) {
// Initialize the pointer to hidden state. // Initialize the pointer to hidden state.
@ -301,9 +244,7 @@ TfLiteStatus EvalHybrid(
recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
num_units, batch_size, num_units, params->activation, num_units, batch_size, num_units, params->activation,
quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr, quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr,
hidden_state_ptr_batch, output_ptr_batch, hidden_state_ptr_batch, output_ptr_batch);
params->asymmetric_quantize_inputs, zero_points_ptr,
accum_scratch_ptr, row_sums_ptr, compute_row_sums);
} }
} else { } else {
// For each batch // For each batch
@ -318,14 +259,13 @@ TfLiteStatus EvalHybrid(
s * input_size; s * input_size;
float* output_ptr_batch = GetTensorData<float>(output) + float* output_ptr_batch = GetTensorData<float>(output) +
b * num_units * max_time + s * num_units; b * num_units * max_time + s * num_units;
kernel_utils::RnnBatchStep( kernel_utils::RnnBatchStep(
input_ptr_batch, input_weights_ptr, input_weights_scale, input_ptr_batch, input_weights_ptr, input_weights_scale,
recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, recurrent_weights_ptr, recurrent_weights_scale, bias_ptr,
input_size, num_units, /*batch_size=*/1, num_units, input_size, num_units, /*batch_size=*/1, num_units,
params->activation, quantized_input_ptr, quantized_hidden_state_ptr, params->activation, quantized_input_ptr, quantized_hidden_state_ptr,
scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch, scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch);
params->asymmetric_quantize_inputs, zero_points_ptr,
accum_scratch_ptr, row_sums_ptr, compute_row_sums);
} }
} }
} }
@ -334,6 +274,7 @@ TfLiteStatus EvalHybrid(
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data); auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor); const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* recurrent_weights = const TfLiteTensor* recurrent_weights =
@ -351,17 +292,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteUInt8: case kTfLiteUInt8:
case kTfLiteInt8: { case kTfLiteInt8: {
// TODO(mirkov): implement eval with quantized inputs as well. // TODO(mirkov): implement eval with quantized inputs as well.
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* input_quantized = GetTemporary(context, node, 0); TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1); TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
TfLiteTensor* scaling_factors = GetTemporary(context, node, 2); TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
TfLiteTensor* accum_scratch = GetTemporary(context, node, 3);
TfLiteTensor* zero_points = GetTemporary(context, node, 4);
TfLiteTensor* row_sums = GetTemporary(context, node, 5);
return EvalHybrid(input, input_weights, recurrent_weights, bias, params, return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
input_quantized, hidden_state_quantized, input_quantized, hidden_state_quantized,
scaling_factors, hidden_state, output, zero_points, scaling_factors, hidden_state, output);
accum_scratch, row_sums, &op_data->compute_row_sums);
} }
default: default:
context->ReportError(context, "Type %d not currently supported.", context->ReportError(context, "Type %d not currently supported.",

View File

@ -174,8 +174,7 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
UnidirectionalRNNOpModel( UnidirectionalRNNOpModel(
int batches, int sequence_len, int units, int size, bool time_major, int batches, int sequence_len, int units, int size, bool time_major,
const TensorType& weights = TensorType_FLOAT32, const TensorType& weights = TensorType_FLOAT32,
const TensorType& recurrent_weights = TensorType_FLOAT32, const TensorType& recurrent_weights = TensorType_FLOAT32)
bool asymmetric_quantize_inputs = false)
: batches_(batches), : batches_(batches),
sequence_len_(sequence_len), sequence_len_(sequence_len),
units_(units), units_(units),
@ -189,8 +188,7 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
BuiltinOptions_SequenceRNNOptions, BuiltinOptions_SequenceRNNOptions,
CreateSequenceRNNOptions(builder_, time_major, CreateSequenceRNNOptions(builder_, time_major,
ActivationFunctionType_RELU, ActivationFunctionType_RELU)
asymmetric_quantize_inputs)
.Union()); .Union());
if (time_major) { if (time_major) {
BuildInterpreter({{sequence_len_, batches_, input_size_}, BuildInterpreter({{sequence_len_, batches_, input_size_},
@ -251,11 +249,9 @@ class HybridUnidirectionalRNNOpModel : public UnidirectionalRNNOpModel {
public: public:
HybridUnidirectionalRNNOpModel(int batches, int sequence_len, int units, HybridUnidirectionalRNNOpModel(int batches, int sequence_len, int units,
int size, bool time_major, int size, bool time_major,
TensorType tensor_type, TensorType tensor_type)
bool asymmetric_quantize_inputs)
: UnidirectionalRNNOpModel(batches, sequence_len, units, size, time_major, : UnidirectionalRNNOpModel(batches, sequence_len, units, size, time_major,
tensor_type, tensor_type, tensor_type, tensor_type) {
asymmetric_quantize_inputs) {
tensor_type_ = tensor_type; tensor_type_ = tensor_type;
} }
@ -301,14 +297,10 @@ TEST(UnidirectionalRNNOpTest, BlackBoxTest) {
EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
} }
class HybridUnidirectionalRNNOpModelOpTest TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestUint8) {
: public ::testing::TestWithParam<bool> {};
TEST_P(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestUint8) {
HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*units=*/16, /*size=*/8, /*units=*/16, /*size=*/8,
/*time_major=*/false, TensorType_UINT8, /*time_major=*/false, TensorType_UINT8);
GetParam());
rnn.SetWeights(rnn_weights); rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias); rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights); rnn.SetRecurrentWeights(rnn_recurrent_weights);
@ -331,11 +323,10 @@ TEST_P(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestUint8) {
expected, /*max_abs_error=*/0.013))); expected, /*max_abs_error=*/0.013)));
} }
TEST_P(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestInt8) { TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestInt8) {
HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*units=*/16, /*size=*/8, /*units=*/16, /*size=*/8,
/*time_major=*/false, TensorType_INT8, /*time_major=*/false, TensorType_INT8);
GetParam());
rnn.SetWeights(rnn_weights); rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias); rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights); rnn.SetRecurrentWeights(rnn_recurrent_weights);
@ -387,11 +378,10 @@ TEST(UnidirectionalRNNOpTest, TimeMajorBlackBoxTest) {
EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
} }
TEST_P(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestUint8) { TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestUint8) {
HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*units=*/16, /*size=*/8, /*units=*/16, /*size=*/8,
/*time_major=*/true, TensorType_UINT8, /*time_major=*/true, TensorType_UINT8);
GetParam());
rnn.SetWeights(rnn_weights); rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias); rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights); rnn.SetRecurrentWeights(rnn_recurrent_weights);
@ -418,11 +408,10 @@ TEST_P(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestUint8) {
expected, /*max_abs_error=*/0.013))); expected, /*max_abs_error=*/0.013)));
} }
TEST_P(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestInt8) { TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestInt8) {
HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*units=*/16, /*size=*/8, /*units=*/16, /*size=*/8,
/*time_major=*/true, TensorType_INT8, /*time_major=*/true, TensorType_INT8);
GetParam());
rnn.SetWeights(rnn_weights); rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias); rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights); rnn.SetRecurrentWeights(rnn_recurrent_weights);
@ -449,9 +438,5 @@ TEST_P(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestInt8) {
expected, /*max_abs_error=*/0.013))); expected, /*max_abs_error=*/0.013)));
} }
INSTANTIATE_TEST_SUITE_P(HybridUnidirectionalRNNOpModelOpTest,
HybridUnidirectionalRNNOpModelOpTest,
::testing::ValuesIn({true, false}));
} // namespace } // namespace
} // namespace tflite } // namespace tflite

View File

@ -519,22 +519,17 @@ table LSHProjectionOptions {
table SVDFOptions { table SVDFOptions {
rank:int; rank:int;
fused_activation_function:ActivationFunctionType; fused_activation_function:ActivationFunctionType;
// For weights-only quantization, use asymmetric quantization for non
// constant inputs at evaluation time.
asymmetric_quantize_inputs:bool;
} }
// An implementation of TensorFlow RNNCell. // An implementation of TensorFlow RNNCell.
table RNNOptions { table RNNOptions {
fused_activation_function:ActivationFunctionType; fused_activation_function:ActivationFunctionType;
asymmetric_quantize_inputs:bool;
} }
// An implementation of TensorFlow dynamic_rnn with RNNCell. // An implementation of TensorFlow dynamic_rnn with RNNCell.
table SequenceRNNOptions { table SequenceRNNOptions {
time_major:bool; time_major:bool;
fused_activation_function:ActivationFunctionType; fused_activation_function:ActivationFunctionType;
asymmetric_quantize_inputs:bool;
} }
// An implementation of TensorFlow bidrectional_dynamic_rnn with RNNCell. // An implementation of TensorFlow bidrectional_dynamic_rnn with RNNCell.
@ -542,7 +537,6 @@ table BidirectionalSequenceRNNOptions {
time_major:bool; time_major:bool;
fused_activation_function:ActivationFunctionType; fused_activation_function:ActivationFunctionType;
merge_outputs: bool; merge_outputs: bool;
asymmetric_quantize_inputs:bool;
} }
enum FullyConnectedOptionsWeightsFormat: byte { enum FullyConnectedOptionsWeightsFormat: byte {
@ -562,11 +556,6 @@ table FullyConnectedOptions {
// If set to true, then the number of dimension is preserved. Furthermore, // If set to true, then the number of dimension is preserved. Furthermore,
// all but the last dimension of the input and output shapes will be equal. // all but the last dimension of the input and output shapes will be equal.
keep_num_dims: bool; keep_num_dims: bool;
// Parameters for FullyConnected version 7 or above.
// If set to true, then weights-only op will use asymmetric quantization for
// inputs.
asymmetric_quantize_inputs: bool;
} }
table SoftmaxOptions { table SoftmaxOptions {
@ -615,9 +604,6 @@ table LSTMOptions {
// Parameters for LSTM version 2 or above. // Parameters for LSTM version 2 or above.
// Basic kernel is only supported in version 2 or above. // Basic kernel is only supported in version 2 or above.
kernel_type: LSTMKernelType = FULL; kernel_type: LSTMKernelType = FULL;
// Parameters for LSTM version 4 or above.
asymmetric_quantize_inputs: bool;
} }
// An implementation of TensorFlow dynamic_rnn with LSTMCell. // An implementation of TensorFlow dynamic_rnn with LSTMCell.
@ -628,9 +614,6 @@ table UnidirectionalSequenceLSTMOptions {
// If true then first dimension is sequence, otherwise batch. // If true then first dimension is sequence, otherwise batch.
time_major:bool; time_major:bool;
// Parameter for Unidirectional Sequence LSTM version 4.
asymmetric_quantize_inputs:bool;
} }
table BidirectionalSequenceLSTMOptions { table BidirectionalSequenceLSTMOptions {
@ -647,9 +630,6 @@ table BidirectionalSequenceLSTMOptions {
// Version 1 implementations assumed time_major to be true, so this default // Version 1 implementations assumed time_major to be true, so this default
// value should never change. // value should never change.
time_major: bool = true; time_major: bool = true;
// Parameters for version 3 or above.
asymmetric_quantize_inputs:bool;
} }
table ResizeBilinearOptions { table ResizeBilinearOptions {

View File

@ -4216,11 +4216,9 @@ struct SVDFOptionsT : public flatbuffers::NativeTable {
typedef SVDFOptions TableType; typedef SVDFOptions TableType;
int32_t rank; int32_t rank;
tflite::ActivationFunctionType fused_activation_function; tflite::ActivationFunctionType fused_activation_function;
bool asymmetric_quantize_inputs;
SVDFOptionsT() SVDFOptionsT()
: rank(0), : rank(0),
fused_activation_function(tflite::ActivationFunctionType_NONE), fused_activation_function(tflite::ActivationFunctionType_NONE) {
asymmetric_quantize_inputs(false) {
} }
}; };
@ -4228,8 +4226,7 @@ struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef SVDFOptionsT NativeTableType; typedef SVDFOptionsT NativeTableType;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
VT_RANK = 4, VT_RANK = 4,
VT_FUSED_ACTIVATION_FUNCTION = 6, VT_FUSED_ACTIVATION_FUNCTION = 6
VT_ASYMMETRIC_QUANTIZE_INPUTS = 8
}; };
int32_t rank() const { int32_t rank() const {
return GetField<int32_t>(VT_RANK, 0); return GetField<int32_t>(VT_RANK, 0);
@ -4237,14 +4234,10 @@ struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
tflite::ActivationFunctionType fused_activation_function() const { tflite::ActivationFunctionType fused_activation_function() const {
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
} }
bool asymmetric_quantize_inputs() const {
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const { bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) && return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, VT_RANK) && VerifyField<int32_t>(verifier, VT_RANK) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
verifier.EndTable(); verifier.EndTable();
} }
SVDFOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; SVDFOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@ -4261,9 +4254,6 @@ struct SVDFOptionsBuilder {
void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
fbb_.AddElement<int8_t>(SVDFOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0); fbb_.AddElement<int8_t>(SVDFOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
} }
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
fbb_.AddElement<uint8_t>(SVDFOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
}
explicit SVDFOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) explicit SVDFOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) { : fbb_(_fbb) {
start_ = fbb_.StartTable(); start_ = fbb_.StartTable();
@ -4279,11 +4269,9 @@ struct SVDFOptionsBuilder {
inline flatbuffers::Offset<SVDFOptions> CreateSVDFOptions( inline flatbuffers::Offset<SVDFOptions> CreateSVDFOptions(
flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::FlatBufferBuilder &_fbb,
int32_t rank = 0, int32_t rank = 0,
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) {
bool asymmetric_quantize_inputs = false) {
SVDFOptionsBuilder builder_(_fbb); SVDFOptionsBuilder builder_(_fbb);
builder_.add_rank(rank); builder_.add_rank(rank);
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
builder_.add_fused_activation_function(fused_activation_function); builder_.add_fused_activation_function(fused_activation_function);
return builder_.Finish(); return builder_.Finish();
} }
@ -4293,29 +4281,22 @@ flatbuffers::Offset<SVDFOptions> CreateSVDFOptions(flatbuffers::FlatBufferBuilde
struct RNNOptionsT : public flatbuffers::NativeTable { struct RNNOptionsT : public flatbuffers::NativeTable {
typedef RNNOptions TableType; typedef RNNOptions TableType;
tflite::ActivationFunctionType fused_activation_function; tflite::ActivationFunctionType fused_activation_function;
bool asymmetric_quantize_inputs;
RNNOptionsT() RNNOptionsT()
: fused_activation_function(tflite::ActivationFunctionType_NONE), : fused_activation_function(tflite::ActivationFunctionType_NONE) {
asymmetric_quantize_inputs(false) {
} }
}; };
struct RNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct RNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef RNNOptionsT NativeTableType; typedef RNNOptionsT NativeTableType;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
VT_FUSED_ACTIVATION_FUNCTION = 4, VT_FUSED_ACTIVATION_FUNCTION = 4
VT_ASYMMETRIC_QUANTIZE_INPUTS = 6
}; };
tflite::ActivationFunctionType fused_activation_function() const { tflite::ActivationFunctionType fused_activation_function() const {
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
} }
bool asymmetric_quantize_inputs() const {
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const { bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) && return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
verifier.EndTable(); verifier.EndTable();
} }
RNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; RNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@ -4329,9 +4310,6 @@ struct RNNOptionsBuilder {
void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
fbb_.AddElement<int8_t>(RNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0); fbb_.AddElement<int8_t>(RNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
} }
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
fbb_.AddElement<uint8_t>(RNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
}
explicit RNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) explicit RNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) { : fbb_(_fbb) {
start_ = fbb_.StartTable(); start_ = fbb_.StartTable();
@ -4346,10 +4324,8 @@ struct RNNOptionsBuilder {
inline flatbuffers::Offset<RNNOptions> CreateRNNOptions( inline flatbuffers::Offset<RNNOptions> CreateRNNOptions(
flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::FlatBufferBuilder &_fbb,
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) {
bool asymmetric_quantize_inputs = false) {
RNNOptionsBuilder builder_(_fbb); RNNOptionsBuilder builder_(_fbb);
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
builder_.add_fused_activation_function(fused_activation_function); builder_.add_fused_activation_function(fused_activation_function);
return builder_.Finish(); return builder_.Finish();
} }
@ -4360,11 +4336,9 @@ struct SequenceRNNOptionsT : public flatbuffers::NativeTable {
typedef SequenceRNNOptions TableType; typedef SequenceRNNOptions TableType;
bool time_major; bool time_major;
tflite::ActivationFunctionType fused_activation_function; tflite::ActivationFunctionType fused_activation_function;
bool asymmetric_quantize_inputs;
SequenceRNNOptionsT() SequenceRNNOptionsT()
: time_major(false), : time_major(false),
fused_activation_function(tflite::ActivationFunctionType_NONE), fused_activation_function(tflite::ActivationFunctionType_NONE) {
asymmetric_quantize_inputs(false) {
} }
}; };
@ -4372,8 +4346,7 @@ struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef SequenceRNNOptionsT NativeTableType; typedef SequenceRNNOptionsT NativeTableType;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
VT_TIME_MAJOR = 4, VT_TIME_MAJOR = 4,
VT_FUSED_ACTIVATION_FUNCTION = 6, VT_FUSED_ACTIVATION_FUNCTION = 6
VT_ASYMMETRIC_QUANTIZE_INPUTS = 8
}; };
bool time_major() const { bool time_major() const {
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0; return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
@ -4381,14 +4354,10 @@ struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
tflite::ActivationFunctionType fused_activation_function() const { tflite::ActivationFunctionType fused_activation_function() const {
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
} }
bool asymmetric_quantize_inputs() const {
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const { bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) && return VerifyTableStart(verifier) &&
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) && VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
verifier.EndTable(); verifier.EndTable();
} }
SequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; SequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@ -4405,9 +4374,6 @@ struct SequenceRNNOptionsBuilder {
void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
fbb_.AddElement<int8_t>(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0); fbb_.AddElement<int8_t>(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
} }
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
fbb_.AddElement<uint8_t>(SequenceRNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
}
explicit SequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) explicit SequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) { : fbb_(_fbb) {
start_ = fbb_.StartTable(); start_ = fbb_.StartTable();
@ -4423,10 +4389,8 @@ struct SequenceRNNOptionsBuilder {
inline flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions( inline flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions(
flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::FlatBufferBuilder &_fbb,
bool time_major = false, bool time_major = false,
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) {
bool asymmetric_quantize_inputs = false) {
SequenceRNNOptionsBuilder builder_(_fbb); SequenceRNNOptionsBuilder builder_(_fbb);
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
builder_.add_fused_activation_function(fused_activation_function); builder_.add_fused_activation_function(fused_activation_function);
builder_.add_time_major(time_major); builder_.add_time_major(time_major);
return builder_.Finish(); return builder_.Finish();
@ -4439,12 +4403,10 @@ struct BidirectionalSequenceRNNOptionsT : public flatbuffers::NativeTable {
bool time_major; bool time_major;
tflite::ActivationFunctionType fused_activation_function; tflite::ActivationFunctionType fused_activation_function;
bool merge_outputs; bool merge_outputs;
bool asymmetric_quantize_inputs;
BidirectionalSequenceRNNOptionsT() BidirectionalSequenceRNNOptionsT()
: time_major(false), : time_major(false),
fused_activation_function(tflite::ActivationFunctionType_NONE), fused_activation_function(tflite::ActivationFunctionType_NONE),
merge_outputs(false), merge_outputs(false) {
asymmetric_quantize_inputs(false) {
} }
}; };
@ -4453,8 +4415,7 @@ struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuf
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
VT_TIME_MAJOR = 4, VT_TIME_MAJOR = 4,
VT_FUSED_ACTIVATION_FUNCTION = 6, VT_FUSED_ACTIVATION_FUNCTION = 6,
VT_MERGE_OUTPUTS = 8, VT_MERGE_OUTPUTS = 8
VT_ASYMMETRIC_QUANTIZE_INPUTS = 10
}; };
bool time_major() const { bool time_major() const {
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0; return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
@ -4465,15 +4426,11 @@ struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuf
bool merge_outputs() const { bool merge_outputs() const {
return GetField<uint8_t>(VT_MERGE_OUTPUTS, 0) != 0; return GetField<uint8_t>(VT_MERGE_OUTPUTS, 0) != 0;
} }
bool asymmetric_quantize_inputs() const {
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const { bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) && return VerifyTableStart(verifier) &&
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) && VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) && VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) &&
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
verifier.EndTable(); verifier.EndTable();
} }
BidirectionalSequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; BidirectionalSequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@ -4493,9 +4450,6 @@ struct BidirectionalSequenceRNNOptionsBuilder {
void add_merge_outputs(bool merge_outputs) { void add_merge_outputs(bool merge_outputs) {
fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_MERGE_OUTPUTS, static_cast<uint8_t>(merge_outputs), 0); fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_MERGE_OUTPUTS, static_cast<uint8_t>(merge_outputs), 0);
} }
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
}
explicit BidirectionalSequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) explicit BidirectionalSequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) { : fbb_(_fbb) {
start_ = fbb_.StartTable(); start_ = fbb_.StartTable();
@ -4512,10 +4466,8 @@ inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> CreateBidirectionalS
flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::FlatBufferBuilder &_fbb,
bool time_major = false, bool time_major = false,
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
bool merge_outputs = false, bool merge_outputs = false) {
bool asymmetric_quantize_inputs = false) {
BidirectionalSequenceRNNOptionsBuilder builder_(_fbb); BidirectionalSequenceRNNOptionsBuilder builder_(_fbb);
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
builder_.add_merge_outputs(merge_outputs); builder_.add_merge_outputs(merge_outputs);
builder_.add_fused_activation_function(fused_activation_function); builder_.add_fused_activation_function(fused_activation_function);
builder_.add_time_major(time_major); builder_.add_time_major(time_major);
@ -4529,12 +4481,10 @@ struct FullyConnectedOptionsT : public flatbuffers::NativeTable {
tflite::ActivationFunctionType fused_activation_function; tflite::ActivationFunctionType fused_activation_function;
tflite::FullyConnectedOptionsWeightsFormat weights_format; tflite::FullyConnectedOptionsWeightsFormat weights_format;
bool keep_num_dims; bool keep_num_dims;
bool asymmetric_quantize_inputs;
FullyConnectedOptionsT() FullyConnectedOptionsT()
: fused_activation_function(tflite::ActivationFunctionType_NONE), : fused_activation_function(tflite::ActivationFunctionType_NONE),
weights_format(tflite::FullyConnectedOptionsWeightsFormat_DEFAULT), weights_format(tflite::FullyConnectedOptionsWeightsFormat_DEFAULT),
keep_num_dims(false), keep_num_dims(false) {
asymmetric_quantize_inputs(false) {
} }
}; };
@ -4543,8 +4493,7 @@ struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tabl
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
VT_FUSED_ACTIVATION_FUNCTION = 4, VT_FUSED_ACTIVATION_FUNCTION = 4,
VT_WEIGHTS_FORMAT = 6, VT_WEIGHTS_FORMAT = 6,
VT_KEEP_NUM_DIMS = 8, VT_KEEP_NUM_DIMS = 8
VT_ASYMMETRIC_QUANTIZE_INPUTS = 10
}; };
tflite::ActivationFunctionType fused_activation_function() const { tflite::ActivationFunctionType fused_activation_function() const {
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
@ -4555,15 +4504,11 @@ struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tabl
bool keep_num_dims() const { bool keep_num_dims() const {
return GetField<uint8_t>(VT_KEEP_NUM_DIMS, 0) != 0; return GetField<uint8_t>(VT_KEEP_NUM_DIMS, 0) != 0;
} }
bool asymmetric_quantize_inputs() const {
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const { bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) && return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
VerifyField<int8_t>(verifier, VT_WEIGHTS_FORMAT) && VerifyField<int8_t>(verifier, VT_WEIGHTS_FORMAT) &&
VerifyField<uint8_t>(verifier, VT_KEEP_NUM_DIMS) && VerifyField<uint8_t>(verifier, VT_KEEP_NUM_DIMS) &&
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
verifier.EndTable(); verifier.EndTable();
} }
FullyConnectedOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; FullyConnectedOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@ -4583,9 +4528,6 @@ struct FullyConnectedOptionsBuilder {
void add_keep_num_dims(bool keep_num_dims) { void add_keep_num_dims(bool keep_num_dims) {
fbb_.AddElement<uint8_t>(FullyConnectedOptions::VT_KEEP_NUM_DIMS, static_cast<uint8_t>(keep_num_dims), 0); fbb_.AddElement<uint8_t>(FullyConnectedOptions::VT_KEEP_NUM_DIMS, static_cast<uint8_t>(keep_num_dims), 0);
} }
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
fbb_.AddElement<uint8_t>(FullyConnectedOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
}
explicit FullyConnectedOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) explicit FullyConnectedOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) { : fbb_(_fbb) {
start_ = fbb_.StartTable(); start_ = fbb_.StartTable();
@ -4602,10 +4544,8 @@ inline flatbuffers::Offset<FullyConnectedOptions> CreateFullyConnectedOptions(
flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::FlatBufferBuilder &_fbb,
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
tflite::FullyConnectedOptionsWeightsFormat weights_format = tflite::FullyConnectedOptionsWeightsFormat_DEFAULT, tflite::FullyConnectedOptionsWeightsFormat weights_format = tflite::FullyConnectedOptionsWeightsFormat_DEFAULT,
bool keep_num_dims = false, bool keep_num_dims = false) {
bool asymmetric_quantize_inputs = false) {
FullyConnectedOptionsBuilder builder_(_fbb); FullyConnectedOptionsBuilder builder_(_fbb);
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
builder_.add_keep_num_dims(keep_num_dims); builder_.add_keep_num_dims(keep_num_dims);
builder_.add_weights_format(weights_format); builder_.add_weights_format(weights_format);
builder_.add_fused_activation_function(fused_activation_function); builder_.add_fused_activation_function(fused_activation_function);
@ -4992,13 +4932,11 @@ struct LSTMOptionsT : public flatbuffers::NativeTable {
float cell_clip; float cell_clip;
float proj_clip; float proj_clip;
tflite::LSTMKernelType kernel_type; tflite::LSTMKernelType kernel_type;
bool asymmetric_quantize_inputs;
LSTMOptionsT() LSTMOptionsT()
: fused_activation_function(tflite::ActivationFunctionType_NONE), : fused_activation_function(tflite::ActivationFunctionType_NONE),
cell_clip(0.0f), cell_clip(0.0f),
proj_clip(0.0f), proj_clip(0.0f),
kernel_type(tflite::LSTMKernelType_FULL), kernel_type(tflite::LSTMKernelType_FULL) {
asymmetric_quantize_inputs(false) {
} }
}; };
@ -5008,8 +4946,7 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VT_FUSED_ACTIVATION_FUNCTION = 4, VT_FUSED_ACTIVATION_FUNCTION = 4,
VT_CELL_CLIP = 6, VT_CELL_CLIP = 6,
VT_PROJ_CLIP = 8, VT_PROJ_CLIP = 8,
VT_KERNEL_TYPE = 10, VT_KERNEL_TYPE = 10
VT_ASYMMETRIC_QUANTIZE_INPUTS = 12
}; };
tflite::ActivationFunctionType fused_activation_function() const { tflite::ActivationFunctionType fused_activation_function() const {
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
@ -5023,16 +4960,12 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
tflite::LSTMKernelType kernel_type() const { tflite::LSTMKernelType kernel_type() const {
return static_cast<tflite::LSTMKernelType>(GetField<int8_t>(VT_KERNEL_TYPE, 0)); return static_cast<tflite::LSTMKernelType>(GetField<int8_t>(VT_KERNEL_TYPE, 0));
} }
bool asymmetric_quantize_inputs() const {
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const { bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) && return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
VerifyField<float>(verifier, VT_CELL_CLIP) && VerifyField<float>(verifier, VT_CELL_CLIP) &&
VerifyField<float>(verifier, VT_PROJ_CLIP) && VerifyField<float>(verifier, VT_PROJ_CLIP) &&
VerifyField<int8_t>(verifier, VT_KERNEL_TYPE) && VerifyField<int8_t>(verifier, VT_KERNEL_TYPE) &&
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
verifier.EndTable(); verifier.EndTable();
} }
LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@ -5055,9 +4988,6 @@ struct LSTMOptionsBuilder {
void add_kernel_type(tflite::LSTMKernelType kernel_type) { void add_kernel_type(tflite::LSTMKernelType kernel_type) {
fbb_.AddElement<int8_t>(LSTMOptions::VT_KERNEL_TYPE, static_cast<int8_t>(kernel_type), 0); fbb_.AddElement<int8_t>(LSTMOptions::VT_KERNEL_TYPE, static_cast<int8_t>(kernel_type), 0);
} }
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
fbb_.AddElement<uint8_t>(LSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
}
explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) { : fbb_(_fbb) {
start_ = fbb_.StartTable(); start_ = fbb_.StartTable();
@ -5075,12 +5005,10 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
float cell_clip = 0.0f, float cell_clip = 0.0f,
float proj_clip = 0.0f, float proj_clip = 0.0f,
tflite::LSTMKernelType kernel_type = tflite::LSTMKernelType_FULL, tflite::LSTMKernelType kernel_type = tflite::LSTMKernelType_FULL) {
bool asymmetric_quantize_inputs = false) {
LSTMOptionsBuilder builder_(_fbb); LSTMOptionsBuilder builder_(_fbb);
builder_.add_proj_clip(proj_clip); builder_.add_proj_clip(proj_clip);
builder_.add_cell_clip(cell_clip); builder_.add_cell_clip(cell_clip);
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
builder_.add_kernel_type(kernel_type); builder_.add_kernel_type(kernel_type);
builder_.add_fused_activation_function(fused_activation_function); builder_.add_fused_activation_function(fused_activation_function);
return builder_.Finish(); return builder_.Finish();
@ -5094,13 +5022,11 @@ struct UnidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
float cell_clip; float cell_clip;
float proj_clip; float proj_clip;
bool time_major; bool time_major;
bool asymmetric_quantize_inputs;
UnidirectionalSequenceLSTMOptionsT() UnidirectionalSequenceLSTMOptionsT()
: fused_activation_function(tflite::ActivationFunctionType_NONE), : fused_activation_function(tflite::ActivationFunctionType_NONE),
cell_clip(0.0f), cell_clip(0.0f),
proj_clip(0.0f), proj_clip(0.0f),
time_major(false), time_major(false) {
asymmetric_quantize_inputs(false) {
} }
}; };
@ -5110,8 +5036,7 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb
VT_FUSED_ACTIVATION_FUNCTION = 4, VT_FUSED_ACTIVATION_FUNCTION = 4,
VT_CELL_CLIP = 6, VT_CELL_CLIP = 6,
VT_PROJ_CLIP = 8, VT_PROJ_CLIP = 8,
VT_TIME_MAJOR = 10, VT_TIME_MAJOR = 10
VT_ASYMMETRIC_QUANTIZE_INPUTS = 12
}; };
tflite::ActivationFunctionType fused_activation_function() const { tflite::ActivationFunctionType fused_activation_function() const {
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
@ -5125,16 +5050,12 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb
bool time_major() const { bool time_major() const {
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0; return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
} }
bool asymmetric_quantize_inputs() const {
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const { bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) && return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
VerifyField<float>(verifier, VT_CELL_CLIP) && VerifyField<float>(verifier, VT_CELL_CLIP) &&
VerifyField<float>(verifier, VT_PROJ_CLIP) && VerifyField<float>(verifier, VT_PROJ_CLIP) &&
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) && VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
verifier.EndTable(); verifier.EndTable();
} }
UnidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; UnidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@ -5157,9 +5078,6 @@ struct UnidirectionalSequenceLSTMOptionsBuilder {
void add_time_major(bool time_major) { void add_time_major(bool time_major) {
fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major), 0); fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major), 0);
} }
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
}
explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) { : fbb_(_fbb) {
start_ = fbb_.StartTable(); start_ = fbb_.StartTable();
@ -5177,12 +5095,10 @@ inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirection
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
float cell_clip = 0.0f, float cell_clip = 0.0f,
float proj_clip = 0.0f, float proj_clip = 0.0f,
bool time_major = false, bool time_major = false) {
bool asymmetric_quantize_inputs = false) {
UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb); UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
builder_.add_proj_clip(proj_clip); builder_.add_proj_clip(proj_clip);
builder_.add_cell_clip(cell_clip); builder_.add_cell_clip(cell_clip);
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
builder_.add_time_major(time_major); builder_.add_time_major(time_major);
builder_.add_fused_activation_function(fused_activation_function); builder_.add_fused_activation_function(fused_activation_function);
return builder_.Finish(); return builder_.Finish();
@ -5197,14 +5113,12 @@ struct BidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
float proj_clip; float proj_clip;
bool merge_outputs; bool merge_outputs;
bool time_major; bool time_major;
bool asymmetric_quantize_inputs;
BidirectionalSequenceLSTMOptionsT() BidirectionalSequenceLSTMOptionsT()
: fused_activation_function(tflite::ActivationFunctionType_NONE), : fused_activation_function(tflite::ActivationFunctionType_NONE),
cell_clip(0.0f), cell_clip(0.0f),
proj_clip(0.0f), proj_clip(0.0f),
merge_outputs(false), merge_outputs(false),
time_major(true), time_major(true) {
asymmetric_quantize_inputs(false) {
} }
}; };
@ -5215,8 +5129,7 @@ struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbu
VT_CELL_CLIP = 6, VT_CELL_CLIP = 6,
VT_PROJ_CLIP = 8, VT_PROJ_CLIP = 8,
VT_MERGE_OUTPUTS = 10, VT_MERGE_OUTPUTS = 10,
VT_TIME_MAJOR = 12, VT_TIME_MAJOR = 12
VT_ASYMMETRIC_QUANTIZE_INPUTS = 14
}; };
tflite::ActivationFunctionType fused_activation_function() const { tflite::ActivationFunctionType fused_activation_function() const {
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
@ -5233,9 +5146,6 @@ struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbu
bool time_major() const { bool time_major() const {
return GetField<uint8_t>(VT_TIME_MAJOR, 1) != 0; return GetField<uint8_t>(VT_TIME_MAJOR, 1) != 0;
} }
bool asymmetric_quantize_inputs() const {
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const { bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) && return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
@ -5243,7 +5153,6 @@ struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbu
VerifyField<float>(verifier, VT_PROJ_CLIP) && VerifyField<float>(verifier, VT_PROJ_CLIP) &&
VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) && VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) &&
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) && VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
verifier.EndTable(); verifier.EndTable();
} }
BidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; BidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@ -5269,9 +5178,6 @@ struct BidirectionalSequenceLSTMOptionsBuilder {
void add_time_major(bool time_major) { void add_time_major(bool time_major) {
fbb_.AddElement<uint8_t>(BidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major), 1); fbb_.AddElement<uint8_t>(BidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major), 1);
} }
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
fbb_.AddElement<uint8_t>(BidirectionalSequenceLSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
}
explicit BidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) explicit BidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) { : fbb_(_fbb) {
start_ = fbb_.StartTable(); start_ = fbb_.StartTable();
@ -5290,12 +5196,10 @@ inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> CreateBidirectional
float cell_clip = 0.0f, float cell_clip = 0.0f,
float proj_clip = 0.0f, float proj_clip = 0.0f,
bool merge_outputs = false, bool merge_outputs = false,
bool time_major = true, bool time_major = true) {
bool asymmetric_quantize_inputs = false) {
BidirectionalSequenceLSTMOptionsBuilder builder_(_fbb); BidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
builder_.add_proj_clip(proj_clip); builder_.add_proj_clip(proj_clip);
builder_.add_cell_clip(cell_clip); builder_.add_cell_clip(cell_clip);
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
builder_.add_time_major(time_major); builder_.add_time_major(time_major);
builder_.add_merge_outputs(merge_outputs); builder_.add_merge_outputs(merge_outputs);
builder_.add_fused_activation_function(fused_activation_function); builder_.add_fused_activation_function(fused_activation_function);
@ -11130,7 +11034,6 @@ inline void SVDFOptions::UnPackTo(SVDFOptionsT *_o, const flatbuffers::resolver_
(void)_resolver; (void)_resolver;
{ auto _e = rank(); _o->rank = _e; } { auto _e = rank(); _o->rank = _e; }
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; } { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
} }
inline flatbuffers::Offset<SVDFOptions> SVDFOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { inline flatbuffers::Offset<SVDFOptions> SVDFOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11143,12 +11046,10 @@ inline flatbuffers::Offset<SVDFOptions> CreateSVDFOptions(flatbuffers::FlatBuffe
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SVDFOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SVDFOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _rank = _o->rank; auto _rank = _o->rank;
auto _fused_activation_function = _o->fused_activation_function; auto _fused_activation_function = _o->fused_activation_function;
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
return tflite::CreateSVDFOptions( return tflite::CreateSVDFOptions(
_fbb, _fbb,
_rank, _rank,
_fused_activation_function, _fused_activation_function);
_asymmetric_quantize_inputs);
} }
inline RNNOptionsT *RNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { inline RNNOptionsT *RNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -11161,7 +11062,6 @@ inline void RNNOptions::UnPackTo(RNNOptionsT *_o, const flatbuffers::resolver_fu
(void)_o; (void)_o;
(void)_resolver; (void)_resolver;
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; } { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
} }
inline flatbuffers::Offset<RNNOptions> RNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { inline flatbuffers::Offset<RNNOptions> RNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11173,11 +11073,9 @@ inline flatbuffers::Offset<RNNOptions> CreateRNNOptions(flatbuffers::FlatBufferB
(void)_o; (void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _fused_activation_function = _o->fused_activation_function; auto _fused_activation_function = _o->fused_activation_function;
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
return tflite::CreateRNNOptions( return tflite::CreateRNNOptions(
_fbb, _fbb,
_fused_activation_function, _fused_activation_function);
_asymmetric_quantize_inputs);
} }
inline SequenceRNNOptionsT *SequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { inline SequenceRNNOptionsT *SequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -11191,7 +11089,6 @@ inline void SequenceRNNOptions::UnPackTo(SequenceRNNOptionsT *_o, const flatbuff
(void)_resolver; (void)_resolver;
{ auto _e = time_major(); _o->time_major = _e; } { auto _e = time_major(); _o->time_major = _e; }
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; } { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
} }
inline flatbuffers::Offset<SequenceRNNOptions> SequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { inline flatbuffers::Offset<SequenceRNNOptions> SequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11204,12 +11101,10 @@ inline flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions(flatbuff
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SequenceRNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SequenceRNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _time_major = _o->time_major; auto _time_major = _o->time_major;
auto _fused_activation_function = _o->fused_activation_function; auto _fused_activation_function = _o->fused_activation_function;
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
return tflite::CreateSequenceRNNOptions( return tflite::CreateSequenceRNNOptions(
_fbb, _fbb,
_time_major, _time_major,
_fused_activation_function, _fused_activation_function);
_asymmetric_quantize_inputs);
} }
inline BidirectionalSequenceRNNOptionsT *BidirectionalSequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { inline BidirectionalSequenceRNNOptionsT *BidirectionalSequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -11224,7 +11119,6 @@ inline void BidirectionalSequenceRNNOptions::UnPackTo(BidirectionalSequenceRNNOp
{ auto _e = time_major(); _o->time_major = _e; } { auto _e = time_major(); _o->time_major = _e; }
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; } { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
{ auto _e = merge_outputs(); _o->merge_outputs = _e; } { auto _e = merge_outputs(); _o->merge_outputs = _e; }
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
} }
inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> BidirectionalSequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> BidirectionalSequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11238,13 +11132,11 @@ inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> CreateBidirectionalS
auto _time_major = _o->time_major; auto _time_major = _o->time_major;
auto _fused_activation_function = _o->fused_activation_function; auto _fused_activation_function = _o->fused_activation_function;
auto _merge_outputs = _o->merge_outputs; auto _merge_outputs = _o->merge_outputs;
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
return tflite::CreateBidirectionalSequenceRNNOptions( return tflite::CreateBidirectionalSequenceRNNOptions(
_fbb, _fbb,
_time_major, _time_major,
_fused_activation_function, _fused_activation_function,
_merge_outputs, _merge_outputs);
_asymmetric_quantize_inputs);
} }
inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -11259,7 +11151,6 @@ inline void FullyConnectedOptions::UnPackTo(FullyConnectedOptionsT *_o, const fl
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; } { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
{ auto _e = weights_format(); _o->weights_format = _e; } { auto _e = weights_format(); _o->weights_format = _e; }
{ auto _e = keep_num_dims(); _o->keep_num_dims = _e; } { auto _e = keep_num_dims(); _o->keep_num_dims = _e; }
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
} }
inline flatbuffers::Offset<FullyConnectedOptions> FullyConnectedOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { inline flatbuffers::Offset<FullyConnectedOptions> FullyConnectedOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11273,13 +11164,11 @@ inline flatbuffers::Offset<FullyConnectedOptions> CreateFullyConnectedOptions(fl
auto _fused_activation_function = _o->fused_activation_function; auto _fused_activation_function = _o->fused_activation_function;
auto _weights_format = _o->weights_format; auto _weights_format = _o->weights_format;
auto _keep_num_dims = _o->keep_num_dims; auto _keep_num_dims = _o->keep_num_dims;
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
return tflite::CreateFullyConnectedOptions( return tflite::CreateFullyConnectedOptions(
_fbb, _fbb,
_fused_activation_function, _fused_activation_function,
_weights_format, _weights_format,
_keep_num_dims, _keep_num_dims);
_asymmetric_quantize_inputs);
} }
inline SoftmaxOptionsT *SoftmaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { inline SoftmaxOptionsT *SoftmaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -11463,7 +11352,6 @@ inline void LSTMOptions::UnPackTo(LSTMOptionsT *_o, const flatbuffers::resolver_
{ auto _e = cell_clip(); _o->cell_clip = _e; } { auto _e = cell_clip(); _o->cell_clip = _e; }
{ auto _e = proj_clip(); _o->proj_clip = _e; } { auto _e = proj_clip(); _o->proj_clip = _e; }
{ auto _e = kernel_type(); _o->kernel_type = _e; } { auto _e = kernel_type(); _o->kernel_type = _e; }
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
} }
inline flatbuffers::Offset<LSTMOptions> LSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { inline flatbuffers::Offset<LSTMOptions> LSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11478,14 +11366,12 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(flatbuffers::FlatBuffe
auto _cell_clip = _o->cell_clip; auto _cell_clip = _o->cell_clip;
auto _proj_clip = _o->proj_clip; auto _proj_clip = _o->proj_clip;
auto _kernel_type = _o->kernel_type; auto _kernel_type = _o->kernel_type;
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
return tflite::CreateLSTMOptions( return tflite::CreateLSTMOptions(
_fbb, _fbb,
_fused_activation_function, _fused_activation_function,
_cell_clip, _cell_clip,
_proj_clip, _proj_clip,
_kernel_type, _kernel_type);
_asymmetric_quantize_inputs);
} }
inline UnidirectionalSequenceLSTMOptionsT *UnidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { inline UnidirectionalSequenceLSTMOptionsT *UnidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -11501,7 +11387,6 @@ inline void UnidirectionalSequenceLSTMOptions::UnPackTo(UnidirectionalSequenceLS
{ auto _e = cell_clip(); _o->cell_clip = _e; } { auto _e = cell_clip(); _o->cell_clip = _e; }
{ auto _e = proj_clip(); _o->proj_clip = _e; } { auto _e = proj_clip(); _o->proj_clip = _e; }
{ auto _e = time_major(); _o->time_major = _e; } { auto _e = time_major(); _o->time_major = _e; }
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
} }
inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> UnidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> UnidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11516,14 +11401,12 @@ inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirection
auto _cell_clip = _o->cell_clip; auto _cell_clip = _o->cell_clip;
auto _proj_clip = _o->proj_clip; auto _proj_clip = _o->proj_clip;
auto _time_major = _o->time_major; auto _time_major = _o->time_major;
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
return tflite::CreateUnidirectionalSequenceLSTMOptions( return tflite::CreateUnidirectionalSequenceLSTMOptions(
_fbb, _fbb,
_fused_activation_function, _fused_activation_function,
_cell_clip, _cell_clip,
_proj_clip, _proj_clip,
_time_major, _time_major);
_asymmetric_quantize_inputs);
} }
inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -11540,7 +11423,6 @@ inline void BidirectionalSequenceLSTMOptions::UnPackTo(BidirectionalSequenceLSTM
{ auto _e = proj_clip(); _o->proj_clip = _e; } { auto _e = proj_clip(); _o->proj_clip = _e; }
{ auto _e = merge_outputs(); _o->merge_outputs = _e; } { auto _e = merge_outputs(); _o->merge_outputs = _e; }
{ auto _e = time_major(); _o->time_major = _e; } { auto _e = time_major(); _o->time_major = _e; }
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
} }
inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> BidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> BidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11556,15 +11438,13 @@ inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> CreateBidirectional
auto _proj_clip = _o->proj_clip; auto _proj_clip = _o->proj_clip;
auto _merge_outputs = _o->merge_outputs; auto _merge_outputs = _o->merge_outputs;
auto _time_major = _o->time_major; auto _time_major = _o->time_major;
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
return tflite::CreateBidirectionalSequenceLSTMOptions( return tflite::CreateBidirectionalSequenceLSTMOptions(
_fbb, _fbb,
_fused_activation_function, _fused_activation_function,
_cell_clip, _cell_clip,
_proj_clip, _proj_clip,
_merge_outputs, _merge_outputs,
_time_major, _time_major);
_asymmetric_quantize_inputs);
} }
inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {