Adds asymmetric quantized inputs for hybrid ops in future models.
PiperOrigin-RevId: 304559648 Change-Id: I8028ae6f65308c9b9fa928b8d755919af1faa7be
This commit is contained in:
parent
0752177439
commit
bb130adb39
@ -124,21 +124,33 @@ typedef struct {
|
|||||||
typedef struct {
|
typedef struct {
|
||||||
int rank;
|
int rank;
|
||||||
TfLiteFusedActivation activation;
|
TfLiteFusedActivation activation;
|
||||||
|
|
||||||
|
// Parameter for SVDF version 4.
|
||||||
|
bool asymmetric_quantize_inputs;
|
||||||
} TfLiteSVDFParams;
|
} TfLiteSVDFParams;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
TfLiteFusedActivation activation;
|
TfLiteFusedActivation activation;
|
||||||
|
|
||||||
|
// Parameter for RNN version 3.
|
||||||
|
bool asymmetric_quantize_inputs;
|
||||||
} TfLiteRNNParams;
|
} TfLiteRNNParams;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
bool time_major;
|
bool time_major;
|
||||||
TfLiteFusedActivation activation;
|
TfLiteFusedActivation activation;
|
||||||
|
|
||||||
|
// Parameter for Sequence RNN version 3.
|
||||||
|
bool asymmetric_quantize_inputs;
|
||||||
} TfLiteSequenceRNNParams;
|
} TfLiteSequenceRNNParams;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
bool time_major;
|
bool time_major;
|
||||||
TfLiteFusedActivation activation;
|
TfLiteFusedActivation activation;
|
||||||
bool merge_outputs;
|
bool merge_outputs;
|
||||||
|
|
||||||
|
// Parameter for Bidirectional RNN verison 3.
|
||||||
|
bool asymmetric_quantize_inputs;
|
||||||
} TfLiteBidirectionalSequenceRNNParams;
|
} TfLiteBidirectionalSequenceRNNParams;
|
||||||
|
|
||||||
typedef enum {
|
typedef enum {
|
||||||
@ -158,6 +170,11 @@ typedef struct {
|
|||||||
// tensors are the same. Furthermore, all but the last dimension of the input
|
// tensors are the same. Furthermore, all but the last dimension of the input
|
||||||
// and output shapes will be equal.
|
// and output shapes will be equal.
|
||||||
bool keep_num_dims;
|
bool keep_num_dims;
|
||||||
|
|
||||||
|
// Parameters for FullyConnected version 7 or above.
|
||||||
|
// If set to true and the weights are quantized, then non constant inputs
|
||||||
|
// are quantized at evaluation time with asymmetric quantization.
|
||||||
|
bool asymmetric_quantize_inputs;
|
||||||
} TfLiteFullyConnectedParams;
|
} TfLiteFullyConnectedParams;
|
||||||
|
|
||||||
typedef enum {
|
typedef enum {
|
||||||
@ -228,6 +245,9 @@ typedef struct {
|
|||||||
// Parameters for LSTM version 2.
|
// Parameters for LSTM version 2.
|
||||||
// kTfLiteLSTMBasicKernel is only supported in version 2 or above.
|
// kTfLiteLSTMBasicKernel is only supported in version 2 or above.
|
||||||
TfLiteLSTMKernelType kernel_type;
|
TfLiteLSTMKernelType kernel_type;
|
||||||
|
|
||||||
|
// Parameters for LSTM version 4.
|
||||||
|
bool asymmetric_quantize_inputs;
|
||||||
} TfLiteLSTMParams;
|
} TfLiteLSTMParams;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
@ -238,6 +258,9 @@ typedef struct {
|
|||||||
|
|
||||||
// If set to true then the first dimension is time, otherwise batch.
|
// If set to true then the first dimension is time, otherwise batch.
|
||||||
bool time_major;
|
bool time_major;
|
||||||
|
|
||||||
|
// Parameter for unidirectional sequence RNN version 3.
|
||||||
|
bool asymmetric_quantize_inputs;
|
||||||
} TfLiteUnidirectionalSequenceLSTMParams;
|
} TfLiteUnidirectionalSequenceLSTMParams;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
@ -253,6 +276,10 @@ typedef struct {
|
|||||||
// Parameters supported by version 2:
|
// Parameters supported by version 2:
|
||||||
// If set to true then the first dimension is time, otherwise batch.
|
// If set to true then the first dimension is time, otherwise batch.
|
||||||
bool time_major;
|
bool time_major;
|
||||||
|
|
||||||
|
// Parameters supported by version 4:
|
||||||
|
// If set to true, then hybrid ops use asymmetric quantization for inputs.
|
||||||
|
bool asymmetric_quantize_inputs;
|
||||||
} TfLiteBidirectionalSequenceLSTMParams;
|
} TfLiteBidirectionalSequenceLSTMParams;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
|
@ -269,6 +269,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
|||||||
params->rank = svdf_params->rank();
|
params->rank = svdf_params->rank();
|
||||||
params->activation =
|
params->activation =
|
||||||
parse_activation(svdf_params->fused_activation_function());
|
parse_activation(svdf_params->fused_activation_function());
|
||||||
|
params->asymmetric_quantize_inputs =
|
||||||
|
svdf_params->asymmetric_quantize_inputs();
|
||||||
}
|
}
|
||||||
*builtin_data = reinterpret_cast<void*>(params.release());
|
*builtin_data = reinterpret_cast<void*>(params.release());
|
||||||
break;
|
break;
|
||||||
@ -280,6 +282,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
|||||||
params->activation =
|
params->activation =
|
||||||
parse_activation(sequence_rnn_params->fused_activation_function());
|
parse_activation(sequence_rnn_params->fused_activation_function());
|
||||||
params->time_major = sequence_rnn_params->time_major();
|
params->time_major = sequence_rnn_params->time_major();
|
||||||
|
params->asymmetric_quantize_inputs =
|
||||||
|
sequence_rnn_params->asymmetric_quantize_inputs();
|
||||||
}
|
}
|
||||||
*builtin_data = reinterpret_cast<void*>(params.release());
|
*builtin_data = reinterpret_cast<void*>(params.release());
|
||||||
break;
|
break;
|
||||||
@ -293,6 +297,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
|||||||
bidi_sequence_rnn_params->fused_activation_function());
|
bidi_sequence_rnn_params->fused_activation_function());
|
||||||
params->time_major = bidi_sequence_rnn_params->time_major();
|
params->time_major = bidi_sequence_rnn_params->time_major();
|
||||||
params->merge_outputs = bidi_sequence_rnn_params->merge_outputs();
|
params->merge_outputs = bidi_sequence_rnn_params->merge_outputs();
|
||||||
|
params->asymmetric_quantize_inputs =
|
||||||
|
bidi_sequence_rnn_params->asymmetric_quantize_inputs();
|
||||||
}
|
}
|
||||||
*builtin_data = reinterpret_cast<void*>(params.release());
|
*builtin_data = reinterpret_cast<void*>(params.release());
|
||||||
break;
|
break;
|
||||||
@ -302,6 +308,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
|||||||
if (const auto* rnn_params = op->builtin_options_as_RNNOptions()) {
|
if (const auto* rnn_params = op->builtin_options_as_RNNOptions()) {
|
||||||
params->activation =
|
params->activation =
|
||||||
parse_activation(rnn_params->fused_activation_function());
|
parse_activation(rnn_params->fused_activation_function());
|
||||||
|
params->asymmetric_quantize_inputs =
|
||||||
|
rnn_params->asymmetric_quantize_inputs();
|
||||||
}
|
}
|
||||||
*builtin_data = reinterpret_cast<void*>(params.release());
|
*builtin_data = reinterpret_cast<void*>(params.release());
|
||||||
break;
|
break;
|
||||||
@ -323,6 +331,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
|||||||
params->activation = parse_activation(
|
params->activation = parse_activation(
|
||||||
fully_connected_params->fused_activation_function());
|
fully_connected_params->fused_activation_function());
|
||||||
params->keep_num_dims = fully_connected_params->keep_num_dims();
|
params->keep_num_dims = fully_connected_params->keep_num_dims();
|
||||||
|
params->asymmetric_quantize_inputs =
|
||||||
|
fully_connected_params->asymmetric_quantize_inputs();
|
||||||
switch (fully_connected_params->weights_format()) {
|
switch (fully_connected_params->weights_format()) {
|
||||||
case FullyConnectedOptionsWeightsFormat_DEFAULT:
|
case FullyConnectedOptionsWeightsFormat_DEFAULT:
|
||||||
params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault;
|
params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault;
|
||||||
@ -440,6 +450,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
|||||||
lstm_params->kernel_type());
|
lstm_params->kernel_type());
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
|
params->asymmetric_quantize_inputs =
|
||||||
|
lstm_params->asymmetric_quantize_inputs();
|
||||||
} else {
|
} else {
|
||||||
TF_LITE_REPORT_ERROR(error_reporter,
|
TF_LITE_REPORT_ERROR(error_reporter,
|
||||||
"No valid LSTM builtin options exist");
|
"No valid LSTM builtin options exist");
|
||||||
@ -458,6 +470,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
|||||||
params->cell_clip = seq_lstm_params->cell_clip();
|
params->cell_clip = seq_lstm_params->cell_clip();
|
||||||
params->proj_clip = seq_lstm_params->proj_clip();
|
params->proj_clip = seq_lstm_params->proj_clip();
|
||||||
params->time_major = seq_lstm_params->time_major();
|
params->time_major = seq_lstm_params->time_major();
|
||||||
|
params->asymmetric_quantize_inputs =
|
||||||
|
seq_lstm_params->asymmetric_quantize_inputs();
|
||||||
}
|
}
|
||||||
*builtin_data = reinterpret_cast<void*>(params.release());
|
*builtin_data = reinterpret_cast<void*>(params.release());
|
||||||
break;
|
break;
|
||||||
@ -473,6 +487,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
|||||||
params->proj_clip = bidi_lstm_params->proj_clip();
|
params->proj_clip = bidi_lstm_params->proj_clip();
|
||||||
params->merge_outputs = bidi_lstm_params->merge_outputs();
|
params->merge_outputs = bidi_lstm_params->merge_outputs();
|
||||||
params->time_major = bidi_lstm_params->time_major();
|
params->time_major = bidi_lstm_params->time_major();
|
||||||
|
params->asymmetric_quantize_inputs =
|
||||||
|
bidi_lstm_params->asymmetric_quantize_inputs();
|
||||||
}
|
}
|
||||||
*builtin_data = reinterpret_cast<void*>(params.release());
|
*builtin_data = reinterpret_cast<void*>(params.release());
|
||||||
break;
|
break;
|
||||||
|
@ -26,6 +26,15 @@ namespace ops {
|
|||||||
namespace builtin {
|
namespace builtin {
|
||||||
namespace rnn {
|
namespace rnn {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct OpData {
|
||||||
|
int scratch_tensor_index;
|
||||||
|
bool compute_row_sums = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
constexpr int kInputTensor = 0;
|
constexpr int kInputTensor = 0;
|
||||||
constexpr int kWeightsTensor = 1;
|
constexpr int kWeightsTensor = 1;
|
||||||
constexpr int kRecurrentWeightsTensor = 2;
|
constexpr int kRecurrentWeightsTensor = 2;
|
||||||
@ -36,13 +45,14 @@ constexpr int kHiddenStateTensor = 4;
|
|||||||
constexpr int kOutputTensor = 0;
|
constexpr int kOutputTensor = 0;
|
||||||
|
|
||||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
auto* scratch_tensor_index = new int;
|
auto* op_data = new OpData();
|
||||||
context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index);
|
context->AddTensors(context, /*tensors_to_add=*/6,
|
||||||
return scratch_tensor_index;
|
&op_data->scratch_tensor_index);
|
||||||
|
return op_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Free(TfLiteContext* context, void* buffer) {
|
void Free(TfLiteContext* context, void* buffer) {
|
||||||
delete reinterpret_cast<int*>(buffer);
|
delete reinterpret_cast<OpData*>(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
@ -89,10 +99,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
// Allocate temporary tensors to store quantized values of input and
|
// Allocate temporary tensors to store quantized values of input and
|
||||||
// hidden_state tensors.
|
// hidden_state tensors.
|
||||||
if (is_hybrid) {
|
if (is_hybrid) {
|
||||||
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
|
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||||
|
op_data->compute_row_sums = true;
|
||||||
TfLiteIntArrayFree(node->temporaries);
|
TfLiteIntArrayFree(node->temporaries);
|
||||||
node->temporaries = TfLiteIntArrayCreate(3);
|
node->temporaries = TfLiteIntArrayCreate(6);
|
||||||
node->temporaries->data[0] = *scratch_tensor_index;
|
node->temporaries->data[0] = op_data->scratch_tensor_index;
|
||||||
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
|
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
|
||||||
input_quantized->type = input_weights->type;
|
input_quantized->type = input_weights->type;
|
||||||
input_quantized->allocation_type = kTfLiteArenaRw;
|
input_quantized->allocation_type = kTfLiteArenaRw;
|
||||||
@ -101,7 +112,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
|
||||||
input_quantized_size));
|
input_quantized_size));
|
||||||
}
|
}
|
||||||
node->temporaries->data[1] = *scratch_tensor_index + 1;
|
node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
|
||||||
TfLiteTensor* hidden_state_quantized =
|
TfLiteTensor* hidden_state_quantized =
|
||||||
GetTemporary(context, node, /*index=*/1);
|
GetTemporary(context, node, /*index=*/1);
|
||||||
hidden_state_quantized->type = input_weights->type;
|
hidden_state_quantized->type = input_weights->type;
|
||||||
@ -114,7 +125,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
context->ResizeTensor(context, hidden_state_quantized,
|
context->ResizeTensor(context, hidden_state_quantized,
|
||||||
hidden_state_quantized_size));
|
hidden_state_quantized_size));
|
||||||
}
|
}
|
||||||
node->temporaries->data[2] = *scratch_tensor_index + 2;
|
node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
|
||||||
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
|
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
|
||||||
scaling_factors->type = kTfLiteFloat32;
|
scaling_factors->type = kTfLiteFloat32;
|
||||||
scaling_factors->allocation_type = kTfLiteArenaRw;
|
scaling_factors->allocation_type = kTfLiteArenaRw;
|
||||||
@ -125,8 +136,43 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
|
||||||
scaling_factors_size));
|
scaling_factors_size));
|
||||||
}
|
}
|
||||||
|
node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
|
||||||
|
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/3);
|
||||||
|
accum_scratch->type = kTfLiteInt32;
|
||||||
|
accum_scratch->allocation_type = kTfLiteArenaRw;
|
||||||
|
int accum_scratch_dims[2] = {num_units, batch_size};
|
||||||
|
if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
|
||||||
|
accum_scratch_dims)) {
|
||||||
|
TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
|
||||||
|
accum_scratch_size->data[0] = accum_scratch_dims[0];
|
||||||
|
accum_scratch_size->data[1] = accum_scratch_dims[1];
|
||||||
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
|
||||||
|
accum_scratch_size));
|
||||||
|
}
|
||||||
|
node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
|
||||||
|
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
|
||||||
|
zero_points->type = kTfLiteInt32;
|
||||||
|
zero_points->allocation_type = kTfLiteArenaRw;
|
||||||
|
int zero_points_dims[1] = {batch_size};
|
||||||
|
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
|
||||||
|
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
|
||||||
|
zero_points_size->data[0] = batch_size;
|
||||||
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
|
||||||
|
zero_points_size));
|
||||||
|
}
|
||||||
|
node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
|
||||||
|
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
|
||||||
|
row_sums->type = kTfLiteInt32;
|
||||||
|
row_sums->allocation_type = kTfLiteArenaRwPersistent;
|
||||||
|
int row_sums_dims[2] = {2, num_units};
|
||||||
|
if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
|
||||||
|
TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
|
||||||
|
row_sums_size->data[0] = row_sums_dims[0];
|
||||||
|
row_sums_size->data[1] = row_sums_dims[1];
|
||||||
|
TF_LITE_ENSURE_OK(
|
||||||
|
context, context->ResizeTensor(context, row_sums, row_sums_size));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -165,7 +211,9 @@ TfLiteStatus EvalHybrid(const TfLiteTensor* input,
|
|||||||
TfLiteTensor* input_scratch,
|
TfLiteTensor* input_scratch,
|
||||||
TfLiteTensor* hidden_state_scratch,
|
TfLiteTensor* hidden_state_scratch,
|
||||||
TfLiteTensor* scaling_factors,
|
TfLiteTensor* scaling_factors,
|
||||||
TfLiteTensor* hidden_state, TfLiteTensor* output) {
|
TfLiteTensor* hidden_state, TfLiteTensor* output,
|
||||||
|
TfLiteTensor* zero_points, TfLiteTensor* accum_scratch,
|
||||||
|
TfLiteTensor* row_sums, bool* compute_row_sums) {
|
||||||
const int batch_size = input->dims->data[0];
|
const int batch_size = input->dims->data[0];
|
||||||
const int num_units = input_weights->dims->data[0];
|
const int num_units = input_weights->dims->data[0];
|
||||||
const int input_size = input->dims->data[1];
|
const int input_size = input->dims->data[1];
|
||||||
@ -190,26 +238,34 @@ TfLiteStatus EvalHybrid(const TfLiteTensor* input,
|
|||||||
int8_t* quantized_hidden_state_ptr =
|
int8_t* quantized_hidden_state_ptr =
|
||||||
GetTensorData<int8_t>(hidden_state_scratch);
|
GetTensorData<int8_t>(hidden_state_scratch);
|
||||||
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
|
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
|
||||||
|
int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
|
||||||
|
int32_t* zero_points_ptr = nullptr;
|
||||||
|
int32_t* row_sums_ptr = nullptr;
|
||||||
|
if (params->asymmetric_quantize_inputs) {
|
||||||
|
zero_points_ptr = GetTensorData<int32_t>(zero_points);
|
||||||
|
row_sums_ptr = GetTensorData<int32_t>(row_sums);
|
||||||
|
}
|
||||||
kernel_utils::RnnBatchStep(
|
kernel_utils::RnnBatchStep(
|
||||||
input_ptr_batch, input_weights_ptr, input_weights_scale,
|
input_ptr_batch, input_weights_ptr, input_weights_scale,
|
||||||
recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
|
recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
|
||||||
num_units, batch_size, output_batch_leading_dim, params->activation,
|
num_units, batch_size, output_batch_leading_dim, params->activation,
|
||||||
quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr,
|
quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr,
|
||||||
hidden_state_ptr_batch, output_ptr_batch);
|
hidden_state_ptr_batch, output_ptr_batch,
|
||||||
|
params->asymmetric_quantize_inputs, zero_points_ptr, accum_scratch_ptr,
|
||||||
|
row_sums_ptr, compute_row_sums);
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data);
|
auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data);
|
||||||
|
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||||
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
|
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
|
||||||
const TfLiteTensor* recurrent_weights =
|
const TfLiteTensor* recurrent_weights =
|
||||||
GetInput(context, node, kRecurrentWeightsTensor);
|
GetInput(context, node, kRecurrentWeightsTensor);
|
||||||
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
|
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
|
||||||
TfLiteTensor* hidden_state =
|
TfLiteTensor* hidden_state =
|
||||||
GetVariableInput(context, node, kHiddenStateTensor);
|
&context->tensors[node->inputs->data[kHiddenStateTensor]];
|
||||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||||
|
|
||||||
// We already checked that weight types are consistent, so branch on one.
|
// We already checked that weight types are consistent, so branch on one.
|
||||||
@ -223,9 +279,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
|
TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
|
||||||
TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
|
TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
|
||||||
TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
|
TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
|
||||||
|
TfLiteTensor* accum_scratch = GetTemporary(context, node, 3);
|
||||||
|
TfLiteTensor* zero_points = GetTemporary(context, node, 4);
|
||||||
|
TfLiteTensor* row_sums = GetTemporary(context, node, 5);
|
||||||
return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
|
return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
|
||||||
input_quantized, hidden_state_quantized,
|
input_quantized, hidden_state_quantized,
|
||||||
scaling_factors, hidden_state, output);
|
scaling_factors, hidden_state, output, zero_points,
|
||||||
|
accum_scratch, row_sums, &op_data->compute_row_sums);
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
context->ReportError(context, "Type %d not currently supported.",
|
context->ReportError(context, "Type %d not currently supported.",
|
||||||
|
@ -175,7 +175,8 @@ class RNNOpModel : public SingleOpModel {
|
|||||||
public:
|
public:
|
||||||
RNNOpModel(int batches, int units, int size,
|
RNNOpModel(int batches, int units, int size,
|
||||||
const TensorType& weights = TensorType_FLOAT32,
|
const TensorType& weights = TensorType_FLOAT32,
|
||||||
const TensorType& recurrent_weights = TensorType_FLOAT32)
|
const TensorType& recurrent_weights = TensorType_FLOAT32,
|
||||||
|
bool asymmetric_quantize_inputs = false)
|
||||||
: batches_(batches), units_(units), input_size_(size) {
|
: batches_(batches), units_(units), input_size_(size) {
|
||||||
input_ = AddInput(TensorType_FLOAT32);
|
input_ = AddInput(TensorType_FLOAT32);
|
||||||
weights_ = AddInput(weights);
|
weights_ = AddInput(weights);
|
||||||
@ -183,9 +184,10 @@ class RNNOpModel : public SingleOpModel {
|
|||||||
bias_ = AddInput(TensorType_FLOAT32);
|
bias_ = AddInput(TensorType_FLOAT32);
|
||||||
hidden_state_ = AddInput(TensorType_FLOAT32, true);
|
hidden_state_ = AddInput(TensorType_FLOAT32, true);
|
||||||
output_ = AddOutput(TensorType_FLOAT32);
|
output_ = AddOutput(TensorType_FLOAT32);
|
||||||
SetBuiltinOp(
|
SetBuiltinOp(BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
|
||||||
BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
|
CreateRNNOptions(builder_, ActivationFunctionType_RELU,
|
||||||
CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
|
asymmetric_quantize_inputs)
|
||||||
|
.Union());
|
||||||
BuildInterpreter({{batches_, input_size_}, // input tensor
|
BuildInterpreter({{batches_, input_size_}, // input tensor
|
||||||
{units_, input_size_}, // weights tensor
|
{units_, input_size_}, // weights tensor
|
||||||
{units_, units_}, // recurrent weights tensor
|
{units_, units_}, // recurrent weights tensor
|
||||||
@ -233,8 +235,10 @@ class RNNOpModel : public SingleOpModel {
|
|||||||
// The hybrid model has quantized weights and recurrent_weights.
|
// The hybrid model has quantized weights and recurrent_weights.
|
||||||
class HybridRNNOpModel : public RNNOpModel {
|
class HybridRNNOpModel : public RNNOpModel {
|
||||||
public:
|
public:
|
||||||
HybridRNNOpModel(int batches, int units, int size, TensorType tensor_type)
|
HybridRNNOpModel(int batches, int units, int size, TensorType tensor_type,
|
||||||
: RNNOpModel(batches, units, size, tensor_type, tensor_type) {
|
bool asymmetric_quantize_inputs)
|
||||||
|
: RNNOpModel(batches, units, size, tensor_type, tensor_type,
|
||||||
|
asymmetric_quantize_inputs) {
|
||||||
tensor_type_ = tensor_type;
|
tensor_type_ = tensor_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -282,8 +286,10 @@ TEST(RnnOpTest, BlackBoxTest) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HybridRnnOpTest, BlackBoxTestUint8) {
|
class HybridRnnOpTest : public ::testing::TestWithParam<bool> {};
|
||||||
HybridRNNOpModel rnn(2, 16, 8, TensorType_UINT8);
|
|
||||||
|
TEST_P(HybridRnnOpTest, BlackBoxTestUint8) {
|
||||||
|
HybridRNNOpModel rnn(2, 16, 8, TensorType_UINT8, GetParam());
|
||||||
rnn.SetWeights(rnn_weights);
|
rnn.SetWeights(rnn_weights);
|
||||||
rnn.SetBias(rnn_bias);
|
rnn.SetBias(rnn_bias);
|
||||||
rnn.SetRecurrentWeights(rnn_recurrent_weights);
|
rnn.SetRecurrentWeights(rnn_recurrent_weights);
|
||||||
@ -310,8 +316,8 @@ TEST(HybridRnnOpTest, BlackBoxTestUint8) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HybridRnnOpTest, BlackBoxTestInt8) {
|
TEST_P(HybridRnnOpTest, BlackBoxTestInt8) {
|
||||||
HybridRNNOpModel rnn(2, 16, 8, TensorType_INT8);
|
HybridRNNOpModel rnn(2, 16, 8, TensorType_INT8, GetParam());
|
||||||
rnn.SetWeights(rnn_weights);
|
rnn.SetWeights(rnn_weights);
|
||||||
rnn.SetBias(rnn_bias);
|
rnn.SetBias(rnn_bias);
|
||||||
rnn.SetRecurrentWeights(rnn_recurrent_weights);
|
rnn.SetRecurrentWeights(rnn_recurrent_weights);
|
||||||
@ -338,5 +344,8 @@ TEST(HybridRnnOpTest, BlackBoxTestInt8) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(HybridRnnOpTest, HybridRnnOpTest,
|
||||||
|
::testing::ValuesIn({false, true}));
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -139,18 +139,28 @@ enum TemporaryTensor {
|
|||||||
kProductScalingFactors = 8,
|
kProductScalingFactors = 8,
|
||||||
kRecoveredCellWeights = 9,
|
kRecoveredCellWeights = 9,
|
||||||
kAccumScratchBuffer = 10,
|
kAccumScratchBuffer = 10,
|
||||||
kAuxInputQuantized = 11, // Optional, quantized tensor for auxiliary input.
|
kZeroPoints = 11,
|
||||||
kNumTemporaryTensors
|
kFwRowSums = 12,
|
||||||
|
kBwRowSums = 13,
|
||||||
|
kAuxInputQuantized = 14, // Optional, quantized tensor for auxiliary input.
|
||||||
|
kNumTemporaryTensors = 15
|
||||||
|
};
|
||||||
|
|
||||||
|
struct OpData {
|
||||||
|
int scratch_tensor_index;
|
||||||
|
bool compute_fw_row_sums = false;
|
||||||
|
bool compute_bw_row_sums = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
auto* scratch_tensor_index = new int;
|
auto* op_data = new OpData();
|
||||||
context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
|
context->AddTensors(context, kNumTemporaryTensors,
|
||||||
return scratch_tensor_index;
|
&op_data->scratch_tensor_index);
|
||||||
|
return op_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Free(TfLiteContext* context, void* buffer) {
|
void Free(TfLiteContext* context, void* buffer) {
|
||||||
delete reinterpret_cast<int*>(buffer);
|
delete reinterpret_cast<OpData*>(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that input tensor dimensions matches with each other.
|
// Check that input tensor dimensions matches with each other.
|
||||||
@ -385,7 +395,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
|||||||
// Resize the output and scratch tensors based on the sizes of the input
|
// Resize the output and scratch tensors based on the sizes of the input
|
||||||
// tensors. Also check that the size of the input tensors match each other.
|
// tensors. Also check that the size of the input tensors match each other.
|
||||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
|
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||||
const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
|
const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
|
||||||
node->builtin_data);
|
node->builtin_data);
|
||||||
|
|
||||||
@ -522,7 +532,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
node->temporaries = TfLiteIntArrayCreate(2); // the two scratch buffers.
|
node->temporaries = TfLiteIntArrayCreate(2); // the two scratch buffers.
|
||||||
}
|
}
|
||||||
// Create a scratch buffer tensor.
|
// Create a scratch buffer tensor.
|
||||||
node->temporaries->data[kFwScratchBuffer] = *scratch_tensor_index;
|
node->temporaries->data[kFwScratchBuffer] = op_data->scratch_tensor_index;
|
||||||
TfLiteTensor* fw_scratch_buffer =
|
TfLiteTensor* fw_scratch_buffer =
|
||||||
GetTemporary(context, node, kFwScratchBuffer);
|
GetTemporary(context, node, kFwScratchBuffer);
|
||||||
fw_scratch_buffer->type = input->type;
|
fw_scratch_buffer->type = input->type;
|
||||||
@ -581,7 +591,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
|
|
||||||
// Create a scratch buffer tensor.
|
// Create a scratch buffer tensor.
|
||||||
node->temporaries->data[kBwScratchBuffer] =
|
node->temporaries->data[kBwScratchBuffer] =
|
||||||
*(scratch_tensor_index) + kBwScratchBuffer;
|
op_data->scratch_tensor_index + kBwScratchBuffer;
|
||||||
TfLiteTensor* bw_scratch_buffer =
|
TfLiteTensor* bw_scratch_buffer =
|
||||||
GetTemporary(context, node, kBwScratchBuffer);
|
GetTemporary(context, node, kBwScratchBuffer);
|
||||||
bw_scratch_buffer->type = input->type;
|
bw_scratch_buffer->type = input->type;
|
||||||
@ -606,10 +616,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer,
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer,
|
||||||
bw_scratch_buffer_size));
|
bw_scratch_buffer_size));
|
||||||
if (is_hybrid_op) {
|
if (is_hybrid_op) {
|
||||||
|
// Compute the row sums for cached zero_point offset calculation.
|
||||||
|
op_data->compute_fw_row_sums = true;
|
||||||
|
op_data->compute_bw_row_sums = true;
|
||||||
// Allocate temporary tensors to store quantized values of input, aux_input
|
// Allocate temporary tensors to store quantized values of input, aux_input
|
||||||
// (if present), activation_state and cell_state tensors.
|
// (if present), activation_state and cell_state tensors.
|
||||||
node->temporaries->data[kInputQuantized] =
|
node->temporaries->data[kInputQuantized] =
|
||||||
*scratch_tensor_index + kInputQuantized;
|
op_data->scratch_tensor_index + kInputQuantized;
|
||||||
TfLiteTensor* input_quantized =
|
TfLiteTensor* input_quantized =
|
||||||
GetTemporary(context, node, kInputQuantized);
|
GetTemporary(context, node, kInputQuantized);
|
||||||
input_quantized->type = fw_input_to_output_weights->type;
|
input_quantized->type = fw_input_to_output_weights->type;
|
||||||
@ -621,7 +634,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
node->temporaries->data[kFwActivationStateQuantized] =
|
node->temporaries->data[kFwActivationStateQuantized] =
|
||||||
*scratch_tensor_index + kFwActivationStateQuantized;
|
op_data->scratch_tensor_index + kFwActivationStateQuantized;
|
||||||
TfLiteTensor* fw_activation_state_quantized =
|
TfLiteTensor* fw_activation_state_quantized =
|
||||||
GetTemporary(context, node, kFwActivationStateQuantized);
|
GetTemporary(context, node, kFwActivationStateQuantized);
|
||||||
fw_activation_state_quantized->type = fw_input_to_output_weights->type;
|
fw_activation_state_quantized->type = fw_input_to_output_weights->type;
|
||||||
@ -635,7 +648,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
fw_activation_state_quantized_size));
|
fw_activation_state_quantized_size));
|
||||||
}
|
}
|
||||||
node->temporaries->data[kBwActivationStateQuantized] =
|
node->temporaries->data[kBwActivationStateQuantized] =
|
||||||
*scratch_tensor_index + kBwActivationStateQuantized;
|
op_data->scratch_tensor_index + kBwActivationStateQuantized;
|
||||||
TfLiteTensor* bw_activation_state_quantized =
|
TfLiteTensor* bw_activation_state_quantized =
|
||||||
GetTemporary(context, node, kBwActivationStateQuantized);
|
GetTemporary(context, node, kBwActivationStateQuantized);
|
||||||
bw_activation_state_quantized->type = fw_input_to_output_weights->type;
|
bw_activation_state_quantized->type = fw_input_to_output_weights->type;
|
||||||
@ -649,7 +662,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
bw_activation_state_quantized_size));
|
bw_activation_state_quantized_size));
|
||||||
}
|
}
|
||||||
node->temporaries->data[kFwCellStateQuantized] =
|
node->temporaries->data[kFwCellStateQuantized] =
|
||||||
*scratch_tensor_index + kFwCellStateQuantized;
|
op_data->scratch_tensor_index + kFwCellStateQuantized;
|
||||||
TfLiteTensor* fw_cell_state_quantized =
|
TfLiteTensor* fw_cell_state_quantized =
|
||||||
GetTemporary(context, node, kFwCellStateQuantized);
|
GetTemporary(context, node, kFwCellStateQuantized);
|
||||||
fw_cell_state_quantized->type = fw_input_to_output_weights->type;
|
fw_cell_state_quantized->type = fw_input_to_output_weights->type;
|
||||||
@ -663,7 +676,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
fw_cell_state_quantized_size));
|
fw_cell_state_quantized_size));
|
||||||
}
|
}
|
||||||
node->temporaries->data[kBwCellStateQuantized] =
|
node->temporaries->data[kBwCellStateQuantized] =
|
||||||
*scratch_tensor_index + kBwCellStateQuantized;
|
op_data->scratch_tensor_index + kBwCellStateQuantized;
|
||||||
TfLiteTensor* bw_cell_state_quantized =
|
TfLiteTensor* bw_cell_state_quantized =
|
||||||
GetTemporary(context, node, kBwCellStateQuantized);
|
GetTemporary(context, node, kBwCellStateQuantized);
|
||||||
bw_cell_state_quantized->type = fw_input_to_output_weights->type;
|
bw_cell_state_quantized->type = fw_input_to_output_weights->type;
|
||||||
@ -683,7 +696,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
// different matrices (which requires multiplying the scaling factors with
|
// different matrices (which requires multiplying the scaling factors with
|
||||||
// the scaling factor of the matrix).
|
// the scaling factor of the matrix).
|
||||||
node->temporaries->data[kScalingFactors] =
|
node->temporaries->data[kScalingFactors] =
|
||||||
*scratch_tensor_index + kScalingFactors;
|
op_data->scratch_tensor_index + kScalingFactors;
|
||||||
TfLiteTensor* scaling_factors =
|
TfLiteTensor* scaling_factors =
|
||||||
GetTemporary(context, node, kScalingFactors);
|
GetTemporary(context, node, kScalingFactors);
|
||||||
scaling_factors->type = kTfLiteFloat32;
|
scaling_factors->type = kTfLiteFloat32;
|
||||||
@ -696,7 +709,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
scaling_factors_size));
|
scaling_factors_size));
|
||||||
}
|
}
|
||||||
node->temporaries->data[kProductScalingFactors] =
|
node->temporaries->data[kProductScalingFactors] =
|
||||||
*scratch_tensor_index + kProductScalingFactors;
|
op_data->scratch_tensor_index + kProductScalingFactors;
|
||||||
TfLiteTensor* prod_scaling_factors =
|
TfLiteTensor* prod_scaling_factors =
|
||||||
GetTemporary(context, node, kProductScalingFactors);
|
GetTemporary(context, node, kProductScalingFactors);
|
||||||
prod_scaling_factors->type = kTfLiteFloat32;
|
prod_scaling_factors->type = kTfLiteFloat32;
|
||||||
@ -713,7 +726,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
// Allocate a temporary tensor to store the recovered cell weights. Since
|
// Allocate a temporary tensor to store the recovered cell weights. Since
|
||||||
// this is used for diagonal matrices, only need to store n_cell values.
|
// this is used for diagonal matrices, only need to store n_cell values.
|
||||||
node->temporaries->data[kRecoveredCellWeights] =
|
node->temporaries->data[kRecoveredCellWeights] =
|
||||||
*scratch_tensor_index + kRecoveredCellWeights;
|
op_data->scratch_tensor_index + kRecoveredCellWeights;
|
||||||
TfLiteTensor* recovered_cell_weights =
|
TfLiteTensor* recovered_cell_weights =
|
||||||
GetTemporary(context, node, kRecoveredCellWeights);
|
GetTemporary(context, node, kRecoveredCellWeights);
|
||||||
recovered_cell_weights->type = kTfLiteFloat32;
|
recovered_cell_weights->type = kTfLiteFloat32;
|
||||||
@ -730,7 +743,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
|
|
||||||
// Allocate a temporary tensor to store the accumulated int32 values.
|
// Allocate a temporary tensor to store the accumulated int32 values.
|
||||||
node->temporaries->data[kAccumScratchBuffer] =
|
node->temporaries->data[kAccumScratchBuffer] =
|
||||||
*scratch_tensor_index + kAccumScratchBuffer;
|
op_data->scratch_tensor_index + kAccumScratchBuffer;
|
||||||
TfLiteTensor* accum_scratch =
|
TfLiteTensor* accum_scratch =
|
||||||
GetTemporary(context, node, kAccumScratchBuffer);
|
GetTemporary(context, node, kAccumScratchBuffer);
|
||||||
accum_scratch->type = kTfLiteInt32;
|
accum_scratch->type = kTfLiteInt32;
|
||||||
@ -750,11 +763,72 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
context, context->ResizeTensor(context, accum_scratch, accum_size));
|
context, context->ResizeTensor(context, accum_scratch, accum_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Allocate temporary tensors for storing zero-points.
|
||||||
|
node->temporaries->data[kZeroPoints] =
|
||||||
|
op_data->scratch_tensor_index + kZeroPoints;
|
||||||
|
TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
|
||||||
|
zero_points->type = kTfLiteFloat32;
|
||||||
|
zero_points->allocation_type = kTfLiteArenaRw;
|
||||||
|
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, scaling_dims)) {
|
||||||
|
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
|
||||||
|
zero_points_size->data[0] = n_batch;
|
||||||
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
|
||||||
|
zero_points_size));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate temporary tensors for caching row sums for hybrid zero-point
|
||||||
|
// calculations.
|
||||||
|
int fw_row_sums_rows = fw_use_cifg ? 6 : 8;
|
||||||
|
if (has_aux_input) {
|
||||||
|
fw_row_sums_rows += fw_use_cifg ? 3 : 4;
|
||||||
|
}
|
||||||
|
const TfLiteTensor* fw_projection_weights =
|
||||||
|
GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor);
|
||||||
|
if (fw_projection_weights != nullptr) {
|
||||||
|
fw_row_sums_rows += ceil(n_fw_output / n_fw_cell);
|
||||||
|
}
|
||||||
|
node->temporaries->data[kFwRowSums] =
|
||||||
|
op_data->scratch_tensor_index + kFwRowSums;
|
||||||
|
TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
|
||||||
|
fw_row_sums->type = kTfLiteInt32;
|
||||||
|
fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
|
||||||
|
int fw_row_sums_dims[2] = {fw_row_sums_rows, n_fw_cell};
|
||||||
|
if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) {
|
||||||
|
TfLiteIntArray* fw_hybrid_scratch_size = TfLiteIntArrayCreate(2);
|
||||||
|
fw_hybrid_scratch_size->data[0] = fw_row_sums_dims[0];
|
||||||
|
fw_hybrid_scratch_size->data[1] = fw_row_sums_dims[1];
|
||||||
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums,
|
||||||
|
fw_hybrid_scratch_size));
|
||||||
|
}
|
||||||
|
|
||||||
|
int bw_row_sums_rows = bw_use_cifg ? 6 : 8;
|
||||||
|
if (has_aux_input) {
|
||||||
|
bw_row_sums_rows += bw_use_cifg ? 3 : 4;
|
||||||
|
}
|
||||||
|
const TfLiteTensor* bw_projection_weights =
|
||||||
|
GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor);
|
||||||
|
if (bw_projection_weights != nullptr) {
|
||||||
|
bw_row_sums_rows += ceil(n_bw_output / n_bw_cell);
|
||||||
|
}
|
||||||
|
node->temporaries->data[kBwRowSums] =
|
||||||
|
op_data->scratch_tensor_index + kBwRowSums;
|
||||||
|
TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
|
||||||
|
bw_row_sums->type = kTfLiteInt32;
|
||||||
|
bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
|
||||||
|
int bw_row_sums_dims[2] = {bw_row_sums_rows, n_bw_cell};
|
||||||
|
if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) {
|
||||||
|
TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2);
|
||||||
|
bw_row_sums_size->data[0] = bw_row_sums_dims[0];
|
||||||
|
bw_row_sums_size->data[1] = bw_row_sums_dims[1];
|
||||||
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums,
|
||||||
|
bw_row_sums_size));
|
||||||
|
}
|
||||||
|
|
||||||
// Only allocate a temporary tensor for quantized auxiliary input if we are
|
// Only allocate a temporary tensor for quantized auxiliary input if we are
|
||||||
// actually going to use it.
|
// actually going to use it.
|
||||||
if (has_aux_input) {
|
if (has_aux_input) {
|
||||||
node->temporaries->data[kAuxInputQuantized] =
|
node->temporaries->data[kAuxInputQuantized] =
|
||||||
*scratch_tensor_index + kAuxInputQuantized;
|
op_data->scratch_tensor_index + kAuxInputQuantized;
|
||||||
TfLiteTensor* aux_input_quantized =
|
TfLiteTensor* aux_input_quantized =
|
||||||
GetTemporary(context, node, kAuxInputQuantized);
|
GetTemporary(context, node, kAuxInputQuantized);
|
||||||
aux_input_quantized->type = fw_input_to_output_weights->type;
|
aux_input_quantized->type = fw_input_to_output_weights->type;
|
||||||
@ -775,7 +849,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
|
const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
|
||||||
node->builtin_data);
|
node->builtin_data);
|
||||||
|
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||||
// Input tensor.
|
// Input tensor.
|
||||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||||
|
|
||||||
@ -909,7 +983,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
|
|
||||||
// Populate a TfLiteLSTMParams struct for the evaluation functions.
|
// Populate a TfLiteLSTMParams struct for the evaluation functions.
|
||||||
TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip,
|
TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip,
|
||||||
params->proj_clip, kTfLiteLSTMFullKernel};
|
params->proj_clip, kTfLiteLSTMFullKernel,
|
||||||
|
params->asymmetric_quantize_inputs};
|
||||||
|
|
||||||
const int bw_output_offset =
|
const int bw_output_offset =
|
||||||
params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0;
|
params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0;
|
||||||
@ -1003,7 +1078,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
: nullptr;
|
: nullptr;
|
||||||
TfLiteTensor* accum_scratch =
|
TfLiteTensor* accum_scratch =
|
||||||
GetTemporary(context, node, kAccumScratchBuffer);
|
GetTemporary(context, node, kAccumScratchBuffer);
|
||||||
|
TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
|
||||||
|
TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
|
||||||
|
TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
|
||||||
|
const int fw_row_sums_size = fw_row_sums->dims->data[0];
|
||||||
|
const int bw_row_sums_size = bw_row_sums->dims->data[0];
|
||||||
TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
|
TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
|
||||||
input, fw_input_to_input_weights, fw_input_to_forget_weights,
|
input, fw_input_to_input_weights, fw_input_to_forget_weights,
|
||||||
fw_input_to_cell_weights, fw_input_to_output_weights,
|
fw_input_to_cell_weights, fw_input_to_output_weights,
|
||||||
@ -1025,6 +1104,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
recovered_cell_weights, input_quantized, aux_input_quantized,
|
recovered_cell_weights, input_quantized, aux_input_quantized,
|
||||||
fw_activation_state_quantized, fw_cell_state_quantized,
|
fw_activation_state_quantized, fw_cell_state_quantized,
|
||||||
fw_activation_state, fw_cell_state, accum_scratch, fw_output,
|
fw_activation_state, fw_cell_state, accum_scratch, fw_output,
|
||||||
|
zero_points, fw_row_sums, fw_row_sums_size,
|
||||||
|
&op_data->compute_fw_row_sums,
|
||||||
CpuBackendContext::GetFromContext(context));
|
CpuBackendContext::GetFromContext(context));
|
||||||
TF_LITE_ENSURE_OK(context, fw_pass_status);
|
TF_LITE_ENSURE_OK(context, fw_pass_status);
|
||||||
|
|
||||||
@ -1049,6 +1130,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
recovered_cell_weights, input_quantized, aux_input_quantized,
|
recovered_cell_weights, input_quantized, aux_input_quantized,
|
||||||
bw_activation_state_quantized, bw_cell_state_quantized,
|
bw_activation_state_quantized, bw_cell_state_quantized,
|
||||||
bw_activation_state, bw_cell_state, accum_scratch, actual_bw_output,
|
bw_activation_state, bw_cell_state, accum_scratch, actual_bw_output,
|
||||||
|
zero_points, bw_row_sums, bw_row_sums_size,
|
||||||
|
&op_data->compute_bw_row_sums,
|
||||||
CpuBackendContext::GetFromContext(context));
|
CpuBackendContext::GetFromContext(context));
|
||||||
TF_LITE_ENSURE_OK(context, bw_pass_status);
|
TF_LITE_ENSURE_OK(context, bw_pass_status);
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
|
@ -40,7 +40,8 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
|
|||||||
bool use_projection_bias, bool merge_outputs,
|
bool use_projection_bias, bool merge_outputs,
|
||||||
bool use_aux_input, float cell_clip, float proj_clip,
|
bool use_aux_input, float cell_clip, float proj_clip,
|
||||||
bool quantize_weights, bool time_major,
|
bool quantize_weights, bool time_major,
|
||||||
const std::vector<std::vector<int>>& input_shapes)
|
const std::vector<std::vector<int>>& input_shapes,
|
||||||
|
bool asymmetric_quantize_inputs = false)
|
||||||
: n_batch_(n_batch),
|
: n_batch_(n_batch),
|
||||||
n_input_(n_input),
|
n_input_(n_input),
|
||||||
n_fw_cell_(n_cell),
|
n_fw_cell_(n_cell),
|
||||||
@ -207,12 +208,13 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
|
|||||||
bw_aux_input_to_output_weights_ = AddNullInput();
|
bw_aux_input_to_output_weights_ = AddNullInput();
|
||||||
}
|
}
|
||||||
|
|
||||||
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
|
SetBuiltinOp(
|
||||||
BuiltinOptions_BidirectionalSequenceLSTMOptions,
|
BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
|
||||||
CreateBidirectionalSequenceLSTMOptions(
|
BuiltinOptions_BidirectionalSequenceLSTMOptions,
|
||||||
builder_, ActivationFunctionType_TANH, cell_clip,
|
CreateBidirectionalSequenceLSTMOptions(
|
||||||
proj_clip, merge_outputs, time_major)
|
builder_, ActivationFunctionType_TANH, cell_clip, proj_clip,
|
||||||
.Union());
|
merge_outputs, time_major, asymmetric_quantize_inputs)
|
||||||
|
.Union());
|
||||||
BuildInterpreter(input_shapes);
|
BuildInterpreter(input_shapes);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -424,11 +426,14 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
|
|||||||
bool quantize_weights_;
|
bool quantize_weights_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Declare LSTMOpTest as a parameterized test, where the parameter is a boolean
|
// Declare LSTMOpTest as a parameterized test.
|
||||||
// indicating whether to use quantization or not.
|
class LSTMOpTest
|
||||||
class LSTMOpTest : public ::testing::TestWithParam<bool> {};
|
: public ::testing::TestWithParam<::testing::tuple<bool, bool>> {};
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, LSTMOpTest, ::testing::Bool());
|
INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, LSTMOpTest,
|
||||||
|
::testing::Combine(
|
||||||
|
/*quantize_weights*/ ::testing::Bool(),
|
||||||
|
/*asymmetric_quantize_inputs*/ ::testing::Bool()));
|
||||||
|
|
||||||
TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
|
TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
|
||||||
const int n_batch = 1;
|
const int n_batch = 1;
|
||||||
@ -437,7 +442,9 @@ TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
|
|||||||
const int n_cell = 4;
|
const int n_cell = 4;
|
||||||
const int n_output = 4;
|
const int n_output = 4;
|
||||||
const int sequence_length = 3;
|
const int sequence_length = 3;
|
||||||
const bool quantize_weights = GetParam();
|
auto params = GetParam();
|
||||||
|
const bool quantize_weights = std::get<0>(params);
|
||||||
|
const bool asymmetric_quantize_inputs = std::get<1>(params);
|
||||||
|
|
||||||
BidirectionalLSTMOpModel lstm(
|
BidirectionalLSTMOpModel lstm(
|
||||||
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
|
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
|
||||||
@ -509,7 +516,8 @@ TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
|
|||||||
{0}, // aux_bw_input_to_forget tensor
|
{0}, // aux_bw_input_to_forget tensor
|
||||||
{0}, // aux_bw_input_to_cell tensor
|
{0}, // aux_bw_input_to_cell tensor
|
||||||
{0}, // aux_bw_input_to_output tensor
|
{0}, // aux_bw_input_to_output tensor
|
||||||
});
|
},
|
||||||
|
asymmetric_quantize_inputs);
|
||||||
|
|
||||||
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
|
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
|
||||||
-0.34550029, 0.04266912, -0.15680569,
|
-0.34550029, 0.04266912, -0.15680569,
|
||||||
@ -600,7 +608,9 @@ TEST_P(LSTMOpTest, BlackBoxTestMergedOutput) {
|
|||||||
const int n_cell = 4;
|
const int n_cell = 4;
|
||||||
const int n_output = 4;
|
const int n_output = 4;
|
||||||
const int sequence_length = 3;
|
const int sequence_length = 3;
|
||||||
const bool quantize_weights = GetParam();
|
auto params = GetParam();
|
||||||
|
const bool quantize_weights = std::get<0>(params);
|
||||||
|
const bool asymmetric_quantize_inputs = std::get<1>(params);
|
||||||
|
|
||||||
BidirectionalLSTMOpModel lstm(
|
BidirectionalLSTMOpModel lstm(
|
||||||
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
|
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
|
||||||
@ -672,7 +682,8 @@ TEST_P(LSTMOpTest, BlackBoxTestMergedOutput) {
|
|||||||
{0}, // aux_bw_input_to_forget tensor
|
{0}, // aux_bw_input_to_forget tensor
|
||||||
{0}, // aux_bw_input_to_cell tensor
|
{0}, // aux_bw_input_to_cell tensor
|
||||||
{0}, // aux_bw_input_to_output tensor
|
{0}, // aux_bw_input_to_output tensor
|
||||||
});
|
},
|
||||||
|
asymmetric_quantize_inputs);
|
||||||
|
|
||||||
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
|
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
|
||||||
-0.34550029, 0.04266912, -0.15680569,
|
-0.34550029, 0.04266912, -0.15680569,
|
||||||
@ -2631,7 +2642,9 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInputZeroAuxWeight) {
|
|||||||
const int n_cell = 4;
|
const int n_cell = 4;
|
||||||
const int n_output = 4;
|
const int n_output = 4;
|
||||||
const int sequence_length = 3;
|
const int sequence_length = 3;
|
||||||
const bool quantize_weights = GetParam();
|
auto params = GetParam();
|
||||||
|
const bool quantize_weights = std::get<0>(params);
|
||||||
|
const bool asymmetric_quantize_inputs = std::get<1>(params);
|
||||||
|
|
||||||
BidirectionalLSTMOpModel lstm(
|
BidirectionalLSTMOpModel lstm(
|
||||||
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
|
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
|
||||||
@ -2703,7 +2716,8 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInputZeroAuxWeight) {
|
|||||||
{n_cell, n_input}, // aux_bw_input_to_forget tensor
|
{n_cell, n_input}, // aux_bw_input_to_forget tensor
|
||||||
{n_cell, n_input}, // aux_bw_input_to_cell tensor
|
{n_cell, n_input}, // aux_bw_input_to_cell tensor
|
||||||
{n_cell, n_input}, // aux_bw_input_to_output tensor
|
{n_cell, n_input}, // aux_bw_input_to_output tensor
|
||||||
});
|
},
|
||||||
|
asymmetric_quantize_inputs);
|
||||||
|
|
||||||
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
|
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
|
||||||
-0.34550029, 0.04266912, -0.15680569,
|
-0.34550029, 0.04266912, -0.15680569,
|
||||||
@ -2802,7 +2816,9 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInput) {
|
|||||||
const int n_cell = 4;
|
const int n_cell = 4;
|
||||||
const int n_output = 4;
|
const int n_output = 4;
|
||||||
const int sequence_length = 3;
|
const int sequence_length = 3;
|
||||||
const bool quantize_weights = GetParam();
|
auto params = GetParam();
|
||||||
|
const bool quantize_weights = std::get<0>(params);
|
||||||
|
const bool asymmetric_quantize_inputs = std::get<1>(params);
|
||||||
|
|
||||||
BidirectionalLSTMOpModel lstm(
|
BidirectionalLSTMOpModel lstm(
|
||||||
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
|
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
|
||||||
@ -2874,7 +2890,8 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInput) {
|
|||||||
{n_cell, n_input}, // aux_bw_input_to_forget tensor
|
{n_cell, n_input}, // aux_bw_input_to_forget tensor
|
||||||
{n_cell, n_input}, // aux_bw_input_to_cell tensor
|
{n_cell, n_input}, // aux_bw_input_to_cell tensor
|
||||||
{n_cell, n_input}, // aux_bw_input_to_output tensor
|
{n_cell, n_input}, // aux_bw_input_to_output tensor
|
||||||
});
|
},
|
||||||
|
asymmetric_quantize_inputs);
|
||||||
|
|
||||||
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
|
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
|
||||||
-0.34550029, 0.04266912, -0.15680569,
|
-0.34550029, 0.04266912, -0.15680569,
|
||||||
|
@ -27,6 +27,16 @@ namespace ops {
|
|||||||
namespace builtin {
|
namespace builtin {
|
||||||
namespace bidirectional_sequence_rnn {
|
namespace bidirectional_sequence_rnn {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct OpData {
|
||||||
|
int scratch_tensor_index;
|
||||||
|
bool fw_compute_row_sums = false;
|
||||||
|
bool bw_compute_row_sums = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// LINT.IfChange
|
// LINT.IfChange
|
||||||
|
|
||||||
constexpr int kInputTensor = 0;
|
constexpr int kInputTensor = 0;
|
||||||
@ -58,18 +68,23 @@ enum TemporaryTensor {
|
|||||||
kFwHiddenStateQuantized = 1,
|
kFwHiddenStateQuantized = 1,
|
||||||
kBwHiddenStateQuantized = 2,
|
kBwHiddenStateQuantized = 2,
|
||||||
kScalingFactors = 3,
|
kScalingFactors = 3,
|
||||||
kAuxInputQuantized = 4,
|
kAccumScratch = 4,
|
||||||
kNumTemporaryTensors = 5
|
kZeroPoints = 5,
|
||||||
|
kFwRowSums = 6,
|
||||||
|
kBwRowSums = 7,
|
||||||
|
kAuxInputQuantized = 8,
|
||||||
|
kNumTemporaryTensors = 9
|
||||||
};
|
};
|
||||||
|
|
||||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
auto* scratch_tensor_index = new int;
|
auto* op_data = new OpData();
|
||||||
context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
|
context->AddTensors(context, kNumTemporaryTensors,
|
||||||
return scratch_tensor_index;
|
&op_data->scratch_tensor_index);
|
||||||
|
return op_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Free(TfLiteContext* context, void* buffer) {
|
void Free(TfLiteContext* context, void* buffer) {
|
||||||
delete reinterpret_cast<int*>(buffer);
|
delete reinterpret_cast<OpData*>(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
@ -157,8 +172,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (IsHybridOp(input, fw_input_weights)) {
|
if (IsHybridOp(input, fw_input_weights)) {
|
||||||
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
|
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||||
|
op_data->fw_compute_row_sums = true;
|
||||||
|
op_data->bw_compute_row_sums = true;
|
||||||
TfLiteIntArrayFree(node->temporaries);
|
TfLiteIntArrayFree(node->temporaries);
|
||||||
if (has_aux_input) {
|
if (has_aux_input) {
|
||||||
node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
|
node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
|
||||||
@ -168,7 +184,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
node->temporaries->data[kInputQuantized] =
|
node->temporaries->data[kInputQuantized] =
|
||||||
*scratch_tensor_index + kInputQuantized;
|
op_data->scratch_tensor_index + kInputQuantized;
|
||||||
TfLiteTensor* input_quantized =
|
TfLiteTensor* input_quantized =
|
||||||
GetTemporary(context, node, kInputQuantized);
|
GetTemporary(context, node, kInputQuantized);
|
||||||
input_quantized->type = fw_input_weights->type;
|
input_quantized->type = fw_input_weights->type;
|
||||||
@ -180,7 +196,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
node->temporaries->data[kFwHiddenStateQuantized] =
|
node->temporaries->data[kFwHiddenStateQuantized] =
|
||||||
*scratch_tensor_index + kFwHiddenStateQuantized;
|
op_data->scratch_tensor_index + kFwHiddenStateQuantized;
|
||||||
TfLiteTensor* fw_hidden_state_quantized =
|
TfLiteTensor* fw_hidden_state_quantized =
|
||||||
GetTemporary(context, node, kFwHiddenStateQuantized);
|
GetTemporary(context, node, kFwHiddenStateQuantized);
|
||||||
fw_hidden_state_quantized->type = fw_input_weights->type;
|
fw_hidden_state_quantized->type = fw_input_weights->type;
|
||||||
@ -195,7 +211,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
node->temporaries->data[kBwHiddenStateQuantized] =
|
node->temporaries->data[kBwHiddenStateQuantized] =
|
||||||
*scratch_tensor_index + kBwHiddenStateQuantized;
|
op_data->scratch_tensor_index + kBwHiddenStateQuantized;
|
||||||
TfLiteTensor* bw_hidden_state_quantized =
|
TfLiteTensor* bw_hidden_state_quantized =
|
||||||
GetTemporary(context, node, kBwHiddenStateQuantized);
|
GetTemporary(context, node, kBwHiddenStateQuantized);
|
||||||
bw_hidden_state_quantized->type = fw_input_weights->type;
|
bw_hidden_state_quantized->type = fw_input_weights->type;
|
||||||
@ -211,7 +227,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
|
|
||||||
// Allocate temporary tensors to store scaling factors of quantization.
|
// Allocate temporary tensors to store scaling factors of quantization.
|
||||||
node->temporaries->data[kScalingFactors] =
|
node->temporaries->data[kScalingFactors] =
|
||||||
*scratch_tensor_index + kScalingFactors;
|
op_data->scratch_tensor_index + kScalingFactors;
|
||||||
TfLiteTensor* scaling_factors =
|
TfLiteTensor* scaling_factors =
|
||||||
GetTemporary(context, node, kScalingFactors);
|
GetTemporary(context, node, kScalingFactors);
|
||||||
scaling_factors->type = kTfLiteFloat32;
|
scaling_factors->type = kTfLiteFloat32;
|
||||||
@ -223,10 +239,66 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
|
||||||
scaling_factors_size));
|
scaling_factors_size));
|
||||||
}
|
}
|
||||||
|
node->temporaries->data[kAccumScratch] =
|
||||||
|
op_data->scratch_tensor_index + kAccumScratch;
|
||||||
|
TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch);
|
||||||
|
accum_scratch->type = kTfLiteInt32;
|
||||||
|
accum_scratch->allocation_type = kTfLiteArenaRw;
|
||||||
|
int accum_scratch_dims[2] = {std::max(fw_num_units, bw_num_units),
|
||||||
|
batch_size};
|
||||||
|
if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
|
||||||
|
accum_scratch_dims)) {
|
||||||
|
TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
|
||||||
|
accum_scratch_size->data[0] = accum_scratch_dims[0];
|
||||||
|
accum_scratch_size->data[1] = accum_scratch_dims[1];
|
||||||
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
|
||||||
|
accum_scratch_size));
|
||||||
|
}
|
||||||
|
node->temporaries->data[kZeroPoints] =
|
||||||
|
op_data->scratch_tensor_index + kZeroPoints;
|
||||||
|
TfLiteTensor* zero_points =
|
||||||
|
GetTemporary(context, node, /*index=*/kZeroPoints);
|
||||||
|
zero_points->type = kTfLiteInt32;
|
||||||
|
zero_points->allocation_type = kTfLiteArenaRw;
|
||||||
|
int zero_points_dims[1] = {batch_size};
|
||||||
|
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
|
||||||
|
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
|
||||||
|
zero_points_size->data[0] = batch_size;
|
||||||
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
|
||||||
|
zero_points_size));
|
||||||
|
}
|
||||||
|
const int num_row_sums = has_aux_input ? 3 : 2;
|
||||||
|
node->temporaries->data[kFwRowSums] =
|
||||||
|
op_data->scratch_tensor_index + kFwRowSums;
|
||||||
|
TfLiteTensor* fw_row_sums =
|
||||||
|
GetTemporary(context, node, /*index=*/kFwRowSums);
|
||||||
|
fw_row_sums->type = kTfLiteInt32;
|
||||||
|
fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
|
||||||
|
int fw_row_sums_dims[2] = {num_row_sums, fw_num_units};
|
||||||
|
if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) {
|
||||||
|
TfLiteIntArray* fw_row_sums_size = TfLiteIntArrayCreate(2);
|
||||||
|
fw_row_sums_size->data[0] = fw_row_sums_dims[0];
|
||||||
|
fw_row_sums_size->data[1] = fw_row_sums_dims[1];
|
||||||
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums,
|
||||||
|
fw_row_sums_size));
|
||||||
|
}
|
||||||
|
node->temporaries->data[kBwRowSums] =
|
||||||
|
op_data->scratch_tensor_index + kBwRowSums;
|
||||||
|
TfLiteTensor* bw_row_sums = GetTemporary(context, node,
|
||||||
|
/*index=*/kBwRowSums);
|
||||||
|
bw_row_sums->type = kTfLiteInt32;
|
||||||
|
bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
|
||||||
|
int bw_row_sums_dims[2] = {num_row_sums, bw_num_units};
|
||||||
|
if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) {
|
||||||
|
TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2);
|
||||||
|
bw_row_sums_size->data[0] = bw_row_sums_dims[0];
|
||||||
|
bw_row_sums_size->data[1] = bw_row_sums_dims[1];
|
||||||
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums,
|
||||||
|
bw_row_sums_size));
|
||||||
|
}
|
||||||
if (has_aux_input) {
|
if (has_aux_input) {
|
||||||
node->temporaries->data[kAuxInputQuantized] =
|
node->temporaries->data[kAuxInputQuantized] =
|
||||||
*scratch_tensor_index + kAuxInputQuantized;
|
op_data->scratch_tensor_index + kAuxInputQuantized;
|
||||||
TfLiteTensor* aux_input_quantized =
|
TfLiteTensor* aux_input_quantized =
|
||||||
GetTemporary(context, node, kAuxInputQuantized);
|
GetTemporary(context, node, kAuxInputQuantized);
|
||||||
aux_input_quantized->type = fw_input_weights->type;
|
aux_input_quantized->type = fw_input_weights->type;
|
||||||
@ -418,7 +490,10 @@ TfLiteStatus EvalHybrid(
|
|||||||
TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized,
|
TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized,
|
||||||
TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
|
TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
|
||||||
TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
|
TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
|
||||||
TfLiteTensor* bw_output) {
|
TfLiteTensor* bw_output, TfLiteTensor* zero_points,
|
||||||
|
TfLiteTensor* accum_scratch, TfLiteTensor* fw_row_sums,
|
||||||
|
TfLiteTensor* bw_row_sums, bool* fw_compute_row_sums,
|
||||||
|
bool* bw_compute_row_sums) {
|
||||||
const bool time_major = params->time_major;
|
const bool time_major = params->time_major;
|
||||||
const int batch_size =
|
const int batch_size =
|
||||||
(time_major) ? input->dims->data[1] : input->dims->data[0];
|
(time_major) ? input->dims->data[1] : input->dims->data[0];
|
||||||
@ -464,11 +539,20 @@ TfLiteStatus EvalHybrid(
|
|||||||
int8_t* bw_quantized_hidden_state_ptr =
|
int8_t* bw_quantized_hidden_state_ptr =
|
||||||
GetTensorData<int8_t>(bw_hidden_state_quantized);
|
GetTensorData<int8_t>(bw_hidden_state_quantized);
|
||||||
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
|
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
|
||||||
|
int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
|
||||||
|
int32_t* zero_points_ptr = nullptr;
|
||||||
|
int32_t* fw_row_sums_ptr = nullptr;
|
||||||
|
int32_t* bw_row_sums_ptr = nullptr;
|
||||||
|
if (params->asymmetric_quantize_inputs) {
|
||||||
|
zero_points_ptr = GetTensorData<int32_t>(zero_points);
|
||||||
|
fw_row_sums_ptr = GetTensorData<int32_t>(fw_row_sums);
|
||||||
|
bw_row_sums_ptr = GetTensorData<int32_t>(bw_row_sums);
|
||||||
|
}
|
||||||
const int fw_output_step =
|
const int fw_output_step =
|
||||||
params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
|
params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
|
||||||
const int bw_output_step =
|
const int bw_output_step =
|
||||||
params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
|
params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
|
||||||
|
|
||||||
if (time_major) {
|
if (time_major) {
|
||||||
for (int t = 0; t < max_time; t++) {
|
for (int t = 0; t < max_time; t++) {
|
||||||
// Forward cell.
|
// Forward cell.
|
||||||
@ -491,7 +575,9 @@ TfLiteStatus EvalHybrid(
|
|||||||
fw_num_units, batch_size, fw_output_step, params->activation,
|
fw_num_units, batch_size, fw_output_step, params->activation,
|
||||||
quantized_input_ptr, aux_quantized_input_ptr,
|
quantized_input_ptr, aux_quantized_input_ptr,
|
||||||
fw_quantized_hidden_state_ptr, scaling_factors_ptr,
|
fw_quantized_hidden_state_ptr, scaling_factors_ptr,
|
||||||
fw_hidden_state_ptr_batch, output_ptr_batch);
|
fw_hidden_state_ptr_batch, output_ptr_batch,
|
||||||
|
params->asymmetric_quantize_inputs, zero_points_ptr,
|
||||||
|
accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums);
|
||||||
}
|
}
|
||||||
// Backward cell.
|
// Backward cell.
|
||||||
float* bw_hidden_state_ptr_batch = GetTensorData<float>(bw_hidden_state);
|
float* bw_hidden_state_ptr_batch = GetTensorData<float>(bw_hidden_state);
|
||||||
@ -516,7 +602,9 @@ TfLiteStatus EvalHybrid(
|
|||||||
bw_num_units, batch_size, bw_output_step, params->activation,
|
bw_num_units, batch_size, bw_output_step, params->activation,
|
||||||
quantized_input_ptr, aux_quantized_input_ptr,
|
quantized_input_ptr, aux_quantized_input_ptr,
|
||||||
bw_quantized_hidden_state_ptr, scaling_factors_ptr,
|
bw_quantized_hidden_state_ptr, scaling_factors_ptr,
|
||||||
bw_hidden_state_ptr_batch, output_ptr_batch);
|
bw_hidden_state_ptr_batch, output_ptr_batch,
|
||||||
|
params->asymmetric_quantize_inputs, zero_points_ptr,
|
||||||
|
accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -545,7 +633,9 @@ TfLiteStatus EvalHybrid(
|
|||||||
fw_num_units, /*batch_size=*/1, fw_output_step, params->activation,
|
fw_num_units, /*batch_size=*/1, fw_output_step, params->activation,
|
||||||
quantized_input_ptr, aux_quantized_input_ptr,
|
quantized_input_ptr, aux_quantized_input_ptr,
|
||||||
fw_quantized_hidden_state_ptr, scaling_factors_ptr,
|
fw_quantized_hidden_state_ptr, scaling_factors_ptr,
|
||||||
fw_hidden_state_ptr_batch, output_ptr_batch);
|
fw_hidden_state_ptr_batch, output_ptr_batch,
|
||||||
|
params->asymmetric_quantize_inputs, zero_points_ptr,
|
||||||
|
accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums);
|
||||||
}
|
}
|
||||||
// Backward cell.
|
// Backward cell.
|
||||||
float* bw_hidden_state_ptr_batch =
|
float* bw_hidden_state_ptr_batch =
|
||||||
@ -574,7 +664,9 @@ TfLiteStatus EvalHybrid(
|
|||||||
bw_num_units, /*batch_size=*/1, bw_output_step, params->activation,
|
bw_num_units, /*batch_size=*/1, bw_output_step, params->activation,
|
||||||
quantized_input_ptr, aux_quantized_input_ptr,
|
quantized_input_ptr, aux_quantized_input_ptr,
|
||||||
bw_quantized_hidden_state_ptr, scaling_factors_ptr,
|
bw_quantized_hidden_state_ptr, scaling_factors_ptr,
|
||||||
bw_hidden_state_ptr_batch, output_ptr_batch);
|
bw_hidden_state_ptr_batch, output_ptr_batch,
|
||||||
|
params->asymmetric_quantize_inputs, zero_points_ptr,
|
||||||
|
accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -656,17 +748,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
GetTemporary(context, node, kBwHiddenStateQuantized);
|
GetTemporary(context, node, kBwHiddenStateQuantized);
|
||||||
TfLiteTensor* scaling_factors =
|
TfLiteTensor* scaling_factors =
|
||||||
GetTemporary(context, node, kScalingFactors);
|
GetTemporary(context, node, kScalingFactors);
|
||||||
|
TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
|
||||||
|
TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch);
|
||||||
|
TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
|
||||||
|
TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
|
||||||
TfLiteTensor* aux_input_quantized =
|
TfLiteTensor* aux_input_quantized =
|
||||||
use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
|
use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
|
||||||
: nullptr;
|
: nullptr;
|
||||||
|
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||||
return EvalHybrid(input, bw_input, fw_input_weights, fw_recurrent_weights,
|
return EvalHybrid(
|
||||||
fw_bias, bw_input_weights, bw_recurrent_weights,
|
input, bw_input, fw_input_weights, fw_recurrent_weights, fw_bias,
|
||||||
bw_bias, real_aux_input, fw_aux_input_weights,
|
bw_input_weights, bw_recurrent_weights, bw_bias, real_aux_input,
|
||||||
bw_aux_input_weights, params, scaling_factors,
|
fw_aux_input_weights, bw_aux_input_weights, params, scaling_factors,
|
||||||
input_quantized, aux_input_quantized,
|
input_quantized, aux_input_quantized, fw_hidden_state_quantized,
|
||||||
fw_hidden_state_quantized, fw_hidden_state, fw_output,
|
fw_hidden_state, fw_output, bw_hidden_state_quantized,
|
||||||
bw_hidden_state_quantized, bw_hidden_state, bw_output);
|
bw_hidden_state, bw_output, zero_points, accum_scratch, fw_row_sums,
|
||||||
|
bw_row_sums, &op_data->fw_compute_row_sums,
|
||||||
|
&op_data->bw_compute_row_sums);
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
context->ReportError(context, "Type not currently supported.");
|
context->ReportError(context, "Type not currently supported.");
|
||||||
|
@ -662,20 +662,24 @@ class BidirectionalRNNOpModel : public SingleOpModel {
|
|||||||
BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units,
|
BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units,
|
||||||
int bw_units, int input_size, int aux_input_size,
|
int bw_units, int input_size, int aux_input_size,
|
||||||
AuxInputMode aux_input_mode, bool time_major,
|
AuxInputMode aux_input_mode, bool time_major,
|
||||||
bool merge_outputs)
|
bool merge_outputs, bool quantize_weights = false,
|
||||||
|
bool asymmetric_quantize_weights = false)
|
||||||
: batches_(batches),
|
: batches_(batches),
|
||||||
sequence_len_(sequence_len),
|
sequence_len_(sequence_len),
|
||||||
fw_units_(fw_units),
|
fw_units_(fw_units),
|
||||||
bw_units_(bw_units),
|
bw_units_(bw_units),
|
||||||
input_size_(input_size),
|
input_size_(input_size),
|
||||||
aux_input_size_(aux_input_size) {
|
aux_input_size_(aux_input_size),
|
||||||
|
quantize_weights_(quantize_weights) {
|
||||||
|
const TensorType tensor_type =
|
||||||
|
quantize_weights ? TensorType_UINT8 : TensorType_FLOAT32;
|
||||||
input_ = AddInput(TensorType_FLOAT32);
|
input_ = AddInput(TensorType_FLOAT32);
|
||||||
fw_weights_ = AddInput(TensorType_FLOAT32);
|
fw_weights_ = AddInput(tensor_type);
|
||||||
fw_recurrent_weights_ = AddInput(TensorType_FLOAT32);
|
fw_recurrent_weights_ = AddInput(tensor_type);
|
||||||
fw_bias_ = AddInput(TensorType_FLOAT32);
|
fw_bias_ = AddInput(TensorType_FLOAT32);
|
||||||
fw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
|
fw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
|
||||||
bw_weights_ = AddInput(TensorType_FLOAT32);
|
bw_weights_ = AddInput(tensor_type);
|
||||||
bw_recurrent_weights_ = AddInput(TensorType_FLOAT32);
|
bw_recurrent_weights_ = AddInput(tensor_type);
|
||||||
bw_bias_ = AddInput(TensorType_FLOAT32);
|
bw_bias_ = AddInput(TensorType_FLOAT32);
|
||||||
bw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
|
bw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
|
||||||
|
|
||||||
@ -697,8 +701,8 @@ class BidirectionalRNNOpModel : public SingleOpModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (aux_input_mode == AuxInputMode::kCrossLinking) {
|
if (aux_input_mode == AuxInputMode::kCrossLinking) {
|
||||||
aux_fw_weights_ = AddInput(TensorType_FLOAT32);
|
aux_fw_weights_ = AddInput(tensor_type);
|
||||||
aux_bw_weights_ = AddInput(TensorType_FLOAT32);
|
aux_bw_weights_ = AddInput(tensor_type);
|
||||||
|
|
||||||
aux_fw_weights_shape = {fw_units, aux_input_size_};
|
aux_fw_weights_shape = {fw_units, aux_input_size_};
|
||||||
aux_bw_weights_shape = {bw_units, aux_input_size_};
|
aux_bw_weights_shape = {bw_units, aux_input_size_};
|
||||||
@ -712,12 +716,12 @@ class BidirectionalRNNOpModel : public SingleOpModel {
|
|||||||
bw_output_ = AddOutput(TensorType_FLOAT32);
|
bw_output_ = AddOutput(TensorType_FLOAT32);
|
||||||
}
|
}
|
||||||
|
|
||||||
SetBuiltinOp(
|
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
|
||||||
BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
|
BuiltinOptions_BidirectionalSequenceRNNOptions,
|
||||||
BuiltinOptions_BidirectionalSequenceRNNOptions,
|
CreateBidirectionalSequenceRNNOptions(
|
||||||
CreateBidirectionalSequenceRNNOptions(
|
builder_, time_major, ActivationFunctionType_RELU,
|
||||||
builder_, time_major, ActivationFunctionType_RELU, merge_outputs)
|
merge_outputs, asymmetric_quantize_weights)
|
||||||
.Union());
|
.Union());
|
||||||
|
|
||||||
BuildInterpreter({
|
BuildInterpreter({
|
||||||
input_shape, // input
|
input_shape, // input
|
||||||
@ -744,19 +748,35 @@ class BidirectionalRNNOpModel : public SingleOpModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void SetFwWeights(const std::vector<float>& f) {
|
void SetFwWeights(const std::vector<float>& f) {
|
||||||
PopulateTensor(fw_weights_, f);
|
if (quantize_weights_) {
|
||||||
|
SymmetricQuantizeAndPopulate(fw_weights_, f);
|
||||||
|
} else {
|
||||||
|
PopulateTensor(fw_weights_, f);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetBwWeights(const std::vector<float>& f) {
|
void SetBwWeights(const std::vector<float>& f) {
|
||||||
PopulateTensor(bw_weights_, f);
|
if (quantize_weights_) {
|
||||||
|
SymmetricQuantizeAndPopulate(bw_weights_, f);
|
||||||
|
} else {
|
||||||
|
PopulateTensor(bw_weights_, f);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetFwRecurrentWeights(const std::vector<float>& f) {
|
void SetFwRecurrentWeights(const std::vector<float>& f) {
|
||||||
PopulateTensor(fw_recurrent_weights_, f);
|
if (quantize_weights_) {
|
||||||
|
SymmetricQuantizeAndPopulate(fw_recurrent_weights_, f);
|
||||||
|
} else {
|
||||||
|
PopulateTensor(fw_recurrent_weights_, f);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetBwRecurrentWeights(const std::vector<float>& f) {
|
void SetBwRecurrentWeights(const std::vector<float>& f) {
|
||||||
PopulateTensor(bw_recurrent_weights_, f);
|
if (quantize_weights_) {
|
||||||
|
SymmetricQuantizeAndPopulate(bw_recurrent_weights_, f);
|
||||||
|
} else {
|
||||||
|
PopulateTensor(bw_recurrent_weights_, f);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetInput(std::initializer_list<float> data) {
|
void SetInput(std::initializer_list<float> data) {
|
||||||
@ -772,11 +792,19 @@ class BidirectionalRNNOpModel : public SingleOpModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void SetAuxFwWeights(const std::vector<float>& f) {
|
void SetAuxFwWeights(const std::vector<float>& f) {
|
||||||
PopulateTensor(aux_fw_weights_, f);
|
if (quantize_weights_) {
|
||||||
|
SymmetricQuantizeAndPopulate(aux_fw_weights_, f);
|
||||||
|
} else {
|
||||||
|
PopulateTensor(aux_fw_weights_, f);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetAuxBwWeights(const std::vector<float>& f) {
|
void SetAuxBwWeights(const std::vector<float>& f) {
|
||||||
PopulateTensor(aux_bw_weights_, f);
|
if (quantize_weights_) {
|
||||||
|
SymmetricQuantizeAndPopulate(aux_bw_weights_, f);
|
||||||
|
} else {
|
||||||
|
PopulateTensor(aux_bw_weights_, f);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); }
|
std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); }
|
||||||
@ -811,17 +839,31 @@ class BidirectionalRNNOpModel : public SingleOpModel {
|
|||||||
int bw_units_;
|
int bw_units_;
|
||||||
int input_size_;
|
int input_size_;
|
||||||
int aux_input_size_;
|
int aux_input_size_;
|
||||||
|
bool quantize_weights_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Declare LSTMOpTest as a parameterized test.
|
||||||
|
class BidirectionalRNNOpTest
|
||||||
|
: public ::testing::TestWithParam<::testing::tuple<bool, bool>> {};
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, BidirectionalRNNOpTest,
|
||||||
|
::testing::Combine(
|
||||||
|
/*quantize_weights*/ ::testing::Bool(),
|
||||||
|
/*asymmetric_quantize_inputs*/ ::testing::Bool()));
|
||||||
|
|
||||||
// TODO(mirkov): add another test which directly compares to TF once TOCO
|
// TODO(mirkov): add another test which directly compares to TF once TOCO
|
||||||
// supports the conversion from dynamic_rnn with BasicRNNCell.
|
// supports the conversion from dynamic_rnn with BasicRNNCell.
|
||||||
TEST(BidirectionalRNNOpTest, BlackBoxTest) {
|
TEST_P(BidirectionalRNNOpTest, BlackBoxTest) {
|
||||||
|
auto params = GetParam();
|
||||||
|
const bool quantize_weights = std::get<0>(params);
|
||||||
|
const bool asymmetric_quantize_inputs = std::get<1>(params);
|
||||||
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
|
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
|
||||||
/*fw_units=*/16, /*bw_units=*/16,
|
/*fw_units=*/16, /*bw_units=*/16,
|
||||||
/*input_size=*/8, /*aux_input_size=*/0,
|
/*input_size=*/8, /*aux_input_size=*/0,
|
||||||
/*aux_input_mode=*/AuxInputMode::kNoAuxInput,
|
/*aux_input_mode=*/AuxInputMode::kNoAuxInput,
|
||||||
/*time_major=*/false,
|
/*time_major=*/false,
|
||||||
/*merge_outputs=*/false);
|
/*merge_outputs=*/false, quantize_weights,
|
||||||
|
asymmetric_quantize_inputs);
|
||||||
rnn.SetFwWeights(weights);
|
rnn.SetFwWeights(weights);
|
||||||
rnn.SetBwWeights(weights);
|
rnn.SetBwWeights(weights);
|
||||||
rnn.SetFwBias(biases);
|
rnn.SetFwBias(biases);
|
||||||
@ -843,7 +885,9 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) {
|
|||||||
std::vector<float> fw_expected;
|
std::vector<float> fw_expected;
|
||||||
fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
|
fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
|
||||||
fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
|
fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
|
||||||
EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
|
EXPECT_THAT(rnn.GetFwOutput(),
|
||||||
|
ElementsAreArray(ArrayFloatNear(
|
||||||
|
fw_expected, quantize_weights ? 1.42e-2 : 1e-5)));
|
||||||
|
|
||||||
float* golden_bw_start = rnn_golden_bw_output;
|
float* golden_bw_start = rnn_golden_bw_output;
|
||||||
float* golden_bw_end =
|
float* golden_bw_end =
|
||||||
@ -851,17 +895,23 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) {
|
|||||||
std::vector<float> bw_expected;
|
std::vector<float> bw_expected;
|
||||||
bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
|
bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
|
||||||
bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
|
bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
|
||||||
EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
|
EXPECT_THAT(rnn.GetBwOutput(),
|
||||||
|
ElementsAreArray(ArrayFloatNear(
|
||||||
|
bw_expected, quantize_weights ? 1.42e-2 : 1e-5)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Same as BlackBox test, but input is reshuffled to time_major format.
|
// Same as BlackBox test, but input is reshuffled to time_major format.
|
||||||
TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) {
|
TEST_P(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) {
|
||||||
|
auto params = GetParam();
|
||||||
|
const bool quantize_weights = std::get<0>(params);
|
||||||
|
const bool asymmetric_quantize_inputs = std::get<1>(params);
|
||||||
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
|
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
|
||||||
/*fw_units=*/16, /*bw_units=*/16,
|
/*fw_units=*/16, /*bw_units=*/16,
|
||||||
/*input_size=*/8, /*aux_input_size=*/0,
|
/*input_size=*/8, /*aux_input_size=*/0,
|
||||||
/*aux_input_mode=*/AuxInputMode::kNoAuxInput,
|
/*aux_input_mode=*/AuxInputMode::kNoAuxInput,
|
||||||
/*time_major=*/true,
|
/*time_major=*/true,
|
||||||
/*merge_outputs=*/false);
|
/*merge_outputs=*/false, quantize_weights,
|
||||||
|
asymmetric_quantize_inputs);
|
||||||
rnn.SetFwWeights(weights);
|
rnn.SetFwWeights(weights);
|
||||||
rnn.SetBwWeights(weights);
|
rnn.SetBwWeights(weights);
|
||||||
rnn.SetFwBias(biases);
|
rnn.SetFwBias(biases);
|
||||||
@ -889,17 +939,26 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) {
|
|||||||
fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
|
fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
|
||||||
fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
|
fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
|
||||||
}
|
}
|
||||||
EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
|
constexpr float kHybridTolerance = 3.57e-1;
|
||||||
|
constexpr float kFloatTolerance = 1e-5;
|
||||||
|
EXPECT_THAT(
|
||||||
|
rnn.GetFwOutput(),
|
||||||
|
ElementsAreArray(ArrayFloatNear(
|
||||||
|
fw_expected, quantize_weights ? kHybridTolerance : kFloatTolerance)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Same as BlackBox test, yet with merged outputs.
|
// Same as BlackBox test, yet with merged outputs.
|
||||||
TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
|
TEST_P(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
|
||||||
|
auto params = GetParam();
|
||||||
|
const bool quantize_weights = std::get<0>(params);
|
||||||
|
const bool asymmetric_quantize_inputs = std::get<1>(params);
|
||||||
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
|
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
|
||||||
/*fw_units=*/16, /*bw_units=*/16,
|
/*fw_units=*/16, /*bw_units=*/16,
|
||||||
/*input_size=*/8, /*aux_input_size=*/0,
|
/*input_size=*/8, /*aux_input_size=*/0,
|
||||||
/*aux_input_mode=*/AuxInputMode::kNoAuxInput,
|
/*aux_input_mode=*/AuxInputMode::kNoAuxInput,
|
||||||
/*time_major=*/false,
|
/*time_major=*/false,
|
||||||
/*merge_outputs=*/true);
|
/*merge_outputs=*/true, quantize_weights,
|
||||||
|
asymmetric_quantize_inputs);
|
||||||
rnn.SetFwWeights(weights);
|
rnn.SetFwWeights(weights);
|
||||||
rnn.SetBwWeights(weights);
|
rnn.SetBwWeights(weights);
|
||||||
rnn.SetFwBias(biases);
|
rnn.SetFwBias(biases);
|
||||||
@ -929,7 +988,8 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
EXPECT_THAT(rnn.GetFwOutput(),
|
EXPECT_THAT(rnn.GetFwOutput(),
|
||||||
ElementsAreArray(ArrayFloatNear(merged_expected)));
|
ElementsAreArray(ArrayFloatNear(
|
||||||
|
merged_expected, quantize_weights ? 1.42e-2 : 1e-5)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Same as BlackBox test, but input is reshuffled to time_major format.
|
// Same as BlackBox test, but input is reshuffled to time_major format.
|
||||||
|
@ -71,6 +71,7 @@ struct OpData {
|
|||||||
int32_t output_activation_max;
|
int32_t output_activation_max;
|
||||||
// The index of the temporary tensor where the quantized inputs are cached.
|
// The index of the temporary tensor where the quantized inputs are cached.
|
||||||
int scratch_tensor_index;
|
int scratch_tensor_index;
|
||||||
|
bool compute_row_sums = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
constexpr int kInputTensor = 0;
|
constexpr int kInputTensor = 0;
|
||||||
@ -131,7 +132,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
|||||||
// Instead, we allocate a new object to carry information from Prepare() to
|
// Instead, we allocate a new object to carry information from Prepare() to
|
||||||
// Eval().
|
// Eval().
|
||||||
auto* op_data = new OpData();
|
auto* op_data = new OpData();
|
||||||
context->AddTensors(context, /*tensors_to_add=*/3,
|
context->AddTensors(context, /*tensors_to_add=*/5,
|
||||||
&op_data->scratch_tensor_index);
|
&op_data->scratch_tensor_index);
|
||||||
return op_data;
|
return op_data;
|
||||||
}
|
}
|
||||||
@ -144,7 +145,6 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
auto* params =
|
auto* params =
|
||||||
reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
|
reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
|
||||||
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
||||||
|
|
||||||
// Check we have all the inputs and outputs we need.
|
// Check we have all the inputs and outputs we need.
|
||||||
TF_LITE_ENSURE(context, node->inputs->size == 2 || node->inputs->size == 3);
|
TF_LITE_ENSURE(context, node->inputs->size == 2 || node->inputs->size == 3);
|
||||||
// Shuffled formats need a workspace to store the shuffled input activations.
|
// Shuffled formats need a workspace to store the shuffled input activations.
|
||||||
@ -208,7 +208,8 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
if (input->type == kTfLiteFloat32 &&
|
if (input->type == kTfLiteFloat32 &&
|
||||||
(filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8)) {
|
(filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8)) {
|
||||||
TfLiteIntArrayFree(node->temporaries);
|
TfLiteIntArrayFree(node->temporaries);
|
||||||
node->temporaries = TfLiteIntArrayCreate(3);
|
data->compute_row_sums = true;
|
||||||
|
node->temporaries = TfLiteIntArrayCreate(5);
|
||||||
node->temporaries->data[0] = data->scratch_tensor_index;
|
node->temporaries->data[0] = data->scratch_tensor_index;
|
||||||
|
|
||||||
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
|
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
|
||||||
@ -245,6 +246,28 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_ENSURE_OK(
|
TF_LITE_ENSURE_OK(
|
||||||
context, context->ResizeTensor(context, accum_scratch, accum_size));
|
context, context->ResizeTensor(context, accum_scratch, accum_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
node->temporaries->data[3] = data->scratch_tensor_index + 3;
|
||||||
|
TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/3);
|
||||||
|
input_offsets->type = kTfLiteInt32;
|
||||||
|
input_offsets->allocation_type = kTfLiteArenaRw;
|
||||||
|
if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) {
|
||||||
|
TfLiteIntArray* input_offsets_size = TfLiteIntArrayCreate(1);
|
||||||
|
input_offsets_size->data[0] = batch_size;
|
||||||
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_offsets,
|
||||||
|
input_offsets_size));
|
||||||
|
}
|
||||||
|
node->temporaries->data[4] = data->scratch_tensor_index + 4;
|
||||||
|
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/4);
|
||||||
|
row_sums->type = kTfLiteInt32;
|
||||||
|
row_sums->allocation_type = kTfLiteArenaRwPersistent;
|
||||||
|
int row_sums_dims[1] = {num_units};
|
||||||
|
if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) {
|
||||||
|
TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1);
|
||||||
|
row_sums_size->data[0] = row_sums_dims[0];
|
||||||
|
TF_LITE_ENSURE_OK(
|
||||||
|
context, context->ResizeTensor(context, row_sums, row_sums_size));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resize output.
|
// Resize output.
|
||||||
@ -332,7 +355,9 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
|
|||||||
TfLiteFullyConnectedParams* params, OpData* data,
|
TfLiteFullyConnectedParams* params, OpData* data,
|
||||||
const TfLiteTensor* input, const TfLiteTensor* filter,
|
const TfLiteTensor* input, const TfLiteTensor* filter,
|
||||||
const TfLiteTensor* bias, TfLiteTensor* input_quantized,
|
const TfLiteTensor* bias, TfLiteTensor* input_quantized,
|
||||||
TfLiteTensor* scaling_factors, TfLiteTensor* output) {
|
TfLiteTensor* scaling_factors,
|
||||||
|
TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
|
||||||
|
TfLiteTensor* input_offsets, TfLiteTensor* output) {
|
||||||
int total_input_size = 1;
|
int total_input_size = 1;
|
||||||
for (int i = 0; i < input->dims->size; i++) {
|
for (int i = 0; i < input->dims->size; i++) {
|
||||||
total_input_size *= input->dims->data[i];
|
total_input_size *= input->dims->data[i];
|
||||||
@ -363,32 +388,39 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
|
|||||||
// Quantize input from float to uint8 + quantization params (scaling factor).
|
// Quantize input from float to uint8 + quantization params (scaling factor).
|
||||||
float unused_min, unused_max;
|
float unused_min, unused_max;
|
||||||
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
|
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
|
||||||
|
int32_t* input_offset_ptr = nullptr;
|
||||||
|
int32_t* row_sums_ptr = nullptr;
|
||||||
|
if (params->asymmetric_quantize_inputs) {
|
||||||
|
input_offset_ptr = GetTensorData<int32_t>(input_offsets);
|
||||||
|
row_sums_ptr = GetTensorData<int32_t>(row_sums);
|
||||||
|
}
|
||||||
int8_t* quant_data = GetTensorData<int8_t>(input_quantized);
|
int8_t* quant_data = GetTensorData<int8_t>(input_quantized);
|
||||||
const int8_t* filter_data = GetTensorData<int8_t>(filter);
|
const int8_t* filter_data = GetTensorData<int8_t>(filter);
|
||||||
|
const float* input_ptr = GetTensorData<float>(input);
|
||||||
// Quantize each batch independently.
|
// Quantize each batch independently.
|
||||||
for (int b = 0; b < batch_size; ++b) {
|
for (int b = 0; b < batch_size; ++b) {
|
||||||
const int offset = b * input_size;
|
const int offset = b * input_size;
|
||||||
tensor_utils::SymmetricQuantizeFloats(
|
if (params->asymmetric_quantize_inputs) {
|
||||||
GetTensorData<float>(input) + offset, input_size, quant_data + offset,
|
tensor_utils::AsymmetricQuantizeFloats(
|
||||||
&unused_min, &unused_max, &scaling_factors_ptr[b]);
|
input_ptr + offset, input_size, quant_data + offset,
|
||||||
|
&scaling_factors_ptr[b], &input_offset_ptr[b]);
|
||||||
|
} else {
|
||||||
|
tensor_utils::SymmetricQuantizeFloats(
|
||||||
|
input_ptr + offset, input_size, quant_data + offset, &unused_min,
|
||||||
|
&unused_max, &scaling_factors_ptr[b]);
|
||||||
|
}
|
||||||
// Incorporate scaling of the filter.
|
// Incorporate scaling of the filter.
|
||||||
scaling_factors_ptr[b] *= filter->params.scale;
|
scaling_factors_ptr[b] *= filter->params.scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute output += weight * quantized_input
|
// Compute output += weight * quantized_input
|
||||||
#ifdef TFLITE_WITH_RUY_GEMV
|
|
||||||
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2);
|
|
||||||
int32_t* scratch = GetTensorData<int32_t>(accum_scratch);
|
int32_t* scratch = GetTensorData<int32_t>(accum_scratch);
|
||||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
|
filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
|
||||||
batch_size, scratch, GetTensorData<float>(output),
|
batch_size, GetTensorData<float>(output), /*per_channel_scale=*/nullptr,
|
||||||
|
input_offset_ptr, scratch, row_sums_ptr, &data->compute_row_sums,
|
||||||
CpuBackendContext::GetFromContext(context));
|
CpuBackendContext::GetFromContext(context));
|
||||||
#else
|
|
||||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
|
||||||
filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
|
|
||||||
batch_size, GetTensorData<float>(output));
|
|
||||||
#endif
|
|
||||||
// Apply activation function to floats.
|
// Apply activation function to floats.
|
||||||
tensor_utils::ApplyActivationToVector(
|
tensor_utils::ApplyActivationToVector(
|
||||||
GetTensorData<float>(output), batch_size * num_units, params->activation,
|
GetTensorData<float>(output), batch_size * num_units, params->activation,
|
||||||
@ -461,8 +493,12 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
|||||||
if (input->type == kTfLiteFloat32) {
|
if (input->type == kTfLiteFloat32) {
|
||||||
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
|
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
|
||||||
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
|
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
|
||||||
|
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2);
|
||||||
|
TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/3);
|
||||||
|
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/4);
|
||||||
return EvalHybrid(context, node, params, data, input, filter, bias,
|
return EvalHybrid(context, node, params, data, input, filter, bias,
|
||||||
input_quantized, scaling_factors, output);
|
input_quantized, scaling_factors, accum_scratch, row_sums,
|
||||||
|
input_offsets, output);
|
||||||
} else {
|
} else {
|
||||||
FullyConnectedParams op_params;
|
FullyConnectedParams op_params;
|
||||||
op_params.input_offset = input_offset;
|
op_params.input_offset = input_offset;
|
||||||
@ -590,7 +626,6 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
|
|||||||
FullyConnectedParams op_params;
|
FullyConnectedParams op_params;
|
||||||
op_params.float_activation_min = output_activation_min;
|
op_params.float_activation_min = output_activation_min;
|
||||||
op_params.float_activation_max = output_activation_max;
|
op_params.float_activation_max = output_activation_max;
|
||||||
|
|
||||||
reference_ops::FullyConnected(
|
reference_ops::FullyConnected(
|
||||||
op_params, GetTensorShape(input), GetTensorData<float>(input),
|
op_params, GetTensorShape(input), GetTensorData<float>(input),
|
||||||
GetTensorShape(filter), GetTensorData<float>(filter),
|
GetTensorShape(filter), GetTensorData<float>(filter),
|
||||||
|
@ -286,7 +286,8 @@ class HybridFullyConnectedOpModel : public SingleOpModel {
|
|||||||
public:
|
public:
|
||||||
HybridFullyConnectedOpModel(int units, int batches, const TensorData& input,
|
HybridFullyConnectedOpModel(int units, int batches, const TensorData& input,
|
||||||
const TensorData& weights,
|
const TensorData& weights,
|
||||||
const TensorData& output = {TensorType_FLOAT32})
|
const TensorData& output = {TensorType_FLOAT32},
|
||||||
|
bool asymmetric_inputs = false)
|
||||||
: batches_(batches), units_(units) {
|
: batches_(batches), units_(units) {
|
||||||
int total_input_size = 1;
|
int total_input_size = 1;
|
||||||
for (size_t i = 0; i < input.shape.size(); ++i) {
|
for (size_t i = 0; i < input.shape.size(); ++i) {
|
||||||
@ -302,10 +303,13 @@ class HybridFullyConnectedOpModel : public SingleOpModel {
|
|||||||
|
|
||||||
output_ = AddOutput(output);
|
output_ = AddOutput(output);
|
||||||
|
|
||||||
SetBuiltinOp(
|
auto options = CreateFullyConnectedOptions(
|
||||||
BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
|
builder_, ActivationFunctionType_RELU,
|
||||||
CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
|
tflite::FullyConnectedOptionsWeightsFormat_DEFAULT,
|
||||||
.Union());
|
false, asymmetric_inputs)
|
||||||
|
.Union();
|
||||||
|
SetBuiltinOp(BuiltinOperator_FULLY_CONNECTED,
|
||||||
|
BuiltinOptions_FullyConnectedOptions, options);
|
||||||
resolver_ = absl::make_unique<SingleOpResolver>(
|
resolver_ = absl::make_unique<SingleOpResolver>(
|
||||||
BuiltinOperator_FULLY_CONNECTED,
|
BuiltinOperator_FULLY_CONNECTED,
|
||||||
ops::builtin::Register_FULLY_CONNECTED_PIE());
|
ops::builtin::Register_FULLY_CONNECTED_PIE());
|
||||||
@ -867,6 +871,66 @@ TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8) {
|
|||||||
/*max_abs_error=*/1.3f)));
|
/*max_abs_error=*/1.3f)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedUint8) {
|
||||||
|
HybridFullyConnectedOpModel m(
|
||||||
|
/*units=*/3, /*batches=*/2,
|
||||||
|
/*input=*/{TensorType_FLOAT32, {2, 10}},
|
||||||
|
/*weights=*/
|
||||||
|
{TensorType_UINT8, {3, 10}, 0, 0, 10.0 / 127.0, 0}, {TensorType_FLOAT32},
|
||||||
|
/*asymmetric_quantize_input*/ true); // Hybrid asymmetric
|
||||||
|
|
||||||
|
m.SetWeights({
|
||||||
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
|
||||||
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
|
||||||
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
|
||||||
|
});
|
||||||
|
m.SetBias({1, 2, 3});
|
||||||
|
|
||||||
|
m.SetInput({
|
||||||
|
1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
|
||||||
|
1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
|
||||||
|
});
|
||||||
|
|
||||||
|
m.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
|
||||||
|
{
|
||||||
|
24, 25, 26, //
|
||||||
|
58, 59, 60, //
|
||||||
|
},
|
||||||
|
/*max_abs_error=*/0.64f)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedInt8) {
|
||||||
|
HybridFullyConnectedOpModel m(
|
||||||
|
/*units=*/3, /*batches=*/2,
|
||||||
|
/*input=*/{TensorType_FLOAT32, {2, 10}},
|
||||||
|
/*weights=*/{TensorType_INT8, {3, 10}, 0, 0, 10.0 / 127.0, 0},
|
||||||
|
{TensorType_FLOAT32},
|
||||||
|
/*asymmetric_quantize_input*/ true);
|
||||||
|
|
||||||
|
m.SetSignedWeights({
|
||||||
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
|
||||||
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
|
||||||
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
|
||||||
|
});
|
||||||
|
m.SetBias({1, 2, 3});
|
||||||
|
|
||||||
|
m.SetInput({
|
||||||
|
1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
|
||||||
|
1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
|
||||||
|
});
|
||||||
|
|
||||||
|
m.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
|
||||||
|
{
|
||||||
|
24, 25, 26, //
|
||||||
|
58, 59, 60, //
|
||||||
|
},
|
||||||
|
/*max_abs_error=*/1.3f)));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) {
|
TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) {
|
||||||
// Note that it is not required that the first dimension be the number of
|
// Note that it is not required that the first dimension be the number of
|
||||||
// batches. All we care is that the input can be evenly distributed in
|
// batches. All we care is that the input can be evenly distributed in
|
||||||
|
@ -123,7 +123,9 @@ void RnnBatchStep(
|
|||||||
int num_units, int batch_size, int output_batch_leading_dim,
|
int num_units, int batch_size, int output_batch_leading_dim,
|
||||||
TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
|
TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
|
||||||
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
|
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
|
||||||
float* hidden_state_ptr_batch, float* output_ptr_batch) {
|
float* hidden_state_ptr_batch, float* output_ptr_batch,
|
||||||
|
bool asymmetric_quantize_inputs, int32_t* zero_points,
|
||||||
|
int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums) {
|
||||||
RnnBatchStep(input_ptr_batch, input_weights_ptr, input_weights_scale,
|
RnnBatchStep(input_ptr_batch, input_weights_ptr, input_weights_scale,
|
||||||
/*aux_input_ptr_batch=*/nullptr,
|
/*aux_input_ptr_batch=*/nullptr,
|
||||||
/*aux_input_weights_ptr=*/nullptr,
|
/*aux_input_weights_ptr=*/nullptr,
|
||||||
@ -133,7 +135,29 @@ void RnnBatchStep(
|
|||||||
output_batch_leading_dim, activation, quantized_input_ptr_batch,
|
output_batch_leading_dim, activation, quantized_input_ptr_batch,
|
||||||
/*aux_quantized_input_ptr_batch=*/nullptr,
|
/*aux_quantized_input_ptr_batch=*/nullptr,
|
||||||
quantized_hidden_state_ptr_batch, scaling_factors,
|
quantized_hidden_state_ptr_batch, scaling_factors,
|
||||||
hidden_state_ptr_batch, output_ptr_batch);
|
hidden_state_ptr_batch, output_ptr_batch,
|
||||||
|
asymmetric_quantize_inputs, zero_points, accum_scratch, row_sums,
|
||||||
|
compute_row_sums);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ComputeMatrixSums(int32_t* input_row_sums, int32_t* aux_input_row_sums,
|
||||||
|
int32_t* recurrent_row_sums, int32_t* row_sums,
|
||||||
|
const float* aux_input_ptr_batch, int num_units,
|
||||||
|
int input_size, int aux_input_size,
|
||||||
|
const int8_t* input_weights_ptr,
|
||||||
|
const int8_t* aux_input_weights_ptr,
|
||||||
|
const int8_t* recurrent_weights_ptr) {
|
||||||
|
memset(input_row_sums, 0, sizeof(int32_t) * num_units);
|
||||||
|
tensor_utils::ReductionSumVector(input_weights_ptr, input_row_sums, num_units,
|
||||||
|
input_size);
|
||||||
|
if (aux_input_ptr_batch) {
|
||||||
|
memset(aux_input_row_sums, 0, sizeof(int32_t) * num_units);
|
||||||
|
tensor_utils::ReductionSumVector(aux_input_weights_ptr, aux_input_row_sums,
|
||||||
|
num_units, aux_input_size);
|
||||||
|
}
|
||||||
|
memset(recurrent_row_sums, 0, sizeof(int32_t) * num_units);
|
||||||
|
tensor_utils::ReductionSumVector(recurrent_weights_ptr, recurrent_row_sums,
|
||||||
|
num_units, num_units);
|
||||||
}
|
}
|
||||||
|
|
||||||
void RnnBatchStep(
|
void RnnBatchStep(
|
||||||
@ -146,9 +170,31 @@ void RnnBatchStep(
|
|||||||
TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
|
TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
|
||||||
int8_t* aux_quantized_input_ptr_batch,
|
int8_t* aux_quantized_input_ptr_batch,
|
||||||
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
|
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
|
||||||
float* hidden_state_ptr_batch, float* output_ptr_batch) {
|
float* hidden_state_ptr_batch, float* output_ptr_batch,
|
||||||
|
bool asymmetric_quantize_inputs, int32_t* zero_points,
|
||||||
|
int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums) {
|
||||||
// Since the output batch rows may not be contiguous (output_batch_leading_dim
|
// Since the output batch rows may not be contiguous (output_batch_leading_dim
|
||||||
// != n_output), we unroll the batched operations where this is the case.
|
// != n_output), we unroll the batched operations where this is the case.
|
||||||
|
|
||||||
|
int32_t* input_row_sums = nullptr;
|
||||||
|
int32_t* aux_input_row_sums = nullptr;
|
||||||
|
int32_t* recurrent_row_sums = nullptr;
|
||||||
|
if (asymmetric_quantize_inputs) {
|
||||||
|
input_row_sums = row_sums;
|
||||||
|
aux_input_row_sums = row_sums;
|
||||||
|
if (aux_input_ptr_batch) {
|
||||||
|
aux_input_row_sums += num_units;
|
||||||
|
}
|
||||||
|
recurrent_row_sums = aux_input_row_sums + num_units;
|
||||||
|
if (*compute_row_sums) {
|
||||||
|
ComputeMatrixSums(input_row_sums, aux_input_row_sums, recurrent_row_sums,
|
||||||
|
row_sums, aux_input_ptr_batch, num_units, input_size,
|
||||||
|
aux_input_size, input_weights_ptr,
|
||||||
|
aux_input_weights_ptr, recurrent_weights_ptr);
|
||||||
|
*compute_row_sums = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (output_batch_leading_dim == num_units) {
|
if (output_batch_leading_dim == num_units) {
|
||||||
// Output = bias
|
// Output = bias
|
||||||
tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
|
tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
|
||||||
@ -163,17 +209,25 @@ void RnnBatchStep(
|
|||||||
// whichever is faster.
|
// whichever is faster.
|
||||||
for (int b = 0; b < batch_size; ++b) {
|
for (int b = 0; b < batch_size; ++b) {
|
||||||
const int offset = b * input_size;
|
const int offset = b * input_size;
|
||||||
tensor_utils::SymmetricQuantizeFloats(
|
if (asymmetric_quantize_inputs) {
|
||||||
input_ptr_batch + offset, input_size,
|
tensor_utils::AsymmetricQuantizeFloats(
|
||||||
quantized_input_ptr_batch + offset, &unused_min, &unused_max,
|
input_ptr_batch + offset, input_size,
|
||||||
&scaling_factors[b]);
|
quantized_input_ptr_batch + offset, &scaling_factors[b],
|
||||||
|
&zero_points[b]);
|
||||||
|
} else {
|
||||||
|
tensor_utils::SymmetricQuantizeFloats(
|
||||||
|
input_ptr_batch + offset, input_size,
|
||||||
|
quantized_input_ptr_batch + offset, &unused_min, &unused_max,
|
||||||
|
&scaling_factors[b]);
|
||||||
|
}
|
||||||
scaling_factors[b] *= input_weights_scale;
|
scaling_factors[b] *= input_weights_scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Output += input * input_weights
|
// Output += input * input_weights
|
||||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
input_weights_ptr, num_units, input_size, quantized_input_ptr_batch,
|
input_weights_ptr, num_units, input_size, quantized_input_ptr_batch,
|
||||||
scaling_factors, batch_size, output_ptr_batch);
|
scaling_factors, batch_size, output_ptr_batch,
|
||||||
|
/*per_channel_scale=*/nullptr, zero_points, accum_scratch,
|
||||||
|
input_row_sums, compute_row_sums, /*context=*/nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (aux_input_ptr_batch &&
|
if (aux_input_ptr_batch &&
|
||||||
@ -182,10 +236,17 @@ void RnnBatchStep(
|
|||||||
float unused_min, unused_max;
|
float unused_min, unused_max;
|
||||||
for (int b = 0; b < batch_size; ++b) {
|
for (int b = 0; b < batch_size; ++b) {
|
||||||
const int offset = b * aux_input_size;
|
const int offset = b * aux_input_size;
|
||||||
tensor_utils::SymmetricQuantizeFloats(
|
if (asymmetric_quantize_inputs) {
|
||||||
aux_input_ptr_batch + offset, aux_input_size,
|
tensor_utils::AsymmetricQuantizeFloats(
|
||||||
aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max,
|
aux_input_ptr_batch + offset, aux_input_size,
|
||||||
&scaling_factors[b]);
|
aux_quantized_input_ptr_batch + offset, &scaling_factors[b],
|
||||||
|
&zero_points[b]);
|
||||||
|
} else {
|
||||||
|
tensor_utils::SymmetricQuantizeFloats(
|
||||||
|
aux_input_ptr_batch + offset, aux_input_size,
|
||||||
|
aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max,
|
||||||
|
&scaling_factors[b]);
|
||||||
|
}
|
||||||
scaling_factors[b] *= aux_input_weights_scale;
|
scaling_factors[b] *= aux_input_weights_scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -193,7 +254,9 @@ void RnnBatchStep(
|
|||||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
aux_input_weights_ptr, num_units, aux_input_size,
|
aux_input_weights_ptr, num_units, aux_input_size,
|
||||||
aux_quantized_input_ptr_batch, scaling_factors, batch_size,
|
aux_quantized_input_ptr_batch, scaling_factors, batch_size,
|
||||||
output_ptr_batch);
|
output_ptr_batch, /*per_channel_scale=*/nullptr, zero_points,
|
||||||
|
accum_scratch, aux_input_row_sums, compute_row_sums,
|
||||||
|
/*context=*/nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save quantization and matmul computation for all zero input.
|
// Save quantization and matmul computation for all zero input.
|
||||||
@ -203,10 +266,17 @@ void RnnBatchStep(
|
|||||||
float unused_min, unused_max;
|
float unused_min, unused_max;
|
||||||
for (int b = 0; b < batch_size; ++b) {
|
for (int b = 0; b < batch_size; ++b) {
|
||||||
const int offset = b * num_units;
|
const int offset = b * num_units;
|
||||||
tensor_utils::SymmetricQuantizeFloats(
|
if (asymmetric_quantize_inputs) {
|
||||||
hidden_state_ptr_batch + offset, num_units,
|
tensor_utils::AsymmetricQuantizeFloats(
|
||||||
quantized_hidden_state_ptr_batch + offset, &unused_min, &unused_max,
|
hidden_state_ptr_batch + offset, num_units,
|
||||||
&scaling_factors[b]);
|
quantized_hidden_state_ptr_batch + offset, &scaling_factors[b],
|
||||||
|
&zero_points[b]);
|
||||||
|
} else {
|
||||||
|
tensor_utils::SymmetricQuantizeFloats(
|
||||||
|
hidden_state_ptr_batch + offset, num_units,
|
||||||
|
quantized_hidden_state_ptr_batch + offset, &unused_min,
|
||||||
|
&unused_max, &scaling_factors[b]);
|
||||||
|
}
|
||||||
scaling_factors[b] *= recurrent_weights_scale;
|
scaling_factors[b] *= recurrent_weights_scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -214,7 +284,9 @@ void RnnBatchStep(
|
|||||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
recurrent_weights_ptr, num_units, num_units,
|
recurrent_weights_ptr, num_units, num_units,
|
||||||
quantized_hidden_state_ptr_batch, scaling_factors, batch_size,
|
quantized_hidden_state_ptr_batch, scaling_factors, batch_size,
|
||||||
output_ptr_batch);
|
output_ptr_batch, /*per_channel_scale=*/nullptr, zero_points,
|
||||||
|
accum_scratch, recurrent_row_sums, compute_row_sums,
|
||||||
|
/*context=*/nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Output = activation(Output) and update hidden_state
|
// Output = activation(Output) and update hidden_state
|
||||||
@ -238,10 +310,17 @@ void RnnBatchStep(
|
|||||||
// whichever is faster.
|
// whichever is faster.
|
||||||
for (int b = 0; b < batch_size; ++b) {
|
for (int b = 0; b < batch_size; ++b) {
|
||||||
const int offset = b * input_size;
|
const int offset = b * input_size;
|
||||||
tensor_utils::SymmetricQuantizeFloats(
|
if (asymmetric_quantize_inputs) {
|
||||||
input_ptr_batch + offset, input_size,
|
tensor_utils::AsymmetricQuantizeFloats(
|
||||||
quantized_input_ptr_batch + offset, &unused_min, &unused_max,
|
input_ptr_batch + offset, input_size,
|
||||||
&scaling_factors[b]);
|
quantized_input_ptr_batch + offset, &scaling_factors[b],
|
||||||
|
&zero_points[b]);
|
||||||
|
} else {
|
||||||
|
tensor_utils::SymmetricQuantizeFloats(
|
||||||
|
input_ptr_batch + offset, input_size,
|
||||||
|
quantized_input_ptr_batch + offset, &unused_min, &unused_max,
|
||||||
|
&scaling_factors[b]);
|
||||||
|
}
|
||||||
scaling_factors[b] *= input_weights_scale;
|
scaling_factors[b] *= input_weights_scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -250,7 +329,9 @@ void RnnBatchStep(
|
|||||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
input_weights_ptr, num_units, input_size,
|
input_weights_ptr, num_units, input_size,
|
||||||
quantized_input_ptr_batch + k * input_size, &scaling_factors[k],
|
quantized_input_ptr_batch + k * input_size, &scaling_factors[k],
|
||||||
/*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim);
|
/*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim,
|
||||||
|
/*per_channel_scale=*/nullptr, zero_points + k, accum_scratch,
|
||||||
|
input_row_sums, compute_row_sums, /*context=*/nullptr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -260,10 +341,17 @@ void RnnBatchStep(
|
|||||||
float unused_min, unused_max;
|
float unused_min, unused_max;
|
||||||
for (int b = 0; b < batch_size; ++b) {
|
for (int b = 0; b < batch_size; ++b) {
|
||||||
const int offset = b * aux_input_size;
|
const int offset = b * aux_input_size;
|
||||||
tensor_utils::SymmetricQuantizeFloats(
|
if (asymmetric_quantize_inputs) {
|
||||||
aux_input_ptr_batch + offset, aux_input_size,
|
tensor_utils::AsymmetricQuantizeFloats(
|
||||||
aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max,
|
aux_input_ptr_batch + offset, aux_input_size,
|
||||||
&scaling_factors[b]);
|
aux_quantized_input_ptr_batch + offset, &scaling_factors[b],
|
||||||
|
&zero_points[b]);
|
||||||
|
} else {
|
||||||
|
tensor_utils::SymmetricQuantizeFloats(
|
||||||
|
aux_input_ptr_batch + offset, aux_input_size,
|
||||||
|
aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max,
|
||||||
|
&scaling_factors[b]);
|
||||||
|
}
|
||||||
scaling_factors[b] *= aux_input_weights_scale;
|
scaling_factors[b] *= aux_input_weights_scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -273,7 +361,9 @@ void RnnBatchStep(
|
|||||||
aux_input_weights_ptr, num_units, aux_input_size,
|
aux_input_weights_ptr, num_units, aux_input_size,
|
||||||
aux_quantized_input_ptr_batch + k * aux_input_size,
|
aux_quantized_input_ptr_batch + k * aux_input_size,
|
||||||
&scaling_factors[k],
|
&scaling_factors[k],
|
||||||
/*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim);
|
/*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim,
|
||||||
|
/*per_channel_scale=*/nullptr, zero_points + k, accum_scratch,
|
||||||
|
aux_input_row_sums, compute_row_sums, /*context=*/nullptr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -284,10 +374,17 @@ void RnnBatchStep(
|
|||||||
float unused_min, unused_max;
|
float unused_min, unused_max;
|
||||||
for (int b = 0; b < batch_size; ++b) {
|
for (int b = 0; b < batch_size; ++b) {
|
||||||
const int offset = b * num_units;
|
const int offset = b * num_units;
|
||||||
tensor_utils::SymmetricQuantizeFloats(
|
if (asymmetric_quantize_inputs) {
|
||||||
hidden_state_ptr_batch + offset, num_units,
|
tensor_utils::AsymmetricQuantizeFloats(
|
||||||
quantized_hidden_state_ptr_batch + offset, &unused_min, &unused_max,
|
hidden_state_ptr_batch + offset, num_units,
|
||||||
&scaling_factors[b]);
|
quantized_hidden_state_ptr_batch + offset, &scaling_factors[b],
|
||||||
|
&zero_points[b]);
|
||||||
|
} else {
|
||||||
|
tensor_utils::SymmetricQuantizeFloats(
|
||||||
|
hidden_state_ptr_batch + offset, num_units,
|
||||||
|
quantized_hidden_state_ptr_batch + offset, &unused_min,
|
||||||
|
&unused_max, &scaling_factors[b]);
|
||||||
|
}
|
||||||
scaling_factors[b] *= recurrent_weights_scale;
|
scaling_factors[b] *= recurrent_weights_scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -296,8 +393,10 @@ void RnnBatchStep(
|
|||||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
recurrent_weights_ptr, num_units, num_units,
|
recurrent_weights_ptr, num_units, num_units,
|
||||||
quantized_hidden_state_ptr_batch + k * num_units,
|
quantized_hidden_state_ptr_batch + k * num_units,
|
||||||
&scaling_factors[k],
|
&scaling_factors[k], /*n_batch=*/1,
|
||||||
/*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim);
|
output_ptr_batch + k * output_batch_leading_dim,
|
||||||
|
/*per_channel_scale=*/nullptr, zero_points + k, accum_scratch,
|
||||||
|
recurrent_row_sums, compute_row_sums, /*context=*/nullptr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,7 +70,9 @@ void RnnBatchStep(
|
|||||||
int num_units, int batch_size, int output_batch_leading_dim,
|
int num_units, int batch_size, int output_batch_leading_dim,
|
||||||
TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
|
TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
|
||||||
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
|
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
|
||||||
float* hidden_state_ptr_batch, float* output_ptr_batch);
|
float* hidden_state_ptr_batch, float* output_ptr_batch,
|
||||||
|
bool asymmetric_quantize_inputs, int32_t* zero_points,
|
||||||
|
int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums);
|
||||||
|
|
||||||
void RnnBatchStep(
|
void RnnBatchStep(
|
||||||
const float* input_ptr_batch, const int8_t* input_weights_ptr,
|
const float* input_ptr_batch, const int8_t* input_weights_ptr,
|
||||||
@ -82,7 +84,9 @@ void RnnBatchStep(
|
|||||||
TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
|
TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
|
||||||
int8_t* aux_quantized_input_ptr_batch,
|
int8_t* aux_quantized_input_ptr_batch,
|
||||||
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
|
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
|
||||||
float* hidden_state_ptr_batch, float* output_ptr_batch);
|
float* hidden_state_ptr_batch, float* output_ptr_batch,
|
||||||
|
bool asymmetric_quantize_inputs, int32_t* zero_points,
|
||||||
|
int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums);
|
||||||
|
|
||||||
} // namespace kernel_utils
|
} // namespace kernel_utils
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -1310,6 +1310,13 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
|
|||||||
const int postamble_half_start = m_cols & ~(kWeightsPerNeonLane - 1);
|
const int postamble_half_start = m_cols & ~(kWeightsPerNeonLane - 1);
|
||||||
const int postamble_start = m_cols & ~((kWeightsPerNeonLane >> 1) - 1);
|
const int postamble_start = m_cols & ~((kWeightsPerNeonLane >> 1) - 1);
|
||||||
|
|
||||||
|
int32_t* row_sums_ptr = row_sums;
|
||||||
|
if (row_sums == nullptr) {
|
||||||
|
row_sums_ptr = static_cast<int32_t*>(malloc(sizeof(int32_t) * m_rows));
|
||||||
|
memset(row_sums_ptr, 0, sizeof(int32_t) * m_rows);
|
||||||
|
NeonReductionSumVector(matrix, row_sums_ptr, m_rows, m_cols);
|
||||||
|
}
|
||||||
|
|
||||||
for (int batch = 0; batch < n_batch; ++batch) {
|
for (int batch = 0; batch < n_batch; ++batch) {
|
||||||
const float batch_scaling_factor = scaling_factors[batch];
|
const float batch_scaling_factor = scaling_factors[batch];
|
||||||
const int batch_input_offset = input_offset[batch];
|
const int batch_input_offset = input_offset[batch];
|
||||||
@ -1327,10 +1334,6 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
|
|||||||
// Initialize the dot product sum for the row to 0.
|
// Initialize the dot product sum for the row to 0.
|
||||||
int32x4_t dotprod_32x4 = vmovq_n_s32(0);
|
int32x4_t dotprod_32x4 = vmovq_n_s32(0);
|
||||||
|
|
||||||
int32x4_t row_sum_32x4;
|
|
||||||
if (row_sums == nullptr) {
|
|
||||||
row_sum_32x4 = vmovq_n_s32(0);
|
|
||||||
}
|
|
||||||
// Prefetch the row to cache.
|
// Prefetch the row to cache.
|
||||||
__builtin_prefetch(row_ptr, 0 /* prefetch for read */,
|
__builtin_prefetch(row_ptr, 0 /* prefetch for read */,
|
||||||
3 /* temporal locality */);
|
3 /* temporal locality */);
|
||||||
@ -1358,10 +1361,6 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
|
|||||||
prod_16x8 =
|
prod_16x8 =
|
||||||
vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16));
|
vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16));
|
||||||
dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
|
dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
|
||||||
if (row_sums == nullptr) {
|
|
||||||
const int16x8_t row_sum_16x8 = vpaddlq_s8(s2_8x16);
|
|
||||||
row_sum_32x4 = vpadalq_s16(row_sum_32x4, row_sum_16x8);
|
|
||||||
}
|
|
||||||
} // for col
|
} // for col
|
||||||
|
|
||||||
// Half iteration dealing only 8 elements
|
// Half iteration dealing only 8 elements
|
||||||
@ -1375,29 +1374,24 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
|
|||||||
const int8x8_t s2_8x8 = vld1_s8((const int8_t*)(row_ptr + col));
|
const int8x8_t s2_8x8 = vld1_s8((const int8_t*)(row_ptr + col));
|
||||||
const int16x8_t prod_16x8 = vmull_s8(s1_8x8, s2_8x8);
|
const int16x8_t prod_16x8 = vmull_s8(s1_8x8, s2_8x8);
|
||||||
dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
|
dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
|
||||||
if (row_sums == nullptr) {
|
|
||||||
const int16x8_t row_sum_16x8 = vmovl_s8(s2_8x8);
|
|
||||||
row_sum_32x4 = vpadalq_s16(row_sum_32x4, row_sum_16x8);
|
|
||||||
}
|
|
||||||
col += (kWeightsPerNeonLane >> 1);
|
col += (kWeightsPerNeonLane >> 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
|
int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
|
||||||
int32_t row_sum = row_sums == nullptr ? AccumulateNeonLane(row_sum_32x4)
|
|
||||||
: row_sums[row];
|
|
||||||
|
|
||||||
// Postamble loop.
|
// Postamble loop.
|
||||||
for (; col < m_cols; ++col) {
|
for (; col < m_cols; ++col) {
|
||||||
dotprod += row_ptr[col] * aligned_vec[col];
|
dotprod += row_ptr[col] * aligned_vec[col];
|
||||||
if (row_sums == nullptr) {
|
|
||||||
row_sum += row_ptr[col];
|
|
||||||
}
|
|
||||||
} // for col
|
} // for col
|
||||||
dotprod -= row_sum * batch_input_offset;
|
dotprod -= row_sums_ptr[row] * batch_input_offset;
|
||||||
*result += dotprod * scale;
|
*result += dotprod * scale;
|
||||||
++result;
|
++result;
|
||||||
} // for row
|
} // for row
|
||||||
} // for batch
|
} // for batch
|
||||||
|
|
||||||
|
if (row_sums == nullptr) {
|
||||||
|
free(row_sums_ptr);
|
||||||
|
}
|
||||||
if (unaligned) {
|
if (unaligned) {
|
||||||
free(aligned_row_free);
|
free(aligned_row_free);
|
||||||
}
|
}
|
||||||
@ -1410,6 +1404,20 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
|
|||||||
int n_batch, float* __restrict__ result, const float* per_channel_scale,
|
int n_batch, float* __restrict__ result, const float* per_channel_scale,
|
||||||
const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
|
const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
|
||||||
bool* compute_row_sums, CpuBackendContext* context) {
|
bool* compute_row_sums, CpuBackendContext* context) {
|
||||||
|
if (input_offset == nullptr) {
|
||||||
|
#ifdef TFLITE_WITH_RUY_GEMV
|
||||||
|
if (context) {
|
||||||
|
NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
|
||||||
|
scaling_factors, n_batch, scratch,
|
||||||
|
result, context);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
|
||||||
|
scaling_factors, n_batch, result);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (compute_row_sums == nullptr || *compute_row_sums) {
|
if (compute_row_sums == nullptr || *compute_row_sums) {
|
||||||
memset(row_sums, 0, sizeof(int32_t) * m_rows);
|
memset(row_sums, 0, sizeof(int32_t) * m_rows);
|
||||||
NeonReductionSumVector(matrix, row_sums, m_rows, m_cols);
|
NeonReductionSumVector(matrix, row_sums, m_rows, m_cols);
|
||||||
@ -1419,7 +1427,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#ifdef TFLITE_WITH_RUY_GEMV
|
#ifdef TFLITE_WITH_RUY_GEMV
|
||||||
if (m_rows % 4 == 0) {
|
if (context != nullptr && m_rows % 4 == 0) {
|
||||||
const int32_t* bias = static_cast<const int32_t*>(nullptr);
|
const int32_t* bias = static_cast<const int32_t*>(nullptr);
|
||||||
NeonCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows, 0,
|
NeonCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows, 0,
|
||||||
scratch, context);
|
scratch, context);
|
||||||
@ -1463,9 +1471,9 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
|
|||||||
for (; i < total_size; i++) {
|
for (; i < total_size; i++) {
|
||||||
const float batch_scaling_factor = scaling_factors[i / m_rows];
|
const float batch_scaling_factor = scaling_factors[i / m_rows];
|
||||||
const int32_t zero_point = input_offset[i / m_rows];
|
const int32_t zero_point = input_offset[i / m_rows];
|
||||||
int32_t x = *(scratch_ptr++);
|
int32_t dotprod = *(scratch_ptr++);
|
||||||
x -= row_sums[i % m_rows] * zero_point;
|
dotprod -= row_sums[i % m_rows] * zero_point;
|
||||||
*result += x * batch_scaling_factor;
|
*result += dotprod * batch_scaling_factor;
|
||||||
++result;
|
++result;
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
|
@ -167,6 +167,11 @@ void SseMatrixBatchVectorMultiplyAccumulate(
|
|||||||
const float* __restrict__ scaling_factors, int n_batch,
|
const float* __restrict__ scaling_factors, int n_batch,
|
||||||
float* __restrict__ result, const float* __restrict__ per_channel_scale,
|
float* __restrict__ result, const float* __restrict__ per_channel_scale,
|
||||||
const int32_t* __restrict__ input_offset) {
|
const int32_t* __restrict__ input_offset) {
|
||||||
|
if (input_offset == nullptr) {
|
||||||
|
SseMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
|
||||||
|
scaling_factors, n_batch, result);
|
||||||
|
return;
|
||||||
|
}
|
||||||
static constexpr std::intptr_t kBlockSize = 16;
|
static constexpr std::intptr_t kBlockSize = 16;
|
||||||
for (std::intptr_t batch = 0; batch < n_batch; ++batch) {
|
for (std::intptr_t batch = 0; batch < n_batch; ++batch) {
|
||||||
const float batch_scaling_factor = scaling_factors[batch];
|
const float batch_scaling_factor = scaling_factors[batch];
|
||||||
|
@ -59,9 +59,10 @@ void MatrixBatchVectorMultiplyAccumulate(
|
|||||||
int n_batch, float* __restrict__ result, const float* per_channel_scale,
|
int n_batch, float* __restrict__ result, const float* per_channel_scale,
|
||||||
const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
|
const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
|
||||||
bool* compute_row_sums, CpuBackendContext* context) {
|
bool* compute_row_sums, CpuBackendContext* context) {
|
||||||
NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
|
PortableMatrixBatchVectorMultiplyAccumulate(
|
||||||
vectors, scaling_factors, n_batch, result, per_channel_scale,
|
matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
|
||||||
input_offset, scratch, row_sums, compute_row_sums, context);
|
per_channel_scale, input_offset, scratch, row_sums, compute_row_sums,
|
||||||
|
context);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MatrixBatchVectorMultiplyAccumulate(
|
void MatrixBatchVectorMultiplyAccumulate(
|
||||||
|
@ -196,6 +196,11 @@ void PortableMatrixBatchVectorMultiplyAccumulate(
|
|||||||
int n_batch, float* __restrict__ result, const float* per_channel_scale,
|
int n_batch, float* __restrict__ result, const float* per_channel_scale,
|
||||||
const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
|
const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
|
||||||
bool* compute_row_sums, CpuBackendContext* context) {
|
bool* compute_row_sums, CpuBackendContext* context) {
|
||||||
|
if (input_offset == nullptr) {
|
||||||
|
PortableMatrixBatchVectorMultiplyAccumulate(
|
||||||
|
matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result);
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (!compute_row_sums || *compute_row_sums) {
|
if (!compute_row_sums || *compute_row_sums) {
|
||||||
memset(row_sums, 0, sizeof(int32_t) * m_rows);
|
memset(row_sums, 0, sizeof(int32_t) * m_rows);
|
||||||
PortableReductionSumVector(matrix, row_sums, m_rows, m_cols);
|
PortableReductionSumVector(matrix, row_sums, m_rows, m_cols);
|
||||||
|
@ -223,7 +223,8 @@ inline void EvalHybridSVDF(
|
|||||||
const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time,
|
const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time,
|
||||||
const TfLiteTensor* bias, const TfLiteSVDFParams* params,
|
const TfLiteTensor* bias, const TfLiteSVDFParams* params,
|
||||||
TfLiteTensor* scratch, TfLiteTensor* scaling_factors,
|
TfLiteTensor* scratch, TfLiteTensor* scaling_factors,
|
||||||
TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output) {
|
TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output,
|
||||||
|
TfLiteTensor* zero_points, TfLiteTensor* row_sums, bool* compute_row_sums) {
|
||||||
const int rank = params->rank;
|
const int rank = params->rank;
|
||||||
const int batch_size = input->dims->data[0];
|
const int batch_size = input->dims->data[0];
|
||||||
const int input_size = input->dims->data[1];
|
const int input_size = input->dims->data[1];
|
||||||
@ -244,6 +245,13 @@ inline void EvalHybridSVDF(
|
|||||||
|
|
||||||
float* output_ptr = GetTensorData<float>(output);
|
float* output_ptr = GetTensorData<float>(output);
|
||||||
|
|
||||||
|
int32_t* zero_points_ptr = nullptr;
|
||||||
|
int32_t* row_sums_ptr = nullptr;
|
||||||
|
if (params->asymmetric_quantize_inputs && row_sums != nullptr) {
|
||||||
|
zero_points_ptr = GetTensorData<int32_t>(zero_points);
|
||||||
|
row_sums_ptr = GetTensorData<int32_t>(row_sums);
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize the weights scale.
|
// Initialize the weights scale.
|
||||||
const float weights_feature_scale = weights_feature->params.scale;
|
const float weights_feature_scale = weights_feature->params.scale;
|
||||||
|
|
||||||
@ -258,21 +266,30 @@ inline void EvalHybridSVDF(
|
|||||||
|
|
||||||
if (!tensor_utils::IsZeroVector(input_ptr, batch_size * input_size)) {
|
if (!tensor_utils::IsZeroVector(input_ptr, batch_size * input_size)) {
|
||||||
// Quantize input from float to int8.
|
// Quantize input from float to int8.
|
||||||
float unused_min, unused_max;
|
|
||||||
for (int b = 0; b < batch_size; ++b) {
|
for (int b = 0; b < batch_size; ++b) {
|
||||||
const int offset = b * input_size;
|
const int offset = b * input_size;
|
||||||
tensor_utils::SymmetricQuantizeFloats(
|
if (params->asymmetric_quantize_inputs) {
|
||||||
input_ptr + offset, input_size, quantized_input_ptr + offset,
|
tensor_utils::AsymmetricQuantizeFloats(
|
||||||
&unused_min, &unused_max, &scaling_factors_ptr[b]);
|
input_ptr + offset, input_size, quantized_input_ptr + offset,
|
||||||
|
&scaling_factors_ptr[b], &zero_points_ptr[b]);
|
||||||
|
} else {
|
||||||
|
// Quantize input from float to int8.
|
||||||
|
float unused_min, unused_max;
|
||||||
|
tensor_utils::SymmetricQuantizeFloats(
|
||||||
|
input_ptr + offset, input_size, quantized_input_ptr + offset,
|
||||||
|
&unused_min, &unused_max, &scaling_factors_ptr[b]);
|
||||||
|
}
|
||||||
scaling_factors_ptr[b] *= weights_feature_scale;
|
scaling_factors_ptr[b] *= weights_feature_scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute conv1d(inputs, weights_feature).
|
// Compute conv1d(inputs, weights_feature).
|
||||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
weights_feature_ptr, num_filters, input_size, quantized_input_ptr,
|
weights_feature_ptr, num_filters, input_size, quantized_input_ptr,
|
||||||
scaling_factors_ptr, batch_size, scratch_ptr);
|
scaling_factors_ptr, batch_size, scratch_ptr,
|
||||||
|
/*per_channel_scale=*/nullptr, zero_points_ptr,
|
||||||
|
reinterpret_cast<int32_t*>(scratch_ptr), row_sums_ptr, compute_row_sums,
|
||||||
|
/*context=*/nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy the latest activation from scratch into activation_state:
|
// Copy the latest activation from scratch into activation_state:
|
||||||
// The last, i.e. (memory_size-1)th entry for each batch, and filter.
|
// The last, i.e. (memory_size-1)th entry for each batch, and filter.
|
||||||
for (int i = 0; i < batch_size * num_filters; ++i) {
|
for (int i = 0; i < batch_size * num_filters; ++i) {
|
||||||
|
@ -55,6 +55,7 @@ struct OpData {
|
|||||||
// These fields are only used by full kernel.
|
// These fields are only used by full kernel.
|
||||||
int scratch_tensor_index;
|
int scratch_tensor_index;
|
||||||
lstm_eval::IntegerLstmParameter integer_lstm_param;
|
lstm_eval::IntegerLstmParameter integer_lstm_param;
|
||||||
|
bool compute_row_sums;
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace full {
|
namespace full {
|
||||||
@ -727,7 +728,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
|
|||||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
auto* op_data = new OpData();
|
auto* op_data = new OpData();
|
||||||
op_data->kernel_type = kTfLiteLSTMFullKernel;
|
op_data->kernel_type = kTfLiteLSTMFullKernel;
|
||||||
context->AddTensors(context, /*tensors_to_add=*/8,
|
context->AddTensors(context, /*tensors_to_add=*/10,
|
||||||
&op_data->scratch_tensor_index);
|
&op_data->scratch_tensor_index);
|
||||||
return op_data;
|
return op_data;
|
||||||
}
|
}
|
||||||
@ -1236,7 +1237,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
|
|
||||||
TfLiteIntArrayFree(node->temporaries);
|
TfLiteIntArrayFree(node->temporaries);
|
||||||
if (is_hybrid_op) {
|
if (is_hybrid_op) {
|
||||||
node->temporaries = TfLiteIntArrayCreate(8);
|
node->temporaries = TfLiteIntArrayCreate(10);
|
||||||
} else if (is_integer) {
|
} else if (is_integer) {
|
||||||
if (is_8x8_16) {
|
if (is_8x8_16) {
|
||||||
node->temporaries = TfLiteIntArrayCreate(6);
|
node->temporaries = TfLiteIntArrayCreate(6);
|
||||||
@ -1273,6 +1274,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (is_hybrid_op) {
|
if (is_hybrid_op) {
|
||||||
|
op_data->compute_row_sums = true;
|
||||||
// Allocate temporary tensors to store quantized values of input,
|
// Allocate temporary tensors to store quantized values of input,
|
||||||
// activation_state and cell_state tensors.
|
// activation_state and cell_state tensors.
|
||||||
node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
|
node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
|
||||||
@ -1370,6 +1372,41 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_ENSURE_OK(
|
TF_LITE_ENSURE_OK(
|
||||||
context, context->ResizeTensor(context, accum_scratch, accum_size));
|
context, context->ResizeTensor(context, accum_scratch, accum_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
node->temporaries->data[8] = op_data->scratch_tensor_index + 8;
|
||||||
|
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/8);
|
||||||
|
zero_points->type = kTfLiteFloat32;
|
||||||
|
zero_points->allocation_type = kTfLiteArenaRw;
|
||||||
|
int zero_points_dims[1] = {n_batch};
|
||||||
|
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
|
||||||
|
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
|
||||||
|
zero_points_size->data[0] = n_batch;
|
||||||
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
|
||||||
|
zero_points_size));
|
||||||
|
}
|
||||||
|
|
||||||
|
node->temporaries->data[9] = op_data->scratch_tensor_index + 9;
|
||||||
|
const TfLiteTensor* input_to_input_weights =
|
||||||
|
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
|
||||||
|
const bool use_cifg = (input_to_input_weights == nullptr);
|
||||||
|
int row_sums_rows = use_cifg ? 6 : 8;
|
||||||
|
const TfLiteTensor* projection_weights =
|
||||||
|
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
|
||||||
|
if (projection_weights != nullptr) {
|
||||||
|
row_sums_rows += ceil(n_output / n_cell);
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/9);
|
||||||
|
row_sums->type = kTfLiteInt32;
|
||||||
|
row_sums->allocation_type = kTfLiteArenaRwPersistent;
|
||||||
|
const int row_sums_dims[2] = {row_sums_rows, n_cell};
|
||||||
|
if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
|
||||||
|
TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
|
||||||
|
row_sums_size->data[0] = row_sums_dims[0];
|
||||||
|
row_sums_size->data[1] = row_sums_dims[1];
|
||||||
|
TF_LITE_ENSURE_OK(
|
||||||
|
context, context->ResizeTensor(context, row_sums, row_sums_size));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (is_integer) {
|
if (is_integer) {
|
||||||
@ -1556,6 +1593,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
GetTemporary(context, node, /*index=*/6);
|
GetTemporary(context, node, /*index=*/6);
|
||||||
TfLiteTensor* output_scratch_buffer =
|
TfLiteTensor* output_scratch_buffer =
|
||||||
GetTemporary(context, node, /*index=*/7);
|
GetTemporary(context, node, /*index=*/7);
|
||||||
|
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/8);
|
||||||
|
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/9);
|
||||||
|
const int row_sums_size = row_sums->dims->data[0];
|
||||||
return lstm_eval::EvalHybrid(
|
return lstm_eval::EvalHybrid(
|
||||||
input, input_to_input_weights, input_to_forget_weights,
|
input, input_to_input_weights, input_to_forget_weights,
|
||||||
input_to_cell_weights, input_to_output_weights,
|
input_to_cell_weights, input_to_output_weights,
|
||||||
@ -1577,7 +1617,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
input_quantized,
|
input_quantized,
|
||||||
/*aux_input_quantized=*/nullptr, activation_state_quantized,
|
/*aux_input_quantized=*/nullptr, activation_state_quantized,
|
||||||
cell_state_quantized, activation_state, cell_state,
|
cell_state_quantized, activation_state, cell_state,
|
||||||
output_scratch_buffer, output,
|
output_scratch_buffer, output, zero_points, row_sums, row_sums_size,
|
||||||
|
&op_data->compute_row_sums,
|
||||||
CpuBackendContext::GetFromContext(context));
|
CpuBackendContext::GetFromContext(context));
|
||||||
} else {
|
} else {
|
||||||
const int num_intermediate_tensors = node->intermediates->size;
|
const int num_intermediate_tensors = node->intermediates->size;
|
||||||
|
@ -33,24 +33,93 @@ namespace builtin {
|
|||||||
namespace lstm_eval {
|
namespace lstm_eval {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
inline float GetTensorScale(const TfLiteTensor* tensor) {
|
void ComputeRowSums(
|
||||||
return tensor == nullptr ? 1.0f : tensor->params.scale;
|
int32_t* input_to_input_row_sums, int32_t* input_to_forget_row_sums,
|
||||||
|
int32_t* input_to_cell_row_sums, int32_t* input_to_output_row_sums,
|
||||||
|
int32_t* aux_input_to_input_row_sums, int32_t* aux_input_to_forget_row_sums,
|
||||||
|
int32_t* aux_input_to_cell_row_sums, int32_t* aux_input_to_output_row_sums,
|
||||||
|
int32_t* recurrent_to_input_row_sums, int32_t* recurrent_to_forget_row_sums,
|
||||||
|
int32_t* recurrent_to_cell_row_sums, int32_t* recurrent_to_output_row_sums,
|
||||||
|
int32_t* projection_weights_row_sums, int32_t* row_sums, int n_cell,
|
||||||
|
int n_input, int n_aux_input, int n_output,
|
||||||
|
const int8_t* input_to_input_weights_ptr,
|
||||||
|
const int8_t* input_to_forget_weights_ptr,
|
||||||
|
const int8_t* input_to_cell_weights_ptr,
|
||||||
|
const int8_t* input_to_output_weights_ptr,
|
||||||
|
const int8_t* aux_input_to_input_weights_ptr,
|
||||||
|
const int8_t* aux_input_to_forget_weights_ptr,
|
||||||
|
const int8_t* aux_input_to_cell_weights_ptr,
|
||||||
|
const int8_t* aux_input_to_output_weights_ptr,
|
||||||
|
const int8_t* recurrent_to_input_weights_ptr,
|
||||||
|
const int8_t* recurrent_to_forget_weights_ptr,
|
||||||
|
const int8_t* recurrent_to_cell_weights_ptr,
|
||||||
|
const int8_t* recurrent_to_output_weights_ptr,
|
||||||
|
const int8_t* projection_weights_ptr, bool use_cifg,
|
||||||
|
const float* aux_input_ptr) {
|
||||||
|
// Compute the row sums for dequantization
|
||||||
|
if (!use_cifg) {
|
||||||
|
memset(input_to_input_row_sums, 0, sizeof(int32_t) * n_cell);
|
||||||
|
tensor_utils::ReductionSumVector(input_to_input_weights_ptr,
|
||||||
|
input_to_input_row_sums, n_cell, n_input);
|
||||||
|
}
|
||||||
|
memset(input_to_forget_row_sums, 0, sizeof(int32_t) * n_cell);
|
||||||
|
tensor_utils::ReductionSumVector(input_to_forget_weights_ptr,
|
||||||
|
input_to_forget_row_sums, n_cell, n_input);
|
||||||
|
memset(input_to_cell_row_sums, 0, sizeof(int32_t) * n_cell);
|
||||||
|
tensor_utils::ReductionSumVector(input_to_cell_weights_ptr,
|
||||||
|
input_to_cell_row_sums, n_cell, n_input);
|
||||||
|
memset(input_to_output_row_sums, 0, sizeof(int32_t) * n_cell);
|
||||||
|
tensor_utils::ReductionSumVector(input_to_output_weights_ptr,
|
||||||
|
input_to_output_row_sums, n_cell, n_input);
|
||||||
|
|
||||||
|
if (aux_input_ptr) {
|
||||||
|
if (!use_cifg) {
|
||||||
|
memset(aux_input_to_input_row_sums, 0, sizeof(int32_t) * n_cell);
|
||||||
|
tensor_utils::ReductionSumVector(aux_input_to_input_weights_ptr,
|
||||||
|
aux_input_to_input_row_sums, n_cell,
|
||||||
|
n_aux_input);
|
||||||
|
}
|
||||||
|
memset(aux_input_to_forget_row_sums, 0, sizeof(int32_t) * n_cell);
|
||||||
|
tensor_utils::ReductionSumVector(aux_input_to_forget_weights_ptr,
|
||||||
|
aux_input_to_forget_row_sums, n_cell,
|
||||||
|
n_aux_input);
|
||||||
|
memset(aux_input_to_cell_row_sums, 0, sizeof(int32_t) * n_cell);
|
||||||
|
tensor_utils::ReductionSumVector(aux_input_to_cell_weights_ptr,
|
||||||
|
aux_input_to_cell_row_sums, n_cell,
|
||||||
|
n_aux_input);
|
||||||
|
memset(aux_input_to_output_row_sums, 0, sizeof(int32_t) * n_cell);
|
||||||
|
tensor_utils::ReductionSumVector(aux_input_to_output_weights_ptr,
|
||||||
|
aux_input_to_output_row_sums, n_cell,
|
||||||
|
n_aux_input);
|
||||||
|
}
|
||||||
|
if (!use_cifg) {
|
||||||
|
memset(recurrent_to_input_row_sums, 0, sizeof(int32_t) * n_cell);
|
||||||
|
tensor_utils::ReductionSumVector(recurrent_to_input_weights_ptr,
|
||||||
|
recurrent_to_input_row_sums, n_cell,
|
||||||
|
n_output);
|
||||||
|
}
|
||||||
|
memset(recurrent_to_forget_row_sums, 0, sizeof(int32_t) * n_cell);
|
||||||
|
tensor_utils::ReductionSumVector(recurrent_to_forget_weights_ptr,
|
||||||
|
recurrent_to_forget_row_sums, n_cell,
|
||||||
|
n_output);
|
||||||
|
memset(recurrent_to_cell_row_sums, 0, sizeof(int32_t) * n_cell);
|
||||||
|
tensor_utils::ReductionSumVector(recurrent_to_cell_weights_ptr,
|
||||||
|
recurrent_to_cell_row_sums, n_cell,
|
||||||
|
n_output);
|
||||||
|
memset(recurrent_to_output_row_sums, 0, sizeof(int32_t) * n_cell);
|
||||||
|
tensor_utils::ReductionSumVector(recurrent_to_output_weights_ptr,
|
||||||
|
recurrent_to_output_row_sums, n_cell,
|
||||||
|
n_output);
|
||||||
|
|
||||||
|
if (projection_weights_ptr != nullptr) {
|
||||||
|
memset(projection_weights_row_sums, 0, sizeof(int32_t) * n_output);
|
||||||
|
tensor_utils::ReductionSumVector(
|
||||||
|
projection_weights_ptr, projection_weights_row_sums, n_output, n_cell);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void MatrixBatchVectorMultiplyAccumulate(
|
inline float GetTensorScale(const TfLiteTensor* tensor) {
|
||||||
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
|
return tensor == nullptr ? 1.0f : tensor->params.scale;
|
||||||
const int8_t* __restrict__ vectors, const float* scaling_factors,
|
|
||||||
int n_batch, int32_t* scratch, float* __restrict__ result,
|
|
||||||
CpuBackendContext* context) {
|
|
||||||
// TODO(b/148289189) Remove when Ruy GEMV is the default.
|
|
||||||
#ifdef TFLITE_WITH_RUY_GEMV
|
|
||||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
|
||||||
matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, scratch,
|
|
||||||
result, context);
|
|
||||||
#else
|
|
||||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
|
||||||
matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result);
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Performs an LSTM batch inference step for input specified by input_ptr.
|
// Performs an LSTM batch inference step for input specified by input_ptr.
|
||||||
@ -473,6 +542,8 @@ inline void LstmStepHybrid(
|
|||||||
int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr,
|
int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr,
|
||||||
int8_t* quantized_cell_state_ptr, float* output_state_ptr,
|
int8_t* quantized_cell_state_ptr, float* output_state_ptr,
|
||||||
float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr,
|
float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr,
|
||||||
|
int32_t* zero_points, int32_t* row_sums, int row_sums_size,
|
||||||
|
bool* compute_row_sums, bool asymmetric_quantize_inputs,
|
||||||
CpuBackendContext* context) {
|
CpuBackendContext* context) {
|
||||||
ruy::profiler::ScopeLabel label("LstmStepHybrid");
|
ruy::profiler::ScopeLabel label("LstmStepHybrid");
|
||||||
// Since we have already checked that weights are all there or none, we
|
// Since we have already checked that weights are all there or none, we
|
||||||
@ -503,53 +574,131 @@ inline void LstmStepHybrid(
|
|||||||
output_gate_scratch);
|
output_gate_scratch);
|
||||||
}
|
}
|
||||||
|
|
||||||
// For each batch and cell: compute input_weight * input.
|
int32_t* input_to_input_row_sums = nullptr;
|
||||||
// Skip if input is all zeros.
|
int32_t* input_to_forget_row_sums = nullptr;
|
||||||
|
int32_t* input_to_cell_row_sums = nullptr;
|
||||||
|
int32_t* input_to_output_row_sums = nullptr;
|
||||||
|
int32_t* aux_input_to_input_row_sums = nullptr;
|
||||||
|
int32_t* aux_input_to_forget_row_sums = nullptr;
|
||||||
|
int32_t* aux_input_to_cell_row_sums = nullptr;
|
||||||
|
int32_t* aux_input_to_output_row_sums = nullptr;
|
||||||
|
int32_t* recurrent_to_input_row_sums = nullptr;
|
||||||
|
int32_t* recurrent_to_forget_row_sums = nullptr;
|
||||||
|
int32_t* recurrent_to_cell_row_sums = nullptr;
|
||||||
|
int32_t* recurrent_to_output_row_sums = nullptr;
|
||||||
|
int32_t* projection_weights_row_sums = nullptr;
|
||||||
|
|
||||||
|
if (asymmetric_quantize_inputs) {
|
||||||
|
int num_row_sums = use_cifg ? 6 : 8;
|
||||||
|
if (aux_input_ptr != nullptr) {
|
||||||
|
num_row_sums += use_cifg ? 3 : 4;
|
||||||
|
}
|
||||||
|
if (projection_weights_ptr != nullptr) {
|
||||||
|
num_row_sums += ceil(n_output / n_cell);
|
||||||
|
}
|
||||||
|
TF_LITE_ASSERT(row_sums_size == num_row_sums);
|
||||||
|
input_to_input_row_sums = row_sums;
|
||||||
|
input_to_forget_row_sums =
|
||||||
|
use_cifg ? input_to_input_row_sums : input_to_input_row_sums + n_cell;
|
||||||
|
input_to_cell_row_sums = input_to_forget_row_sums + n_cell;
|
||||||
|
input_to_output_row_sums = input_to_cell_row_sums + n_cell;
|
||||||
|
if (aux_input_ptr != nullptr) {
|
||||||
|
aux_input_to_input_row_sums = input_to_output_row_sums + n_cell;
|
||||||
|
aux_input_to_forget_row_sums = use_cifg
|
||||||
|
? aux_input_to_input_row_sums
|
||||||
|
: aux_input_to_input_row_sums + n_cell;
|
||||||
|
aux_input_to_cell_row_sums = aux_input_to_forget_row_sums + n_cell;
|
||||||
|
aux_input_to_output_row_sums = aux_input_to_cell_row_sums + n_cell;
|
||||||
|
}
|
||||||
|
recurrent_to_input_row_sums = aux_input_ptr
|
||||||
|
? aux_input_to_output_row_sums + n_cell
|
||||||
|
: input_to_output_row_sums + n_cell;
|
||||||
|
recurrent_to_forget_row_sums = use_cifg
|
||||||
|
? recurrent_to_input_row_sums
|
||||||
|
: recurrent_to_input_row_sums + n_cell;
|
||||||
|
recurrent_to_cell_row_sums = recurrent_to_forget_row_sums + n_cell;
|
||||||
|
recurrent_to_output_row_sums = recurrent_to_cell_row_sums + n_cell;
|
||||||
|
if (projection_weights_ptr != nullptr) {
|
||||||
|
projection_weights_row_sums = recurrent_to_output_row_sums + n_cell;
|
||||||
|
}
|
||||||
|
if (*compute_row_sums) {
|
||||||
|
ComputeRowSums(
|
||||||
|
input_to_input_row_sums, input_to_forget_row_sums,
|
||||||
|
input_to_cell_row_sums, input_to_output_row_sums,
|
||||||
|
aux_input_to_input_row_sums, aux_input_to_forget_row_sums,
|
||||||
|
aux_input_to_cell_row_sums, aux_input_to_output_row_sums,
|
||||||
|
recurrent_to_input_row_sums, recurrent_to_forget_row_sums,
|
||||||
|
recurrent_to_cell_row_sums, recurrent_to_output_row_sums,
|
||||||
|
projection_weights_row_sums, row_sums, n_cell, n_input, n_aux_input,
|
||||||
|
n_output, input_to_input_weights_ptr, input_to_forget_weights_ptr,
|
||||||
|
input_to_cell_weights_ptr, input_to_output_weights_ptr,
|
||||||
|
aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
|
||||||
|
aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
|
||||||
|
recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
|
||||||
|
recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
|
||||||
|
projection_weights_ptr, use_cifg, aux_input_ptr);
|
||||||
|
*compute_row_sums = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (!tensor_utils::IsZeroVector(input_ptr, n_batch * n_input)) {
|
if (!tensor_utils::IsZeroVector(input_ptr, n_batch * n_input)) {
|
||||||
for (int b = 0; b < n_batch; ++b) {
|
for (int b = 0; b < n_batch; ++b) {
|
||||||
const int offset = b * n_input;
|
const int offset = b * n_input;
|
||||||
float unused_min, unused_max;
|
if (asymmetric_quantize_inputs) {
|
||||||
tensor_utils::SymmetricQuantizeFloats(
|
tensor_utils::AsymmetricQuantizeFloats(
|
||||||
input_ptr + offset, n_input, quantized_input_ptr + offset,
|
input_ptr + offset, n_input, quantized_input_ptr + offset,
|
||||||
&unused_min, &unused_max, &scaling_factors[b]);
|
&scaling_factors[b], &zero_points[b]);
|
||||||
|
} else {
|
||||||
|
float unused_min, unused_max;
|
||||||
|
tensor_utils::SymmetricQuantizeFloats(
|
||||||
|
input_ptr + offset, n_input, quantized_input_ptr + offset,
|
||||||
|
&unused_min, &unused_max, &scaling_factors[b]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (!use_cifg) {
|
if (!use_cifg) {
|
||||||
for (int b = 0; b < n_batch; ++b) {
|
for (int b = 0; b < n_batch; ++b) {
|
||||||
product_scaling_factors[b] =
|
product_scaling_factors[b] =
|
||||||
scaling_factors[b] * input_to_input_weights_scale;
|
scaling_factors[b] * input_to_input_weights_scale;
|
||||||
}
|
}
|
||||||
MatrixBatchVectorMultiplyAccumulate(
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
input_to_input_weights_ptr, n_cell, n_input, quantized_input_ptr,
|
input_to_input_weights_ptr, n_cell, n_input, quantized_input_ptr,
|
||||||
product_scaling_factors, n_batch, accum_scratch_ptr,
|
product_scaling_factors, n_batch, input_gate_scratch,
|
||||||
input_gate_scratch, context);
|
/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
|
||||||
|
input_to_input_row_sums, compute_row_sums, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int b = 0; b < n_batch; ++b) {
|
for (int b = 0; b < n_batch; ++b) {
|
||||||
product_scaling_factors[b] =
|
product_scaling_factors[b] =
|
||||||
scaling_factors[b] * input_to_forget_weights_scale;
|
scaling_factors[b] * input_to_forget_weights_scale;
|
||||||
}
|
}
|
||||||
MatrixBatchVectorMultiplyAccumulate(
|
|
||||||
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr,
|
input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr,
|
||||||
product_scaling_factors, n_batch, accum_scratch_ptr,
|
product_scaling_factors, n_batch, forget_gate_scratch,
|
||||||
forget_gate_scratch, context);
|
/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
|
||||||
|
input_to_forget_row_sums, compute_row_sums, context);
|
||||||
|
|
||||||
for (int b = 0; b < n_batch; ++b) {
|
for (int b = 0; b < n_batch; ++b) {
|
||||||
product_scaling_factors[b] =
|
product_scaling_factors[b] =
|
||||||
scaling_factors[b] * input_to_cell_weights_scale;
|
scaling_factors[b] * input_to_cell_weights_scale;
|
||||||
}
|
}
|
||||||
MatrixBatchVectorMultiplyAccumulate(
|
|
||||||
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr,
|
input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr,
|
||||||
product_scaling_factors, n_batch, accum_scratch_ptr, cell_scratch,
|
product_scaling_factors, n_batch, cell_scratch,
|
||||||
context);
|
/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
|
||||||
|
input_to_cell_row_sums, compute_row_sums, context);
|
||||||
|
|
||||||
for (int b = 0; b < n_batch; ++b) {
|
for (int b = 0; b < n_batch; ++b) {
|
||||||
product_scaling_factors[b] =
|
product_scaling_factors[b] =
|
||||||
scaling_factors[b] * input_to_output_weights_scale;
|
scaling_factors[b] * input_to_output_weights_scale;
|
||||||
}
|
}
|
||||||
MatrixBatchVectorMultiplyAccumulate(
|
|
||||||
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr,
|
input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr,
|
||||||
product_scaling_factors, n_batch, accum_scratch_ptr,
|
product_scaling_factors, n_batch, output_gate_scratch,
|
||||||
output_gate_scratch, context);
|
/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
|
||||||
|
input_to_output_row_sums, compute_row_sums, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
// For each batch and cell: compute aux_input_weight * aux_input.
|
// For each batch and cell: compute aux_input_weight * aux_input.
|
||||||
@ -558,59 +707,84 @@ inline void LstmStepHybrid(
|
|||||||
!tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input)) {
|
!tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input)) {
|
||||||
for (int b = 0; b < n_batch; ++b) {
|
for (int b = 0; b < n_batch; ++b) {
|
||||||
const int offset = b * n_aux_input;
|
const int offset = b * n_aux_input;
|
||||||
float unused_min, unused_max;
|
if (asymmetric_quantize_inputs) {
|
||||||
tensor_utils::SymmetricQuantizeFloats(
|
tensor_utils::AsymmetricQuantizeFloats(
|
||||||
aux_input_ptr + offset, n_aux_input, quantized_aux_input_ptr + offset,
|
aux_input_ptr + offset, n_aux_input,
|
||||||
&unused_min, &unused_max, &scaling_factors[b]);
|
quantized_aux_input_ptr + offset, &scaling_factors[b],
|
||||||
|
&zero_points[b]);
|
||||||
|
} else {
|
||||||
|
float unused_min, unused_max;
|
||||||
|
tensor_utils::SymmetricQuantizeFloats(
|
||||||
|
aux_input_ptr + offset, n_aux_input,
|
||||||
|
quantized_aux_input_ptr + offset, &unused_min, &unused_max,
|
||||||
|
&scaling_factors[b]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!use_cifg) {
|
if (!use_cifg) {
|
||||||
for (int b = 0; b < n_batch; ++b) {
|
for (int b = 0; b < n_batch; ++b) {
|
||||||
product_scaling_factors[b] =
|
product_scaling_factors[b] =
|
||||||
scaling_factors[b] * aux_input_to_input_weights_scale;
|
scaling_factors[b] * aux_input_to_input_weights_scale;
|
||||||
}
|
}
|
||||||
MatrixBatchVectorMultiplyAccumulate(
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
aux_input_to_input_weights_ptr, n_cell, n_aux_input,
|
aux_input_to_input_weights_ptr, n_cell, n_aux_input,
|
||||||
quantized_aux_input_ptr, product_scaling_factors, n_batch,
|
quantized_aux_input_ptr, product_scaling_factors, n_batch,
|
||||||
accum_scratch_ptr, input_gate_scratch, context);
|
input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
|
||||||
|
accum_scratch_ptr, aux_input_to_input_row_sums, compute_row_sums,
|
||||||
|
context);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int b = 0; b < n_batch; ++b) {
|
for (int b = 0; b < n_batch; ++b) {
|
||||||
product_scaling_factors[b] =
|
product_scaling_factors[b] =
|
||||||
scaling_factors[b] * aux_input_to_forget_weights_scale;
|
scaling_factors[b] * aux_input_to_forget_weights_scale;
|
||||||
}
|
}
|
||||||
MatrixBatchVectorMultiplyAccumulate(
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
|
aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
|
||||||
quantized_aux_input_ptr, product_scaling_factors, n_batch,
|
quantized_aux_input_ptr, product_scaling_factors, n_batch,
|
||||||
accum_scratch_ptr, forget_gate_scratch, context);
|
forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
|
||||||
|
accum_scratch_ptr, aux_input_to_forget_row_sums, compute_row_sums,
|
||||||
|
context);
|
||||||
|
row_sums += n_cell;
|
||||||
|
|
||||||
for (int b = 0; b < n_batch; ++b) {
|
for (int b = 0; b < n_batch; ++b) {
|
||||||
product_scaling_factors[b] =
|
product_scaling_factors[b] =
|
||||||
scaling_factors[b] * aux_input_to_cell_weights_scale;
|
scaling_factors[b] * aux_input_to_cell_weights_scale;
|
||||||
}
|
}
|
||||||
MatrixBatchVectorMultiplyAccumulate(
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
aux_input_to_cell_weights_ptr, n_cell, n_aux_input,
|
aux_input_to_cell_weights_ptr, n_cell, n_aux_input,
|
||||||
quantized_aux_input_ptr, product_scaling_factors, n_batch,
|
quantized_aux_input_ptr, product_scaling_factors, n_batch, cell_scratch,
|
||||||
accum_scratch_ptr, cell_scratch, context);
|
/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
|
||||||
|
aux_input_to_cell_row_sums, compute_row_sums, context);
|
||||||
|
|
||||||
for (int b = 0; b < n_batch; ++b) {
|
for (int b = 0; b < n_batch; ++b) {
|
||||||
product_scaling_factors[b] =
|
product_scaling_factors[b] =
|
||||||
scaling_factors[b] * aux_input_to_output_weights_scale;
|
scaling_factors[b] * aux_input_to_output_weights_scale;
|
||||||
}
|
}
|
||||||
MatrixBatchVectorMultiplyAccumulate(
|
|
||||||
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
aux_input_to_output_weights_ptr, n_cell, n_aux_input,
|
aux_input_to_output_weights_ptr, n_cell, n_aux_input,
|
||||||
quantized_aux_input_ptr, product_scaling_factors, n_batch,
|
quantized_aux_input_ptr, product_scaling_factors, n_batch,
|
||||||
accum_scratch_ptr, output_gate_scratch, context);
|
output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
|
||||||
|
accum_scratch_ptr, aux_input_to_output_row_sums, compute_row_sums,
|
||||||
|
context);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
|
if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
|
||||||
// Save quantization and matmul computation for all zero input.
|
// Save quantization and matmul computation for all zero input.
|
||||||
for (int b = 0; b < n_batch; ++b) {
|
for (int b = 0; b < n_batch; ++b) {
|
||||||
const int offset = b * n_output;
|
const int offset = b * n_output;
|
||||||
float unused_min, unused_max;
|
if (asymmetric_quantize_inputs) {
|
||||||
tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
|
tensor_utils::AsymmetricQuantizeFloats(
|
||||||
quantized_output_state_ptr + offset,
|
output_state_ptr + offset, n_output,
|
||||||
&unused_min, &unused_max,
|
quantized_output_state_ptr + offset, &scaling_factors[b],
|
||||||
&scaling_factors[b]);
|
&zero_points[b]);
|
||||||
|
} else {
|
||||||
|
float unused_min, unused_max;
|
||||||
|
tensor_utils::SymmetricQuantizeFloats(
|
||||||
|
output_state_ptr + offset, n_output,
|
||||||
|
quantized_output_state_ptr + offset, &unused_min, &unused_max,
|
||||||
|
&scaling_factors[b]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// For each batch and cell: compute recurrent_weight * output_state.
|
// For each batch and cell: compute recurrent_weight * output_state.
|
||||||
if (!use_cifg) {
|
if (!use_cifg) {
|
||||||
@ -618,38 +792,46 @@ inline void LstmStepHybrid(
|
|||||||
product_scaling_factors[b] =
|
product_scaling_factors[b] =
|
||||||
scaling_factors[b] * recurrent_to_input_weights_scale;
|
scaling_factors[b] * recurrent_to_input_weights_scale;
|
||||||
}
|
}
|
||||||
MatrixBatchVectorMultiplyAccumulate(
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
recurrent_to_input_weights_ptr, n_cell, n_output,
|
recurrent_to_input_weights_ptr, n_cell, n_output,
|
||||||
quantized_output_state_ptr, product_scaling_factors, n_batch,
|
quantized_output_state_ptr, product_scaling_factors, n_batch,
|
||||||
accum_scratch_ptr, input_gate_scratch, context);
|
input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
|
||||||
|
accum_scratch_ptr, recurrent_to_input_row_sums, compute_row_sums,
|
||||||
|
context);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int b = 0; b < n_batch; ++b) {
|
for (int b = 0; b < n_batch; ++b) {
|
||||||
product_scaling_factors[b] =
|
product_scaling_factors[b] =
|
||||||
scaling_factors[b] * recurrent_to_forget_weights_scale;
|
scaling_factors[b] * recurrent_to_forget_weights_scale;
|
||||||
}
|
}
|
||||||
MatrixBatchVectorMultiplyAccumulate(
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
recurrent_to_forget_weights_ptr, n_cell, n_output,
|
recurrent_to_forget_weights_ptr, n_cell, n_output,
|
||||||
quantized_output_state_ptr, product_scaling_factors, n_batch,
|
quantized_output_state_ptr, product_scaling_factors, n_batch,
|
||||||
accum_scratch_ptr, forget_gate_scratch, context);
|
forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
|
||||||
|
accum_scratch_ptr, recurrent_to_forget_row_sums, compute_row_sums,
|
||||||
|
context);
|
||||||
|
|
||||||
for (int b = 0; b < n_batch; ++b) {
|
for (int b = 0; b < n_batch; ++b) {
|
||||||
product_scaling_factors[b] =
|
product_scaling_factors[b] =
|
||||||
scaling_factors[b] * recurrent_to_cell_weights_scale;
|
scaling_factors[b] * recurrent_to_cell_weights_scale;
|
||||||
}
|
}
|
||||||
MatrixBatchVectorMultiplyAccumulate(
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
recurrent_to_cell_weights_ptr, n_cell, n_output,
|
recurrent_to_cell_weights_ptr, n_cell, n_output,
|
||||||
quantized_output_state_ptr, product_scaling_factors, n_batch,
|
quantized_output_state_ptr, product_scaling_factors, n_batch,
|
||||||
accum_scratch_ptr, cell_scratch, context);
|
cell_scratch, /*per_channel_scale=*/nullptr, zero_points,
|
||||||
|
accum_scratch_ptr, recurrent_to_cell_row_sums, compute_row_sums,
|
||||||
|
context);
|
||||||
|
|
||||||
for (int b = 0; b < n_batch; ++b) {
|
for (int b = 0; b < n_batch; ++b) {
|
||||||
product_scaling_factors[b] =
|
product_scaling_factors[b] =
|
||||||
scaling_factors[b] * recurrent_to_output_weights_scale;
|
scaling_factors[b] * recurrent_to_output_weights_scale;
|
||||||
}
|
}
|
||||||
MatrixBatchVectorMultiplyAccumulate(
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
recurrent_to_output_weights_ptr, n_cell, n_output,
|
recurrent_to_output_weights_ptr, n_cell, n_output,
|
||||||
quantized_output_state_ptr, product_scaling_factors, n_batch,
|
quantized_output_state_ptr, product_scaling_factors, n_batch,
|
||||||
accum_scratch_ptr, output_gate_scratch, context);
|
output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
|
||||||
|
accum_scratch_ptr, recurrent_to_output_row_sums, compute_row_sums,
|
||||||
|
context);
|
||||||
}
|
}
|
||||||
|
|
||||||
// For each batch and cell: update input gate.
|
// For each batch and cell: update input gate.
|
||||||
@ -770,22 +952,32 @@ inline void LstmStepHybrid(
|
|||||||
// Save quantization and matmul computation for all zero input.
|
// Save quantization and matmul computation for all zero input.
|
||||||
for (int b = 0; b < n_batch; ++b) {
|
for (int b = 0; b < n_batch; ++b) {
|
||||||
const int offset = b * n_cell;
|
const int offset = b * n_cell;
|
||||||
float unused_min, unused_max;
|
if (asymmetric_quantize_inputs) {
|
||||||
tensor_utils::SymmetricQuantizeFloats(
|
tensor_utils::AsymmetricQuantizeFloats(
|
||||||
output_gate_scratch + offset, n_cell,
|
output_gate_scratch + offset, n_cell,
|
||||||
quantized_cell_state_ptr + offset, &unused_min, &unused_max,
|
quantized_cell_state_ptr + offset, &scaling_factors[b],
|
||||||
&scaling_factors[b]);
|
&zero_points[b]);
|
||||||
|
} else {
|
||||||
|
float unused_min, unused_max;
|
||||||
|
tensor_utils::SymmetricQuantizeFloats(
|
||||||
|
output_gate_scratch + offset, n_cell,
|
||||||
|
quantized_cell_state_ptr + offset, &unused_min, &unused_max,
|
||||||
|
&scaling_factors[b]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for (int b = 0; b < n_batch; ++b) {
|
for (int b = 0; b < n_batch; ++b) {
|
||||||
product_scaling_factors[b] =
|
product_scaling_factors[b] =
|
||||||
scaling_factors[b] * projection_weights_scale;
|
scaling_factors[b] * projection_weights_scale;
|
||||||
}
|
}
|
||||||
for (int b = 0; b < n_batch; b++) {
|
for (int b = 0; b < n_batch; b++) {
|
||||||
MatrixBatchVectorMultiplyAccumulate(
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
projection_weights_ptr, n_output, n_cell,
|
projection_weights_ptr, n_output, n_cell,
|
||||||
quantized_cell_state_ptr + b * n_cell, &product_scaling_factors[b],
|
quantized_cell_state_ptr + b * n_cell, &product_scaling_factors[b],
|
||||||
/*n_batch=*/1, accum_scratch_ptr,
|
/*n_batch=*/1, output_ptr + b * output_batch_leading_dim,
|
||||||
output_ptr + b * output_batch_leading_dim, context);
|
/*per_channel_scale=*/nullptr,
|
||||||
|
asymmetric_quantize_inputs ? &zero_points[b] : nullptr,
|
||||||
|
accum_scratch_ptr, projection_weights_row_sums, compute_row_sums,
|
||||||
|
context);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (params->proj_clip > 0.0) {
|
if (params->proj_clip > 0.0) {
|
||||||
@ -1615,7 +1807,8 @@ TfLiteStatus EvalHybrid(
|
|||||||
TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
|
TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
|
||||||
TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
|
TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
|
||||||
TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer,
|
TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer,
|
||||||
TfLiteTensor* output, CpuBackendContext* context) {
|
TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* row_sums,
|
||||||
|
int row_sums_size, bool* compute_row_sums, CpuBackendContext* context) {
|
||||||
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
|
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
|
||||||
const int n_input = input->dims->data[input->dims->size - 1];
|
const int n_input = input->dims->data[input->dims->size - 1];
|
||||||
int max_time, n_batch;
|
int max_time, n_batch;
|
||||||
@ -1654,6 +1847,14 @@ TfLiteStatus EvalHybrid(
|
|||||||
|
|
||||||
const int output_batch_leading_dim =
|
const int output_batch_leading_dim =
|
||||||
output->dims->data[output->dims->size - 1];
|
output->dims->data[output->dims->size - 1];
|
||||||
|
|
||||||
|
int32_t* zero_points_ptr = nullptr;
|
||||||
|
int32_t* row_sums_ptr = nullptr;
|
||||||
|
if (params->asymmetric_quantize_inputs) {
|
||||||
|
zero_points_ptr = GetTensorData<int32_t>(zero_points);
|
||||||
|
row_sums_ptr = GetTensorData<int32_t>(row_sums);
|
||||||
|
}
|
||||||
|
|
||||||
if (time_major) {
|
if (time_major) {
|
||||||
// Feed the sequence into the LSTM step-by-step.
|
// Feed the sequence into the LSTM step-by-step.
|
||||||
const int input_step = n_batch * n_input;
|
const int input_step = n_batch * n_input;
|
||||||
@ -1721,7 +1922,9 @@ TfLiteStatus EvalHybrid(
|
|||||||
GetTensorData<int8_t>(output_state_quantized),
|
GetTensorData<int8_t>(output_state_quantized),
|
||||||
GetTensorData<int8_t>(cell_state_quantized),
|
GetTensorData<int8_t>(cell_state_quantized),
|
||||||
GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
|
GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
|
||||||
GetTensorData<int32_t>(output_scratch_buffer), output_ptr, context);
|
GetTensorData<int32_t>(output_scratch_buffer), output_ptr,
|
||||||
|
zero_points_ptr, row_sums_ptr, row_sums_size, compute_row_sums,
|
||||||
|
params->asymmetric_quantize_inputs, context);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int b = 0; b < n_batch; b++) {
|
for (int b = 0; b < n_batch; b++) {
|
||||||
@ -1806,7 +2009,8 @@ TfLiteStatus EvalHybrid(
|
|||||||
GetTensorData<int8_t>(output_state_quantized),
|
GetTensorData<int8_t>(output_state_quantized),
|
||||||
GetTensorData<int8_t>(cell_state_quantized), output_state_ptr,
|
GetTensorData<int8_t>(cell_state_quantized), output_state_ptr,
|
||||||
cell_state_ptr, GetTensorData<int32_t>(output_scratch_buffer),
|
cell_state_ptr, GetTensorData<int32_t>(output_scratch_buffer),
|
||||||
output_ptr, context);
|
output_ptr, zero_points_ptr, row_sums_ptr, row_sums_size,
|
||||||
|
compute_row_sums, params->asymmetric_quantize_inputs, context);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -156,7 +156,8 @@ TfLiteStatus EvalHybrid(
|
|||||||
TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
|
TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
|
||||||
TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
|
TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
|
||||||
TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer,
|
TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer,
|
||||||
TfLiteTensor* output, CpuBackendContext* context);
|
TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* row_sums,
|
||||||
|
int row_sums_size, bool* compute_row_sums, CpuBackendContext* context);
|
||||||
|
|
||||||
TfLiteStatus EvalInteger8x8_16(
|
TfLiteStatus EvalInteger8x8_16(
|
||||||
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
|
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
|
||||||
|
@ -38,7 +38,8 @@ class LSTMOpModel : public SingleOpModel {
|
|||||||
bool use_peephole, bool use_projection_weights,
|
bool use_peephole, bool use_projection_weights,
|
||||||
bool use_projection_bias, float cell_clip, float proj_clip,
|
bool use_projection_bias, float cell_clip, float proj_clip,
|
||||||
const std::vector<std::vector<int>>& input_shapes,
|
const std::vector<std::vector<int>>& input_shapes,
|
||||||
const TensorType weight_type, bool is_layer_norm)
|
const TensorType weight_type, bool is_layer_norm,
|
||||||
|
bool asymmetric_quantize_inputs = false)
|
||||||
: n_batch_(n_batch),
|
: n_batch_(n_batch),
|
||||||
n_input_(n_input),
|
n_input_(n_input),
|
||||||
n_cell_(n_cell),
|
n_cell_(n_cell),
|
||||||
@ -129,10 +130,12 @@ class LSTMOpModel : public SingleOpModel {
|
|||||||
|
|
||||||
output_ = AddOutput(TensorType_FLOAT32);
|
output_ = AddOutput(TensorType_FLOAT32);
|
||||||
|
|
||||||
SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
|
SetBuiltinOp(
|
||||||
CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
|
BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
|
||||||
cell_clip, proj_clip)
|
CreateLSTMOptions(builder_, ActivationFunctionType_TANH, cell_clip,
|
||||||
.Union());
|
proj_clip, ::tflite::LSTMKernelType_FULL,
|
||||||
|
asymmetric_quantize_inputs)
|
||||||
|
.Union());
|
||||||
|
|
||||||
// Do not apply delegate yet since tensor values are not known (and more
|
// Do not apply delegate yet since tensor values are not known (and more
|
||||||
// specifically scales in quantized tensors are not known).
|
// specifically scales in quantized tensors are not known).
|
||||||
@ -315,7 +318,7 @@ class LSTMOpModel : public SingleOpModel {
|
|||||||
const TensorType weight_type_;
|
const TensorType weight_type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class BaseLstmTest : public ::testing::Test {
|
class BaseLstmTest : public ::testing::TestWithParam<bool> {
|
||||||
protected:
|
protected:
|
||||||
// Weights of the LSTM model. Some are optional.
|
// Weights of the LSTM model. Some are optional.
|
||||||
std::vector<float> input_to_input_weights_;
|
std::vector<float> input_to_input_weights_;
|
||||||
@ -565,8 +568,11 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest,
|
|||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
|
TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
|
||||||
HybridLstmBlackBoxTestUint8) {
|
HybridLstmBlackBoxTestUint8) {
|
||||||
|
if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
const int n_batch = 1;
|
const int n_batch = 1;
|
||||||
const int n_input = 2;
|
const int n_input = 2;
|
||||||
// n_cell and n_output have the same size when there is no projection.
|
// n_cell and n_output have the same size when there is no projection.
|
||||||
@ -604,14 +610,20 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
|
|||||||
{0}, // projection_bias tensor
|
{0}, // projection_bias tensor
|
||||||
},
|
},
|
||||||
/*weight_type=*/TensorType_UINT8,
|
/*weight_type=*/TensorType_UINT8,
|
||||||
/*is_layer_norm=*/false);
|
/*is_layer_norm=*/false, GetParam());
|
||||||
|
|
||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
|
||||||
/*tolerance=*/0.0157651);
|
/*tolerance=*/0.0157651);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
|
class NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test
|
||||||
|
: public NoCifgNoPeepholeNoProjectionNoClippingLstmTest {};
|
||||||
|
|
||||||
|
TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test,
|
||||||
HybridLstmBlackBoxTestInt8) {
|
HybridLstmBlackBoxTestInt8) {
|
||||||
|
if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
const int n_batch = 1;
|
const int n_batch = 1;
|
||||||
const int n_input = 2;
|
const int n_input = 2;
|
||||||
// n_cell and n_output have the same size when there is no projection.
|
// n_cell and n_output have the same size when there is no projection.
|
||||||
@ -649,7 +661,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
|
|||||||
{0}, // projection_bias tensor
|
{0}, // projection_bias tensor
|
||||||
},
|
},
|
||||||
/*weight_type=*/TensorType_INT8,
|
/*weight_type=*/TensorType_INT8,
|
||||||
/*is_layer_norm=*/false);
|
/*is_layer_norm=*/false, GetParam());
|
||||||
|
|
||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
|
||||||
/*tolerance=*/0.0157651);
|
/*tolerance=*/0.0157651);
|
||||||
@ -745,8 +757,11 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
|||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest,
|
TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmTest,
|
||||||
HybridLstmBlackBoxTestUint8) {
|
HybridLstmBlackBoxTestUint8) {
|
||||||
|
if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
const int n_batch = 1;
|
const int n_batch = 1;
|
||||||
const int n_input = 2;
|
const int n_input = 2;
|
||||||
// n_cell and n_output have the same size when there is no projection.
|
// n_cell and n_output have the same size when there is no projection.
|
||||||
@ -784,13 +799,18 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest,
|
|||||||
{0}, // projection_bias tensor
|
{0}, // projection_bias tensor
|
||||||
},
|
},
|
||||||
/*weight_type=*/TensorType_UINT8,
|
/*weight_type=*/TensorType_UINT8,
|
||||||
/*is_layer_norm=*/false);
|
/*is_layer_norm=*/false, GetParam());
|
||||||
|
|
||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
|
||||||
}
|
}
|
||||||
|
class CifgNoPeepholeNoProjectionNoClippingLstmInt8Test
|
||||||
|
: public CifgNoPeepholeNoProjectionNoClippingLstmTest {};
|
||||||
|
|
||||||
TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest,
|
TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmInt8Test,
|
||||||
HybridLstmBlackBoxTestInt8) {
|
HybridLstmBlackBoxTestInt8) {
|
||||||
|
if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
const int n_batch = 1;
|
const int n_batch = 1;
|
||||||
const int n_input = 2;
|
const int n_input = 2;
|
||||||
// n_cell and n_output have the same size when there is no projection.
|
// n_cell and n_output have the same size when there is no projection.
|
||||||
@ -828,7 +848,7 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest,
|
|||||||
{0}, // projection_bias tensor
|
{0}, // projection_bias tensor
|
||||||
},
|
},
|
||||||
/*weight_type=*/TensorType_INT8,
|
/*weight_type=*/TensorType_INT8,
|
||||||
/*is_layer_norm=*/false);
|
/*is_layer_norm=*/false, GetParam());
|
||||||
|
|
||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
|
||||||
}
|
}
|
||||||
@ -1474,50 +1494,11 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
|||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTestInt8) {
|
TEST_P(NoCifgPeepholeProjectionNoClippingLstmTest,
|
||||||
const int n_batch = 2;
|
|
||||||
const int n_input = 5;
|
|
||||||
const int n_cell = 20;
|
|
||||||
const int n_output = 16;
|
|
||||||
|
|
||||||
LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
|
|
||||||
/*use_cifg=*/false, /*use_peephole=*/true,
|
|
||||||
/*use_projection_weights=*/true,
|
|
||||||
/*use_projection_bias=*/false,
|
|
||||||
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
|
|
||||||
{
|
|
||||||
{n_batch, n_input}, // input tensor
|
|
||||||
|
|
||||||
{n_cell, n_input}, // input_to_input_weight tensor
|
|
||||||
{n_cell, n_input}, // input_to_forget_weight tensor
|
|
||||||
{n_cell, n_input}, // input_to_cell_weight tensor
|
|
||||||
{n_cell, n_input}, // input_to_output_weight tensor
|
|
||||||
|
|
||||||
{n_cell, n_output}, // recurrent_to_input_weight tensor
|
|
||||||
{n_cell, n_output}, // recurrent_to_forget_weight tensor
|
|
||||||
{n_cell, n_output}, // recurrent_to_cell_weight tensor
|
|
||||||
{n_cell, n_output}, // recurrent_to_output_weight tensor
|
|
||||||
|
|
||||||
{n_cell}, // cell_to_input_weight tensor
|
|
||||||
{n_cell}, // cell_to_forget_weight tensor
|
|
||||||
{n_cell}, // cell_to_output_weight tensor
|
|
||||||
|
|
||||||
{n_cell}, // input_gate_bias tensor
|
|
||||||
{n_cell}, // forget_gate_bias tensor
|
|
||||||
{n_cell}, // cell_bias tensor
|
|
||||||
{n_cell}, // output_gate_bias tensor
|
|
||||||
|
|
||||||
{n_output, n_cell}, // projection_weight tensor
|
|
||||||
{0}, // projection_bias tensor
|
|
||||||
},
|
|
||||||
/*weight_type=*/TensorType_INT8,
|
|
||||||
/*is_layer_norm=*/false);
|
|
||||||
|
|
||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest,
|
|
||||||
HybridLstmBlackBoxTestUint8) {
|
HybridLstmBlackBoxTestUint8) {
|
||||||
|
if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
const int n_batch = 2;
|
const int n_batch = 2;
|
||||||
const int n_input = 5;
|
const int n_input = 5;
|
||||||
const int n_cell = 20;
|
const int n_cell = 20;
|
||||||
@ -1554,11 +1535,60 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest,
|
|||||||
{0}, // projection_bias tensor
|
{0}, // projection_bias tensor
|
||||||
},
|
},
|
||||||
/*weight_type=*/TensorType_UINT8,
|
/*weight_type=*/TensorType_UINT8,
|
||||||
/*is_layer_norm=*/false);
|
/*is_layer_norm=*/false, GetParam());
|
||||||
|
|
||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class NoCifgPeepholeProjectionNoClippingLstmInt8Test
|
||||||
|
: public NoCifgPeepholeProjectionNoClippingLstmTest {};
|
||||||
|
|
||||||
|
TEST_P(NoCifgPeepholeProjectionNoClippingLstmInt8Test,
|
||||||
|
HybridLstmBlackBoxTestInt8) {
|
||||||
|
if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const int n_batch = 2;
|
||||||
|
const int n_input = 5;
|
||||||
|
const int n_cell = 20;
|
||||||
|
const int n_output = 16;
|
||||||
|
|
||||||
|
LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
|
||||||
|
/*use_cifg=*/false, /*use_peephole=*/true,
|
||||||
|
/*use_projection_weights=*/true,
|
||||||
|
/*use_projection_bias=*/false,
|
||||||
|
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
|
||||||
|
{
|
||||||
|
{n_batch, n_input}, // input tensor
|
||||||
|
|
||||||
|
{n_cell, n_input}, // input_to_input_weight tensor
|
||||||
|
{n_cell, n_input}, // input_to_forget_weight tensor
|
||||||
|
{n_cell, n_input}, // input_to_cell_weight tensor
|
||||||
|
{n_cell, n_input}, // input_to_output_weight tensor
|
||||||
|
|
||||||
|
{n_cell, n_output}, // recurrent_to_input_weight tensor
|
||||||
|
{n_cell, n_output}, // recurrent_to_forget_weight tensor
|
||||||
|
{n_cell, n_output}, // recurrent_to_cell_weight tensor
|
||||||
|
{n_cell, n_output}, // recurrent_to_output_weight tensor
|
||||||
|
|
||||||
|
{n_cell}, // cell_to_input_weight tensor
|
||||||
|
{n_cell}, // cell_to_forget_weight tensor
|
||||||
|
{n_cell}, // cell_to_output_weight tensor
|
||||||
|
|
||||||
|
{n_cell}, // input_gate_bias tensor
|
||||||
|
{n_cell}, // forget_gate_bias tensor
|
||||||
|
{n_cell}, // cell_bias tensor
|
||||||
|
{n_cell}, // output_gate_bias tensor
|
||||||
|
|
||||||
|
{n_output, n_cell}, // projection_weight tensor
|
||||||
|
{0}, // projection_bias tensor
|
||||||
|
},
|
||||||
|
/*weight_type=*/TensorType_INT8,
|
||||||
|
/*is_layer_norm=*/false, GetParam());
|
||||||
|
|
||||||
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.0015);
|
||||||
|
}
|
||||||
|
|
||||||
class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
|
class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
|
||||||
: public BaseLstmTest {
|
: public BaseLstmTest {
|
||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
@ -1693,8 +1723,11 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
|||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||||
HybridLayerNormLstmBlackBoxTestUint8) {
|
HybridLayerNormLstmBlackBoxTestUint8) {
|
||||||
|
if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
const int n_batch = 2;
|
const int n_batch = 2;
|
||||||
const int n_input = 5;
|
const int n_input = 5;
|
||||||
const int n_cell = 4;
|
const int n_cell = 4;
|
||||||
@ -1741,7 +1774,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
|||||||
{n_cell}, // output_layer_norm_coefficient tensor
|
{n_cell}, // output_layer_norm_coefficient tensor
|
||||||
},
|
},
|
||||||
/*weight_type=*/TensorType_UINT8,
|
/*weight_type=*/TensorType_UINT8,
|
||||||
/*is_layer_norm=*/true);
|
/*is_layer_norm=*/true, GetParam());
|
||||||
|
|
||||||
lstm_golden_output_ = {{
|
lstm_golden_output_ = {{
|
||||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||||
@ -1760,8 +1793,14 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
|||||||
/*tolerance=*/0.0010907);
|
/*tolerance=*/0.0010907);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
class NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test
|
||||||
|
: public NoCifgPeepholeProjectionNoClippingLayerNormLstmTest {};
|
||||||
|
|
||||||
|
TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
|
||||||
HybridLayerNormLstmBlackBoxTestInt8) {
|
HybridLayerNormLstmBlackBoxTestInt8) {
|
||||||
|
if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
const int n_batch = 2;
|
const int n_batch = 2;
|
||||||
const int n_input = 5;
|
const int n_input = 5;
|
||||||
const int n_cell = 4;
|
const int n_cell = 4;
|
||||||
@ -1808,22 +1847,24 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
|||||||
{n_cell}, // output_layer_norm_coefficient tensor
|
{n_cell}, // output_layer_norm_coefficient tensor
|
||||||
},
|
},
|
||||||
/*weight_type=*/TensorType_INT8,
|
/*weight_type=*/TensorType_INT8,
|
||||||
/*is_layer_norm=*/true);
|
/*is_layer_norm=*/true, GetParam());
|
||||||
|
|
||||||
|
// Goldens are calculated from weight_type=TensorType_FLOAT32.
|
||||||
lstm_golden_output_ = {{
|
lstm_golden_output_ = {{
|
||||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||||
0.0244576, 0.127847, -0.00181765, // seq 0
|
0.0244077, 0.128027, -0.00170918, // seq 0
|
||||||
0.0137518, 0.140892, 0.0402234, // seq 1
|
0.0137642, 0.140751, 0.0395835, // seq 1
|
||||||
-0.0048839, 0.155096, 0.0840309, // seq 2
|
-0.00459233, 0.155278, 0.0837378, // seq 2
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Batch1: 3 (input_sequence_size) * 3 (n_output)
|
// Batch1: 3 (input_sequence_size) * 3 (n_output)
|
||||||
-0.00728636, 0.0843957, 0.0634786, // seq 0
|
-0.00692428, 0.0848741, 0.063445, // seq 0
|
||||||
-0.00448382, 0.139278, 0.0737372, // seq 1
|
-0.00403911, 0.139963, 0.072681, // seq 1
|
||||||
0.00734616, 0.161793, 0.0560238, // seq 2
|
0.00752708, 0.161903, 0.0561371, // seq 2
|
||||||
}};
|
}};
|
||||||
|
|
||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm,
|
||||||
|
/*tolerance=*/1.06e-3);
|
||||||
}
|
}
|
||||||
|
|
||||||
class CifgPeepholeProjectionNoClippingLayerNormLstmTest : public BaseLstmTest {
|
class CifgPeepholeProjectionNoClippingLayerNormLstmTest : public BaseLstmTest {
|
||||||
@ -1940,8 +1981,11 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
|||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||||
HybridLayerNormLstmBlackBoxTestUint8) {
|
HybridLayerNormLstmBlackBoxTestUint8) {
|
||||||
|
if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
const int n_batch = 2;
|
const int n_batch = 2;
|
||||||
const int n_input = 5;
|
const int n_input = 5;
|
||||||
const int n_cell = 4;
|
const int n_cell = 4;
|
||||||
@ -1988,7 +2032,7 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
|||||||
{n_cell}, // output_layer_norm_coefficient tensor
|
{n_cell}, // output_layer_norm_coefficient tensor
|
||||||
},
|
},
|
||||||
/*weight_type=*/TensorType_UINT8,
|
/*weight_type=*/TensorType_UINT8,
|
||||||
/*is_layer_norm=*/true);
|
/*is_layer_norm=*/true, GetParam());
|
||||||
|
|
||||||
// Verify the final output.
|
// Verify the final output.
|
||||||
lstm_golden_output_ = {
|
lstm_golden_output_ = {
|
||||||
@ -2009,7 +2053,10 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
|||||||
/*tolerance=*/0.000902065);
|
/*tolerance=*/0.000902065);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
class CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test
|
||||||
|
: public CifgPeepholeProjectionNoClippingLayerNormLstmTest {};
|
||||||
|
|
||||||
|
TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
|
||||||
HybridLayerNormLstmBlackBoxTestInt8) {
|
HybridLayerNormLstmBlackBoxTestInt8) {
|
||||||
const int n_batch = 2;
|
const int n_batch = 2;
|
||||||
const int n_input = 5;
|
const int n_input = 5;
|
||||||
@ -2057,24 +2104,24 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
|||||||
{n_cell}, // output_layer_norm_coefficient tensor
|
{n_cell}, // output_layer_norm_coefficient tensor
|
||||||
},
|
},
|
||||||
/*weight_type=*/TensorType_INT8,
|
/*weight_type=*/TensorType_INT8,
|
||||||
/*is_layer_norm=*/true);
|
/*is_layer_norm=*/true, GetParam());
|
||||||
|
|
||||||
// Verify the final output.
|
// Goldens are results using FLOAT32 inference.
|
||||||
lstm_golden_output_ = {
|
lstm_golden_output_ = {{
|
||||||
{
|
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
0.0212971, 0.140816, 0.0112733, // seq 0
|
||||||
0.0212250091, 0.140474007, 0.0115012666, // seq 0
|
0.0132302, 0.152308, 0.0346313, // seq 1
|
||||||
0.0130806509, 0.152660668, 0.0347516984, // seq 1
|
-0.0123688, 0.16579, 0.0893078, // seq 2
|
||||||
-0.0124010444, 0.166042402, 0.0898982584, // seq 2
|
},
|
||||||
},
|
{
|
||||||
{
|
// Batch1: 3 (input_sequence_size) * 3 (n_output)
|
||||||
// Batch1: 3 (input_sequence_size) * 3 (n_output)
|
-0.0226351, 0.0916948, 0.0769176, // seq 0
|
||||||
-0.0228835996, 0.0917588323, 0.0778886303, // seq 0
|
-0.0269967, 0.149708, 0.0941492, // seq 1
|
||||||
-0.0275101066, 0.148769245, 0.0938384682, // seq 1
|
-0.0103429, 0.173016, 0.0720509, // seq 2
|
||||||
-0.0103605557, 0.172605693, 0.0728750974, // seq 2
|
}};
|
||||||
}};
|
|
||||||
|
|
||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm,
|
||||||
|
/*tolerance=*/1e-3);
|
||||||
}
|
}
|
||||||
|
|
||||||
class LSTMIntegerOpModel : public SingleOpModel {
|
class LSTMIntegerOpModel : public SingleOpModel {
|
||||||
@ -3311,5 +3358,22 @@ TEST(LSTMOpModel, InvalidTypeTest) {
|
|||||||
"");
|
"");
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#define QUANTIZE_PARAMETER_TEST(test) \
|
||||||
|
INSTANTIATE_TEST_SUITE_P(test, test, ::testing::Bool())
|
||||||
|
|
||||||
|
QUANTIZE_PARAMETER_TEST(NoCifgNoPeepholeNoProjectionNoClippingLstmTest);
|
||||||
|
QUANTIZE_PARAMETER_TEST(NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test);
|
||||||
|
QUANTIZE_PARAMETER_TEST(CifgNoPeepholeNoProjectionNoClippingLstmTest);
|
||||||
|
QUANTIZE_PARAMETER_TEST(CifgNoPeepholeNoProjectionNoClippingLstmInt8Test);
|
||||||
|
QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionNoClippingLstmTest);
|
||||||
|
QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionNoClippingLstmInt8Test);
|
||||||
|
QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest);
|
||||||
|
QUANTIZE_PARAMETER_TEST(
|
||||||
|
NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test);
|
||||||
|
QUANTIZE_PARAMETER_TEST(CifgPeepholeProjectionNoClippingLayerNormLstmTest);
|
||||||
|
QUANTIZE_PARAMETER_TEST(CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test);
|
||||||
|
#undef QUANTIZE_PARAMETER_TEST
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -43,6 +43,7 @@ struct OpData {
|
|||||||
int effective_scale_1_b;
|
int effective_scale_1_b;
|
||||||
int32 effective_scale_2_a;
|
int32 effective_scale_2_a;
|
||||||
int effective_scale_2_b;
|
int effective_scale_2_b;
|
||||||
|
bool compute_row_sums = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -61,8 +62,8 @@ constexpr int kOutputTensor = 0;
|
|||||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
auto* op_data = new OpData();
|
auto* op_data = new OpData();
|
||||||
op_data->float_weights_time_initialized = false;
|
op_data->float_weights_time_initialized = false;
|
||||||
// Note: only needs 4 scratch tensors when is_hybrid_op, only 1 otherwise.
|
// Note: only needs 6 scratch tensors when is_hybrid_op, only 1 otherwise.
|
||||||
context->AddTensors(context, /*tensors_to_add=*/4,
|
context->AddTensors(context, /*tensors_to_add=*/6,
|
||||||
&op_data->scratch_tensor_index);
|
&op_data->scratch_tensor_index);
|
||||||
return op_data;
|
return op_data;
|
||||||
}
|
}
|
||||||
@ -130,7 +131,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
// Resize scratch.
|
// Resize scratch.
|
||||||
TfLiteIntArrayFree(node->temporaries);
|
TfLiteIntArrayFree(node->temporaries);
|
||||||
if (is_hybrid_op) {
|
if (is_hybrid_op) {
|
||||||
node->temporaries = TfLiteIntArrayCreate(4);
|
node->temporaries = TfLiteIntArrayCreate(6);
|
||||||
} else if (is_full_integer) {
|
} else if (is_full_integer) {
|
||||||
node->temporaries = TfLiteIntArrayCreate(2);
|
node->temporaries = TfLiteIntArrayCreate(2);
|
||||||
} else {
|
} else {
|
||||||
@ -156,6 +157,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
scratch_size_array));
|
scratch_size_array));
|
||||||
|
|
||||||
if (is_hybrid_op) {
|
if (is_hybrid_op) {
|
||||||
|
op_data->compute_row_sums = true;
|
||||||
// Tell interpreter to allocate temporary tensors to store quantized values
|
// Tell interpreter to allocate temporary tensors to store quantized values
|
||||||
// of input tensors.
|
// of input tensors.
|
||||||
node->temporaries->data[1] = scratch_tensor_index + 1;
|
node->temporaries->data[1] = scratch_tensor_index + 1;
|
||||||
@ -195,6 +197,30 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
context->ResizeTensor(context, float_weights_time,
|
context->ResizeTensor(context, float_weights_time,
|
||||||
float_weights_time_size));
|
float_weights_time_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
node->temporaries->data[4] = scratch_tensor_index + 4;
|
||||||
|
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
|
||||||
|
zero_points->type = kTfLiteFloat32;
|
||||||
|
zero_points->allocation_type = kTfLiteArenaRw;
|
||||||
|
int zero_points_dims[1] = {batch_size};
|
||||||
|
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
|
||||||
|
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
|
||||||
|
zero_points_size->data[0] = zero_points_dims[0];
|
||||||
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
|
||||||
|
zero_points_size));
|
||||||
|
}
|
||||||
|
|
||||||
|
node->temporaries->data[5] = scratch_tensor_index + 5;
|
||||||
|
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
|
||||||
|
row_sums->type = kTfLiteFloat32;
|
||||||
|
row_sums->allocation_type = kTfLiteArenaRwPersistent;
|
||||||
|
int row_sums_dims[1] = {num_filters};
|
||||||
|
if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) {
|
||||||
|
TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1);
|
||||||
|
row_sums_size->data[0] = row_sums_dims[0];
|
||||||
|
TF_LITE_ENSURE_OK(
|
||||||
|
context, context->ResizeTensor(context, row_sums, row_sums_size));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (is_full_integer) {
|
if (is_full_integer) {
|
||||||
// Allocated one extra tensor.
|
// Allocated one extra tensor.
|
||||||
@ -267,7 +293,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
GetTemporary(context, node, /*index=*/2);
|
GetTemporary(context, node, /*index=*/2);
|
||||||
TfLiteTensor* float_weights_time =
|
TfLiteTensor* float_weights_time =
|
||||||
GetTemporary(context, node, /*index=*/3);
|
GetTemporary(context, node, /*index=*/3);
|
||||||
|
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
|
||||||
|
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
|
||||||
// Dequantize weights time.
|
// Dequantize weights time.
|
||||||
// TODO(alanchiao): this dequantization initialization only needs to
|
// TODO(alanchiao): this dequantization initialization only needs to
|
||||||
// happen once per model and should theoretically be placed in either
|
// happen once per model and should theoretically be placed in either
|
||||||
@ -285,10 +312,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
}
|
}
|
||||||
op_data->float_weights_time_initialized = true;
|
op_data->float_weights_time_initialized = true;
|
||||||
}
|
}
|
||||||
reference_ops::EvalHybridSVDF(context, node, input, weights_feature,
|
|
||||||
float_weights_time, bias, params, scratch,
|
reference_ops::EvalHybridSVDF(
|
||||||
scaling_factors, input_quantized,
|
context, node, input, weights_feature, float_weights_time, bias,
|
||||||
activation_state, output);
|
params, scratch, scaling_factors, input_quantized, activation_state,
|
||||||
|
output, zero_points, row_sums, &op_data->compute_row_sums);
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
} else {
|
} else {
|
||||||
auto* input_params = reinterpret_cast<TfLiteAffineQuantization*>(
|
auto* input_params = reinterpret_cast<TfLiteAffineQuantization*>(
|
||||||
|
@ -131,7 +131,8 @@ class BaseSVDFOpModel : public SingleOpModel {
|
|||||||
BaseSVDFOpModel(int batches, int units, int input_size, int memory_size,
|
BaseSVDFOpModel(int batches, int units, int input_size, int memory_size,
|
||||||
int rank,
|
int rank,
|
||||||
TensorType weights_feature_type = TensorType_FLOAT32,
|
TensorType weights_feature_type = TensorType_FLOAT32,
|
||||||
TensorType weights_time_type = TensorType_FLOAT32)
|
TensorType weights_time_type = TensorType_FLOAT32,
|
||||||
|
bool asymmetric_quantize_inputs = false)
|
||||||
: batches_(batches),
|
: batches_(batches),
|
||||||
units_(units),
|
units_(units),
|
||||||
input_size_(input_size),
|
input_size_(input_size),
|
||||||
@ -146,9 +147,10 @@ class BaseSVDFOpModel : public SingleOpModel {
|
|||||||
TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}},
|
TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}},
|
||||||
/*is_variable=*/true);
|
/*is_variable=*/true);
|
||||||
output_ = AddOutput(TensorType_FLOAT32);
|
output_ = AddOutput(TensorType_FLOAT32);
|
||||||
SetBuiltinOp(
|
SetBuiltinOp(BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
|
||||||
BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
|
CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE,
|
||||||
CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union());
|
asymmetric_quantize_inputs)
|
||||||
|
.Union());
|
||||||
BuildInterpreter({
|
BuildInterpreter({
|
||||||
{batches_, input_size_}, // input tensor
|
{batches_, input_size_}, // input tensor
|
||||||
{units_ * rank, input_size_}, // weights_feature tensor
|
{units_ * rank, input_size_}, // weights_feature tensor
|
||||||
@ -203,9 +205,10 @@ class SVDFOpModel : public BaseSVDFOpModel {
|
|||||||
class HybridSVDFOpModel : public BaseSVDFOpModel {
|
class HybridSVDFOpModel : public BaseSVDFOpModel {
|
||||||
public:
|
public:
|
||||||
HybridSVDFOpModel(int batches, int units, int input_size, int memory_size,
|
HybridSVDFOpModel(int batches, int units, int input_size, int memory_size,
|
||||||
int rank, TensorType tensor_type)
|
int rank, TensorType tensor_type,
|
||||||
|
bool asymmetric_quantize_inputs)
|
||||||
: BaseSVDFOpModel(batches, units, input_size, memory_size, rank,
|
: BaseSVDFOpModel(batches, units, input_size, memory_size, rank,
|
||||||
tensor_type, tensor_type) {
|
tensor_type, tensor_type, asymmetric_quantize_inputs) {
|
||||||
tensor_type_ = tensor_type;
|
tensor_type_ = tensor_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -229,7 +232,7 @@ class HybridSVDFOpModel : public BaseSVDFOpModel {
|
|||||||
TensorType tensor_type_;
|
TensorType tensor_type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class SVDFOpTest : public ::testing::Test {
|
class SVDFOpTest : public ::testing::TestWithParam<bool> {
|
||||||
protected:
|
protected:
|
||||||
void VerifyGoldens(float golden_input[], float golden_output[],
|
void VerifyGoldens(float golden_input[], float golden_output[],
|
||||||
int golden_size, BaseSVDFOpModel* svdf,
|
int golden_size, BaseSVDFOpModel* svdf,
|
||||||
@ -262,6 +265,9 @@ class SVDFOpTest : public ::testing::Test {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(SVDFOpTest, SVDFOpTest,
|
||||||
|
::testing::ValuesIn({false, true}));
|
||||||
|
|
||||||
TEST_F(SVDFOpTest, BlackBoxTestRank1) {
|
TEST_F(SVDFOpTest, BlackBoxTestRank1) {
|
||||||
SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
|
SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
|
||||||
/*memory_size=*/10, /*rank=*/1);
|
/*memory_size=*/10, /*rank=*/1);
|
||||||
@ -325,9 +331,10 @@ TEST_F(SVDFOpTest, BlackBoxTestRank2) {
|
|||||||
&svdf);
|
&svdf);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Uint8) {
|
TEST_P(SVDFOpTest, BlackBoxTestHybridRank1Uint8) {
|
||||||
HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
|
HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
|
||||||
/*memory_size=*/10, /*rank=*/1, TensorType_UINT8);
|
/*memory_size=*/10, /*rank=*/1, TensorType_UINT8,
|
||||||
|
GetParam());
|
||||||
svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
|
svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
|
||||||
0.22197971, 0.12416199, 0.27901134, 0.27557442,
|
0.22197971, 0.12416199, 0.27901134, 0.27557442,
|
||||||
0.3905206, -0.36137494, -0.06634006, -0.10640851});
|
0.3905206, -0.36137494, -0.06634006, -0.10640851});
|
||||||
@ -347,12 +354,13 @@ TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Uint8) {
|
|||||||
|
|
||||||
VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
|
VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
|
||||||
&svdf,
|
&svdf,
|
||||||
/*tolerance=*/0.002945);
|
/*tolerance=*/0.004285);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Uint8) {
|
TEST_P(SVDFOpTest, BlackBoxTestHybridRank2Uint8) {
|
||||||
HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
|
HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
|
||||||
/*memory_size=*/10, /*rank=*/2, TensorType_UINT8);
|
/*memory_size=*/10, /*rank=*/2, TensorType_UINT8,
|
||||||
|
GetParam());
|
||||||
svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347,
|
svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347,
|
||||||
0.12416199, 0.15785322, 0.27901134, 0.3905206,
|
0.12416199, 0.15785322, 0.27901134, 0.3905206,
|
||||||
0.21931258, -0.36137494, -0.10640851, 0.31053296,
|
0.21931258, -0.36137494, -0.10640851, 0.31053296,
|
||||||
@ -387,12 +395,13 @@ TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Uint8) {
|
|||||||
|
|
||||||
VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
|
VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
|
||||||
&svdf,
|
&svdf,
|
||||||
/*tolerance=*/0.00625109);
|
/*tolerance=*/0.007175);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Int8) {
|
TEST_P(SVDFOpTest, BlackBoxTestHybridRank1Int8) {
|
||||||
HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
|
HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
|
||||||
/*memory_size=*/10, /*rank=*/1, TensorType_INT8);
|
/*memory_size=*/10, /*rank=*/1, TensorType_INT8,
|
||||||
|
GetParam());
|
||||||
svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
|
svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
|
||||||
0.22197971, 0.12416199, 0.27901134, 0.27557442,
|
0.22197971, 0.12416199, 0.27901134, 0.27557442,
|
||||||
0.3905206, -0.36137494, -0.06634006, -0.10640851});
|
0.3905206, -0.36137494, -0.06634006, -0.10640851});
|
||||||
@ -412,12 +421,13 @@ TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Int8) {
|
|||||||
|
|
||||||
VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
|
VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
|
||||||
&svdf,
|
&svdf,
|
||||||
/*tolerance=*/0.002945);
|
/*tolerance=*/0.004285);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Int8) {
|
TEST_P(SVDFOpTest, BlackBoxTestHybridRank2Int8) {
|
||||||
HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
|
HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
|
||||||
/*memory_size=*/10, /*rank=*/2, TensorType_INT8);
|
/*memory_size=*/10, /*rank=*/2, TensorType_INT8,
|
||||||
|
GetParam());
|
||||||
svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347,
|
svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347,
|
||||||
0.12416199, 0.15785322, 0.27901134, 0.3905206,
|
0.12416199, 0.15785322, 0.27901134, 0.3905206,
|
||||||
0.21931258, -0.36137494, -0.10640851, 0.31053296,
|
0.21931258, -0.36137494, -0.10640851, 0.31053296,
|
||||||
@ -452,7 +462,7 @@ TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Int8) {
|
|||||||
|
|
||||||
VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
|
VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
|
||||||
&svdf,
|
&svdf,
|
||||||
/*tolerance=*/0.00625109);
|
/*tolerance=*/0.007175);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test case for full integer quantization of SVDF.
|
// Test case for full integer quantization of SVDF.
|
||||||
|
@ -33,6 +33,7 @@ struct OpData {
|
|||||||
bool is_layer_norm_lstm;
|
bool is_layer_norm_lstm;
|
||||||
// The scratch tensor index.
|
// The scratch tensor index.
|
||||||
int scratch_tensor_index;
|
int scratch_tensor_index;
|
||||||
|
bool compute_row_sums = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Input Tensors of size {max_time, n_batch, n_input}
|
// Input Tensors of size {max_time, n_batch, n_input}
|
||||||
@ -92,7 +93,9 @@ enum TemporaryTensor {
|
|||||||
kProductScalingFactors = 5,
|
kProductScalingFactors = 5,
|
||||||
kRecoveredCellWeights = 6,
|
kRecoveredCellWeights = 6,
|
||||||
kAccumScratch = 7,
|
kAccumScratch = 7,
|
||||||
kNumTemporaryTensors
|
kZeroPoints = 8,
|
||||||
|
kRowSums = 9,
|
||||||
|
kNumTemporaryTensors = 10
|
||||||
};
|
};
|
||||||
|
|
||||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
@ -408,6 +411,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
scratch_buffer_size));
|
scratch_buffer_size));
|
||||||
|
|
||||||
if (IsHybridOp(input, input_to_output_weights)) {
|
if (IsHybridOp(input, input_to_output_weights)) {
|
||||||
|
op_data->compute_row_sums = true;
|
||||||
// Allocate temporary tensors to store quantized values of input,
|
// Allocate temporary tensors to store quantized values of input,
|
||||||
// activation_state and cell_state tensors.
|
// activation_state and cell_state tensors.
|
||||||
node->temporaries->data[kInputQuantized] =
|
node->temporaries->data[kInputQuantized] =
|
||||||
@ -515,6 +519,34 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_ENSURE_OK(
|
TF_LITE_ENSURE_OK(
|
||||||
context, context->ResizeTensor(context, accum_scratch, accum_size));
|
context, context->ResizeTensor(context, accum_scratch, accum_size));
|
||||||
}
|
}
|
||||||
|
node->temporaries->data[kZeroPoints] = scratch_tensor_index + kZeroPoints;
|
||||||
|
TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
|
||||||
|
zero_points->type = kTfLiteFloat32;
|
||||||
|
zero_points->allocation_type = kTfLiteArenaRw;
|
||||||
|
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, scaling_dims)) {
|
||||||
|
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
|
||||||
|
zero_points_size->data[0] = n_batch;
|
||||||
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
|
||||||
|
zero_points_size));
|
||||||
|
}
|
||||||
|
node->temporaries->data[kRowSums] = scratch_tensor_index + kRowSums;
|
||||||
|
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
|
||||||
|
row_sums->type = kTfLiteInt32;
|
||||||
|
row_sums->allocation_type = kTfLiteArenaRwPersistent;
|
||||||
|
int row_sums_rows = use_cifg ? 6 : 8;
|
||||||
|
const TfLiteTensor* projection_weights =
|
||||||
|
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
|
||||||
|
if (projection_weights != nullptr) {
|
||||||
|
row_sums_rows += ceil(n_output / n_cell);
|
||||||
|
}
|
||||||
|
int row_sums_dims[2] = {row_sums_rows, n_cell};
|
||||||
|
if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
|
||||||
|
TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
|
||||||
|
row_sums_size->data[0] = row_sums_dims[0];
|
||||||
|
row_sums_size->data[1] = row_sums_dims[1];
|
||||||
|
TF_LITE_ENSURE_OK(
|
||||||
|
context, context->ResizeTensor(context, row_sums, row_sums_size));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
@ -600,6 +632,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
lstm_params.activation = params->activation;
|
lstm_params.activation = params->activation;
|
||||||
lstm_params.cell_clip = params->cell_clip;
|
lstm_params.cell_clip = params->cell_clip;
|
||||||
lstm_params.proj_clip = params->proj_clip;
|
lstm_params.proj_clip = params->proj_clip;
|
||||||
|
lstm_params.asymmetric_quantize_inputs = params->asymmetric_quantize_inputs;
|
||||||
|
|
||||||
switch (input_to_output_weights->type) {
|
switch (input_to_output_weights->type) {
|
||||||
case kTfLiteFloat32: {
|
case kTfLiteFloat32: {
|
||||||
@ -623,6 +656,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
}
|
}
|
||||||
case kTfLiteUInt8:
|
case kTfLiteUInt8:
|
||||||
case kTfLiteInt8: {
|
case kTfLiteInt8: {
|
||||||
|
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||||
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
|
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
|
||||||
TfLiteTensor* activation_state_quantized =
|
TfLiteTensor* activation_state_quantized =
|
||||||
GetTemporary(context, node, /*index=*/2);
|
GetTemporary(context, node, /*index=*/2);
|
||||||
@ -635,6 +669,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
GetTemporary(context, node, /*index=*/6);
|
GetTemporary(context, node, /*index=*/6);
|
||||||
TfLiteTensor* accum_scratch =
|
TfLiteTensor* accum_scratch =
|
||||||
GetTemporary(context, node, /*index=*/kAccumScratch);
|
GetTemporary(context, node, /*index=*/kAccumScratch);
|
||||||
|
TfLiteTensor* zero_points =
|
||||||
|
GetTemporary(context, node, /*index=*/kZeroPoints);
|
||||||
|
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/kRowSums);
|
||||||
|
const int row_sums_size = row_sums->dims->data[0];
|
||||||
return lstm_eval::EvalHybrid(
|
return lstm_eval::EvalHybrid(
|
||||||
input, input_to_input_weights, input_to_forget_weights,
|
input, input_to_input_weights, input_to_forget_weights,
|
||||||
input_to_cell_weights, input_to_output_weights,
|
input_to_cell_weights, input_to_output_weights,
|
||||||
@ -654,7 +692,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
prod_scaling_factors, recovered_cell_weights, input_quantized,
|
prod_scaling_factors, recovered_cell_weights, input_quantized,
|
||||||
/*aux_input_quantized=*/nullptr, activation_state_quantized,
|
/*aux_input_quantized=*/nullptr, activation_state_quantized,
|
||||||
cell_state_quantized, activation_state, cell_state, accum_scratch,
|
cell_state_quantized, activation_state, cell_state, accum_scratch,
|
||||||
output, CpuBackendContext::GetFromContext(context));
|
output, zero_points, row_sums, row_sums_size,
|
||||||
|
&op_data->compute_row_sums,
|
||||||
|
CpuBackendContext::GetFromContext(context));
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
context->ReportError(context, "Type %d is not currently supported.",
|
context->ReportError(context, "Type %d is not currently supported.",
|
||||||
|
@ -38,7 +38,8 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
|
|||||||
float proj_clip,
|
float proj_clip,
|
||||||
const std::vector<std::vector<int>>& input_shapes,
|
const std::vector<std::vector<int>>& input_shapes,
|
||||||
const TensorType& weights_type = TensorType_FLOAT32,
|
const TensorType& weights_type = TensorType_FLOAT32,
|
||||||
bool is_layer_norm = false)
|
bool is_layer_norm = false,
|
||||||
|
bool asymmetric_quantize_inputs = false)
|
||||||
: n_batch_(n_batch),
|
: n_batch_(n_batch),
|
||||||
n_input_(n_input),
|
n_input_(n_input),
|
||||||
n_cell_(n_cell),
|
n_cell_(n_cell),
|
||||||
@ -131,7 +132,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
|
|||||||
BuiltinOptions_UnidirectionalSequenceLSTMOptions,
|
BuiltinOptions_UnidirectionalSequenceLSTMOptions,
|
||||||
CreateUnidirectionalSequenceLSTMOptions(
|
CreateUnidirectionalSequenceLSTMOptions(
|
||||||
builder_, ActivationFunctionType_TANH, cell_clip,
|
builder_, ActivationFunctionType_TANH, cell_clip,
|
||||||
proj_clip, time_major)
|
proj_clip, time_major, asymmetric_quantize_inputs)
|
||||||
.Union());
|
.Union());
|
||||||
BuildInterpreter(input_shapes);
|
BuildInterpreter(input_shapes);
|
||||||
}
|
}
|
||||||
@ -292,11 +293,12 @@ class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
|
|||||||
bool time_major, bool use_cifg, bool use_peephole,
|
bool time_major, bool use_cifg, bool use_peephole,
|
||||||
bool use_projection_weights, bool use_projection_bias, float cell_clip,
|
bool use_projection_weights, bool use_projection_bias, float cell_clip,
|
||||||
float proj_clip, const std::vector<std::vector<int>>& input_shapes,
|
float proj_clip, const std::vector<std::vector<int>>& input_shapes,
|
||||||
TensorType tensor_type)
|
TensorType tensor_type, bool asymmetric_quantize_inputs)
|
||||||
: UnidirectionalLSTMOpModel(
|
: UnidirectionalLSTMOpModel(
|
||||||
n_batch, n_input, n_cell, n_output, sequence_length, time_major,
|
n_batch, n_input, n_cell, n_output, sequence_length, time_major,
|
||||||
use_cifg, use_peephole, use_projection_weights, use_projection_bias,
|
use_cifg, use_peephole, use_projection_weights, use_projection_bias,
|
||||||
cell_clip, proj_clip, input_shapes, tensor_type) {
|
cell_clip, proj_clip, input_shapes, tensor_type, false,
|
||||||
|
asymmetric_quantize_inputs) {
|
||||||
tensor_type_ = tensor_type;
|
tensor_type_ = tensor_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -360,7 +362,7 @@ class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
|
|||||||
TensorType tensor_type_;
|
TensorType tensor_type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class BaseUnidirectionalLstmTest : public ::testing::Test {
|
class BaseUnidirectionalLstmTest : public ::testing::TestWithParam<bool> {
|
||||||
protected:
|
protected:
|
||||||
// Weights of the LSTM model. Some are optional.
|
// Weights of the LSTM model. Some are optional.
|
||||||
std::vector<float> input_to_input_weights_;
|
std::vector<float> input_to_input_weights_;
|
||||||
@ -626,7 +628,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
|||||||
/*time_major=*/false);
|
/*time_major=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
||||||
HybridLstmBlackBoxTestUint8) {
|
HybridLstmBlackBoxTestUint8) {
|
||||||
const int n_batch = 1;
|
const int n_batch = 1;
|
||||||
const int n_input = 2;
|
const int n_input = 2;
|
||||||
@ -668,7 +670,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
|||||||
{n_batch, n_output}, // activation_state tensor
|
{n_batch, n_output}, // activation_state tensor
|
||||||
{n_batch, n_cell}, // cell_state tensor
|
{n_batch, n_cell}, // cell_state tensor
|
||||||
},
|
},
|
||||||
TensorType_UINT8);
|
TensorType_UINT8, GetParam());
|
||||||
|
|
||||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||||
@ -689,7 +691,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
|||||||
/*tolerance=*/0.0157651);
|
/*tolerance=*/0.0157651);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
||||||
HybridLstmBlackBoxTestInt8) {
|
HybridLstmBlackBoxTestInt8) {
|
||||||
const int n_batch = 1;
|
const int n_batch = 1;
|
||||||
const int n_input = 2;
|
const int n_input = 2;
|
||||||
@ -731,7 +733,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
|||||||
{n_batch, n_output}, // activation_state tensor
|
{n_batch, n_output}, // activation_state tensor
|
||||||
{n_batch, n_cell}, // cell_state tensor
|
{n_batch, n_cell}, // cell_state tensor
|
||||||
},
|
},
|
||||||
TensorType_INT8);
|
TensorType_INT8, GetParam());
|
||||||
|
|
||||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||||
@ -862,7 +864,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
|||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
||||||
HybridLstmBlackBoxTestUint8) {
|
HybridLstmBlackBoxTestUint8) {
|
||||||
const int n_batch = 1;
|
const int n_batch = 1;
|
||||||
const int n_input = 2;
|
const int n_input = 2;
|
||||||
@ -880,11 +882,10 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
|||||||
{
|
{
|
||||||
{sequence_length, n_batch, n_input}, // input tensor
|
{sequence_length, n_batch, n_input}, // input tensor
|
||||||
|
|
||||||
{0, 0}, // input_to_input_weight tensor
|
{0, 0}, // input_to_input_weight tensor
|
||||||
{n_cell, n_input}, // input_to_forget_weight tensor
|
{n_cell, n_input}, // input_to_forget_weight tensor
|
||||||
{n_cell, n_input}, // input_to_cell_weight tensor
|
{n_cell, n_input}, // input_to_cell_weight tensor
|
||||||
{n_cell, n_input}, // input_to_output_weight tensor
|
{n_cell, n_input}, // input_to_output_weight tensor
|
||||||
|
|
||||||
{0, 0}, // recurrent_to_input_weight tensor
|
{0, 0}, // recurrent_to_input_weight tensor
|
||||||
{n_cell, n_output}, // recurrent_to_forget_weight tensor
|
{n_cell, n_output}, // recurrent_to_forget_weight tensor
|
||||||
{n_cell, n_output}, // recurrent_to_cell_weight tensor
|
{n_cell, n_output}, // recurrent_to_cell_weight tensor
|
||||||
@ -905,7 +906,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
|||||||
{n_batch, n_output}, // activation_state tensor
|
{n_batch, n_output}, // activation_state tensor
|
||||||
{n_batch, n_cell}, // cell_state tensor
|
{n_batch, n_cell}, // cell_state tensor
|
||||||
},
|
},
|
||||||
TensorType_UINT8);
|
TensorType_UINT8, GetParam());
|
||||||
|
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||||
@ -925,7 +926,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
|||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
||||||
HybridLstmBlackBoxTestInt8) {
|
HybridLstmBlackBoxTestInt8) {
|
||||||
const int n_batch = 1;
|
const int n_batch = 1;
|
||||||
const int n_input = 2;
|
const int n_input = 2;
|
||||||
@ -968,7 +969,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
|||||||
{n_batch, n_output}, // activation_state tensor
|
{n_batch, n_output}, // activation_state tensor
|
||||||
{n_batch, n_cell}, // cell_state tensor
|
{n_batch, n_cell}, // cell_state tensor
|
||||||
},
|
},
|
||||||
TensorType_INT8);
|
TensorType_INT8, GetParam());
|
||||||
|
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||||
@ -1655,14 +1656,16 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
|
|||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
|
TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
|
||||||
HybridLstmBlackBoxTestUint8) {
|
HybridLstmBlackBoxTestUint8) {
|
||||||
const int n_batch = 2;
|
const int n_batch = 2;
|
||||||
const int n_input = 5;
|
const int n_input = 5;
|
||||||
const int n_cell = 20;
|
const int n_cell = 20;
|
||||||
const int n_output = 16;
|
const int n_output = 16;
|
||||||
const int sequence_length = 4;
|
const int sequence_length = 4;
|
||||||
|
if (GetParam()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
HybridUnidirectionalLSTMOpModel lstm(
|
HybridUnidirectionalLSTMOpModel lstm(
|
||||||
n_batch, n_input, n_cell, n_output, sequence_length,
|
n_batch, n_input, n_cell, n_output, sequence_length,
|
||||||
/*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true,
|
/*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true,
|
||||||
@ -1697,7 +1700,7 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
|
|||||||
{n_batch, n_output}, // activation_state tensor
|
{n_batch, n_output}, // activation_state tensor
|
||||||
{n_batch, n_cell}, // cell_state tensor
|
{n_batch, n_cell}, // cell_state tensor
|
||||||
},
|
},
|
||||||
TensorType_UINT8);
|
TensorType_UINT8, GetParam());
|
||||||
|
|
||||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||||
@ -1723,8 +1726,11 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
|
|||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
|
TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
|
||||||
HybridLstmBlackBoxTestInt8) {
|
HybridLstmBlackBoxTestInt8) {
|
||||||
|
if (GetParam()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
const int n_batch = 2;
|
const int n_batch = 2;
|
||||||
const int n_input = 5;
|
const int n_input = 5;
|
||||||
const int n_cell = 20;
|
const int n_cell = 20;
|
||||||
@ -1765,7 +1771,7 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
|
|||||||
{n_batch, n_output}, // activation_state tensor
|
{n_batch, n_output}, // activation_state tensor
|
||||||
{n_batch, n_cell}, // cell_state tensor
|
{n_batch, n_cell}, // cell_state tensor
|
||||||
},
|
},
|
||||||
TensorType_INT8);
|
TensorType_INT8, GetParam());
|
||||||
|
|
||||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||||
@ -2737,5 +2743,14 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
|||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define QUANTIZE_PARAMETER_TEST(test) \
|
||||||
|
INSTANTIATE_TEST_SUITE_P(test, test, ::testing::ValuesIn({false, true}));
|
||||||
|
|
||||||
|
QUANTIZE_PARAMETER_TEST(
|
||||||
|
CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest);
|
||||||
|
QUANTIZE_PARAMETER_TEST(
|
||||||
|
NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest);
|
||||||
|
QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest);
|
||||||
|
#undef QUANTIZE_PARAMETER_TEST
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -26,6 +26,15 @@ namespace ops {
|
|||||||
namespace builtin {
|
namespace builtin {
|
||||||
namespace unidirectional_sequence_rnn {
|
namespace unidirectional_sequence_rnn {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct OpData {
|
||||||
|
int scratch_tensor_index;
|
||||||
|
bool compute_row_sums = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Input tensors.
|
// Input tensors.
|
||||||
constexpr int kInputTensor = 0;
|
constexpr int kInputTensor = 0;
|
||||||
constexpr int kWeightsTensor = 1;
|
constexpr int kWeightsTensor = 1;
|
||||||
@ -37,13 +46,14 @@ constexpr int kHiddenStateTensor = 4;
|
|||||||
constexpr int kOutputTensor = 0;
|
constexpr int kOutputTensor = 0;
|
||||||
|
|
||||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
auto* scratch_tensor_index = new int;
|
auto* op_data = new OpData();
|
||||||
context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index);
|
context->AddTensors(context, /*tensors_to_add=*/6,
|
||||||
return scratch_tensor_index;
|
&op_data->scratch_tensor_index);
|
||||||
|
return op_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Free(TfLiteContext* context, void* buffer) {
|
void Free(TfLiteContext* context, void* buffer) {
|
||||||
delete reinterpret_cast<int*>(buffer);
|
delete reinterpret_cast<OpData*>(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
@ -96,10 +106,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
// Allocate temporary tensors to store quantized values of input and
|
// Allocate temporary tensors to store quantized values of input and
|
||||||
// hidden_state tensors.
|
// hidden_state tensors.
|
||||||
if (is_hybrid) {
|
if (is_hybrid) {
|
||||||
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
|
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||||
|
op_data->compute_row_sums = true;
|
||||||
TfLiteIntArrayFree(node->temporaries);
|
TfLiteIntArrayFree(node->temporaries);
|
||||||
node->temporaries = TfLiteIntArrayCreate(3);
|
node->temporaries = TfLiteIntArrayCreate(6);
|
||||||
node->temporaries->data[0] = *scratch_tensor_index;
|
node->temporaries->data[0] = op_data->scratch_tensor_index;
|
||||||
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
|
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
|
||||||
input_quantized->type = input_weights->type;
|
input_quantized->type = input_weights->type;
|
||||||
input_quantized->allocation_type = kTfLiteArenaRw;
|
input_quantized->allocation_type = kTfLiteArenaRw;
|
||||||
@ -108,7 +119,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
|
||||||
input_quantized_size));
|
input_quantized_size));
|
||||||
}
|
}
|
||||||
node->temporaries->data[1] = *scratch_tensor_index + 1;
|
node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
|
||||||
TfLiteTensor* hidden_state_quantized =
|
TfLiteTensor* hidden_state_quantized =
|
||||||
GetTemporary(context, node, /*index=*/1);
|
GetTemporary(context, node, /*index=*/1);
|
||||||
hidden_state_quantized->type = input_weights->type;
|
hidden_state_quantized->type = input_weights->type;
|
||||||
@ -121,7 +132,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
context->ResizeTensor(context, hidden_state_quantized,
|
context->ResizeTensor(context, hidden_state_quantized,
|
||||||
hidden_state_quantized_size));
|
hidden_state_quantized_size));
|
||||||
}
|
}
|
||||||
node->temporaries->data[2] = *scratch_tensor_index + 2;
|
node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
|
||||||
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
|
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
|
||||||
scaling_factors->type = kTfLiteFloat32;
|
scaling_factors->type = kTfLiteFloat32;
|
||||||
scaling_factors->allocation_type = kTfLiteArenaRw;
|
scaling_factors->allocation_type = kTfLiteArenaRw;
|
||||||
@ -132,6 +143,42 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
|
||||||
scaling_factors_size));
|
scaling_factors_size));
|
||||||
}
|
}
|
||||||
|
node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
|
||||||
|
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/3);
|
||||||
|
accum_scratch->type = kTfLiteInt32;
|
||||||
|
accum_scratch->allocation_type = kTfLiteArenaRw;
|
||||||
|
int accum_scratch_dims[2] = {num_units, batch_size};
|
||||||
|
if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
|
||||||
|
accum_scratch_dims)) {
|
||||||
|
TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
|
||||||
|
accum_scratch_size->data[0] = accum_scratch_dims[0];
|
||||||
|
accum_scratch_size->data[1] = accum_scratch_dims[1];
|
||||||
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
|
||||||
|
accum_scratch_size));
|
||||||
|
}
|
||||||
|
node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
|
||||||
|
TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
|
||||||
|
zero_points->type = kTfLiteInt32;
|
||||||
|
zero_points->allocation_type = kTfLiteArenaRw;
|
||||||
|
int zero_points_dims[1] = {batch_size};
|
||||||
|
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
|
||||||
|
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
|
||||||
|
zero_points_size->data[0] = batch_size;
|
||||||
|
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
|
||||||
|
zero_points_size));
|
||||||
|
}
|
||||||
|
node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
|
||||||
|
TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
|
||||||
|
row_sums->type = kTfLiteInt32;
|
||||||
|
row_sums->allocation_type = kTfLiteArenaRwPersistent;
|
||||||
|
int row_sums_dims[2] = {2, num_units};
|
||||||
|
if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
|
||||||
|
TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
|
||||||
|
row_sums_size->data[0] = row_sums_dims[0];
|
||||||
|
row_sums_size->data[1] = row_sums_dims[1];
|
||||||
|
TF_LITE_ENSURE_OK(
|
||||||
|
context, context->ResizeTensor(context, row_sums, row_sums_size));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
@ -202,7 +249,9 @@ TfLiteStatus EvalHybrid(
|
|||||||
const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias,
|
const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias,
|
||||||
const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch,
|
const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch,
|
||||||
TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors,
|
TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors,
|
||||||
TfLiteTensor* hidden_state, TfLiteTensor* output) {
|
TfLiteTensor* hidden_state, TfLiteTensor* output, TfLiteTensor* zero_points,
|
||||||
|
TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
|
||||||
|
bool* compute_row_sums) {
|
||||||
const bool time_major = params->time_major;
|
const bool time_major = params->time_major;
|
||||||
const int batch_size =
|
const int batch_size =
|
||||||
(time_major) ? input->dims->data[1] : input->dims->data[0];
|
(time_major) ? input->dims->data[1] : input->dims->data[0];
|
||||||
@ -227,6 +276,14 @@ TfLiteStatus EvalHybrid(
|
|||||||
float input_weights_scale = input_weights->params.scale;
|
float input_weights_scale = input_weights->params.scale;
|
||||||
float recurrent_weights_scale = recurrent_weights->params.scale;
|
float recurrent_weights_scale = recurrent_weights->params.scale;
|
||||||
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
|
float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
|
||||||
|
int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
|
||||||
|
int32_t* zero_points_ptr = nullptr;
|
||||||
|
int32_t* row_sums_ptr = nullptr;
|
||||||
|
|
||||||
|
if (params->asymmetric_quantize_inputs) {
|
||||||
|
zero_points_ptr = GetTensorData<int32_t>(zero_points);
|
||||||
|
row_sums_ptr = GetTensorData<int32_t>(row_sums);
|
||||||
|
}
|
||||||
|
|
||||||
if (time_major) {
|
if (time_major) {
|
||||||
// Initialize the pointer to hidden state.
|
// Initialize the pointer to hidden state.
|
||||||
@ -244,7 +301,9 @@ TfLiteStatus EvalHybrid(
|
|||||||
recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
|
recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
|
||||||
num_units, batch_size, num_units, params->activation,
|
num_units, batch_size, num_units, params->activation,
|
||||||
quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr,
|
quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr,
|
||||||
hidden_state_ptr_batch, output_ptr_batch);
|
hidden_state_ptr_batch, output_ptr_batch,
|
||||||
|
params->asymmetric_quantize_inputs, zero_points_ptr,
|
||||||
|
accum_scratch_ptr, row_sums_ptr, compute_row_sums);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// For each batch
|
// For each batch
|
||||||
@ -259,13 +318,14 @@ TfLiteStatus EvalHybrid(
|
|||||||
s * input_size;
|
s * input_size;
|
||||||
float* output_ptr_batch = GetTensorData<float>(output) +
|
float* output_ptr_batch = GetTensorData<float>(output) +
|
||||||
b * num_units * max_time + s * num_units;
|
b * num_units * max_time + s * num_units;
|
||||||
|
|
||||||
kernel_utils::RnnBatchStep(
|
kernel_utils::RnnBatchStep(
|
||||||
input_ptr_batch, input_weights_ptr, input_weights_scale,
|
input_ptr_batch, input_weights_ptr, input_weights_scale,
|
||||||
recurrent_weights_ptr, recurrent_weights_scale, bias_ptr,
|
recurrent_weights_ptr, recurrent_weights_scale, bias_ptr,
|
||||||
input_size, num_units, /*batch_size=*/1, num_units,
|
input_size, num_units, /*batch_size=*/1, num_units,
|
||||||
params->activation, quantized_input_ptr, quantized_hidden_state_ptr,
|
params->activation, quantized_input_ptr, quantized_hidden_state_ptr,
|
||||||
scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch);
|
scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch,
|
||||||
|
params->asymmetric_quantize_inputs, zero_points_ptr,
|
||||||
|
accum_scratch_ptr, row_sums_ptr, compute_row_sums);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -274,7 +334,6 @@ TfLiteStatus EvalHybrid(
|
|||||||
|
|
||||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
|
auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
|
||||||
|
|
||||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||||
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
|
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
|
||||||
const TfLiteTensor* recurrent_weights =
|
const TfLiteTensor* recurrent_weights =
|
||||||
@ -292,12 +351,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
case kTfLiteUInt8:
|
case kTfLiteUInt8:
|
||||||
case kTfLiteInt8: {
|
case kTfLiteInt8: {
|
||||||
// TODO(mirkov): implement eval with quantized inputs as well.
|
// TODO(mirkov): implement eval with quantized inputs as well.
|
||||||
|
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||||
TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
|
TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
|
||||||
TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
|
TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
|
||||||
TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
|
TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
|
||||||
|
TfLiteTensor* accum_scratch = GetTemporary(context, node, 3);
|
||||||
|
TfLiteTensor* zero_points = GetTemporary(context, node, 4);
|
||||||
|
TfLiteTensor* row_sums = GetTemporary(context, node, 5);
|
||||||
return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
|
return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
|
||||||
input_quantized, hidden_state_quantized,
|
input_quantized, hidden_state_quantized,
|
||||||
scaling_factors, hidden_state, output);
|
scaling_factors, hidden_state, output, zero_points,
|
||||||
|
accum_scratch, row_sums, &op_data->compute_row_sums);
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
context->ReportError(context, "Type %d not currently supported.",
|
context->ReportError(context, "Type %d not currently supported.",
|
||||||
|
@ -174,7 +174,8 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
|
|||||||
UnidirectionalRNNOpModel(
|
UnidirectionalRNNOpModel(
|
||||||
int batches, int sequence_len, int units, int size, bool time_major,
|
int batches, int sequence_len, int units, int size, bool time_major,
|
||||||
const TensorType& weights = TensorType_FLOAT32,
|
const TensorType& weights = TensorType_FLOAT32,
|
||||||
const TensorType& recurrent_weights = TensorType_FLOAT32)
|
const TensorType& recurrent_weights = TensorType_FLOAT32,
|
||||||
|
bool asymmetric_quantize_inputs = false)
|
||||||
: batches_(batches),
|
: batches_(batches),
|
||||||
sequence_len_(sequence_len),
|
sequence_len_(sequence_len),
|
||||||
units_(units),
|
units_(units),
|
||||||
@ -188,7 +189,8 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
|
|||||||
SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
|
SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
|
||||||
BuiltinOptions_SequenceRNNOptions,
|
BuiltinOptions_SequenceRNNOptions,
|
||||||
CreateSequenceRNNOptions(builder_, time_major,
|
CreateSequenceRNNOptions(builder_, time_major,
|
||||||
ActivationFunctionType_RELU)
|
ActivationFunctionType_RELU,
|
||||||
|
asymmetric_quantize_inputs)
|
||||||
.Union());
|
.Union());
|
||||||
if (time_major) {
|
if (time_major) {
|
||||||
BuildInterpreter({{sequence_len_, batches_, input_size_},
|
BuildInterpreter({{sequence_len_, batches_, input_size_},
|
||||||
@ -249,9 +251,11 @@ class HybridUnidirectionalRNNOpModel : public UnidirectionalRNNOpModel {
|
|||||||
public:
|
public:
|
||||||
HybridUnidirectionalRNNOpModel(int batches, int sequence_len, int units,
|
HybridUnidirectionalRNNOpModel(int batches, int sequence_len, int units,
|
||||||
int size, bool time_major,
|
int size, bool time_major,
|
||||||
TensorType tensor_type)
|
TensorType tensor_type,
|
||||||
|
bool asymmetric_quantize_inputs)
|
||||||
: UnidirectionalRNNOpModel(batches, sequence_len, units, size, time_major,
|
: UnidirectionalRNNOpModel(batches, sequence_len, units, size, time_major,
|
||||||
tensor_type, tensor_type) {
|
tensor_type, tensor_type,
|
||||||
|
asymmetric_quantize_inputs) {
|
||||||
tensor_type_ = tensor_type;
|
tensor_type_ = tensor_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -297,10 +301,14 @@ TEST(UnidirectionalRNNOpTest, BlackBoxTest) {
|
|||||||
EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
|
EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestUint8) {
|
class HybridUnidirectionalRNNOpModelOpTest
|
||||||
|
: public ::testing::TestWithParam<bool> {};
|
||||||
|
|
||||||
|
TEST_P(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestUint8) {
|
||||||
HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
|
HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
|
||||||
/*units=*/16, /*size=*/8,
|
/*units=*/16, /*size=*/8,
|
||||||
/*time_major=*/false, TensorType_UINT8);
|
/*time_major=*/false, TensorType_UINT8,
|
||||||
|
GetParam());
|
||||||
rnn.SetWeights(rnn_weights);
|
rnn.SetWeights(rnn_weights);
|
||||||
rnn.SetBias(rnn_bias);
|
rnn.SetBias(rnn_bias);
|
||||||
rnn.SetRecurrentWeights(rnn_recurrent_weights);
|
rnn.SetRecurrentWeights(rnn_recurrent_weights);
|
||||||
@ -323,10 +331,11 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestUint8) {
|
|||||||
expected, /*max_abs_error=*/0.013)));
|
expected, /*max_abs_error=*/0.013)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestInt8) {
|
TEST_P(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestInt8) {
|
||||||
HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
|
HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
|
||||||
/*units=*/16, /*size=*/8,
|
/*units=*/16, /*size=*/8,
|
||||||
/*time_major=*/false, TensorType_INT8);
|
/*time_major=*/false, TensorType_INT8,
|
||||||
|
GetParam());
|
||||||
rnn.SetWeights(rnn_weights);
|
rnn.SetWeights(rnn_weights);
|
||||||
rnn.SetBias(rnn_bias);
|
rnn.SetBias(rnn_bias);
|
||||||
rnn.SetRecurrentWeights(rnn_recurrent_weights);
|
rnn.SetRecurrentWeights(rnn_recurrent_weights);
|
||||||
@ -378,10 +387,11 @@ TEST(UnidirectionalRNNOpTest, TimeMajorBlackBoxTest) {
|
|||||||
EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
|
EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestUint8) {
|
TEST_P(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestUint8) {
|
||||||
HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
|
HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
|
||||||
/*units=*/16, /*size=*/8,
|
/*units=*/16, /*size=*/8,
|
||||||
/*time_major=*/true, TensorType_UINT8);
|
/*time_major=*/true, TensorType_UINT8,
|
||||||
|
GetParam());
|
||||||
rnn.SetWeights(rnn_weights);
|
rnn.SetWeights(rnn_weights);
|
||||||
rnn.SetBias(rnn_bias);
|
rnn.SetBias(rnn_bias);
|
||||||
rnn.SetRecurrentWeights(rnn_recurrent_weights);
|
rnn.SetRecurrentWeights(rnn_recurrent_weights);
|
||||||
@ -408,10 +418,11 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestUint8) {
|
|||||||
expected, /*max_abs_error=*/0.013)));
|
expected, /*max_abs_error=*/0.013)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestInt8) {
|
TEST_P(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestInt8) {
|
||||||
HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
|
HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
|
||||||
/*units=*/16, /*size=*/8,
|
/*units=*/16, /*size=*/8,
|
||||||
/*time_major=*/true, TensorType_INT8);
|
/*time_major=*/true, TensorType_INT8,
|
||||||
|
GetParam());
|
||||||
rnn.SetWeights(rnn_weights);
|
rnn.SetWeights(rnn_weights);
|
||||||
rnn.SetBias(rnn_bias);
|
rnn.SetBias(rnn_bias);
|
||||||
rnn.SetRecurrentWeights(rnn_recurrent_weights);
|
rnn.SetRecurrentWeights(rnn_recurrent_weights);
|
||||||
@ -438,5 +449,9 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestInt8) {
|
|||||||
expected, /*max_abs_error=*/0.013)));
|
expected, /*max_abs_error=*/0.013)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(HybridUnidirectionalRNNOpModelOpTest,
|
||||||
|
HybridUnidirectionalRNNOpModelOpTest,
|
||||||
|
::testing::ValuesIn({true, false}));
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -519,17 +519,22 @@ table LSHProjectionOptions {
|
|||||||
table SVDFOptions {
|
table SVDFOptions {
|
||||||
rank:int;
|
rank:int;
|
||||||
fused_activation_function:ActivationFunctionType;
|
fused_activation_function:ActivationFunctionType;
|
||||||
|
// For weights-only quantization, use asymmetric quantization for non
|
||||||
|
// constant inputs at evaluation time.
|
||||||
|
asymmetric_quantize_inputs:bool;
|
||||||
}
|
}
|
||||||
|
|
||||||
// An implementation of TensorFlow RNNCell.
|
// An implementation of TensorFlow RNNCell.
|
||||||
table RNNOptions {
|
table RNNOptions {
|
||||||
fused_activation_function:ActivationFunctionType;
|
fused_activation_function:ActivationFunctionType;
|
||||||
|
asymmetric_quantize_inputs:bool;
|
||||||
}
|
}
|
||||||
|
|
||||||
// An implementation of TensorFlow dynamic_rnn with RNNCell.
|
// An implementation of TensorFlow dynamic_rnn with RNNCell.
|
||||||
table SequenceRNNOptions {
|
table SequenceRNNOptions {
|
||||||
time_major:bool;
|
time_major:bool;
|
||||||
fused_activation_function:ActivationFunctionType;
|
fused_activation_function:ActivationFunctionType;
|
||||||
|
asymmetric_quantize_inputs:bool;
|
||||||
}
|
}
|
||||||
|
|
||||||
// An implementation of TensorFlow bidrectional_dynamic_rnn with RNNCell.
|
// An implementation of TensorFlow bidrectional_dynamic_rnn with RNNCell.
|
||||||
@ -537,6 +542,7 @@ table BidirectionalSequenceRNNOptions {
|
|||||||
time_major:bool;
|
time_major:bool;
|
||||||
fused_activation_function:ActivationFunctionType;
|
fused_activation_function:ActivationFunctionType;
|
||||||
merge_outputs: bool;
|
merge_outputs: bool;
|
||||||
|
asymmetric_quantize_inputs:bool;
|
||||||
}
|
}
|
||||||
|
|
||||||
enum FullyConnectedOptionsWeightsFormat: byte {
|
enum FullyConnectedOptionsWeightsFormat: byte {
|
||||||
@ -556,6 +562,11 @@ table FullyConnectedOptions {
|
|||||||
// If set to true, then the number of dimension is preserved. Furthermore,
|
// If set to true, then the number of dimension is preserved. Furthermore,
|
||||||
// all but the last dimension of the input and output shapes will be equal.
|
// all but the last dimension of the input and output shapes will be equal.
|
||||||
keep_num_dims: bool;
|
keep_num_dims: bool;
|
||||||
|
|
||||||
|
// Parameters for FullyConnected version 7 or above.
|
||||||
|
// If set to true, then weights-only op will use asymmetric quantization for
|
||||||
|
// inputs.
|
||||||
|
asymmetric_quantize_inputs: bool;
|
||||||
}
|
}
|
||||||
|
|
||||||
table SoftmaxOptions {
|
table SoftmaxOptions {
|
||||||
@ -604,6 +615,9 @@ table LSTMOptions {
|
|||||||
// Parameters for LSTM version 2 or above.
|
// Parameters for LSTM version 2 or above.
|
||||||
// Basic kernel is only supported in version 2 or above.
|
// Basic kernel is only supported in version 2 or above.
|
||||||
kernel_type: LSTMKernelType = FULL;
|
kernel_type: LSTMKernelType = FULL;
|
||||||
|
|
||||||
|
// Parameters for LSTM version 4 or above.
|
||||||
|
asymmetric_quantize_inputs: bool;
|
||||||
}
|
}
|
||||||
|
|
||||||
// An implementation of TensorFlow dynamic_rnn with LSTMCell.
|
// An implementation of TensorFlow dynamic_rnn with LSTMCell.
|
||||||
@ -614,6 +628,9 @@ table UnidirectionalSequenceLSTMOptions {
|
|||||||
|
|
||||||
// If true then first dimension is sequence, otherwise batch.
|
// If true then first dimension is sequence, otherwise batch.
|
||||||
time_major:bool;
|
time_major:bool;
|
||||||
|
|
||||||
|
// Parameter for Unidirectional Sequence LSTM version 4.
|
||||||
|
asymmetric_quantize_inputs:bool;
|
||||||
}
|
}
|
||||||
|
|
||||||
table BidirectionalSequenceLSTMOptions {
|
table BidirectionalSequenceLSTMOptions {
|
||||||
@ -630,6 +647,9 @@ table BidirectionalSequenceLSTMOptions {
|
|||||||
// Version 1 implementations assumed time_major to be true, so this default
|
// Version 1 implementations assumed time_major to be true, so this default
|
||||||
// value should never change.
|
// value should never change.
|
||||||
time_major: bool = true;
|
time_major: bool = true;
|
||||||
|
|
||||||
|
// Parameters for version 3 or above.
|
||||||
|
asymmetric_quantize_inputs:bool;
|
||||||
}
|
}
|
||||||
|
|
||||||
table ResizeBilinearOptions {
|
table ResizeBilinearOptions {
|
||||||
|
@ -4216,9 +4216,11 @@ struct SVDFOptionsT : public flatbuffers::NativeTable {
|
|||||||
typedef SVDFOptions TableType;
|
typedef SVDFOptions TableType;
|
||||||
int32_t rank;
|
int32_t rank;
|
||||||
tflite::ActivationFunctionType fused_activation_function;
|
tflite::ActivationFunctionType fused_activation_function;
|
||||||
|
bool asymmetric_quantize_inputs;
|
||||||
SVDFOptionsT()
|
SVDFOptionsT()
|
||||||
: rank(0),
|
: rank(0),
|
||||||
fused_activation_function(tflite::ActivationFunctionType_NONE) {
|
fused_activation_function(tflite::ActivationFunctionType_NONE),
|
||||||
|
asymmetric_quantize_inputs(false) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -4226,7 +4228,8 @@ struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||||||
typedef SVDFOptionsT NativeTableType;
|
typedef SVDFOptionsT NativeTableType;
|
||||||
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
||||||
VT_RANK = 4,
|
VT_RANK = 4,
|
||||||
VT_FUSED_ACTIVATION_FUNCTION = 6
|
VT_FUSED_ACTIVATION_FUNCTION = 6,
|
||||||
|
VT_ASYMMETRIC_QUANTIZE_INPUTS = 8
|
||||||
};
|
};
|
||||||
int32_t rank() const {
|
int32_t rank() const {
|
||||||
return GetField<int32_t>(VT_RANK, 0);
|
return GetField<int32_t>(VT_RANK, 0);
|
||||||
@ -4234,10 +4237,14 @@ struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||||||
tflite::ActivationFunctionType fused_activation_function() const {
|
tflite::ActivationFunctionType fused_activation_function() const {
|
||||||
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
||||||
}
|
}
|
||||||
|
bool asymmetric_quantize_inputs() const {
|
||||||
|
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
|
||||||
|
}
|
||||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
return VerifyTableStart(verifier) &&
|
return VerifyTableStart(verifier) &&
|
||||||
VerifyField<int32_t>(verifier, VT_RANK) &&
|
VerifyField<int32_t>(verifier, VT_RANK) &&
|
||||||
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
||||||
|
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
|
||||||
verifier.EndTable();
|
verifier.EndTable();
|
||||||
}
|
}
|
||||||
SVDFOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
SVDFOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
@ -4254,6 +4261,9 @@ struct SVDFOptionsBuilder {
|
|||||||
void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
|
void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
|
||||||
fbb_.AddElement<int8_t>(SVDFOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
|
fbb_.AddElement<int8_t>(SVDFOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
|
||||||
}
|
}
|
||||||
|
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
|
||||||
|
fbb_.AddElement<uint8_t>(SVDFOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
|
||||||
|
}
|
||||||
explicit SVDFOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
explicit SVDFOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
: fbb_(_fbb) {
|
: fbb_(_fbb) {
|
||||||
start_ = fbb_.StartTable();
|
start_ = fbb_.StartTable();
|
||||||
@ -4269,9 +4279,11 @@ struct SVDFOptionsBuilder {
|
|||||||
inline flatbuffers::Offset<SVDFOptions> CreateSVDFOptions(
|
inline flatbuffers::Offset<SVDFOptions> CreateSVDFOptions(
|
||||||
flatbuffers::FlatBufferBuilder &_fbb,
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
int32_t rank = 0,
|
int32_t rank = 0,
|
||||||
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) {
|
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
||||||
|
bool asymmetric_quantize_inputs = false) {
|
||||||
SVDFOptionsBuilder builder_(_fbb);
|
SVDFOptionsBuilder builder_(_fbb);
|
||||||
builder_.add_rank(rank);
|
builder_.add_rank(rank);
|
||||||
|
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
|
||||||
builder_.add_fused_activation_function(fused_activation_function);
|
builder_.add_fused_activation_function(fused_activation_function);
|
||||||
return builder_.Finish();
|
return builder_.Finish();
|
||||||
}
|
}
|
||||||
@ -4281,22 +4293,29 @@ flatbuffers::Offset<SVDFOptions> CreateSVDFOptions(flatbuffers::FlatBufferBuilde
|
|||||||
struct RNNOptionsT : public flatbuffers::NativeTable {
|
struct RNNOptionsT : public flatbuffers::NativeTable {
|
||||||
typedef RNNOptions TableType;
|
typedef RNNOptions TableType;
|
||||||
tflite::ActivationFunctionType fused_activation_function;
|
tflite::ActivationFunctionType fused_activation_function;
|
||||||
|
bool asymmetric_quantize_inputs;
|
||||||
RNNOptionsT()
|
RNNOptionsT()
|
||||||
: fused_activation_function(tflite::ActivationFunctionType_NONE) {
|
: fused_activation_function(tflite::ActivationFunctionType_NONE),
|
||||||
|
asymmetric_quantize_inputs(false) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct RNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
struct RNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||||
typedef RNNOptionsT NativeTableType;
|
typedef RNNOptionsT NativeTableType;
|
||||||
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
||||||
VT_FUSED_ACTIVATION_FUNCTION = 4
|
VT_FUSED_ACTIVATION_FUNCTION = 4,
|
||||||
|
VT_ASYMMETRIC_QUANTIZE_INPUTS = 6
|
||||||
};
|
};
|
||||||
tflite::ActivationFunctionType fused_activation_function() const {
|
tflite::ActivationFunctionType fused_activation_function() const {
|
||||||
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
||||||
}
|
}
|
||||||
|
bool asymmetric_quantize_inputs() const {
|
||||||
|
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
|
||||||
|
}
|
||||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
return VerifyTableStart(verifier) &&
|
return VerifyTableStart(verifier) &&
|
||||||
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
||||||
|
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
|
||||||
verifier.EndTable();
|
verifier.EndTable();
|
||||||
}
|
}
|
||||||
RNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
RNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
@ -4310,6 +4329,9 @@ struct RNNOptionsBuilder {
|
|||||||
void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
|
void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
|
||||||
fbb_.AddElement<int8_t>(RNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
|
fbb_.AddElement<int8_t>(RNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
|
||||||
}
|
}
|
||||||
|
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
|
||||||
|
fbb_.AddElement<uint8_t>(RNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
|
||||||
|
}
|
||||||
explicit RNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
explicit RNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
: fbb_(_fbb) {
|
: fbb_(_fbb) {
|
||||||
start_ = fbb_.StartTable();
|
start_ = fbb_.StartTable();
|
||||||
@ -4324,8 +4346,10 @@ struct RNNOptionsBuilder {
|
|||||||
|
|
||||||
inline flatbuffers::Offset<RNNOptions> CreateRNNOptions(
|
inline flatbuffers::Offset<RNNOptions> CreateRNNOptions(
|
||||||
flatbuffers::FlatBufferBuilder &_fbb,
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) {
|
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
||||||
|
bool asymmetric_quantize_inputs = false) {
|
||||||
RNNOptionsBuilder builder_(_fbb);
|
RNNOptionsBuilder builder_(_fbb);
|
||||||
|
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
|
||||||
builder_.add_fused_activation_function(fused_activation_function);
|
builder_.add_fused_activation_function(fused_activation_function);
|
||||||
return builder_.Finish();
|
return builder_.Finish();
|
||||||
}
|
}
|
||||||
@ -4336,9 +4360,11 @@ struct SequenceRNNOptionsT : public flatbuffers::NativeTable {
|
|||||||
typedef SequenceRNNOptions TableType;
|
typedef SequenceRNNOptions TableType;
|
||||||
bool time_major;
|
bool time_major;
|
||||||
tflite::ActivationFunctionType fused_activation_function;
|
tflite::ActivationFunctionType fused_activation_function;
|
||||||
|
bool asymmetric_quantize_inputs;
|
||||||
SequenceRNNOptionsT()
|
SequenceRNNOptionsT()
|
||||||
: time_major(false),
|
: time_major(false),
|
||||||
fused_activation_function(tflite::ActivationFunctionType_NONE) {
|
fused_activation_function(tflite::ActivationFunctionType_NONE),
|
||||||
|
asymmetric_quantize_inputs(false) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -4346,7 +4372,8 @@ struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||||||
typedef SequenceRNNOptionsT NativeTableType;
|
typedef SequenceRNNOptionsT NativeTableType;
|
||||||
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
||||||
VT_TIME_MAJOR = 4,
|
VT_TIME_MAJOR = 4,
|
||||||
VT_FUSED_ACTIVATION_FUNCTION = 6
|
VT_FUSED_ACTIVATION_FUNCTION = 6,
|
||||||
|
VT_ASYMMETRIC_QUANTIZE_INPUTS = 8
|
||||||
};
|
};
|
||||||
bool time_major() const {
|
bool time_major() const {
|
||||||
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
|
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
|
||||||
@ -4354,10 +4381,14 @@ struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||||||
tflite::ActivationFunctionType fused_activation_function() const {
|
tflite::ActivationFunctionType fused_activation_function() const {
|
||||||
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
||||||
}
|
}
|
||||||
|
bool asymmetric_quantize_inputs() const {
|
||||||
|
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
|
||||||
|
}
|
||||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
return VerifyTableStart(verifier) &&
|
return VerifyTableStart(verifier) &&
|
||||||
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
|
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
|
||||||
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
||||||
|
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
|
||||||
verifier.EndTable();
|
verifier.EndTable();
|
||||||
}
|
}
|
||||||
SequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
SequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
@ -4374,6 +4405,9 @@ struct SequenceRNNOptionsBuilder {
|
|||||||
void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
|
void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
|
||||||
fbb_.AddElement<int8_t>(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
|
fbb_.AddElement<int8_t>(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
|
||||||
}
|
}
|
||||||
|
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
|
||||||
|
fbb_.AddElement<uint8_t>(SequenceRNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
|
||||||
|
}
|
||||||
explicit SequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
explicit SequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
: fbb_(_fbb) {
|
: fbb_(_fbb) {
|
||||||
start_ = fbb_.StartTable();
|
start_ = fbb_.StartTable();
|
||||||
@ -4389,8 +4423,10 @@ struct SequenceRNNOptionsBuilder {
|
|||||||
inline flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions(
|
inline flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions(
|
||||||
flatbuffers::FlatBufferBuilder &_fbb,
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
bool time_major = false,
|
bool time_major = false,
|
||||||
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) {
|
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
||||||
|
bool asymmetric_quantize_inputs = false) {
|
||||||
SequenceRNNOptionsBuilder builder_(_fbb);
|
SequenceRNNOptionsBuilder builder_(_fbb);
|
||||||
|
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
|
||||||
builder_.add_fused_activation_function(fused_activation_function);
|
builder_.add_fused_activation_function(fused_activation_function);
|
||||||
builder_.add_time_major(time_major);
|
builder_.add_time_major(time_major);
|
||||||
return builder_.Finish();
|
return builder_.Finish();
|
||||||
@ -4403,10 +4439,12 @@ struct BidirectionalSequenceRNNOptionsT : public flatbuffers::NativeTable {
|
|||||||
bool time_major;
|
bool time_major;
|
||||||
tflite::ActivationFunctionType fused_activation_function;
|
tflite::ActivationFunctionType fused_activation_function;
|
||||||
bool merge_outputs;
|
bool merge_outputs;
|
||||||
|
bool asymmetric_quantize_inputs;
|
||||||
BidirectionalSequenceRNNOptionsT()
|
BidirectionalSequenceRNNOptionsT()
|
||||||
: time_major(false),
|
: time_major(false),
|
||||||
fused_activation_function(tflite::ActivationFunctionType_NONE),
|
fused_activation_function(tflite::ActivationFunctionType_NONE),
|
||||||
merge_outputs(false) {
|
merge_outputs(false),
|
||||||
|
asymmetric_quantize_inputs(false) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -4415,7 +4453,8 @@ struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuf
|
|||||||
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
||||||
VT_TIME_MAJOR = 4,
|
VT_TIME_MAJOR = 4,
|
||||||
VT_FUSED_ACTIVATION_FUNCTION = 6,
|
VT_FUSED_ACTIVATION_FUNCTION = 6,
|
||||||
VT_MERGE_OUTPUTS = 8
|
VT_MERGE_OUTPUTS = 8,
|
||||||
|
VT_ASYMMETRIC_QUANTIZE_INPUTS = 10
|
||||||
};
|
};
|
||||||
bool time_major() const {
|
bool time_major() const {
|
||||||
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
|
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
|
||||||
@ -4426,11 +4465,15 @@ struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuf
|
|||||||
bool merge_outputs() const {
|
bool merge_outputs() const {
|
||||||
return GetField<uint8_t>(VT_MERGE_OUTPUTS, 0) != 0;
|
return GetField<uint8_t>(VT_MERGE_OUTPUTS, 0) != 0;
|
||||||
}
|
}
|
||||||
|
bool asymmetric_quantize_inputs() const {
|
||||||
|
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
|
||||||
|
}
|
||||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
return VerifyTableStart(verifier) &&
|
return VerifyTableStart(verifier) &&
|
||||||
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
|
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
|
||||||
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
||||||
VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) &&
|
VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) &&
|
||||||
|
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
|
||||||
verifier.EndTable();
|
verifier.EndTable();
|
||||||
}
|
}
|
||||||
BidirectionalSequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
BidirectionalSequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
@ -4450,6 +4493,9 @@ struct BidirectionalSequenceRNNOptionsBuilder {
|
|||||||
void add_merge_outputs(bool merge_outputs) {
|
void add_merge_outputs(bool merge_outputs) {
|
||||||
fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_MERGE_OUTPUTS, static_cast<uint8_t>(merge_outputs), 0);
|
fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_MERGE_OUTPUTS, static_cast<uint8_t>(merge_outputs), 0);
|
||||||
}
|
}
|
||||||
|
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
|
||||||
|
fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
|
||||||
|
}
|
||||||
explicit BidirectionalSequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
explicit BidirectionalSequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
: fbb_(_fbb) {
|
: fbb_(_fbb) {
|
||||||
start_ = fbb_.StartTable();
|
start_ = fbb_.StartTable();
|
||||||
@ -4466,8 +4512,10 @@ inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> CreateBidirectionalS
|
|||||||
flatbuffers::FlatBufferBuilder &_fbb,
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
bool time_major = false,
|
bool time_major = false,
|
||||||
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
||||||
bool merge_outputs = false) {
|
bool merge_outputs = false,
|
||||||
|
bool asymmetric_quantize_inputs = false) {
|
||||||
BidirectionalSequenceRNNOptionsBuilder builder_(_fbb);
|
BidirectionalSequenceRNNOptionsBuilder builder_(_fbb);
|
||||||
|
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
|
||||||
builder_.add_merge_outputs(merge_outputs);
|
builder_.add_merge_outputs(merge_outputs);
|
||||||
builder_.add_fused_activation_function(fused_activation_function);
|
builder_.add_fused_activation_function(fused_activation_function);
|
||||||
builder_.add_time_major(time_major);
|
builder_.add_time_major(time_major);
|
||||||
@ -4481,10 +4529,12 @@ struct FullyConnectedOptionsT : public flatbuffers::NativeTable {
|
|||||||
tflite::ActivationFunctionType fused_activation_function;
|
tflite::ActivationFunctionType fused_activation_function;
|
||||||
tflite::FullyConnectedOptionsWeightsFormat weights_format;
|
tflite::FullyConnectedOptionsWeightsFormat weights_format;
|
||||||
bool keep_num_dims;
|
bool keep_num_dims;
|
||||||
|
bool asymmetric_quantize_inputs;
|
||||||
FullyConnectedOptionsT()
|
FullyConnectedOptionsT()
|
||||||
: fused_activation_function(tflite::ActivationFunctionType_NONE),
|
: fused_activation_function(tflite::ActivationFunctionType_NONE),
|
||||||
weights_format(tflite::FullyConnectedOptionsWeightsFormat_DEFAULT),
|
weights_format(tflite::FullyConnectedOptionsWeightsFormat_DEFAULT),
|
||||||
keep_num_dims(false) {
|
keep_num_dims(false),
|
||||||
|
asymmetric_quantize_inputs(false) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -4493,7 +4543,8 @@ struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tabl
|
|||||||
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
||||||
VT_FUSED_ACTIVATION_FUNCTION = 4,
|
VT_FUSED_ACTIVATION_FUNCTION = 4,
|
||||||
VT_WEIGHTS_FORMAT = 6,
|
VT_WEIGHTS_FORMAT = 6,
|
||||||
VT_KEEP_NUM_DIMS = 8
|
VT_KEEP_NUM_DIMS = 8,
|
||||||
|
VT_ASYMMETRIC_QUANTIZE_INPUTS = 10
|
||||||
};
|
};
|
||||||
tflite::ActivationFunctionType fused_activation_function() const {
|
tflite::ActivationFunctionType fused_activation_function() const {
|
||||||
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
||||||
@ -4504,11 +4555,15 @@ struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tabl
|
|||||||
bool keep_num_dims() const {
|
bool keep_num_dims() const {
|
||||||
return GetField<uint8_t>(VT_KEEP_NUM_DIMS, 0) != 0;
|
return GetField<uint8_t>(VT_KEEP_NUM_DIMS, 0) != 0;
|
||||||
}
|
}
|
||||||
|
bool asymmetric_quantize_inputs() const {
|
||||||
|
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
|
||||||
|
}
|
||||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
return VerifyTableStart(verifier) &&
|
return VerifyTableStart(verifier) &&
|
||||||
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
||||||
VerifyField<int8_t>(verifier, VT_WEIGHTS_FORMAT) &&
|
VerifyField<int8_t>(verifier, VT_WEIGHTS_FORMAT) &&
|
||||||
VerifyField<uint8_t>(verifier, VT_KEEP_NUM_DIMS) &&
|
VerifyField<uint8_t>(verifier, VT_KEEP_NUM_DIMS) &&
|
||||||
|
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
|
||||||
verifier.EndTable();
|
verifier.EndTable();
|
||||||
}
|
}
|
||||||
FullyConnectedOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
FullyConnectedOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
@ -4528,6 +4583,9 @@ struct FullyConnectedOptionsBuilder {
|
|||||||
void add_keep_num_dims(bool keep_num_dims) {
|
void add_keep_num_dims(bool keep_num_dims) {
|
||||||
fbb_.AddElement<uint8_t>(FullyConnectedOptions::VT_KEEP_NUM_DIMS, static_cast<uint8_t>(keep_num_dims), 0);
|
fbb_.AddElement<uint8_t>(FullyConnectedOptions::VT_KEEP_NUM_DIMS, static_cast<uint8_t>(keep_num_dims), 0);
|
||||||
}
|
}
|
||||||
|
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
|
||||||
|
fbb_.AddElement<uint8_t>(FullyConnectedOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
|
||||||
|
}
|
||||||
explicit FullyConnectedOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
explicit FullyConnectedOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
: fbb_(_fbb) {
|
: fbb_(_fbb) {
|
||||||
start_ = fbb_.StartTable();
|
start_ = fbb_.StartTable();
|
||||||
@ -4544,8 +4602,10 @@ inline flatbuffers::Offset<FullyConnectedOptions> CreateFullyConnectedOptions(
|
|||||||
flatbuffers::FlatBufferBuilder &_fbb,
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
||||||
tflite::FullyConnectedOptionsWeightsFormat weights_format = tflite::FullyConnectedOptionsWeightsFormat_DEFAULT,
|
tflite::FullyConnectedOptionsWeightsFormat weights_format = tflite::FullyConnectedOptionsWeightsFormat_DEFAULT,
|
||||||
bool keep_num_dims = false) {
|
bool keep_num_dims = false,
|
||||||
|
bool asymmetric_quantize_inputs = false) {
|
||||||
FullyConnectedOptionsBuilder builder_(_fbb);
|
FullyConnectedOptionsBuilder builder_(_fbb);
|
||||||
|
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
|
||||||
builder_.add_keep_num_dims(keep_num_dims);
|
builder_.add_keep_num_dims(keep_num_dims);
|
||||||
builder_.add_weights_format(weights_format);
|
builder_.add_weights_format(weights_format);
|
||||||
builder_.add_fused_activation_function(fused_activation_function);
|
builder_.add_fused_activation_function(fused_activation_function);
|
||||||
@ -4932,11 +4992,13 @@ struct LSTMOptionsT : public flatbuffers::NativeTable {
|
|||||||
float cell_clip;
|
float cell_clip;
|
||||||
float proj_clip;
|
float proj_clip;
|
||||||
tflite::LSTMKernelType kernel_type;
|
tflite::LSTMKernelType kernel_type;
|
||||||
|
bool asymmetric_quantize_inputs;
|
||||||
LSTMOptionsT()
|
LSTMOptionsT()
|
||||||
: fused_activation_function(tflite::ActivationFunctionType_NONE),
|
: fused_activation_function(tflite::ActivationFunctionType_NONE),
|
||||||
cell_clip(0.0f),
|
cell_clip(0.0f),
|
||||||
proj_clip(0.0f),
|
proj_clip(0.0f),
|
||||||
kernel_type(tflite::LSTMKernelType_FULL) {
|
kernel_type(tflite::LSTMKernelType_FULL),
|
||||||
|
asymmetric_quantize_inputs(false) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -4946,7 +5008,8 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||||||
VT_FUSED_ACTIVATION_FUNCTION = 4,
|
VT_FUSED_ACTIVATION_FUNCTION = 4,
|
||||||
VT_CELL_CLIP = 6,
|
VT_CELL_CLIP = 6,
|
||||||
VT_PROJ_CLIP = 8,
|
VT_PROJ_CLIP = 8,
|
||||||
VT_KERNEL_TYPE = 10
|
VT_KERNEL_TYPE = 10,
|
||||||
|
VT_ASYMMETRIC_QUANTIZE_INPUTS = 12
|
||||||
};
|
};
|
||||||
tflite::ActivationFunctionType fused_activation_function() const {
|
tflite::ActivationFunctionType fused_activation_function() const {
|
||||||
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
||||||
@ -4960,12 +5023,16 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||||||
tflite::LSTMKernelType kernel_type() const {
|
tflite::LSTMKernelType kernel_type() const {
|
||||||
return static_cast<tflite::LSTMKernelType>(GetField<int8_t>(VT_KERNEL_TYPE, 0));
|
return static_cast<tflite::LSTMKernelType>(GetField<int8_t>(VT_KERNEL_TYPE, 0));
|
||||||
}
|
}
|
||||||
|
bool asymmetric_quantize_inputs() const {
|
||||||
|
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
|
||||||
|
}
|
||||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
return VerifyTableStart(verifier) &&
|
return VerifyTableStart(verifier) &&
|
||||||
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
||||||
VerifyField<float>(verifier, VT_CELL_CLIP) &&
|
VerifyField<float>(verifier, VT_CELL_CLIP) &&
|
||||||
VerifyField<float>(verifier, VT_PROJ_CLIP) &&
|
VerifyField<float>(verifier, VT_PROJ_CLIP) &&
|
||||||
VerifyField<int8_t>(verifier, VT_KERNEL_TYPE) &&
|
VerifyField<int8_t>(verifier, VT_KERNEL_TYPE) &&
|
||||||
|
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
|
||||||
verifier.EndTable();
|
verifier.EndTable();
|
||||||
}
|
}
|
||||||
LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
@ -4988,6 +5055,9 @@ struct LSTMOptionsBuilder {
|
|||||||
void add_kernel_type(tflite::LSTMKernelType kernel_type) {
|
void add_kernel_type(tflite::LSTMKernelType kernel_type) {
|
||||||
fbb_.AddElement<int8_t>(LSTMOptions::VT_KERNEL_TYPE, static_cast<int8_t>(kernel_type), 0);
|
fbb_.AddElement<int8_t>(LSTMOptions::VT_KERNEL_TYPE, static_cast<int8_t>(kernel_type), 0);
|
||||||
}
|
}
|
||||||
|
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
|
||||||
|
fbb_.AddElement<uint8_t>(LSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
|
||||||
|
}
|
||||||
explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
: fbb_(_fbb) {
|
: fbb_(_fbb) {
|
||||||
start_ = fbb_.StartTable();
|
start_ = fbb_.StartTable();
|
||||||
@ -5005,10 +5075,12 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(
|
|||||||
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
||||||
float cell_clip = 0.0f,
|
float cell_clip = 0.0f,
|
||||||
float proj_clip = 0.0f,
|
float proj_clip = 0.0f,
|
||||||
tflite::LSTMKernelType kernel_type = tflite::LSTMKernelType_FULL) {
|
tflite::LSTMKernelType kernel_type = tflite::LSTMKernelType_FULL,
|
||||||
|
bool asymmetric_quantize_inputs = false) {
|
||||||
LSTMOptionsBuilder builder_(_fbb);
|
LSTMOptionsBuilder builder_(_fbb);
|
||||||
builder_.add_proj_clip(proj_clip);
|
builder_.add_proj_clip(proj_clip);
|
||||||
builder_.add_cell_clip(cell_clip);
|
builder_.add_cell_clip(cell_clip);
|
||||||
|
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
|
||||||
builder_.add_kernel_type(kernel_type);
|
builder_.add_kernel_type(kernel_type);
|
||||||
builder_.add_fused_activation_function(fused_activation_function);
|
builder_.add_fused_activation_function(fused_activation_function);
|
||||||
return builder_.Finish();
|
return builder_.Finish();
|
||||||
@ -5022,11 +5094,13 @@ struct UnidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
|
|||||||
float cell_clip;
|
float cell_clip;
|
||||||
float proj_clip;
|
float proj_clip;
|
||||||
bool time_major;
|
bool time_major;
|
||||||
|
bool asymmetric_quantize_inputs;
|
||||||
UnidirectionalSequenceLSTMOptionsT()
|
UnidirectionalSequenceLSTMOptionsT()
|
||||||
: fused_activation_function(tflite::ActivationFunctionType_NONE),
|
: fused_activation_function(tflite::ActivationFunctionType_NONE),
|
||||||
cell_clip(0.0f),
|
cell_clip(0.0f),
|
||||||
proj_clip(0.0f),
|
proj_clip(0.0f),
|
||||||
time_major(false) {
|
time_major(false),
|
||||||
|
asymmetric_quantize_inputs(false) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -5036,7 +5110,8 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb
|
|||||||
VT_FUSED_ACTIVATION_FUNCTION = 4,
|
VT_FUSED_ACTIVATION_FUNCTION = 4,
|
||||||
VT_CELL_CLIP = 6,
|
VT_CELL_CLIP = 6,
|
||||||
VT_PROJ_CLIP = 8,
|
VT_PROJ_CLIP = 8,
|
||||||
VT_TIME_MAJOR = 10
|
VT_TIME_MAJOR = 10,
|
||||||
|
VT_ASYMMETRIC_QUANTIZE_INPUTS = 12
|
||||||
};
|
};
|
||||||
tflite::ActivationFunctionType fused_activation_function() const {
|
tflite::ActivationFunctionType fused_activation_function() const {
|
||||||
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
||||||
@ -5050,12 +5125,16 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb
|
|||||||
bool time_major() const {
|
bool time_major() const {
|
||||||
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
|
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
|
||||||
}
|
}
|
||||||
|
bool asymmetric_quantize_inputs() const {
|
||||||
|
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
|
||||||
|
}
|
||||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
return VerifyTableStart(verifier) &&
|
return VerifyTableStart(verifier) &&
|
||||||
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
||||||
VerifyField<float>(verifier, VT_CELL_CLIP) &&
|
VerifyField<float>(verifier, VT_CELL_CLIP) &&
|
||||||
VerifyField<float>(verifier, VT_PROJ_CLIP) &&
|
VerifyField<float>(verifier, VT_PROJ_CLIP) &&
|
||||||
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
|
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
|
||||||
|
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
|
||||||
verifier.EndTable();
|
verifier.EndTable();
|
||||||
}
|
}
|
||||||
UnidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
UnidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
@ -5078,6 +5157,9 @@ struct UnidirectionalSequenceLSTMOptionsBuilder {
|
|||||||
void add_time_major(bool time_major) {
|
void add_time_major(bool time_major) {
|
||||||
fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major), 0);
|
fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major), 0);
|
||||||
}
|
}
|
||||||
|
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
|
||||||
|
fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
|
||||||
|
}
|
||||||
explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
: fbb_(_fbb) {
|
: fbb_(_fbb) {
|
||||||
start_ = fbb_.StartTable();
|
start_ = fbb_.StartTable();
|
||||||
@ -5095,10 +5177,12 @@ inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirection
|
|||||||
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
||||||
float cell_clip = 0.0f,
|
float cell_clip = 0.0f,
|
||||||
float proj_clip = 0.0f,
|
float proj_clip = 0.0f,
|
||||||
bool time_major = false) {
|
bool time_major = false,
|
||||||
|
bool asymmetric_quantize_inputs = false) {
|
||||||
UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
|
UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
|
||||||
builder_.add_proj_clip(proj_clip);
|
builder_.add_proj_clip(proj_clip);
|
||||||
builder_.add_cell_clip(cell_clip);
|
builder_.add_cell_clip(cell_clip);
|
||||||
|
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
|
||||||
builder_.add_time_major(time_major);
|
builder_.add_time_major(time_major);
|
||||||
builder_.add_fused_activation_function(fused_activation_function);
|
builder_.add_fused_activation_function(fused_activation_function);
|
||||||
return builder_.Finish();
|
return builder_.Finish();
|
||||||
@ -5113,12 +5197,14 @@ struct BidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
|
|||||||
float proj_clip;
|
float proj_clip;
|
||||||
bool merge_outputs;
|
bool merge_outputs;
|
||||||
bool time_major;
|
bool time_major;
|
||||||
|
bool asymmetric_quantize_inputs;
|
||||||
BidirectionalSequenceLSTMOptionsT()
|
BidirectionalSequenceLSTMOptionsT()
|
||||||
: fused_activation_function(tflite::ActivationFunctionType_NONE),
|
: fused_activation_function(tflite::ActivationFunctionType_NONE),
|
||||||
cell_clip(0.0f),
|
cell_clip(0.0f),
|
||||||
proj_clip(0.0f),
|
proj_clip(0.0f),
|
||||||
merge_outputs(false),
|
merge_outputs(false),
|
||||||
time_major(true) {
|
time_major(true),
|
||||||
|
asymmetric_quantize_inputs(false) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -5129,7 +5215,8 @@ struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbu
|
|||||||
VT_CELL_CLIP = 6,
|
VT_CELL_CLIP = 6,
|
||||||
VT_PROJ_CLIP = 8,
|
VT_PROJ_CLIP = 8,
|
||||||
VT_MERGE_OUTPUTS = 10,
|
VT_MERGE_OUTPUTS = 10,
|
||||||
VT_TIME_MAJOR = 12
|
VT_TIME_MAJOR = 12,
|
||||||
|
VT_ASYMMETRIC_QUANTIZE_INPUTS = 14
|
||||||
};
|
};
|
||||||
tflite::ActivationFunctionType fused_activation_function() const {
|
tflite::ActivationFunctionType fused_activation_function() const {
|
||||||
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
||||||
@ -5146,6 +5233,9 @@ struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbu
|
|||||||
bool time_major() const {
|
bool time_major() const {
|
||||||
return GetField<uint8_t>(VT_TIME_MAJOR, 1) != 0;
|
return GetField<uint8_t>(VT_TIME_MAJOR, 1) != 0;
|
||||||
}
|
}
|
||||||
|
bool asymmetric_quantize_inputs() const {
|
||||||
|
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
|
||||||
|
}
|
||||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
return VerifyTableStart(verifier) &&
|
return VerifyTableStart(verifier) &&
|
||||||
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
||||||
@ -5153,6 +5243,7 @@ struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbu
|
|||||||
VerifyField<float>(verifier, VT_PROJ_CLIP) &&
|
VerifyField<float>(verifier, VT_PROJ_CLIP) &&
|
||||||
VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) &&
|
VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) &&
|
||||||
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
|
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
|
||||||
|
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
|
||||||
verifier.EndTable();
|
verifier.EndTable();
|
||||||
}
|
}
|
||||||
BidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
BidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
@ -5178,6 +5269,9 @@ struct BidirectionalSequenceLSTMOptionsBuilder {
|
|||||||
void add_time_major(bool time_major) {
|
void add_time_major(bool time_major) {
|
||||||
fbb_.AddElement<uint8_t>(BidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major), 1);
|
fbb_.AddElement<uint8_t>(BidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major), 1);
|
||||||
}
|
}
|
||||||
|
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
|
||||||
|
fbb_.AddElement<uint8_t>(BidirectionalSequenceLSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
|
||||||
|
}
|
||||||
explicit BidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
explicit BidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
: fbb_(_fbb) {
|
: fbb_(_fbb) {
|
||||||
start_ = fbb_.StartTable();
|
start_ = fbb_.StartTable();
|
||||||
@ -5196,10 +5290,12 @@ inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> CreateBidirectional
|
|||||||
float cell_clip = 0.0f,
|
float cell_clip = 0.0f,
|
||||||
float proj_clip = 0.0f,
|
float proj_clip = 0.0f,
|
||||||
bool merge_outputs = false,
|
bool merge_outputs = false,
|
||||||
bool time_major = true) {
|
bool time_major = true,
|
||||||
|
bool asymmetric_quantize_inputs = false) {
|
||||||
BidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
|
BidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
|
||||||
builder_.add_proj_clip(proj_clip);
|
builder_.add_proj_clip(proj_clip);
|
||||||
builder_.add_cell_clip(cell_clip);
|
builder_.add_cell_clip(cell_clip);
|
||||||
|
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
|
||||||
builder_.add_time_major(time_major);
|
builder_.add_time_major(time_major);
|
||||||
builder_.add_merge_outputs(merge_outputs);
|
builder_.add_merge_outputs(merge_outputs);
|
||||||
builder_.add_fused_activation_function(fused_activation_function);
|
builder_.add_fused_activation_function(fused_activation_function);
|
||||||
@ -11034,6 +11130,7 @@ inline void SVDFOptions::UnPackTo(SVDFOptionsT *_o, const flatbuffers::resolver_
|
|||||||
(void)_resolver;
|
(void)_resolver;
|
||||||
{ auto _e = rank(); _o->rank = _e; }
|
{ auto _e = rank(); _o->rank = _e; }
|
||||||
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
|
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
|
||||||
|
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
|
||||||
}
|
}
|
||||||
|
|
||||||
inline flatbuffers::Offset<SVDFOptions> SVDFOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
inline flatbuffers::Offset<SVDFOptions> SVDFOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||||
@ -11046,10 +11143,12 @@ inline flatbuffers::Offset<SVDFOptions> CreateSVDFOptions(flatbuffers::FlatBuffe
|
|||||||
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SVDFOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SVDFOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
||||||
auto _rank = _o->rank;
|
auto _rank = _o->rank;
|
||||||
auto _fused_activation_function = _o->fused_activation_function;
|
auto _fused_activation_function = _o->fused_activation_function;
|
||||||
|
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
|
||||||
return tflite::CreateSVDFOptions(
|
return tflite::CreateSVDFOptions(
|
||||||
_fbb,
|
_fbb,
|
||||||
_rank,
|
_rank,
|
||||||
_fused_activation_function);
|
_fused_activation_function,
|
||||||
|
_asymmetric_quantize_inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline RNNOptionsT *RNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
inline RNNOptionsT *RNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
@ -11062,6 +11161,7 @@ inline void RNNOptions::UnPackTo(RNNOptionsT *_o, const flatbuffers::resolver_fu
|
|||||||
(void)_o;
|
(void)_o;
|
||||||
(void)_resolver;
|
(void)_resolver;
|
||||||
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
|
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
|
||||||
|
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
|
||||||
}
|
}
|
||||||
|
|
||||||
inline flatbuffers::Offset<RNNOptions> RNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
inline flatbuffers::Offset<RNNOptions> RNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||||
@ -11073,9 +11173,11 @@ inline flatbuffers::Offset<RNNOptions> CreateRNNOptions(flatbuffers::FlatBufferB
|
|||||||
(void)_o;
|
(void)_o;
|
||||||
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
||||||
auto _fused_activation_function = _o->fused_activation_function;
|
auto _fused_activation_function = _o->fused_activation_function;
|
||||||
|
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
|
||||||
return tflite::CreateRNNOptions(
|
return tflite::CreateRNNOptions(
|
||||||
_fbb,
|
_fbb,
|
||||||
_fused_activation_function);
|
_fused_activation_function,
|
||||||
|
_asymmetric_quantize_inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline SequenceRNNOptionsT *SequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
inline SequenceRNNOptionsT *SequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
@ -11089,6 +11191,7 @@ inline void SequenceRNNOptions::UnPackTo(SequenceRNNOptionsT *_o, const flatbuff
|
|||||||
(void)_resolver;
|
(void)_resolver;
|
||||||
{ auto _e = time_major(); _o->time_major = _e; }
|
{ auto _e = time_major(); _o->time_major = _e; }
|
||||||
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
|
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
|
||||||
|
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
|
||||||
}
|
}
|
||||||
|
|
||||||
inline flatbuffers::Offset<SequenceRNNOptions> SequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
inline flatbuffers::Offset<SequenceRNNOptions> SequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||||
@ -11101,10 +11204,12 @@ inline flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions(flatbuff
|
|||||||
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SequenceRNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SequenceRNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
||||||
auto _time_major = _o->time_major;
|
auto _time_major = _o->time_major;
|
||||||
auto _fused_activation_function = _o->fused_activation_function;
|
auto _fused_activation_function = _o->fused_activation_function;
|
||||||
|
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
|
||||||
return tflite::CreateSequenceRNNOptions(
|
return tflite::CreateSequenceRNNOptions(
|
||||||
_fbb,
|
_fbb,
|
||||||
_time_major,
|
_time_major,
|
||||||
_fused_activation_function);
|
_fused_activation_function,
|
||||||
|
_asymmetric_quantize_inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline BidirectionalSequenceRNNOptionsT *BidirectionalSequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
inline BidirectionalSequenceRNNOptionsT *BidirectionalSequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
@ -11119,6 +11224,7 @@ inline void BidirectionalSequenceRNNOptions::UnPackTo(BidirectionalSequenceRNNOp
|
|||||||
{ auto _e = time_major(); _o->time_major = _e; }
|
{ auto _e = time_major(); _o->time_major = _e; }
|
||||||
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
|
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
|
||||||
{ auto _e = merge_outputs(); _o->merge_outputs = _e; }
|
{ auto _e = merge_outputs(); _o->merge_outputs = _e; }
|
||||||
|
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
|
||||||
}
|
}
|
||||||
|
|
||||||
inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> BidirectionalSequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> BidirectionalSequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||||
@ -11132,11 +11238,13 @@ inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> CreateBidirectionalS
|
|||||||
auto _time_major = _o->time_major;
|
auto _time_major = _o->time_major;
|
||||||
auto _fused_activation_function = _o->fused_activation_function;
|
auto _fused_activation_function = _o->fused_activation_function;
|
||||||
auto _merge_outputs = _o->merge_outputs;
|
auto _merge_outputs = _o->merge_outputs;
|
||||||
|
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
|
||||||
return tflite::CreateBidirectionalSequenceRNNOptions(
|
return tflite::CreateBidirectionalSequenceRNNOptions(
|
||||||
_fbb,
|
_fbb,
|
||||||
_time_major,
|
_time_major,
|
||||||
_fused_activation_function,
|
_fused_activation_function,
|
||||||
_merge_outputs);
|
_merge_outputs,
|
||||||
|
_asymmetric_quantize_inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
@ -11151,6 +11259,7 @@ inline void FullyConnectedOptions::UnPackTo(FullyConnectedOptionsT *_o, const fl
|
|||||||
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
|
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
|
||||||
{ auto _e = weights_format(); _o->weights_format = _e; }
|
{ auto _e = weights_format(); _o->weights_format = _e; }
|
||||||
{ auto _e = keep_num_dims(); _o->keep_num_dims = _e; }
|
{ auto _e = keep_num_dims(); _o->keep_num_dims = _e; }
|
||||||
|
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
|
||||||
}
|
}
|
||||||
|
|
||||||
inline flatbuffers::Offset<FullyConnectedOptions> FullyConnectedOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
inline flatbuffers::Offset<FullyConnectedOptions> FullyConnectedOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||||
@ -11164,11 +11273,13 @@ inline flatbuffers::Offset<FullyConnectedOptions> CreateFullyConnectedOptions(fl
|
|||||||
auto _fused_activation_function = _o->fused_activation_function;
|
auto _fused_activation_function = _o->fused_activation_function;
|
||||||
auto _weights_format = _o->weights_format;
|
auto _weights_format = _o->weights_format;
|
||||||
auto _keep_num_dims = _o->keep_num_dims;
|
auto _keep_num_dims = _o->keep_num_dims;
|
||||||
|
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
|
||||||
return tflite::CreateFullyConnectedOptions(
|
return tflite::CreateFullyConnectedOptions(
|
||||||
_fbb,
|
_fbb,
|
||||||
_fused_activation_function,
|
_fused_activation_function,
|
||||||
_weights_format,
|
_weights_format,
|
||||||
_keep_num_dims);
|
_keep_num_dims,
|
||||||
|
_asymmetric_quantize_inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline SoftmaxOptionsT *SoftmaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
inline SoftmaxOptionsT *SoftmaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
@ -11352,6 +11463,7 @@ inline void LSTMOptions::UnPackTo(LSTMOptionsT *_o, const flatbuffers::resolver_
|
|||||||
{ auto _e = cell_clip(); _o->cell_clip = _e; }
|
{ auto _e = cell_clip(); _o->cell_clip = _e; }
|
||||||
{ auto _e = proj_clip(); _o->proj_clip = _e; }
|
{ auto _e = proj_clip(); _o->proj_clip = _e; }
|
||||||
{ auto _e = kernel_type(); _o->kernel_type = _e; }
|
{ auto _e = kernel_type(); _o->kernel_type = _e; }
|
||||||
|
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
|
||||||
}
|
}
|
||||||
|
|
||||||
inline flatbuffers::Offset<LSTMOptions> LSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
inline flatbuffers::Offset<LSTMOptions> LSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||||
@ -11366,12 +11478,14 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(flatbuffers::FlatBuffe
|
|||||||
auto _cell_clip = _o->cell_clip;
|
auto _cell_clip = _o->cell_clip;
|
||||||
auto _proj_clip = _o->proj_clip;
|
auto _proj_clip = _o->proj_clip;
|
||||||
auto _kernel_type = _o->kernel_type;
|
auto _kernel_type = _o->kernel_type;
|
||||||
|
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
|
||||||
return tflite::CreateLSTMOptions(
|
return tflite::CreateLSTMOptions(
|
||||||
_fbb,
|
_fbb,
|
||||||
_fused_activation_function,
|
_fused_activation_function,
|
||||||
_cell_clip,
|
_cell_clip,
|
||||||
_proj_clip,
|
_proj_clip,
|
||||||
_kernel_type);
|
_kernel_type,
|
||||||
|
_asymmetric_quantize_inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline UnidirectionalSequenceLSTMOptionsT *UnidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
inline UnidirectionalSequenceLSTMOptionsT *UnidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
@ -11387,6 +11501,7 @@ inline void UnidirectionalSequenceLSTMOptions::UnPackTo(UnidirectionalSequenceLS
|
|||||||
{ auto _e = cell_clip(); _o->cell_clip = _e; }
|
{ auto _e = cell_clip(); _o->cell_clip = _e; }
|
||||||
{ auto _e = proj_clip(); _o->proj_clip = _e; }
|
{ auto _e = proj_clip(); _o->proj_clip = _e; }
|
||||||
{ auto _e = time_major(); _o->time_major = _e; }
|
{ auto _e = time_major(); _o->time_major = _e; }
|
||||||
|
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
|
||||||
}
|
}
|
||||||
|
|
||||||
inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> UnidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> UnidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||||
@ -11401,12 +11516,14 @@ inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirection
|
|||||||
auto _cell_clip = _o->cell_clip;
|
auto _cell_clip = _o->cell_clip;
|
||||||
auto _proj_clip = _o->proj_clip;
|
auto _proj_clip = _o->proj_clip;
|
||||||
auto _time_major = _o->time_major;
|
auto _time_major = _o->time_major;
|
||||||
|
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
|
||||||
return tflite::CreateUnidirectionalSequenceLSTMOptions(
|
return tflite::CreateUnidirectionalSequenceLSTMOptions(
|
||||||
_fbb,
|
_fbb,
|
||||||
_fused_activation_function,
|
_fused_activation_function,
|
||||||
_cell_clip,
|
_cell_clip,
|
||||||
_proj_clip,
|
_proj_clip,
|
||||||
_time_major);
|
_time_major,
|
||||||
|
_asymmetric_quantize_inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
@ -11423,6 +11540,7 @@ inline void BidirectionalSequenceLSTMOptions::UnPackTo(BidirectionalSequenceLSTM
|
|||||||
{ auto _e = proj_clip(); _o->proj_clip = _e; }
|
{ auto _e = proj_clip(); _o->proj_clip = _e; }
|
||||||
{ auto _e = merge_outputs(); _o->merge_outputs = _e; }
|
{ auto _e = merge_outputs(); _o->merge_outputs = _e; }
|
||||||
{ auto _e = time_major(); _o->time_major = _e; }
|
{ auto _e = time_major(); _o->time_major = _e; }
|
||||||
|
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
|
||||||
}
|
}
|
||||||
|
|
||||||
inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> BidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> BidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||||
@ -11438,13 +11556,15 @@ inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> CreateBidirectional
|
|||||||
auto _proj_clip = _o->proj_clip;
|
auto _proj_clip = _o->proj_clip;
|
||||||
auto _merge_outputs = _o->merge_outputs;
|
auto _merge_outputs = _o->merge_outputs;
|
||||||
auto _time_major = _o->time_major;
|
auto _time_major = _o->time_major;
|
||||||
|
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
|
||||||
return tflite::CreateBidirectionalSequenceLSTMOptions(
|
return tflite::CreateBidirectionalSequenceLSTMOptions(
|
||||||
_fbb,
|
_fbb,
|
||||||
_fused_activation_function,
|
_fused_activation_function,
|
||||||
_cell_clip,
|
_cell_clip,
|
||||||
_proj_clip,
|
_proj_clip,
|
||||||
_merge_outputs,
|
_merge_outputs,
|
||||||
_time_major);
|
_time_major,
|
||||||
|
_asymmetric_quantize_inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
|
Loading…
Reference in New Issue
Block a user