Adds asymmetric quantized inputs for hybrid ops in future models.
PiperOrigin-RevId: 303262193 Change-Id: I13e2bddee0e9bf10af9d5911d004ca31be430401
This commit is contained in:
parent 857f0c9557
commit e8dbf1de1a
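All of the hunks below revolve around one mechanism: when a "hybrid" kernel (float activations, quantized weights) quantizes its inputs on the fly, it can do so either symmetrically (zero point fixed at 0) or asymmetrically (a per-invocation scale and zero point), and the asymmetric path needs cached per-row weight sums to apply the zero-point correction cheaply. The following standalone sketch is not the TensorFlow Lite implementation and every name in it is illustrative; it only shows the arithmetic that the asymmetric_quantize_inputs flag selects.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Asymmetric quantization of a runtime tensor: q = round(x / scale) + zero_point,
    // with scale/zero_point derived from the observed min/max of this invocation.
    void AsymmetricQuantize(const std::vector<float>& x, float* scale,
                            int32_t* zero_point, std::vector<int8_t>* q) {
      const auto mm = std::minmax_element(x.begin(), x.end());
      const float min_val = std::min(*mm.first, 0.0f);
      const float max_val = std::max(*mm.second, 0.0f);
      *scale = (max_val - min_val) / 255.0f;
      if (*scale == 0.0f) *scale = 1.0f;  // all-zero input
      *zero_point = static_cast<int32_t>(std::round(-128.0f - min_val / *scale));
      q->resize(x.size());
      for (size_t i = 0; i < x.size(); ++i) {
        int32_t v = static_cast<int32_t>(std::round(x[i] / *scale)) + *zero_point;
        (*q)[i] = static_cast<int8_t>(std::min(127, std::max(-128, v)));
      }
    }

    // One output row of a hybrid matmul. Because
    //   sum_j w[j] * (q[j] - zp) = sum_j w[j] * q[j] - zp * sum_j w[j],
    // the zero-point correction only needs the row sum of the (constant) weights,
    // which is why the kernels cache row sums in a persistent scratch tensor.
    float HybridDotRow(const int8_t* w_row, int n, float w_scale,
                       const int8_t* q_in, float in_scale, int32_t in_zp,
                       int32_t cached_row_sum) {
      int32_t acc = 0;
      for (int j = 0; j < n; ++j) acc += static_cast<int32_t>(w_row[j]) * q_in[j];
      acc -= in_zp * cached_row_sum;    // zero-point correction
      return acc * w_scale * in_scale;  // dequantize the accumulator
    }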
@@ -124,33 +124,21 @@ typedef struct {
 typedef struct {
   int rank;
   TfLiteFusedActivation activation;
-
-  // Parameter for SVDF version 4.
-  bool asymmetric_quantize_inputs;
 } TfLiteSVDFParams;
 
 typedef struct {
   TfLiteFusedActivation activation;
-
-  // Parameter for RNN version 3.
-  bool asymmetric_quantize_inputs;
 } TfLiteRNNParams;
 
 typedef struct {
   bool time_major;
   TfLiteFusedActivation activation;
-
-  // Parameter for Sequence RNN version 3.
-  bool asymmetric_quantize_inputs;
 } TfLiteSequenceRNNParams;
 
 typedef struct {
   bool time_major;
   TfLiteFusedActivation activation;
   bool merge_outputs;
-
-  // Parameter for Bidirectional RNN verison 3.
-  bool asymmetric_quantize_inputs;
 } TfLiteBidirectionalSequenceRNNParams;
 
 typedef enum {

@@ -170,11 +158,6 @@ typedef struct {
   // tensors are the same. Furthermore, all but the last dimension of the input
   // and output shapes will be equal.
   bool keep_num_dims;
-
-  // Parameters for FullyConnected version 7 or above.
-  // If set to true and the weights are quantized, then non constant inputs
-  // are quantized at evaluation time with asymmetric quantization.
-  bool asymmetric_quantize_inputs;
 } TfLiteFullyConnectedParams;
 
 typedef enum {

@@ -245,9 +228,6 @@ typedef struct {
   // Parameters for LSTM version 2.
   // kTfLiteLSTMBasicKernel is only supported in version 2 or above.
   TfLiteLSTMKernelType kernel_type;
-
-  // Parameters for LSTM version 4.
-  bool asymmetric_quantize_inputs;
 } TfLiteLSTMParams;
 
 typedef struct {

@@ -258,9 +238,6 @@ typedef struct {
 
   // If set to true then the first dimension is time, otherwise batch.
   bool time_major;
-
-  // Parameter for unidirectional sequence RNN version 3.
-  bool asymmetric_quantize_inputs;
 } TfLiteUnidirectionalSequenceLSTMParams;
 
 typedef struct {

@@ -276,10 +253,6 @@ typedef struct {
   // Parameters supported by version 2:
   // If set to true then the first dimension is time, otherwise batch.
   bool time_major;
-
-  // Parameters supported by version 4:
-  // If set to true, then hybrid ops use asymmetric quantization for inputs.
-  bool asymmetric_quantize_inputs;
 } TfLiteBidirectionalSequenceLSTMParams;
 
 typedef struct {
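These structs (apparently the builtin-op parameter declarations, likely lite/c/builtin_op_data.h) are what a builtin kernel receives through node->builtin_data, so the presence or absence of the flag is visible to every hybrid kernel. A minimal sketch of how a kernel reads such an option at evaluation time follows; it assumes the field is declared, which is only true on the side of this diff that carries it.

    #include "tensorflow/lite/c/builtin_op_data.h"
    #include "tensorflow/lite/c/common.h"

    // Sketch only: reading per-op options inside a builtin kernel's Eval.
    TfLiteStatus EvalSketch(TfLiteContext* context, TfLiteNode* node) {
      const auto* params =
          reinterpret_cast<const TfLiteSVDFParams*>(node->builtin_data);
      if (params->asymmetric_quantize_inputs) {
        // Quantize non-constant inputs with a per-invocation scale and zero point.
      } else {
        // Symmetric path: the zero point is fixed at 0.
      }
      return kTfLiteOk;
    }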
@@ -269,8 +269,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
       params->rank = svdf_params->rank();
       params->activation =
           parse_activation(svdf_params->fused_activation_function());
-      params->asymmetric_quantize_inputs =
-          svdf_params->asymmetric_quantize_inputs();
     }
     *builtin_data = reinterpret_cast<void*>(params.release());
     break;

@@ -282,8 +280,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
       params->activation =
           parse_activation(sequence_rnn_params->fused_activation_function());
       params->time_major = sequence_rnn_params->time_major();
-      params->asymmetric_quantize_inputs =
-          sequence_rnn_params->asymmetric_quantize_inputs();
     }
     *builtin_data = reinterpret_cast<void*>(params.release());
     break;

@@ -297,8 +293,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
           bidi_sequence_rnn_params->fused_activation_function());
       params->time_major = bidi_sequence_rnn_params->time_major();
       params->merge_outputs = bidi_sequence_rnn_params->merge_outputs();
-      params->asymmetric_quantize_inputs =
-          bidi_sequence_rnn_params->asymmetric_quantize_inputs();
     }
     *builtin_data = reinterpret_cast<void*>(params.release());
     break;

@@ -308,8 +302,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
     if (const auto* rnn_params = op->builtin_options_as_RNNOptions()) {
       params->activation =
           parse_activation(rnn_params->fused_activation_function());
-      params->asymmetric_quantize_inputs =
-          rnn_params->asymmetric_quantize_inputs();
     }
     *builtin_data = reinterpret_cast<void*>(params.release());
     break;

@@ -331,8 +323,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
       params->activation = parse_activation(
           fully_connected_params->fused_activation_function());
       params->keep_num_dims = fully_connected_params->keep_num_dims();
-      params->asymmetric_quantize_inputs =
-          fully_connected_params->asymmetric_quantize_inputs();
       switch (fully_connected_params->weights_format()) {
         case FullyConnectedOptionsWeightsFormat_DEFAULT:
           params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault;

@@ -450,8 +440,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
                                lstm_params->kernel_type());
           return kTfLiteError;
       }
-      params->asymmetric_quantize_inputs =
-          lstm_params->asymmetric_quantize_inputs();
     } else {
       TF_LITE_REPORT_ERROR(error_reporter,
                            "No valid LSTM builtin options exist");

@@ -470,8 +458,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
       params->cell_clip = seq_lstm_params->cell_clip();
       params->proj_clip = seq_lstm_params->proj_clip();
       params->time_major = seq_lstm_params->time_major();
-      params->asymmetric_quantize_inputs =
-          seq_lstm_params->asymmetric_quantize_inputs();
     }
     *builtin_data = reinterpret_cast<void*>(params.release());
     break;

@@ -487,8 +473,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
       params->proj_clip = bidi_lstm_params->proj_clip();
       params->merge_outputs = bidi_lstm_params->merge_outputs();
       params->time_major = bidi_lstm_params->time_major();
-      params->asymmetric_quantize_inputs =
-          bidi_lstm_params->asymmetric_quantize_inputs();
     }
     *builtin_data = reinterpret_cast<void*>(params.release());
     break;
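The ParseOpData hunks above copy the FlatBuffer option into the corresponding C struct. One practical detail worth noting (an assumption about the schema, not stated in this diff): boolean options read from a FlatBuffer table fall back to their schema default, normally false, when an older model omits them, so models serialized before the field existed keep the symmetric behavior. A condensed sketch of the parsing pattern, with the flag line shown as conditional on the field existing:

    if (const auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
      params->rank = svdf_params->rank();
      params->activation =
          parse_activation(svdf_params->fused_activation_function());
      // Only when both the schema and TfLiteSVDFParams carry the field:
      // params->asymmetric_quantize_inputs =
      //     svdf_params->asymmetric_quantize_inputs();
    }
    *builtin_data = reinterpret_cast<void*>(params.release());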
@@ -26,15 +26,6 @@ namespace ops {
 namespace builtin {
 namespace rnn {
 
-namespace {
-
-struct OpData {
-  int scratch_tensor_index;
-  bool compute_row_sums = false;
-};
-
-}  // namespace
-
 constexpr int kInputTensor = 0;
 constexpr int kWeightsTensor = 1;
 constexpr int kRecurrentWeightsTensor = 2;

@@ -45,14 +36,13 @@ constexpr int kHiddenStateTensor = 4;
 constexpr int kOutputTensor = 0;
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  auto* op_data = new OpData();
-  context->AddTensors(context, /*tensors_to_add=*/6,
-                      &op_data->scratch_tensor_index);
-  return op_data;
+  auto* scratch_tensor_index = new int;
+  context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index);
+  return scratch_tensor_index;
 }
 
 void Free(TfLiteContext* context, void* buffer) {
-  delete reinterpret_cast<OpData*>(buffer);
+  delete reinterpret_cast<int*>(buffer);
 }
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

@@ -99,11 +89,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Allocate temporary tensors to store quantized values of input and
   // hidden_state tensors.
   if (is_hybrid) {
-    auto* op_data = reinterpret_cast<OpData*>(node->user_data);
-    op_data->compute_row_sums = true;
+    int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
     TfLiteIntArrayFree(node->temporaries);
-    node->temporaries = TfLiteIntArrayCreate(6);
-    node->temporaries->data[0] = op_data->scratch_tensor_index;
+    node->temporaries = TfLiteIntArrayCreate(3);
+    node->temporaries->data[0] = *scratch_tensor_index;
     TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
     input_quantized->type = input_weights->type;
     input_quantized->allocation_type = kTfLiteArenaRw;

@@ -112,7 +101,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
       TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                        input_quantized_size));
     }
-    node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
+    node->temporaries->data[1] = *scratch_tensor_index + 1;
     TfLiteTensor* hidden_state_quantized =
         GetTemporary(context, node, /*index=*/1);
     hidden_state_quantized->type = input_weights->type;

@@ -125,7 +114,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                         context->ResizeTensor(context, hidden_state_quantized,
                                               hidden_state_quantized_size));
     }
-    node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
+    node->temporaries->data[2] = *scratch_tensor_index + 2;
     TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
     scaling_factors->type = kTfLiteFloat32;
     scaling_factors->allocation_type = kTfLiteArenaRw;

@@ -136,43 +125,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
       TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                        scaling_factors_size));
     }
-    node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
-    TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/3);
-    accum_scratch->type = kTfLiteInt32;
-    accum_scratch->allocation_type = kTfLiteArenaRw;
-    int accum_scratch_dims[2] = {num_units, batch_size};
-    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
-                                   accum_scratch_dims)) {
-      TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
-      accum_scratch_size->data[0] = accum_scratch_dims[0];
-      accum_scratch_size->data[1] = accum_scratch_dims[1];
-      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
-                                                       accum_scratch_size));
-    }
-    node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
-    TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
-    zero_points->type = kTfLiteInt32;
-    zero_points->allocation_type = kTfLiteArenaRw;
-    int zero_points_dims[1] = {batch_size};
-    if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
-      TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
-      zero_points_size->data[0] = batch_size;
-      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
-                                                       zero_points_size));
-    }
-    node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
-    TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
-    row_sums->type = kTfLiteInt32;
-    row_sums->allocation_type = kTfLiteArenaRwPersistent;
-    int row_sums_dims[2] = {2, num_units};
-    if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
-      TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
-      row_sums_size->data[0] = row_sums_dims[0];
-      row_sums_size->data[1] = row_sums_dims[1];
-      TF_LITE_ENSURE_OK(
-          context, context->ResizeTensor(context, row_sums, row_sums_size));
-    }
   }
 
   return kTfLiteOk;
 }
 
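The Prepare() hunks above all follow the same temporary-tensor pattern: Init() reserves a block of tensor indices with context->AddTensors(), and Prepare() points node->temporaries->data[i] at them and resizes each buffer. The zero-point and accumulator buffers use kTfLiteArenaRw (reusable scratch), while the row-sums buffer uses kTfLiteArenaRwPersistent so values computed on the first evaluation survive later invocations. A condensed sketch of that pattern, using the public TfLiteContext API but with a made-up helper name and illustrative arguments:

    #include <initializer_list>
    #include "tensorflow/lite/c/common.h"

    // Sketch: point temporary slot `slot` at tensor `base_index + slot` and give
    // it the requested type, lifetime and shape.
    void SetupTemporary(TfLiteContext* context, TfLiteNode* node, int slot,
                        int base_index, TfLiteType type,
                        TfLiteAllocationType allocation,
                        std::initializer_list<int> dims) {
      node->temporaries->data[slot] = base_index + slot;
      TfLiteTensor* t = &context->tensors[node->temporaries->data[slot]];
      t->type = type;
      t->allocation_type = allocation;
      TfLiteIntArray* size = TfLiteIntArrayCreate(static_cast<int>(dims.size()));
      int i = 0;
      for (int d : dims) size->data[i++] = d;
      context->ResizeTensor(context, t, size);
    }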
@@ -211,9 +165,7 @@ TfLiteStatus EvalHybrid(const TfLiteTensor* input,
                         TfLiteTensor* input_scratch,
                         TfLiteTensor* hidden_state_scratch,
                         TfLiteTensor* scaling_factors,
-                        TfLiteTensor* hidden_state, TfLiteTensor* output,
-                        TfLiteTensor* zero_points, TfLiteTensor* accum_scratch,
-                        TfLiteTensor* row_sums, bool* compute_row_sums) {
+                        TfLiteTensor* hidden_state, TfLiteTensor* output) {
   const int batch_size = input->dims->data[0];
   const int num_units = input_weights->dims->data[0];
   const int input_size = input->dims->data[1];

@@ -238,34 +190,26 @@ TfLiteStatus EvalHybrid(const TfLiteTensor* input,
   int8_t* quantized_hidden_state_ptr =
       GetTensorData<int8_t>(hidden_state_scratch);
   float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
-  int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
-  int32_t* zero_points_ptr = nullptr;
-  int32_t* row_sums_ptr = nullptr;
-  if (params->asymmetric_quantize_inputs) {
-    zero_points_ptr = GetTensorData<int32_t>(zero_points);
-    row_sums_ptr = GetTensorData<int32_t>(row_sums);
-  }
   kernel_utils::RnnBatchStep(
       input_ptr_batch, input_weights_ptr, input_weights_scale,
       recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
       num_units, batch_size, output_batch_leading_dim, params->activation,
       quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr,
-      hidden_state_ptr_batch, output_ptr_batch,
-      params->asymmetric_quantize_inputs, zero_points_ptr, accum_scratch_ptr,
-      row_sums_ptr, compute_row_sums);
+      hidden_state_ptr_batch, output_ptr_batch);
   return kTfLiteOk;
 }
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data);
-  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
   const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
   const TfLiteTensor* recurrent_weights =
       GetInput(context, node, kRecurrentWeightsTensor);
   const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
   TfLiteTensor* hidden_state =
-      &context->tensors[node->inputs->data[kHiddenStateTensor]];
+      GetVariableInput(context, node, kHiddenStateTensor);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
 
   // We already checked that weight types are consistent, so branch on one.

@@ -279,13 +223,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
       TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
       TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
-      TfLiteTensor* accum_scratch = GetTemporary(context, node, 3);
-      TfLiteTensor* zero_points = GetTemporary(context, node, 4);
-      TfLiteTensor* row_sums = GetTemporary(context, node, 5);
       return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
                         input_quantized, hidden_state_quantized,
-                        scaling_factors, hidden_state, output, zero_points,
-                        accum_scratch, row_sums, &op_data->compute_row_sums);
+                        scaling_factors, hidden_state, output);
     }
     default:
       context->ReportError(context, "Type %d not currently supported.",
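EvalHybrid() above only threads pointers through to the kernel utility; the row sums themselves are computed once and then reused, guarded by the compute_row_sums flag held in OpData on the side of the diff that declares it. A standalone sketch of that precomputation (names illustrative, not the library function):

    #include <cstdint>
    #include <vector>

    // Compute per-row weight sums once, then reuse them on every invocation.
    void ComputeRowSumsIfNeeded(const int8_t* weights, int rows, int cols,
                                std::vector<int32_t>* row_sums,
                                bool* compute_row_sums) {
      if (!*compute_row_sums) return;  // already cached (persistent tensor)
      row_sums->assign(rows, 0);
      for (int i = 0; i < rows; ++i) {
        int32_t sum = 0;
        for (int j = 0; j < cols; ++j) sum += weights[i * cols + j];
        (*row_sums)[i] = sum;
      }
      *compute_row_sums = false;
    }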
@@ -175,8 +175,7 @@ class RNNOpModel : public SingleOpModel {
  public:
   RNNOpModel(int batches, int units, int size,
              const TensorType& weights = TensorType_FLOAT32,
-             const TensorType& recurrent_weights = TensorType_FLOAT32,
-             bool asymmetric_quantize_inputs = false)
+             const TensorType& recurrent_weights = TensorType_FLOAT32)
       : batches_(batches), units_(units), input_size_(size) {
     input_ = AddInput(TensorType_FLOAT32);
     weights_ = AddInput(weights);

@@ -184,10 +183,9 @@ class RNNOpModel : public SingleOpModel {
     bias_ = AddInput(TensorType_FLOAT32);
     hidden_state_ = AddInput(TensorType_FLOAT32, true);
     output_ = AddOutput(TensorType_FLOAT32);
-    SetBuiltinOp(BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
-                 CreateRNNOptions(builder_, ActivationFunctionType_RELU,
-                                  asymmetric_quantize_inputs)
-                     .Union());
+    SetBuiltinOp(
+        BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
+        CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
     BuildInterpreter({{batches_, input_size_},  // input tensor
                       {units_, input_size_},    // weights tensor
                       {units_, units_},         // recurrent weights tensor

@@ -235,10 +233,8 @@ class RNNOpModel : public SingleOpModel {
 // The hybrid model has quantized weights and recurrent_weights.
 class HybridRNNOpModel : public RNNOpModel {
  public:
-  HybridRNNOpModel(int batches, int units, int size, TensorType tensor_type,
-                   bool asymmetric_quantize_inputs)
-      : RNNOpModel(batches, units, size, tensor_type, tensor_type,
-                   asymmetric_quantize_inputs) {
+  HybridRNNOpModel(int batches, int units, int size, TensorType tensor_type)
+      : RNNOpModel(batches, units, size, tensor_type, tensor_type) {
     tensor_type_ = tensor_type;
   }
 

@@ -286,10 +282,8 @@ TEST(RnnOpTest, BlackBoxTest) {
   }
 }
 
-class HybridRnnOpTest : public ::testing::TestWithParam<bool> {};
-
-TEST_P(HybridRnnOpTest, BlackBoxTestUint8) {
-  HybridRNNOpModel rnn(2, 16, 8, TensorType_UINT8, GetParam());
+TEST(HybridRnnOpTest, BlackBoxTestUint8) {
+  HybridRNNOpModel rnn(2, 16, 8, TensorType_UINT8);
   rnn.SetWeights(rnn_weights);
   rnn.SetBias(rnn_bias);
   rnn.SetRecurrentWeights(rnn_recurrent_weights);

@@ -316,8 +310,8 @@ TEST_P(HybridRnnOpTest, BlackBoxTestUint8) {
   }
 }
 
-TEST_P(HybridRnnOpTest, BlackBoxTestInt8) {
-  HybridRNNOpModel rnn(2, 16, 8, TensorType_INT8, GetParam());
+TEST(HybridRnnOpTest, BlackBoxTestInt8) {
+  HybridRNNOpModel rnn(2, 16, 8, TensorType_INT8);
   rnn.SetWeights(rnn_weights);
   rnn.SetBias(rnn_bias);
   rnn.SetRecurrentWeights(rnn_recurrent_weights);

@@ -344,8 +338,5 @@ TEST_P(HybridRnnOpTest, BlackBoxTestInt8) {
   }
 }
 
-INSTANTIATE_TEST_SUITE_P(HybridRnnOpTest, HybridRnnOpTest,
-                         ::testing::ValuesIn({false, true}));
-
 }  // namespace
 }  // namespace tflite
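The test hunks above switch between a plain TEST and a value-parameterized TEST_P over a bool, so a single test body covers both the symmetric and asymmetric input-quantization paths. For reference, a minimal self-contained example of the parameterized form (the test and suite names here are made up):

    #include <gtest/gtest.h>

    // Each TEST_P body runs once per value supplied to INSTANTIATE_TEST_SUITE_P.
    class HybridKernelTest : public ::testing::TestWithParam<bool> {};

    TEST_P(HybridKernelTest, RunsForBothSettings) {
      const bool asymmetric_quantize_inputs = GetParam();
      EXPECT_TRUE(asymmetric_quantize_inputs || !asymmetric_quantize_inputs);
    }

    INSTANTIATE_TEST_SUITE_P(BoolParam, HybridKernelTest, ::testing::Bool());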
@@ -139,28 +139,18 @@ enum TemporaryTensor {
   kProductScalingFactors = 8,
   kRecoveredCellWeights = 9,
   kAccumScratchBuffer = 10,
-  kZeroPoints = 11,
-  kFwRowSums = 12,
-  kBwRowSums = 13,
-  kAuxInputQuantized = 14,  // Optional, quantized tensor for auxiliary input.
-  kNumTemporaryTensors = 15
-};
-
-struct OpData {
-  int scratch_tensor_index;
-  bool compute_fw_row_sums = false;
-  bool compute_bw_row_sums = false;
+  kAuxInputQuantized = 11,  // Optional, quantized tensor for auxiliary input.
+  kNumTemporaryTensors
 };
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  auto* op_data = new OpData();
-  context->AddTensors(context, kNumTemporaryTensors,
-                      &op_data->scratch_tensor_index);
-  return op_data;
+  auto* scratch_tensor_index = new int;
+  context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
+  return scratch_tensor_index;
 }
 
 void Free(TfLiteContext* context, void* buffer) {
-  delete reinterpret_cast<OpData*>(buffer);
+  delete reinterpret_cast<int*>(buffer);
 }
 
 // Check that input tensor dimensions matches with each other.

@@ -395,7 +385,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
 // Resize the output and scratch tensors based on the sizes of the input
 // tensors. Also check that the size of the input tensors match each other.
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
+  int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
   const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
       node->builtin_data);
 

@@ -532,7 +522,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     node->temporaries = TfLiteIntArrayCreate(2);  // the two scratch buffers.
   }
   // Create a scratch buffer tensor.
-  node->temporaries->data[kFwScratchBuffer] = op_data->scratch_tensor_index;
+  node->temporaries->data[kFwScratchBuffer] = *scratch_tensor_index;
   TfLiteTensor* fw_scratch_buffer =
       GetTemporary(context, node, kFwScratchBuffer);
   fw_scratch_buffer->type = input->type;

@@ -591,7 +581,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 
   // Create a scratch buffer tensor.
   node->temporaries->data[kBwScratchBuffer] =
-      op_data->scratch_tensor_index + kBwScratchBuffer;
+      *(scratch_tensor_index) + kBwScratchBuffer;
   TfLiteTensor* bw_scratch_buffer =
       GetTemporary(context, node, kBwScratchBuffer);
   bw_scratch_buffer->type = input->type;

@@ -616,13 +606,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer,
                                                     bw_scratch_buffer_size));
   if (is_hybrid_op) {
-    // Compute the row sums for cached zero_point offset calculation.
-    op_data->compute_fw_row_sums = true;
-    op_data->compute_bw_row_sums = true;
     // Allocate temporary tensors to store quantized values of input, aux_input
     // (if present), activation_state and cell_state tensors.
     node->temporaries->data[kInputQuantized] =
-        op_data->scratch_tensor_index + kInputQuantized;
+        *scratch_tensor_index + kInputQuantized;
     TfLiteTensor* input_quantized =
         GetTemporary(context, node, kInputQuantized);
     input_quantized->type = fw_input_to_output_weights->type;

@@ -634,7 +621,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     }
 
     node->temporaries->data[kFwActivationStateQuantized] =
-        op_data->scratch_tensor_index + kFwActivationStateQuantized;
+        *scratch_tensor_index + kFwActivationStateQuantized;
     TfLiteTensor* fw_activation_state_quantized =
         GetTemporary(context, node, kFwActivationStateQuantized);
     fw_activation_state_quantized->type = fw_input_to_output_weights->type;

@@ -648,7 +635,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                         fw_activation_state_quantized_size));
     }
     node->temporaries->data[kBwActivationStateQuantized] =
-        op_data->scratch_tensor_index + kBwActivationStateQuantized;
+        *scratch_tensor_index + kBwActivationStateQuantized;
     TfLiteTensor* bw_activation_state_quantized =
         GetTemporary(context, node, kBwActivationStateQuantized);
     bw_activation_state_quantized->type = fw_input_to_output_weights->type;

@@ -662,7 +649,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                         bw_activation_state_quantized_size));
     }
     node->temporaries->data[kFwCellStateQuantized] =
-        op_data->scratch_tensor_index + kFwCellStateQuantized;
+        *scratch_tensor_index + kFwCellStateQuantized;
     TfLiteTensor* fw_cell_state_quantized =
         GetTemporary(context, node, kFwCellStateQuantized);
     fw_cell_state_quantized->type = fw_input_to_output_weights->type;

@@ -676,7 +663,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                         fw_cell_state_quantized_size));
     }
     node->temporaries->data[kBwCellStateQuantized] =
-        op_data->scratch_tensor_index + kBwCellStateQuantized;
+        *scratch_tensor_index + kBwCellStateQuantized;
     TfLiteTensor* bw_cell_state_quantized =
         GetTemporary(context, node, kBwCellStateQuantized);
     bw_cell_state_quantized->type = fw_input_to_output_weights->type;

@@ -696,7 +683,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     // different matrices (which requires multiplying the scaling factors with
     // the scaling factor of the matrix).
     node->temporaries->data[kScalingFactors] =
-        op_data->scratch_tensor_index + kScalingFactors;
+        *scratch_tensor_index + kScalingFactors;
     TfLiteTensor* scaling_factors =
         GetTemporary(context, node, kScalingFactors);
     scaling_factors->type = kTfLiteFloat32;

@@ -709,7 +696,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                                        scaling_factors_size));
     }
     node->temporaries->data[kProductScalingFactors] =
-        op_data->scratch_tensor_index + kProductScalingFactors;
+        *scratch_tensor_index + kProductScalingFactors;
     TfLiteTensor* prod_scaling_factors =
         GetTemporary(context, node, kProductScalingFactors);
     prod_scaling_factors->type = kTfLiteFloat32;

@@ -726,7 +713,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     // Allocate a temporary tensor to store the recovered cell weights. Since
     // this is used for diagonal matrices, only need to store n_cell values.
     node->temporaries->data[kRecoveredCellWeights] =
-        op_data->scratch_tensor_index + kRecoveredCellWeights;
+        *scratch_tensor_index + kRecoveredCellWeights;
     TfLiteTensor* recovered_cell_weights =
         GetTemporary(context, node, kRecoveredCellWeights);
     recovered_cell_weights->type = kTfLiteFloat32;

@@ -743,7 +730,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 
     // Allocate a temporary tensor to store the accumulated int32 values.
     node->temporaries->data[kAccumScratchBuffer] =
-        op_data->scratch_tensor_index + kAccumScratchBuffer;
+        *scratch_tensor_index + kAccumScratchBuffer;
     TfLiteTensor* accum_scratch =
         GetTemporary(context, node, kAccumScratchBuffer);
     accum_scratch->type = kTfLiteInt32;

@@ -763,72 +750,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
         context, context->ResizeTensor(context, accum_scratch, accum_size));
     }
 
-    // Allocate temporary tensors for storing zero-points.
-    node->temporaries->data[kZeroPoints] =
-        op_data->scratch_tensor_index + kZeroPoints;
-    TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
-    zero_points->type = kTfLiteFloat32;
-    zero_points->allocation_type = kTfLiteArenaRw;
-    if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, scaling_dims)) {
-      TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
-      zero_points_size->data[0] = n_batch;
-      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
-                                                       zero_points_size));
-    }
-
-    // Allocate temporary tensors for caching row sums for hybrid zero-point
-    // calculations.
-    int fw_row_sums_rows = fw_use_cifg ? 6 : 8;
-    if (has_aux_input) {
-      fw_row_sums_rows += fw_use_cifg ? 3 : 4;
-    }
-    const TfLiteTensor* fw_projection_weights =
-        GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor);
-    if (fw_projection_weights != nullptr) {
-      fw_row_sums_rows += ceil(n_fw_output / n_fw_cell);
-    }
-    node->temporaries->data[kFwRowSums] =
-        op_data->scratch_tensor_index + kFwRowSums;
-    TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
-    fw_row_sums->type = kTfLiteInt32;
-    fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
-    int fw_row_sums_dims[2] = {fw_row_sums_rows, n_fw_cell};
-    if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) {
-      TfLiteIntArray* fw_hybrid_scratch_size = TfLiteIntArrayCreate(2);
-      fw_hybrid_scratch_size->data[0] = fw_row_sums_dims[0];
-      fw_hybrid_scratch_size->data[1] = fw_row_sums_dims[1];
-      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums,
-                                                       fw_hybrid_scratch_size));
-    }
-
-    int bw_row_sums_rows = bw_use_cifg ? 6 : 8;
-    if (has_aux_input) {
-      bw_row_sums_rows += bw_use_cifg ? 3 : 4;
-    }
-    const TfLiteTensor* bw_projection_weights =
-        GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor);
-    if (bw_projection_weights != nullptr) {
-      bw_row_sums_rows += ceil(n_bw_output / n_bw_cell);
-    }
-    node->temporaries->data[kBwRowSums] =
-        op_data->scratch_tensor_index + kBwRowSums;
-    TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
-    bw_row_sums->type = kTfLiteInt32;
-    bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
-    int bw_row_sums_dims[2] = {bw_row_sums_rows, n_bw_cell};
-    if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) {
-      TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2);
-      bw_row_sums_size->data[0] = bw_row_sums_dims[0];
-      bw_row_sums_size->data[1] = bw_row_sums_dims[1];
-      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums,
-                                                       bw_row_sums_size));
-    }
-
     // Only allocate a temporary tensor for quantized auxiliary input if we are
     // actually going to use it.
     if (has_aux_input) {
       node->temporaries->data[kAuxInputQuantized] =
-          op_data->scratch_tensor_index + kAuxInputQuantized;
+          *scratch_tensor_index + kAuxInputQuantized;
       TfLiteTensor* aux_input_quantized =
           GetTemporary(context, node, kAuxInputQuantized);
       aux_input_quantized->type = fw_input_to_output_weights->type;

@@ -849,7 +775,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
       node->builtin_data);
-  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
   // Input tensor.
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
 

@@ -983,8 +909,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 
   // Populate a TfLiteLSTMParams struct for the evaluation functions.
   TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip,
-                                  params->proj_clip, kTfLiteLSTMFullKernel,
-                                  params->asymmetric_quantize_inputs};
+                                  params->proj_clip, kTfLiteLSTMFullKernel};
 
   const int bw_output_offset =
       params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0;

@@ -1078,11 +1003,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                             : nullptr;
       TfLiteTensor* accum_scratch =
           GetTemporary(context, node, kAccumScratchBuffer);
-      TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
-      TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
-      TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
-      const int fw_row_sums_size = fw_row_sums->dims->data[0];
-      const int bw_row_sums_size = bw_row_sums->dims->data[0];
       TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
           input, fw_input_to_input_weights, fw_input_to_forget_weights,
           fw_input_to_cell_weights, fw_input_to_output_weights,

@@ -1104,8 +1025,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           recovered_cell_weights, input_quantized, aux_input_quantized,
           fw_activation_state_quantized, fw_cell_state_quantized,
           fw_activation_state, fw_cell_state, accum_scratch, fw_output,
-          zero_points, fw_row_sums, fw_row_sums_size,
-          &op_data->compute_fw_row_sums,
           CpuBackendContext::GetFromContext(context));
       TF_LITE_ENSURE_OK(context, fw_pass_status);
 

@@ -1130,8 +1049,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           recovered_cell_weights, input_quantized, aux_input_quantized,
           bw_activation_state_quantized, bw_cell_state_quantized,
           bw_activation_state, bw_cell_state, accum_scratch, actual_bw_output,
-          zero_points, bw_row_sums, bw_row_sums_size,
-          &op_data->compute_bw_row_sums,
           CpuBackendContext::GetFromContext(context));
       TF_LITE_ENSURE_OK(context, bw_pass_status);
       return kTfLiteOk;
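The large Prepare() hunk above sizes the fw/bw row-sums tensors as {rows, n_cell}. The row count appears to be one row of cached sums per weight matrix that multiplies a runtime-quantized vector: four gates times {input, recurrent} weights (three gates under CIFG), plus the aux-input weights when present, plus enough extra rows to hold the projection matrix's n_output sums packed n_cell per row. A sketch of that computation, not the exact expression in the kernel (which calls ceil() on the output/cell ratio):

    // Illustrative recomputation of the row-sums row count.
    int NumRowSumRows(bool use_cifg, bool has_aux_input, bool has_projection,
                      int n_output, int n_cell) {
      int rows = use_cifg ? 6 : 8;                  // (3 or 4 gates) x {input, recurrent}
      if (has_aux_input) rows += use_cifg ? 3 : 4;  // aux-input weights per gate
      if (has_projection) rows += (n_output + n_cell - 1) / n_cell;  // ceil-div
      return rows;
    }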
@@ -40,8 +40,7 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
                             bool use_projection_bias, bool merge_outputs,
                             bool use_aux_input, float cell_clip, float proj_clip,
                             bool quantize_weights, bool time_major,
-                            const std::vector<std::vector<int>>& input_shapes,
-                            bool asymmetric_quantize_inputs = false)
+                            const std::vector<std::vector<int>>& input_shapes)
       : n_batch_(n_batch),
         n_input_(n_input),
         n_fw_cell_(n_cell),

@@ -208,12 +207,11 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
       bw_aux_input_to_output_weights_ = AddNullInput();
     }
 
-    SetBuiltinOp(
-        BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
-        BuiltinOptions_BidirectionalSequenceLSTMOptions,
-        CreateBidirectionalSequenceLSTMOptions(
-            builder_, ActivationFunctionType_TANH, cell_clip, proj_clip,
-            merge_outputs, time_major, asymmetric_quantize_inputs)
-            .Union());
+    SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
+                 BuiltinOptions_BidirectionalSequenceLSTMOptions,
+                 CreateBidirectionalSequenceLSTMOptions(
+                     builder_, ActivationFunctionType_TANH, cell_clip,
+                     proj_clip, merge_outputs, time_major)
+                     .Union());
     BuildInterpreter(input_shapes);
   }

@@ -426,14 +424,11 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
   bool quantize_weights_;
 };
 
-// Declare LSTMOpTest as a parameterized test.
-class LSTMOpTest
-    : public ::testing::TestWithParam<::testing::tuple<bool, bool>> {};
+// Declare LSTMOpTest as a parameterized test, where the parameter is a boolean
+// indicating whether to use quantization or not.
+class LSTMOpTest : public ::testing::TestWithParam<bool> {};
 
-INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, LSTMOpTest,
-                         ::testing::Combine(
-                             /*quantize_weights*/ ::testing::Bool(),
-                             /*asymmetric_quantize_inputs*/ ::testing::Bool()));
+INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, LSTMOpTest, ::testing::Bool());
 
 TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
   const int n_batch = 1;

@@ -442,9 +437,7 @@ TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
   const int n_cell = 4;
   const int n_output = 4;
   const int sequence_length = 3;
-  auto params = GetParam();
-  const bool quantize_weights = std::get<0>(params);
-  const bool asymmetric_quantize_inputs = std::get<1>(params);
+  const bool quantize_weights = GetParam();
 
   BidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,

@@ -516,8 +509,7 @@ TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
           {0},  // aux_bw_input_to_forget tensor
           {0},  // aux_bw_input_to_cell tensor
           {0},  // aux_bw_input_to_output tensor
-      },
-      asymmetric_quantize_inputs);
+      });
 
   lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
                                -0.34550029, 0.04266912, -0.15680569,

@@ -608,9 +600,7 @@ TEST_P(LSTMOpTest, BlackBoxTestMergedOutput) {
   const int n_cell = 4;
   const int n_output = 4;
   const int sequence_length = 3;
-  auto params = GetParam();
-  const bool quantize_weights = std::get<0>(params);
-  const bool asymmetric_quantize_inputs = std::get<1>(params);
+  const bool quantize_weights = GetParam();
 
   BidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,

@@ -682,8 +672,7 @@ TEST_P(LSTMOpTest, BlackBoxTestMergedOutput) {
           {0},  // aux_bw_input_to_forget tensor
           {0},  // aux_bw_input_to_cell tensor
           {0},  // aux_bw_input_to_output tensor
-      },
-      asymmetric_quantize_inputs);
+      });
 
   lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
                                -0.34550029, 0.04266912, -0.15680569,

@@ -2642,9 +2631,7 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInputZeroAuxWeight) {
   const int n_cell = 4;
   const int n_output = 4;
   const int sequence_length = 3;
-  auto params = GetParam();
-  const bool quantize_weights = std::get<0>(params);
-  const bool asymmetric_quantize_inputs = std::get<1>(params);
+  const bool quantize_weights = GetParam();
 
   BidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,

@@ -2716,8 +2703,7 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInputZeroAuxWeight) {
           {n_cell, n_input},  // aux_bw_input_to_forget tensor
           {n_cell, n_input},  // aux_bw_input_to_cell tensor
           {n_cell, n_input},  // aux_bw_input_to_output tensor
-      },
-      asymmetric_quantize_inputs);
+      });
 
   lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
                                -0.34550029, 0.04266912, -0.15680569,

@@ -2816,9 +2802,7 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInput) {
   const int n_cell = 4;
   const int n_output = 4;
   const int sequence_length = 3;
-  auto params = GetParam();
-  const bool quantize_weights = std::get<0>(params);
-  const bool asymmetric_quantize_inputs = std::get<1>(params);
+  const bool quantize_weights = GetParam();
 
   BidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,

@@ -2890,8 +2874,7 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInput) {
           {n_cell, n_input},  // aux_bw_input_to_forget tensor
           {n_cell, n_input},  // aux_bw_input_to_cell tensor
           {n_cell, n_input},  // aux_bw_input_to_output tensor
-      },
-      asymmetric_quantize_inputs);
+      });
 
   lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
                                -0.34550029, 0.04266912, -0.15680569,
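The LSTM test file drives two independent booleans, so the tuple-typed TestWithParam plus ::testing::Combine form appears on one side of the diff. A minimal self-contained version of that pattern (class and suite names are made up):

    #include <gtest/gtest.h>
    #include <tuple>

    // ::testing::Combine produces the cross product of the two axes; each axis
    // is unpacked from the tuple with std::get.
    class TwoFlagTest
        : public ::testing::TestWithParam<std::tuple<bool, bool>> {};

    TEST_P(TwoFlagTest, CrossProduct) {
      const bool quantize_weights = std::get<0>(GetParam());
      const bool asymmetric_quantize_inputs = std::get<1>(GetParam());
      SUCCEED() << quantize_weights << " " << asymmetric_quantize_inputs;
    }

    INSTANTIATE_TEST_SUITE_P(AllCombos, TwoFlagTest,
                             ::testing::Combine(::testing::Bool(),
                                                ::testing::Bool()));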
@ -27,16 +27,6 @@ namespace ops {
|
|||||||
namespace builtin {
|
namespace builtin {
|
||||||
namespace bidirectional_sequence_rnn {
|
namespace bidirectional_sequence_rnn {
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
struct OpData {
|
|
||||||
int scratch_tensor_index;
|
|
||||||
bool fw_compute_row_sums = false;
|
|
||||||
bool bw_compute_row_sums = false;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// LINT.IfChange
|
// LINT.IfChange
|
||||||
|
|
||||||
constexpr int kInputTensor = 0;
|
constexpr int kInputTensor = 0;
|
||||||
@ -68,23 +58,18 @@ enum TemporaryTensor {
|
|||||||
kFwHiddenStateQuantized = 1,
|
kFwHiddenStateQuantized = 1,
|
||||||
kBwHiddenStateQuantized = 2,
|
kBwHiddenStateQuantized = 2,
|
||||||
kScalingFactors = 3,
|
kScalingFactors = 3,
|
||||||
kAccumScratch = 4,
|
kAuxInputQuantized = 4,
|
||||||
kZeroPoints = 5,
|
kNumTemporaryTensors = 5
|
||||||
kFwRowSums = 6,
|
|
||||||
kBwRowSums = 7,
|
|
||||||
kAuxInputQuantized = 8,
|
|
||||||
kNumTemporaryTensors = 9
|
|
||||||
};
|
};
|
||||||
|
|
||||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
auto* op_data = new OpData();
|
auto* scratch_tensor_index = new int;
|
||||||
context->AddTensors(context, kNumTemporaryTensors,
|
context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
|
||||||
&op_data->scratch_tensor_index);
|
return scratch_tensor_index;
|
||||||
return op_data;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Free(TfLiteContext* context, void* buffer) {
|
void Free(TfLiteContext* context, void* buffer) {
|
||||||
delete reinterpret_cast<OpData*>(buffer);
|
delete reinterpret_cast<int*>(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
@@ -172,9 +157,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }

 if (IsHybridOp(input, fw_input_weights)) {
-OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
-op_data->fw_compute_row_sums = true;
-op_data->bw_compute_row_sums = true;
+int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
 TfLiteIntArrayFree(node->temporaries);
 if (has_aux_input) {
 node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
@@ -184,7 +168,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }

 node->temporaries->data[kInputQuantized] =
-op_data->scratch_tensor_index + kInputQuantized;
+*scratch_tensor_index + kInputQuantized;
 TfLiteTensor* input_quantized =
 GetTemporary(context, node, kInputQuantized);
 input_quantized->type = fw_input_weights->type;
@@ -196,7 +180,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }

 node->temporaries->data[kFwHiddenStateQuantized] =
-op_data->scratch_tensor_index + kFwHiddenStateQuantized;
+*scratch_tensor_index + kFwHiddenStateQuantized;
 TfLiteTensor* fw_hidden_state_quantized =
 GetTemporary(context, node, kFwHiddenStateQuantized);
 fw_hidden_state_quantized->type = fw_input_weights->type;
@@ -211,7 +195,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }

 node->temporaries->data[kBwHiddenStateQuantized] =
-op_data->scratch_tensor_index + kBwHiddenStateQuantized;
+*scratch_tensor_index + kBwHiddenStateQuantized;
 TfLiteTensor* bw_hidden_state_quantized =
 GetTemporary(context, node, kBwHiddenStateQuantized);
 bw_hidden_state_quantized->type = fw_input_weights->type;
@@ -227,7 +211,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

 // Allocate temporary tensors to store scaling factors of quantization.
 node->temporaries->data[kScalingFactors] =
-op_data->scratch_tensor_index + kScalingFactors;
+*scratch_tensor_index + kScalingFactors;
 TfLiteTensor* scaling_factors =
 GetTemporary(context, node, kScalingFactors);
 scaling_factors->type = kTfLiteFloat32;
@@ -239,66 +223,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
 scaling_factors_size));
 }
-node->temporaries->data[kAccumScratch] =
-op_data->scratch_tensor_index + kAccumScratch;
-TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch);
-accum_scratch->type = kTfLiteInt32;
-accum_scratch->allocation_type = kTfLiteArenaRw;
-int accum_scratch_dims[2] = {std::max(fw_num_units, bw_num_units),
-batch_size};
-if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
-accum_scratch_dims)) {
-TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
-accum_scratch_size->data[0] = accum_scratch_dims[0];
-accum_scratch_size->data[1] = accum_scratch_dims[1];
-TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
-accum_scratch_size));
-}
-node->temporaries->data[kZeroPoints] =
-op_data->scratch_tensor_index + kZeroPoints;
-TfLiteTensor* zero_points =
-GetTemporary(context, node, /*index=*/kZeroPoints);
-zero_points->type = kTfLiteInt32;
-zero_points->allocation_type = kTfLiteArenaRw;
-int zero_points_dims[1] = {batch_size};
-if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
-TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
-zero_points_size->data[0] = batch_size;
-TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
-zero_points_size));
-}
-const int num_row_sums = has_aux_input ? 3 : 2;
-node->temporaries->data[kFwRowSums] =
-op_data->scratch_tensor_index + kFwRowSums;
-TfLiteTensor* fw_row_sums =
-GetTemporary(context, node, /*index=*/kFwRowSums);
-fw_row_sums->type = kTfLiteInt32;
-fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
-int fw_row_sums_dims[2] = {num_row_sums, fw_num_units};
-if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) {
-TfLiteIntArray* fw_row_sums_size = TfLiteIntArrayCreate(2);
-fw_row_sums_size->data[0] = fw_row_sums_dims[0];
-fw_row_sums_size->data[1] = fw_row_sums_dims[1];
-TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums,
-fw_row_sums_size));
-}
-node->temporaries->data[kBwRowSums] =
-op_data->scratch_tensor_index + kBwRowSums;
-TfLiteTensor* bw_row_sums = GetTemporary(context, node,
-/*index=*/kBwRowSums);
-bw_row_sums->type = kTfLiteInt32;
-bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
-int bw_row_sums_dims[2] = {num_row_sums, bw_num_units};
-if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) {
-TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2);
-bw_row_sums_size->data[0] = bw_row_sums_dims[0];
-bw_row_sums_size->data[1] = bw_row_sums_dims[1];
-TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums,
-bw_row_sums_size));
-}
 if (has_aux_input) {
 node->temporaries->data[kAuxInputQuantized] =
-op_data->scratch_tensor_index + kAuxInputQuantized;
+*scratch_tensor_index + kAuxInputQuantized;
 TfLiteTensor* aux_input_quantized =
 GetTemporary(context, node, kAuxInputQuantized);
 aux_input_quantized->type = fw_input_weights->type;
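Note on the temporaries that appear only on the removed side above (kAccumScratch, kZeroPoints, kFwRowSums, kBwRowSums): they back the asymmetric input-quantization path. A sketch of the standard identity they support (this is background, not text from this change; W is an int8 weight matrix, q a quantized input row, z its zero point, s the combined scale):

\[ y_j \approx s\Big(\sum_i W_{ji}\,q_i \;-\; z\sum_i W_{ji}\Big) \]

The per-row weight sums \(\sum_i W_{ji}\) are what the row-sum tensors cache, so per invocation only the integer dot products and the per-batch zero points change.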
@@ -490,10 +418,7 @@ TfLiteStatus EvalHybrid(
 TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized,
 TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
 TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
-TfLiteTensor* bw_output, TfLiteTensor* zero_points,
-TfLiteTensor* accum_scratch, TfLiteTensor* fw_row_sums,
-TfLiteTensor* bw_row_sums, bool* fw_compute_row_sums,
-bool* bw_compute_row_sums) {
+TfLiteTensor* bw_output) {
 const bool time_major = params->time_major;
 const int batch_size =
 (time_major) ? input->dims->data[1] : input->dims->data[0];
@@ -539,20 +464,11 @@ TfLiteStatus EvalHybrid(
 int8_t* bw_quantized_hidden_state_ptr =
 GetTensorData<int8_t>(bw_hidden_state_quantized);
 float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
-int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
-int32_t* zero_points_ptr = nullptr;
-int32_t* fw_row_sums_ptr = nullptr;
-int32_t* bw_row_sums_ptr = nullptr;
-if (params->asymmetric_quantize_inputs) {
-zero_points_ptr = GetTensorData<int32_t>(zero_points);
-fw_row_sums_ptr = GetTensorData<int32_t>(fw_row_sums);
-bw_row_sums_ptr = GetTensorData<int32_t>(bw_row_sums);
-}
 const int fw_output_step =
 params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
 const int bw_output_step =
 params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;

 if (time_major) {
 for (int t = 0; t < max_time; t++) {
 // Forward cell.
@@ -575,9 +491,7 @@ TfLiteStatus EvalHybrid(
 fw_num_units, batch_size, fw_output_step, params->activation,
 quantized_input_ptr, aux_quantized_input_ptr,
 fw_quantized_hidden_state_ptr, scaling_factors_ptr,
-fw_hidden_state_ptr_batch, output_ptr_batch,
-params->asymmetric_quantize_inputs, zero_points_ptr,
-accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums);
+fw_hidden_state_ptr_batch, output_ptr_batch);
 }
 // Backward cell.
 float* bw_hidden_state_ptr_batch = GetTensorData<float>(bw_hidden_state);
@@ -602,9 +516,7 @@ TfLiteStatus EvalHybrid(
 bw_num_units, batch_size, bw_output_step, params->activation,
 quantized_input_ptr, aux_quantized_input_ptr,
 bw_quantized_hidden_state_ptr, scaling_factors_ptr,
-bw_hidden_state_ptr_batch, output_ptr_batch,
-params->asymmetric_quantize_inputs, zero_points_ptr,
-accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums);
+bw_hidden_state_ptr_batch, output_ptr_batch);
 }
 }
 } else {
@@ -633,9 +545,7 @@ TfLiteStatus EvalHybrid(
 fw_num_units, /*batch_size=*/1, fw_output_step, params->activation,
 quantized_input_ptr, aux_quantized_input_ptr,
 fw_quantized_hidden_state_ptr, scaling_factors_ptr,
-fw_hidden_state_ptr_batch, output_ptr_batch,
-params->asymmetric_quantize_inputs, zero_points_ptr,
-accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums);
+fw_hidden_state_ptr_batch, output_ptr_batch);
 }
 // Backward cell.
 float* bw_hidden_state_ptr_batch =
@@ -664,9 +574,7 @@ TfLiteStatus EvalHybrid(
 bw_num_units, /*batch_size=*/1, bw_output_step, params->activation,
 quantized_input_ptr, aux_quantized_input_ptr,
 bw_quantized_hidden_state_ptr, scaling_factors_ptr,
-bw_hidden_state_ptr_batch, output_ptr_batch,
-params->asymmetric_quantize_inputs, zero_points_ptr,
-accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums);
+bw_hidden_state_ptr_batch, output_ptr_batch);
 }
 }
 }
@@ -748,23 +656,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 GetTemporary(context, node, kBwHiddenStateQuantized);
 TfLiteTensor* scaling_factors =
 GetTemporary(context, node, kScalingFactors);
-TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
-TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch);
-TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
-TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
 TfLiteTensor* aux_input_quantized =
 use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
 : nullptr;
-auto* op_data = reinterpret_cast<OpData*>(node->user_data);
-return EvalHybrid(
-input, bw_input, fw_input_weights, fw_recurrent_weights, fw_bias,
-bw_input_weights, bw_recurrent_weights, bw_bias, real_aux_input,
-fw_aux_input_weights, bw_aux_input_weights, params, scaling_factors,
-input_quantized, aux_input_quantized, fw_hidden_state_quantized,
-fw_hidden_state, fw_output, bw_hidden_state_quantized,
-bw_hidden_state, bw_output, zero_points, accum_scratch, fw_row_sums,
-bw_row_sums, &op_data->fw_compute_row_sums,
-&op_data->bw_compute_row_sums);
+return EvalHybrid(input, bw_input, fw_input_weights, fw_recurrent_weights,
+fw_bias, bw_input_weights, bw_recurrent_weights,
+bw_bias, real_aux_input, fw_aux_input_weights,
+bw_aux_input_weights, params, scaling_factors,
+input_quantized, aux_input_quantized,
+fw_hidden_state_quantized, fw_hidden_state, fw_output,
+bw_hidden_state_quantized, bw_hidden_state, bw_output);
 }
 default:
 context->ReportError(context, "Type not currently supported.");
@@ -662,24 +662,20 @@ class BidirectionalRNNOpModel : public SingleOpModel {
 BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units,
 int bw_units, int input_size, int aux_input_size,
 AuxInputMode aux_input_mode, bool time_major,
-bool merge_outputs, bool quantize_weights = false,
-bool asymmetric_quantize_weights = false)
+bool merge_outputs)
 : batches_(batches),
 sequence_len_(sequence_len),
 fw_units_(fw_units),
 bw_units_(bw_units),
 input_size_(input_size),
-aux_input_size_(aux_input_size),
-quantize_weights_(quantize_weights) {
-const TensorType tensor_type =
-quantize_weights ? TensorType_UINT8 : TensorType_FLOAT32;
+aux_input_size_(aux_input_size) {
 input_ = AddInput(TensorType_FLOAT32);
-fw_weights_ = AddInput(tensor_type);
-fw_recurrent_weights_ = AddInput(tensor_type);
+fw_weights_ = AddInput(TensorType_FLOAT32);
+fw_recurrent_weights_ = AddInput(TensorType_FLOAT32);
 fw_bias_ = AddInput(TensorType_FLOAT32);
 fw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
-bw_weights_ = AddInput(tensor_type);
-bw_recurrent_weights_ = AddInput(tensor_type);
+bw_weights_ = AddInput(TensorType_FLOAT32);
+bw_recurrent_weights_ = AddInput(TensorType_FLOAT32);
 bw_bias_ = AddInput(TensorType_FLOAT32);
 bw_hidden_state_ = AddInput(TensorType_FLOAT32, true);

@@ -701,8 +697,8 @@ class BidirectionalRNNOpModel : public SingleOpModel {
 }

 if (aux_input_mode == AuxInputMode::kCrossLinking) {
-aux_fw_weights_ = AddInput(tensor_type);
-aux_bw_weights_ = AddInput(tensor_type);
+aux_fw_weights_ = AddInput(TensorType_FLOAT32);
+aux_bw_weights_ = AddInput(TensorType_FLOAT32);

 aux_fw_weights_shape = {fw_units, aux_input_size_};
 aux_bw_weights_shape = {bw_units, aux_input_size_};
@@ -716,11 +712,11 @@ class BidirectionalRNNOpModel : public SingleOpModel {
 bw_output_ = AddOutput(TensorType_FLOAT32);
 }

-SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
+SetBuiltinOp(
+BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
 BuiltinOptions_BidirectionalSequenceRNNOptions,
 CreateBidirectionalSequenceRNNOptions(
-builder_, time_major, ActivationFunctionType_RELU,
-merge_outputs, asymmetric_quantize_weights)
+builder_, time_major, ActivationFunctionType_RELU, merge_outputs)
 .Union());

 BuildInterpreter({
@@ -748,36 +744,20 @@ class BidirectionalRNNOpModel : public SingleOpModel {
 }

 void SetFwWeights(const std::vector<float>& f) {
-if (quantize_weights_) {
-SymmetricQuantizeAndPopulate(fw_weights_, f);
-} else {
 PopulateTensor(fw_weights_, f);
 }
-}

 void SetBwWeights(const std::vector<float>& f) {
-if (quantize_weights_) {
-SymmetricQuantizeAndPopulate(bw_weights_, f);
-} else {
 PopulateTensor(bw_weights_, f);
 }
-}

 void SetFwRecurrentWeights(const std::vector<float>& f) {
-if (quantize_weights_) {
-SymmetricQuantizeAndPopulate(fw_recurrent_weights_, f);
-} else {
 PopulateTensor(fw_recurrent_weights_, f);
 }
-}

 void SetBwRecurrentWeights(const std::vector<float>& f) {
-if (quantize_weights_) {
-SymmetricQuantizeAndPopulate(bw_recurrent_weights_, f);
-} else {
 PopulateTensor(bw_recurrent_weights_, f);
 }
-}

 void SetInput(std::initializer_list<float> data) {
 PopulateTensor(input_, data);
@@ -792,20 +772,12 @@ class BidirectionalRNNOpModel : public SingleOpModel {
 }

 void SetAuxFwWeights(const std::vector<float>& f) {
-if (quantize_weights_) {
-SymmetricQuantizeAndPopulate(aux_fw_weights_, f);
-} else {
 PopulateTensor(aux_fw_weights_, f);
 }
-}

 void SetAuxBwWeights(const std::vector<float>& f) {
-if (quantize_weights_) {
-SymmetricQuantizeAndPopulate(aux_bw_weights_, f);
-} else {
 PopulateTensor(aux_bw_weights_, f);
 }
-}

 std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); }
 std::vector<float> GetBwOutput() { return ExtractVector<float>(bw_output_); }
@@ -839,31 +811,17 @@ class BidirectionalRNNOpModel : public SingleOpModel {
 int bw_units_;
 int input_size_;
 int aux_input_size_;
-bool quantize_weights_;
 };

-// Declare LSTMOpTest as a parameterized test.
-class BidirectionalRNNOpTest
-: public ::testing::TestWithParam<::testing::tuple<bool, bool>> {};
-
-INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, BidirectionalRNNOpTest,
-::testing::Combine(
-/*quantize_weights*/ ::testing::Bool(),
-/*asymmetric_quantize_inputs*/ ::testing::Bool()));
-
 // TODO(mirkov): add another test which directly compares to TF once TOCO
 // supports the conversion from dynamic_rnn with BasicRNNCell.
-TEST_P(BidirectionalRNNOpTest, BlackBoxTest) {
-auto params = GetParam();
-const bool quantize_weights = std::get<0>(params);
-const bool asymmetric_quantize_inputs = std::get<1>(params);
+TEST(BidirectionalRNNOpTest, BlackBoxTest) {
 BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
 /*fw_units=*/16, /*bw_units=*/16,
 /*input_size=*/8, /*aux_input_size=*/0,
 /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
 /*time_major=*/false,
-/*merge_outputs=*/false, quantize_weights,
-asymmetric_quantize_inputs);
+/*merge_outputs=*/false);
 rnn.SetFwWeights(weights);
 rnn.SetBwWeights(weights);
 rnn.SetFwBias(biases);
@@ -885,9 +843,7 @@ TEST_P(BidirectionalRNNOpTest, BlackBoxTest) {
 std::vector<float> fw_expected;
 fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
 fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
-EXPECT_THAT(rnn.GetFwOutput(),
-ElementsAreArray(ArrayFloatNear(
-fw_expected, quantize_weights ? 1.42e-2 : 1e-5)));
+EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));

 float* golden_bw_start = rnn_golden_bw_output;
 float* golden_bw_end =
@@ -895,23 +851,17 @@ TEST_P(BidirectionalRNNOpTest, BlackBoxTest) {
 std::vector<float> bw_expected;
 bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
 bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end);
-EXPECT_THAT(rnn.GetBwOutput(),
-ElementsAreArray(ArrayFloatNear(
-bw_expected, quantize_weights ? 1.42e-2 : 1e-5)));
+EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
 }

 // Same as BlackBox test, but input is reshuffled to time_major format.
-TEST_P(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) {
-auto params = GetParam();
-const bool quantize_weights = std::get<0>(params);
-const bool asymmetric_quantize_inputs = std::get<1>(params);
+TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) {
 BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
 /*fw_units=*/16, /*bw_units=*/16,
 /*input_size=*/8, /*aux_input_size=*/0,
 /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
 /*time_major=*/true,
-/*merge_outputs=*/false, quantize_weights,
-asymmetric_quantize_inputs);
+/*merge_outputs=*/false);
 rnn.SetFwWeights(weights);
 rnn.SetBwWeights(weights);
 rnn.SetFwBias(biases);
@@ -939,26 +889,17 @@ TEST_P(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) {
 fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
 fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
 }
-constexpr float kHybridTolerance = 3.57e-1;
-constexpr float kFloatTolerance = 1e-5;
-EXPECT_THAT(
-rnn.GetFwOutput(),
-ElementsAreArray(ArrayFloatNear(
-fw_expected, quantize_weights ? kHybridTolerance : kFloatTolerance)));
+EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
 }

 // Same as BlackBox test, yet with merged outputs.
-TEST_P(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
-auto params = GetParam();
-const bool quantize_weights = std::get<0>(params);
-const bool asymmetric_quantize_inputs = std::get<1>(params);
+TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
 BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
 /*fw_units=*/16, /*bw_units=*/16,
 /*input_size=*/8, /*aux_input_size=*/0,
 /*aux_input_mode=*/AuxInputMode::kNoAuxInput,
 /*time_major=*/false,
-/*merge_outputs=*/true, quantize_weights,
-asymmetric_quantize_inputs);
+/*merge_outputs=*/true);
 rnn.SetFwWeights(weights);
 rnn.SetBwWeights(weights);
 rnn.SetFwBias(biases);
@@ -988,8 +929,7 @@ TEST_P(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
 }
 }
 EXPECT_THAT(rnn.GetFwOutput(),
-ElementsAreArray(ArrayFloatNear(
-merged_expected, quantize_weights ? 1.42e-2 : 1e-5)));
+ElementsAreArray(ArrayFloatNear(merged_expected)));
 }

 // Same as BlackBox test, but input is reshuffled to time_major format.
@@ -71,7 +71,6 @@ struct OpData {
 int32_t output_activation_max;
 // The index of the temporary tensor where the quantized inputs are cached.
 int scratch_tensor_index;
-bool compute_row_sums = false;
 };

 constexpr int kInputTensor = 0;
@@ -132,7 +131,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
 // Instead, we allocate a new object to carry information from Prepare() to
 // Eval().
 auto* op_data = new OpData();
-context->AddTensors(context, /*tensors_to_add=*/5,
+context->AddTensors(context, /*tensors_to_add=*/3,
 &op_data->scratch_tensor_index);
 return op_data;
 }
@@ -145,6 +144,7 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
 auto* params =
 reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
 OpData* data = reinterpret_cast<OpData*>(node->user_data);

 // Check we have all the inputs and outputs we need.
 TF_LITE_ENSURE(context, node->inputs->size == 2 || node->inputs->size == 3);
 // Shuffled formats need a workspace to store the shuffled input activations.
@@ -208,8 +208,7 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
 if (input->type == kTfLiteFloat32 &&
 (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8)) {
 TfLiteIntArrayFree(node->temporaries);
-data->compute_row_sums = true;
-node->temporaries = TfLiteIntArrayCreate(5);
+node->temporaries = TfLiteIntArrayCreate(3);
 node->temporaries->data[0] = data->scratch_tensor_index;

 TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
@@ -246,28 +245,6 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
 TF_LITE_ENSURE_OK(
 context, context->ResizeTensor(context, accum_scratch, accum_size));
 }
-
-node->temporaries->data[3] = data->scratch_tensor_index + 3;
-TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/3);
-input_offsets->type = kTfLiteInt32;
-input_offsets->allocation_type = kTfLiteArenaRw;
-if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) {
-TfLiteIntArray* input_offsets_size = TfLiteIntArrayCreate(1);
-input_offsets_size->data[0] = batch_size;
-TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_offsets,
-input_offsets_size));
-}
-node->temporaries->data[4] = data->scratch_tensor_index + 4;
-TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/4);
-row_sums->type = kTfLiteInt32;
-row_sums->allocation_type = kTfLiteArenaRwPersistent;
-int row_sums_dims[1] = {num_units};
-if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) {
-TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1);
-row_sums_size->data[0] = row_sums_dims[0];
-TF_LITE_ENSURE_OK(
-context, context->ResizeTensor(context, row_sums, row_sums_size));
-}
 }

 // Resize output.
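For orientation, a sketch of the temporary-tensor layout implied by the indices used in the removed PrepareImpl() code above; the enum and its name are illustrative only, not TFLite API:

// Illustrative layout of the hybrid FullyConnected scratch tensors.
enum HybridFullyConnectedTemporaries {
  kInputQuantized = 0,  // int8 copy of the float input
  kScalingFactors = 1,  // one float scale per batch row
  kAccumScratch = 2,    // int32 accumulator workspace
  kInputOffsets = 3,    // per-batch input zero points (asymmetric path)
  kRowSums = 4          // per-output-row weight sums, persistent across calls
};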
@@ -355,9 +332,7 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
 TfLiteFullyConnectedParams* params, OpData* data,
 const TfLiteTensor* input, const TfLiteTensor* filter,
 const TfLiteTensor* bias, TfLiteTensor* input_quantized,
-TfLiteTensor* scaling_factors,
-TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
-TfLiteTensor* input_offsets, TfLiteTensor* output) {
+TfLiteTensor* scaling_factors, TfLiteTensor* output) {
 int total_input_size = 1;
 for (int i = 0; i < input->dims->size; i++) {
 total_input_size *= input->dims->data[i];
@@ -388,39 +363,32 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
 // Quantize input from float to uint8 + quantization params (scaling factor).
 float unused_min, unused_max;
 float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
-int32_t* input_offset_ptr = nullptr;
-int32_t* row_sums_ptr = nullptr;
-if (params->asymmetric_quantize_inputs) {
-input_offset_ptr = GetTensorData<int32_t>(input_offsets);
-row_sums_ptr = GetTensorData<int32_t>(row_sums);
-}
 int8_t* quant_data = GetTensorData<int8_t>(input_quantized);
 const int8_t* filter_data = GetTensorData<int8_t>(filter);
-const float* input_ptr = GetTensorData<float>(input);
 // Quantize each batch independently.
 for (int b = 0; b < batch_size; ++b) {
 const int offset = b * input_size;
-if (params->asymmetric_quantize_inputs) {
-tensor_utils::AsymmetricQuantizeFloats(
-input_ptr + offset, input_size, quant_data + offset,
-&scaling_factors_ptr[b], &input_offset_ptr[b]);
-} else {
 tensor_utils::SymmetricQuantizeFloats(
-input_ptr + offset, input_size, quant_data + offset, &unused_min,
-&unused_max, &scaling_factors_ptr[b]);
-}
+GetTensorData<float>(input) + offset, input_size, quant_data + offset,
+&unused_min, &unused_max, &scaling_factors_ptr[b]);
 // Incorporate scaling of the filter.
 scaling_factors_ptr[b] *= filter->params.scale;
 }

 // Compute output += weight * quantized_input
+#ifdef TFLITE_WITH_RUY_GEMV
+TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2);
 int32_t* scratch = GetTensorData<int32_t>(accum_scratch);
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
-batch_size, GetTensorData<float>(output), /*per_channel_scale=*/nullptr,
-input_offset_ptr, scratch, row_sums_ptr, &data->compute_row_sums,
+batch_size, scratch, GetTensorData<float>(output),
 CpuBackendContext::GetFromContext(context));
+#else
+tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
+batch_size, GetTensorData<float>(output));
+#endif
 // Apply activation function to floats.
 tensor_utils::ApplyActivationToVector(
 GetTensorData<float>(output), batch_size * num_units, params->activation,
@@ -493,12 +461,8 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
 if (input->type == kTfLiteFloat32) {
 TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
 TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
-TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2);
-TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/3);
-TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/4);
 return EvalHybrid(context, node, params, data, input, filter, bias,
-input_quantized, scaling_factors, accum_scratch, row_sums,
-input_offsets, output);
+input_quantized, scaling_factors, output);
 } else {
 FullyConnectedParams op_params;
 op_params.input_offset = input_offset;
@@ -626,6 +590,7 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
 FullyConnectedParams op_params;
 op_params.float_activation_min = output_activation_min;
 op_params.float_activation_max = output_activation_max;
+
 reference_ops::FullyConnected(
 op_params, GetTensorShape(input), GetTensorData<float>(input),
 GetTensorShape(filter), GetTensorData<float>(filter),
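The hybrid path on the removed side above calls tensor_utils::AsymmetricQuantizeFloats instead of the symmetric variant. A rough, self-contained illustration of what a per-row asymmetric quantization computes (a sketch under the usual int8 affine-quantization assumptions, not the TFLite implementation; rounding and range handling differ):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative sketch only: map one float row onto int8 with a per-row scale
// and zero point, so both signs of the input range stay representable.
void AsymmetricQuantizeRowSketch(const float* values, int size,
                                 int8_t* quantized, float* scale,
                                 int32_t* zero_point) {
  const auto mm = std::minmax_element(values, values + size);
  const float rmin = std::min(0.0f, *mm.first);
  const float rmax = std::max(0.0f, *mm.second);
  *scale = (rmax - rmin) / 255.0f;    // 256 int8 levels
  if (*scale == 0.0f) *scale = 1.0f;  // all-zero row
  *zero_point = static_cast<int32_t>(std::round(-rmin / *scale)) - 128;
  for (int i = 0; i < size; ++i) {
    const int32_t q =
        static_cast<int32_t>(std::round(values[i] / *scale)) + *zero_point;
    quantized[i] = static_cast<int8_t>(std::max<int32_t>(-128, std::min<int32_t>(127, q)));
  }
}

The per-row zero point sketched here corresponds to what the removed code stores in input_offsets and later compensates for using the cached row sums.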
@@ -286,8 +286,7 @@ class HybridFullyConnectedOpModel : public SingleOpModel {
 public:
 HybridFullyConnectedOpModel(int units, int batches, const TensorData& input,
 const TensorData& weights,
-const TensorData& output = {TensorType_FLOAT32},
-bool asymmetric_inputs = false)
+const TensorData& output = {TensorType_FLOAT32})
 : batches_(batches), units_(units) {
 int total_input_size = 1;
 for (size_t i = 0; i < input.shape.size(); ++i) {
@@ -303,13 +302,10 @@ class HybridFullyConnectedOpModel : public SingleOpModel {

 output_ = AddOutput(output);

-auto options = CreateFullyConnectedOptions(
-builder_, ActivationFunctionType_RELU,
-tflite::FullyConnectedOptionsWeightsFormat_DEFAULT,
-false, asymmetric_inputs)
-.Union();
-SetBuiltinOp(BuiltinOperator_FULLY_CONNECTED,
-BuiltinOptions_FullyConnectedOptions, options);
+SetBuiltinOp(
+BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
+CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
+.Union());
 resolver_ = absl::make_unique<SingleOpResolver>(
 BuiltinOperator_FULLY_CONNECTED,
 ops::builtin::Register_FULLY_CONNECTED_PIE());
@@ -871,66 +867,6 @@ TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8) {
 /*max_abs_error=*/1.3f)));
 }

-TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedUint8) {
-HybridFullyConnectedOpModel m(
-/*units=*/3, /*batches=*/2,
-/*input=*/{TensorType_FLOAT32, {2, 10}},
-/*weights=*/
-{TensorType_UINT8, {3, 10}, 0, 0, 10.0 / 127.0, 0}, {TensorType_FLOAT32},
-/*asymmetric_quantize_input*/ true);  // Hybrid asymmetric
-
-m.SetWeights({
-1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
-1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
-1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
-});
-m.SetBias({1, 2, 3});
-
-m.SetInput({
-1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
-1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
-});
-
-m.Invoke();
-
-EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
-{
-24, 25, 26,  //
-58, 59, 60,  //
-},
-/*max_abs_error=*/0.64f)));
-}
-
-TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedInt8) {
-HybridFullyConnectedOpModel m(
-/*units=*/3, /*batches=*/2,
-/*input=*/{TensorType_FLOAT32, {2, 10}},
-/*weights=*/{TensorType_INT8, {3, 10}, 0, 0, 10.0 / 127.0, 0},
-{TensorType_FLOAT32},
-/*asymmetric_quantize_input*/ true);
-
-m.SetSignedWeights({
-1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
-1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
-1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
-});
-m.SetBias({1, 2, 3});
-
-m.SetInput({
-1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
-1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
-});
-
-m.Invoke();
-
-EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
-{
-24, 25, 26,  //
-58, 59, 60,  //
-},
-/*max_abs_error=*/1.3f)));
-}
-
 TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) {
 // Note that it is not required that the first dimension be the number of
 // batches. All we care is that the input can be evenly distributed in
@@ -123,9 +123,7 @@ void RnnBatchStep(
 int num_units, int batch_size, int output_batch_leading_dim,
 TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
 int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
-float* hidden_state_ptr_batch, float* output_ptr_batch,
-bool asymmetric_quantize_inputs, int32_t* zero_points,
-int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums) {
+float* hidden_state_ptr_batch, float* output_ptr_batch) {
 RnnBatchStep(input_ptr_batch, input_weights_ptr, input_weights_scale,
 /*aux_input_ptr_batch=*/nullptr,
 /*aux_input_weights_ptr=*/nullptr,
@@ -135,29 +133,7 @@ void RnnBatchStep(
 output_batch_leading_dim, activation, quantized_input_ptr_batch,
 /*aux_quantized_input_ptr_batch=*/nullptr,
 quantized_hidden_state_ptr_batch, scaling_factors,
-hidden_state_ptr_batch, output_ptr_batch,
-asymmetric_quantize_inputs, zero_points, accum_scratch, row_sums,
-compute_row_sums);
-}
-
-void ComputeMatrixSums(int32_t* input_row_sums, int32_t* aux_input_row_sums,
-int32_t* recurrent_row_sums, int32_t* row_sums,
-const float* aux_input_ptr_batch, int num_units,
-int input_size, int aux_input_size,
-const int8_t* input_weights_ptr,
-const int8_t* aux_input_weights_ptr,
-const int8_t* recurrent_weights_ptr) {
-memset(input_row_sums, 0, sizeof(int32_t) * num_units);
-tensor_utils::ReductionSumVector(input_weights_ptr, input_row_sums, num_units,
-input_size);
-if (aux_input_ptr_batch) {
-memset(aux_input_row_sums, 0, sizeof(int32_t) * num_units);
-tensor_utils::ReductionSumVector(aux_input_weights_ptr, aux_input_row_sums,
-num_units, aux_input_size);
-}
-memset(recurrent_row_sums, 0, sizeof(int32_t) * num_units);
-tensor_utils::ReductionSumVector(recurrent_weights_ptr, recurrent_row_sums,
-num_units, num_units);
+hidden_state_ptr_batch, output_ptr_batch);
 }

 void RnnBatchStep(
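ComputeMatrixSums, present only on the removed side above, precomputes one int32 sum per weight row for the input, auxiliary-input, and recurrent matrices; these are the cached row sums the asymmetric path consumes. A minimal sketch of that reduction (illustrative; tensor_utils::ReductionSumVector is the optimized routine used in the removed code):

#include <cstdint>
#include <vector>

// Illustrative sketch: sums[r] = sum over the columns of row r of an int8,
// row-major weight matrix.
std::vector<int32_t> RowSumsSketch(const int8_t* weights, int rows, int cols) {
  std::vector<int32_t> sums(rows, 0);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      sums[r] += weights[r * cols + c];
    }
  }
  return sums;
}

Because the weights are constant, the removed code computes these sums once, guarded by the compute_row_sums flags set in Prepare, and reuses them on every invocation.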
@@ -170,31 +146,9 @@ void RnnBatchStep(
 TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
 int8_t* aux_quantized_input_ptr_batch,
 int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
-float* hidden_state_ptr_batch, float* output_ptr_batch,
-bool asymmetric_quantize_inputs, int32_t* zero_points,
-int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums) {
+float* hidden_state_ptr_batch, float* output_ptr_batch) {
 // Since the output batch rows may not be contiguous (output_batch_leading_dim
 // != n_output), we unroll the batched operations where this is the case.
-
-int32_t* input_row_sums = nullptr;
-int32_t* aux_input_row_sums = nullptr;
-int32_t* recurrent_row_sums = nullptr;
-if (asymmetric_quantize_inputs) {
-input_row_sums = row_sums;
-aux_input_row_sums = row_sums;
-if (aux_input_ptr_batch) {
-aux_input_row_sums += num_units;
-}
-recurrent_row_sums = aux_input_row_sums + num_units;
-if (*compute_row_sums) {
-ComputeMatrixSums(input_row_sums, aux_input_row_sums, recurrent_row_sums,
-row_sums, aux_input_ptr_batch, num_units, input_size,
-aux_input_size, input_weights_ptr,
-aux_input_weights_ptr, recurrent_weights_ptr);
-*compute_row_sums = false;
-}
-}
-
 if (output_batch_leading_dim == num_units) {
 // Output = bias
 tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
@@ -209,25 +163,17 @@ void RnnBatchStep(
 // whichever is faster.
 for (int b = 0; b < batch_size; ++b) {
 const int offset = b * input_size;
-if (asymmetric_quantize_inputs) {
-tensor_utils::AsymmetricQuantizeFloats(
-input_ptr_batch + offset, input_size,
-quantized_input_ptr_batch + offset, &scaling_factors[b],
-&zero_points[b]);
-} else {
 tensor_utils::SymmetricQuantizeFloats(
 input_ptr_batch + offset, input_size,
 quantized_input_ptr_batch + offset, &unused_min, &unused_max,
 &scaling_factors[b]);
-}
 scaling_factors[b] *= input_weights_scale;
 }

 // Output += input * input_weights
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 input_weights_ptr, num_units, input_size, quantized_input_ptr_batch,
-scaling_factors, batch_size, output_ptr_batch,
-/*per_channel_scale=*/nullptr, zero_points, accum_scratch,
-input_row_sums, compute_row_sums, /*context=*/nullptr);
+scaling_factors, batch_size, output_ptr_batch);
 }

 if (aux_input_ptr_batch &&
@@ -236,17 +182,10 @@ void RnnBatchStep(
 float unused_min, unused_max;
 for (int b = 0; b < batch_size; ++b) {
 const int offset = b * aux_input_size;
-if (asymmetric_quantize_inputs) {
-tensor_utils::AsymmetricQuantizeFloats(
-aux_input_ptr_batch + offset, aux_input_size,
-aux_quantized_input_ptr_batch + offset, &scaling_factors[b],
-&zero_points[b]);
-} else {
 tensor_utils::SymmetricQuantizeFloats(
 aux_input_ptr_batch + offset, aux_input_size,
 aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max,
 &scaling_factors[b]);
-}
 scaling_factors[b] *= aux_input_weights_scale;
 }

@@ -254,9 +193,7 @@ void RnnBatchStep(
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 aux_input_weights_ptr, num_units, aux_input_size,
 aux_quantized_input_ptr_batch, scaling_factors, batch_size,
-output_ptr_batch, /*per_channel_scale=*/nullptr, zero_points,
-accum_scratch, aux_input_row_sums, compute_row_sums,
-/*context=*/nullptr);
+output_ptr_batch);
 }

 // Save quantization and matmul computation for all zero input.
@@ -266,17 +203,10 @@ void RnnBatchStep(
 float unused_min, unused_max;
 for (int b = 0; b < batch_size; ++b) {
 const int offset = b * num_units;
-if (asymmetric_quantize_inputs) {
-tensor_utils::AsymmetricQuantizeFloats(
-hidden_state_ptr_batch + offset, num_units,
-quantized_hidden_state_ptr_batch + offset, &scaling_factors[b],
-&zero_points[b]);
-} else {
 tensor_utils::SymmetricQuantizeFloats(
 hidden_state_ptr_batch + offset, num_units,
-quantized_hidden_state_ptr_batch + offset, &unused_min,
-&unused_max, &scaling_factors[b]);
-}
+quantized_hidden_state_ptr_batch + offset, &unused_min, &unused_max,
+&scaling_factors[b]);
 scaling_factors[b] *= recurrent_weights_scale;
 }

@@ -284,9 +214,7 @@ void RnnBatchStep(
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 recurrent_weights_ptr, num_units, num_units,
 quantized_hidden_state_ptr_batch, scaling_factors, batch_size,
-output_ptr_batch, /*per_channel_scale=*/nullptr, zero_points,
-accum_scratch, recurrent_row_sums, compute_row_sums,
-/*context=*/nullptr);
+output_ptr_batch);
 }

 // Output = activation(Output) and update hidden_state
@@ -310,17 +238,10 @@ void RnnBatchStep(
 // whichever is faster.
 for (int b = 0; b < batch_size; ++b) {
 const int offset = b * input_size;
-if (asymmetric_quantize_inputs) {
-tensor_utils::AsymmetricQuantizeFloats(
-input_ptr_batch + offset, input_size,
-quantized_input_ptr_batch + offset, &scaling_factors[b],
-&zero_points[b]);
-} else {
 tensor_utils::SymmetricQuantizeFloats(
 input_ptr_batch + offset, input_size,
 quantized_input_ptr_batch + offset, &unused_min, &unused_max,
 &scaling_factors[b]);
-}
 scaling_factors[b] *= input_weights_scale;
 }

@@ -329,9 +250,7 @@ void RnnBatchStep(
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 input_weights_ptr, num_units, input_size,
 quantized_input_ptr_batch + k * input_size, &scaling_factors[k],
-/*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim,
-/*per_channel_scale=*/nullptr, zero_points + k, accum_scratch,
-input_row_sums, compute_row_sums, /*context=*/nullptr);
+/*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim);
 }
 }

@@ -341,17 +260,10 @@ void RnnBatchStep(
 float unused_min, unused_max;
 for (int b = 0; b < batch_size; ++b) {
 const int offset = b * aux_input_size;
-if (asymmetric_quantize_inputs) {
-tensor_utils::AsymmetricQuantizeFloats(
-aux_input_ptr_batch + offset, aux_input_size,
-aux_quantized_input_ptr_batch + offset, &scaling_factors[b],
-&zero_points[b]);
-} else {
 tensor_utils::SymmetricQuantizeFloats(
 aux_input_ptr_batch + offset, aux_input_size,
 aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max,
 &scaling_factors[b]);
-}
 scaling_factors[b] *= aux_input_weights_scale;
 }

@@ -361,9 +273,7 @@ void RnnBatchStep(
 aux_input_weights_ptr, num_units, aux_input_size,
 aux_quantized_input_ptr_batch + k * aux_input_size,
 &scaling_factors[k],
-/*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim,
-/*per_channel_scale=*/nullptr, zero_points + k, accum_scratch,
-aux_input_row_sums, compute_row_sums, /*context=*/nullptr);
+/*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim);
 }
 }

@@ -374,17 +284,10 @@ void RnnBatchStep(
 float unused_min, unused_max;
 for (int b = 0; b < batch_size; ++b) {
 const int offset = b * num_units;
-if (asymmetric_quantize_inputs) {
-tensor_utils::AsymmetricQuantizeFloats(
-hidden_state_ptr_batch + offset, num_units,
-quantized_hidden_state_ptr_batch + offset, &scaling_factors[b],
-&zero_points[b]);
-} else {
 tensor_utils::SymmetricQuantizeFloats(
 hidden_state_ptr_batch + offset, num_units,
-quantized_hidden_state_ptr_batch + offset, &unused_min,
-&unused_max, &scaling_factors[b]);
-}
+quantized_hidden_state_ptr_batch + offset, &unused_min, &unused_max,
+&scaling_factors[b]);
 scaling_factors[b] *= recurrent_weights_scale;
 }

@@ -393,10 +296,8 @@ void RnnBatchStep(
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 recurrent_weights_ptr, num_units, num_units,
 quantized_hidden_state_ptr_batch + k * num_units,
-&scaling_factors[k], /*n_batch=*/1,
-output_ptr_batch + k * output_batch_leading_dim,
-/*per_channel_scale=*/nullptr, zero_points + k, accum_scratch,
-recurrent_row_sums, compute_row_sums, /*context=*/nullptr);
+&scaling_factors[k],
+/*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim);
 }
 }

|
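The per-batch quantize-then-scale pattern repeated throughout the RnnBatchStep hunks above is easier to follow in isolation. Below is a minimal, self-contained sketch of what the symmetric and asymmetric float-to-int8 reductions compute (simplified helper names and rounding details, not the actual tensor_utils signatures); in the kernels the resulting per-batch scale is then folded into the weight scale, e.g. scaling_factors[b] *= input_weights_scale.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Symmetric: scale comes from the max absolute value, zero point fixed at 0.
inline void SymmetricQuantizeSketch(const float* in, int n, int8_t* out,
                                    float* scale) {
  float max_abs = 0.f;
  for (int i = 0; i < n; ++i) max_abs = std::max(max_abs, std::fabs(in[i]));
  *scale = max_abs > 0.f ? max_abs / 127.f : 1.f;
  for (int i = 0; i < n; ++i)
    out[i] = static_cast<int8_t>(std::round(in[i] / *scale));
}

// Asymmetric: scale and zero point come from the actual [min, max] range, so a
// skewed activation distribution wastes fewer of the 256 int8 buckets.
inline void AsymmetricQuantizeSketch(const float* in, int n, int8_t* out,
                                     float* scale, int32_t* zero_point) {
  float mn = 0.f, mx = 0.f;
  for (int i = 0; i < n; ++i) {
    mn = std::min(mn, in[i]);
    mx = std::max(mx, in[i]);
  }
  *scale = (mx - mn) > 0.f ? (mx - mn) / 255.f : 1.f;
  *zero_point = static_cast<int32_t>(std::round(-128.f - mn / *scale));
  for (int i = 0; i < n; ++i) {
    const int32_t q =
        static_cast<int32_t>(std::round(in[i] / *scale)) + *zero_point;
    out[i] = static_cast<int8_t>(std::min(127, std::max(-128, q)));
  }
}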
@@ -70,9 +70,7 @@ void RnnBatchStep(
 int num_units, int batch_size, int output_batch_leading_dim,
 TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
 int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
-float* hidden_state_ptr_batch, float* output_ptr_batch,
-bool asymmetric_quantize_inputs, int32_t* zero_points,
-int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums);
+float* hidden_state_ptr_batch, float* output_ptr_batch);

 void RnnBatchStep(
 const float* input_ptr_batch, const int8_t* input_weights_ptr,

@@ -84,9 +82,7 @@ void RnnBatchStep(
 TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
 int8_t* aux_quantized_input_ptr_batch,
 int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
-float* hidden_state_ptr_batch, float* output_ptr_batch,
-bool asymmetric_quantize_inputs, int32_t* zero_points,
-int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums);
+float* hidden_state_ptr_batch, float* output_ptr_batch);

 }  // namespace kernel_utils
 }  // namespace tflite
@@ -1310,13 +1310,6 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
 const int postamble_half_start = m_cols & ~(kWeightsPerNeonLane - 1);
 const int postamble_start = m_cols & ~((kWeightsPerNeonLane >> 1) - 1);

-int32_t* row_sums_ptr = row_sums;
-if (row_sums == nullptr) {
-row_sums_ptr = static_cast<int32_t*>(malloc(sizeof(int32_t) * m_rows));
-memset(row_sums_ptr, 0, sizeof(int32_t) * m_rows);
-NeonReductionSumVector(matrix, row_sums_ptr, m_rows, m_cols);
-}
-
 for (int batch = 0; batch < n_batch; ++batch) {
 const float batch_scaling_factor = scaling_factors[batch];
 const int batch_input_offset = input_offset[batch];

@@ -1334,6 +1327,10 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
 // Initialize the dot product sum for the row to 0.
 int32x4_t dotprod_32x4 = vmovq_n_s32(0);

+int32x4_t row_sum_32x4;
+if (row_sums == nullptr) {
+row_sum_32x4 = vmovq_n_s32(0);
+}
 // Prefetch the row to cache.
 __builtin_prefetch(row_ptr, 0 /* prefetch for read */,
 3 /* temporal locality */);

@@ -1361,6 +1358,10 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
 prod_16x8 =
 vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16));
 dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
+if (row_sums == nullptr) {
+const int16x8_t row_sum_16x8 = vpaddlq_s8(s2_8x16);
+row_sum_32x4 = vpadalq_s16(row_sum_32x4, row_sum_16x8);
+}
 }  // for col

 // Half iteration dealing only 8 elements

@@ -1374,24 +1375,29 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
 const int8x8_t s2_8x8 = vld1_s8((const int8_t*)(row_ptr + col));
 const int16x8_t prod_16x8 = vmull_s8(s1_8x8, s2_8x8);
 dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
+if (row_sums == nullptr) {
+const int16x8_t row_sum_16x8 = vmovl_s8(s2_8x8);
+row_sum_32x4 = vpadalq_s16(row_sum_32x4, row_sum_16x8);
+}
 col += (kWeightsPerNeonLane >> 1);
 }

 int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
+int32_t row_sum = row_sums == nullptr ? AccumulateNeonLane(row_sum_32x4)
+: row_sums[row];

 // Postamble loop.
 for (; col < m_cols; ++col) {
 dotprod += row_ptr[col] * aligned_vec[col];
+if (row_sums == nullptr) {
+row_sum += row_ptr[col];
+}
 }  // for col
-dotprod -= row_sums_ptr[row] * batch_input_offset;
+dotprod -= row_sum * batch_input_offset;
 *result += dotprod * scale;
 ++result;
 }  // for row
 }  // for batch

-if (row_sums == nullptr) {
-free(row_sums_ptr);
-}
 if (unaligned) {
 free(aligned_row_free);
 }
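The dotprod -= row_sum * batch_input_offset step above is the standard zero-point correction for asymmetric inputs: with q_j = round(x_j / s) + zp, sum_j w_ij * (q_j - zp) = sum_j w_ij * q_j - zp * sum_j w_ij, so one sum per weight row removes the offset from the int32 accumulator. A scalar reference sketch of the same accumulation (illustrative helper, not the NEON entry point):

#include <cstdint>

// Scalar reference for one (row, batch) pair of the hybrid matmul:
// result = scale * (dot(row, q_input) - zero_point * row_sum(row)).
inline float HybridDotWithZeroPoint(const int8_t* row, const int8_t* q_input,
                                    int n_cols, float scale,
                                    int32_t zero_point) {
  int32_t dotprod = 0;
  int32_t row_sum = 0;  // Can be precomputed once per row and cached.
  for (int c = 0; c < n_cols; ++c) {
    dotprod += static_cast<int32_t>(row[c]) * q_input[c];
    row_sum += row[c];
  }
  // Subtracting zero_point * row_sum is equivalent to multiplying the row by
  // the de-offset quantized input (q_input[c] - zero_point).
  return scale * static_cast<float>(dotprod - zero_point * row_sum);
}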
@@ -1404,20 +1410,6 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
 int n_batch, float* __restrict__ result, const float* per_channel_scale,
 const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
 bool* compute_row_sums, CpuBackendContext* context) {
-if (input_offset == nullptr) {
-#ifdef TFLITE_WITH_RUY_GEMV
-if (context) {
-NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
-scaling_factors, n_batch, scratch,
-result, context);
-return;
-}
-#endif
-NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
-scaling_factors, n_batch, result);
-return;
-}
-
 if (compute_row_sums == nullptr || *compute_row_sums) {
 memset(row_sums, 0, sizeof(int32_t) * m_rows);
 NeonReductionSumVector(matrix, row_sums, m_rows, m_cols);

@@ -1427,7 +1419,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
 }

 #ifdef TFLITE_WITH_RUY_GEMV
-if (context != nullptr && m_rows % 4 == 0) {
+if (m_rows % 4 == 0) {
 const int32_t* bias = static_cast<const int32_t*>(nullptr);
 NeonCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows, 0,
 scratch, context);

@@ -1471,9 +1463,9 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
 for (; i < total_size; i++) {
 const float batch_scaling_factor = scaling_factors[i / m_rows];
 const int32_t zero_point = input_offset[i / m_rows];
-int32_t dotprod = *(scratch_ptr++);
-dotprod -= row_sums[i % m_rows] * zero_point;
-*result += dotprod * batch_scaling_factor;
+int32_t x = *(scratch_ptr++);
+x -= row_sums[i % m_rows] * zero_point;
+*result += x * batch_scaling_factor;
 ++result;
 }
 return;
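Because the weight matrix is constant across calls, the row sums above only have to be computed once; compute_row_sums acts as the cache-valid bit and is cleared after the first fill. A minimal sketch of that caching pattern, assuming a plain row-major int8 matrix and simplified names rather than the actual TFLite entry points:

#include <cstdint>
#include <cstring>

// Fills row_sums[r] with the sum of row r of an m_rows x m_cols int8 matrix,
// but only the first time (or whenever the caller re-arms the flag).
inline void MaybeComputeRowSums(const int8_t* matrix, int m_rows, int m_cols,
                                int32_t* row_sums, bool* compute_row_sums) {
  if (compute_row_sums == nullptr || *compute_row_sums) {
    std::memset(row_sums, 0, sizeof(int32_t) * m_rows);
    for (int r = 0; r < m_rows; ++r)
      for (int c = 0; c < m_cols; ++c)
        row_sums[r] += matrix[r * m_cols + c];
    if (compute_row_sums != nullptr) *compute_row_sums = false;
  }
}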
@@ -167,11 +167,6 @@ void SseMatrixBatchVectorMultiplyAccumulate(
 const float* __restrict__ scaling_factors, int n_batch,
 float* __restrict__ result, const float* __restrict__ per_channel_scale,
 const int32_t* __restrict__ input_offset) {
-if (input_offset == nullptr) {
-SseMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
-scaling_factors, n_batch, result);
-return;
-}
 static constexpr std::intptr_t kBlockSize = 16;
 for (std::intptr_t batch = 0; batch < n_batch; ++batch) {
 const float batch_scaling_factor = scaling_factors[batch];
@@ -196,11 +196,6 @@ void PortableMatrixBatchVectorMultiplyAccumulate(
 int n_batch, float* __restrict__ result, const float* per_channel_scale,
 const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
 bool* compute_row_sums, CpuBackendContext* context) {
-if (input_offset == nullptr) {
-PortableMatrixBatchVectorMultiplyAccumulate(
-matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result);
-return;
-}
 if (!compute_row_sums || *compute_row_sums) {
 memset(row_sums, 0, sizeof(int32_t) * m_rows);
 PortableReductionSumVector(matrix, row_sums, m_rows, m_cols);
@@ -223,8 +223,7 @@ inline void EvalHybridSVDF(
 const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time,
 const TfLiteTensor* bias, const TfLiteSVDFParams* params,
 TfLiteTensor* scratch, TfLiteTensor* scaling_factors,
-TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output,
-TfLiteTensor* zero_points, TfLiteTensor* row_sums, bool* compute_row_sums) {
+TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output) {
 const int rank = params->rank;
 const int batch_size = input->dims->data[0];
 const int input_size = input->dims->data[1];

@@ -245,13 +244,6 @@ inline void EvalHybridSVDF(
 float* output_ptr = GetTensorData<float>(output);

-int32_t* zero_points_ptr = nullptr;
-int32_t* row_sums_ptr = nullptr;
-if (params->asymmetric_quantize_inputs && row_sums != nullptr) {
-zero_points_ptr = GetTensorData<int32_t>(zero_points);
-row_sums_ptr = GetTensorData<int32_t>(row_sums);
-}
-
 // Initialize the weights scale.
 const float weights_feature_scale = weights_feature->params.scale;

@@ -266,30 +258,21 @@ inline void EvalHybridSVDF(
 if (!tensor_utils::IsZeroVector(input_ptr, batch_size * input_size)) {
 // Quantize input from float to int8.
+float unused_min, unused_max;
 for (int b = 0; b < batch_size; ++b) {
 const int offset = b * input_size;
-if (params->asymmetric_quantize_inputs) {
-tensor_utils::AsymmetricQuantizeFloats(
-input_ptr + offset, input_size, quantized_input_ptr + offset,
-&scaling_factors_ptr[b], &zero_points_ptr[b]);
-} else {
-// Quantize input from float to int8.
-float unused_min, unused_max;
 tensor_utils::SymmetricQuantizeFloats(
 input_ptr + offset, input_size, quantized_input_ptr + offset,
 &unused_min, &unused_max, &scaling_factors_ptr[b]);
-}
 scaling_factors_ptr[b] *= weights_feature_scale;
 }

 // Compute conv1d(inputs, weights_feature).
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 weights_feature_ptr, num_filters, input_size, quantized_input_ptr,
-scaling_factors_ptr, batch_size, scratch_ptr,
-/*per_channel_scale=*/nullptr, zero_points_ptr,
-reinterpret_cast<int32_t*>(scratch_ptr), row_sums_ptr, compute_row_sums,
-/*context=*/nullptr);
+scaling_factors_ptr, batch_size, scratch_ptr);
 }

 // Copy the latest activation from scratch into activation_state:
 // The last, i.e. (memory_size-1)th entry for each batch, and filter.
 for (int i = 0; i < batch_size * num_filters; ++i) {
@@ -55,7 +55,6 @@ struct OpData {
 // These fields are only used by full kernel.
 int scratch_tensor_index;
 lstm_eval::IntegerLstmParameter integer_lstm_param;
-bool compute_row_sums;
 };

 namespace full {

@@ -728,7 +727,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
 auto* op_data = new OpData();
 op_data->kernel_type = kTfLiteLSTMFullKernel;
-context->AddTensors(context, /*tensors_to_add=*/10,
+context->AddTensors(context, /*tensors_to_add=*/8,
 &op_data->scratch_tensor_index);
 return op_data;
 }

@@ -1237,7 +1236,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 TfLiteIntArrayFree(node->temporaries);
 if (is_hybrid_op) {
-node->temporaries = TfLiteIntArrayCreate(10);
+node->temporaries = TfLiteIntArrayCreate(8);
 } else if (is_integer) {
 if (is_8x8_16) {
 node->temporaries = TfLiteIntArrayCreate(6);

@@ -1274,7 +1273,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }

 if (is_hybrid_op) {
-op_data->compute_row_sums = true;
 // Allocate temporary tensors to store quantized values of input,
 // activation_state and cell_state tensors.
 node->temporaries->data[1] = op_data->scratch_tensor_index + 1;

@@ -1372,41 +1370,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 TF_LITE_ENSURE_OK(
 context, context->ResizeTensor(context, accum_scratch, accum_size));
 }
-
-node->temporaries->data[8] = op_data->scratch_tensor_index + 8;
-TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/8);
-zero_points->type = kTfLiteFloat32;
-zero_points->allocation_type = kTfLiteArenaRw;
-int zero_points_dims[1] = {n_batch};
-if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
-TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
-zero_points_size->data[0] = n_batch;
-TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
-zero_points_size));
-}
-
-node->temporaries->data[9] = op_data->scratch_tensor_index + 9;
-const TfLiteTensor* input_to_input_weights =
-GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
-const bool use_cifg = (input_to_input_weights == nullptr);
-int row_sums_rows = use_cifg ? 6 : 8;
-const TfLiteTensor* projection_weights =
-GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
-if (projection_weights != nullptr) {
-row_sums_rows += ceil(n_output / n_cell);
-}
-
-TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/9);
-row_sums->type = kTfLiteInt32;
-row_sums->allocation_type = kTfLiteArenaRwPersistent;
-const int row_sums_dims[2] = {row_sums_rows, n_cell};
-if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
-TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
-row_sums_size->data[0] = row_sums_dims[0];
-row_sums_size->data[1] = row_sums_dims[1];
-TF_LITE_ENSURE_OK(
-context, context->ResizeTensor(context, row_sums, row_sums_size));
-}
 }

 if (is_integer) {
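For a sense of scale, the two temporaries handled in the hunk above are small: zero_points holds one entry per batch row, and row_sums caches one n_cell-wide row of weight sums per hybrid weight matrix. A tiny worked example with hypothetical sizes (illustrative numbers only):

#include <cstdio>

int main() {
  // Hypothetical sizes, for illustration only.
  const int n_batch = 2, n_cell = 20, n_output = 20;
  const bool use_cifg = false, has_projection = true;

  // zero_points temporary: one entry per batch row -> shape {n_batch} = {2}.
  // row_sums temporary: one n_cell-wide row per hybrid weight matrix:
  // 8 gate matrices (6 under CIFG) plus ceil(n_output / n_cell) rows for the
  // projection (1 here, since n_output == n_cell).
  const int row_sums_rows = (use_cifg ? 6 : 8) + (has_projection ? 1 : 0);
  std::printf("row_sums shape: {%d, %d}\n", row_sums_rows, n_cell);  // {9, 20}
  return 0;
}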
@@ -1593,9 +1556,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 GetTemporary(context, node, /*index=*/6);
 TfLiteTensor* output_scratch_buffer =
 GetTemporary(context, node, /*index=*/7);
-TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/8);
-TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/9);
-const int row_sums_size = row_sums->dims->data[0];
 return lstm_eval::EvalHybrid(
 input, input_to_input_weights, input_to_forget_weights,
 input_to_cell_weights, input_to_output_weights,

@@ -1617,8 +1577,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 input_quantized,
 /*aux_input_quantized=*/nullptr, activation_state_quantized,
 cell_state_quantized, activation_state, cell_state,
-output_scratch_buffer, output, zero_points, row_sums, row_sums_size,
-&op_data->compute_row_sums,
+output_scratch_buffer, output,
 CpuBackendContext::GetFromContext(context));
 } else {
 const int num_intermediate_tensors = node->intermediates->size;
@@ -33,95 +33,26 @@ namespace builtin {
 namespace lstm_eval {
 namespace {

-void ComputeRowSums(
-int32_t* input_to_input_row_sums, int32_t* input_to_forget_row_sums,
-int32_t* input_to_cell_row_sums, int32_t* input_to_output_row_sums,
-int32_t* aux_input_to_input_row_sums, int32_t* aux_input_to_forget_row_sums,
-int32_t* aux_input_to_cell_row_sums, int32_t* aux_input_to_output_row_sums,
-int32_t* recurrent_to_input_row_sums, int32_t* recurrent_to_forget_row_sums,
-int32_t* recurrent_to_cell_row_sums, int32_t* recurrent_to_output_row_sums,
-int32_t* projection_weights_row_sums, int32_t* row_sums, int n_cell,
-int n_input, int n_aux_input, int n_output,
-const int8_t* input_to_input_weights_ptr,
-const int8_t* input_to_forget_weights_ptr,
-const int8_t* input_to_cell_weights_ptr,
-const int8_t* input_to_output_weights_ptr,
-const int8_t* aux_input_to_input_weights_ptr,
-const int8_t* aux_input_to_forget_weights_ptr,
-const int8_t* aux_input_to_cell_weights_ptr,
-const int8_t* aux_input_to_output_weights_ptr,
-const int8_t* recurrent_to_input_weights_ptr,
-const int8_t* recurrent_to_forget_weights_ptr,
-const int8_t* recurrent_to_cell_weights_ptr,
-const int8_t* recurrent_to_output_weights_ptr,
-const int8_t* projection_weights_ptr, bool use_cifg,
-const float* aux_input_ptr) {
-// Compute the row sums for dequantization
-if (!use_cifg) {
-memset(input_to_input_row_sums, 0, sizeof(int32_t) * n_cell);
-tensor_utils::ReductionSumVector(input_to_input_weights_ptr,
-input_to_input_row_sums, n_cell, n_input);
-}
-memset(input_to_forget_row_sums, 0, sizeof(int32_t) * n_cell);
-tensor_utils::ReductionSumVector(input_to_forget_weights_ptr,
-input_to_forget_row_sums, n_cell, n_input);
-memset(input_to_cell_row_sums, 0, sizeof(int32_t) * n_cell);
-tensor_utils::ReductionSumVector(input_to_cell_weights_ptr,
-input_to_cell_row_sums, n_cell, n_input);
-memset(input_to_output_row_sums, 0, sizeof(int32_t) * n_cell);
-tensor_utils::ReductionSumVector(input_to_output_weights_ptr,
-input_to_output_row_sums, n_cell, n_input);
-
-if (aux_input_ptr) {
-if (!use_cifg) {
-memset(aux_input_to_input_row_sums, 0, sizeof(int32_t) * n_cell);
-tensor_utils::ReductionSumVector(aux_input_to_input_weights_ptr,
-aux_input_to_input_row_sums, n_cell,
-n_aux_input);
-}
-memset(aux_input_to_forget_row_sums, 0, sizeof(int32_t) * n_cell);
-tensor_utils::ReductionSumVector(aux_input_to_forget_weights_ptr,
-aux_input_to_forget_row_sums, n_cell,
-n_aux_input);
-memset(aux_input_to_cell_row_sums, 0, sizeof(int32_t) * n_cell);
-tensor_utils::ReductionSumVector(aux_input_to_cell_weights_ptr,
-aux_input_to_cell_row_sums, n_cell,
-n_aux_input);
-memset(aux_input_to_output_row_sums, 0, sizeof(int32_t) * n_cell);
-tensor_utils::ReductionSumVector(aux_input_to_output_weights_ptr,
-aux_input_to_output_row_sums, n_cell,
-n_aux_input);
-}
-if (!use_cifg) {
-memset(recurrent_to_input_row_sums, 0, sizeof(int32_t) * n_cell);
-tensor_utils::ReductionSumVector(recurrent_to_input_weights_ptr,
-recurrent_to_input_row_sums, n_cell,
-n_output);
-}
-memset(recurrent_to_forget_row_sums, 0, sizeof(int32_t) * n_cell);
-tensor_utils::ReductionSumVector(recurrent_to_forget_weights_ptr,
-recurrent_to_forget_row_sums, n_cell,
-n_output);
-memset(recurrent_to_cell_row_sums, 0, sizeof(int32_t) * n_cell);
-tensor_utils::ReductionSumVector(recurrent_to_cell_weights_ptr,
-recurrent_to_cell_row_sums, n_cell,
-n_output);
-memset(recurrent_to_output_row_sums, 0, sizeof(int32_t) * n_cell);
-tensor_utils::ReductionSumVector(recurrent_to_output_weights_ptr,
-recurrent_to_output_row_sums, n_cell,
-n_output);
-
-if (projection_weights_ptr != nullptr) {
-memset(projection_weights_row_sums, 0, sizeof(int32_t) * n_output);
-tensor_utils::ReductionSumVector(
-projection_weights_ptr, projection_weights_row_sums, n_output, n_cell);
-}
-}
-
 inline float GetTensorScale(const TfLiteTensor* tensor) {
 return tensor == nullptr ? 1.0f : tensor->params.scale;
 }

+inline void MatrixBatchVectorMultiplyAccumulate(
+const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+const int8_t* __restrict__ vectors, const float* scaling_factors,
+int n_batch, int32_t* scratch, float* __restrict__ result,
+CpuBackendContext* context) {
+// TODO(b/148289189) Remove when Ruy GEMV is the default.
+#ifdef TFLITE_WITH_RUY_GEMV
+tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, scratch,
+result, context);
+#else
+tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result);
+#endif
+}
+
 // Performs an LSTM batch inference step for input specified by input_ptr.
 // The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
 // biases (*_bias_ptr), and buffers (*_scratch), along with additional
@@ -542,8 +473,6 @@ inline void LstmStepHybrid(
 int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr,
 int8_t* quantized_cell_state_ptr, float* output_state_ptr,
 float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr,
-int32_t* zero_points, int32_t* row_sums, int row_sums_size,
-bool* compute_row_sums, bool asymmetric_quantize_inputs,
 CpuBackendContext* context) {
 ruy::profiler::ScopeLabel label("LstmStepHybrid");
 // Since we have already checked that weights are all there or none, we
@@ -574,131 +503,53 @@ inline void LstmStepHybrid(
 output_gate_scratch);
 }

-int32_t* input_to_input_row_sums = nullptr;
-int32_t* input_to_forget_row_sums = nullptr;
-int32_t* input_to_cell_row_sums = nullptr;
-int32_t* input_to_output_row_sums = nullptr;
-int32_t* aux_input_to_input_row_sums = nullptr;
-int32_t* aux_input_to_forget_row_sums = nullptr;
-int32_t* aux_input_to_cell_row_sums = nullptr;
-int32_t* aux_input_to_output_row_sums = nullptr;
-int32_t* recurrent_to_input_row_sums = nullptr;
-int32_t* recurrent_to_forget_row_sums = nullptr;
-int32_t* recurrent_to_cell_row_sums = nullptr;
-int32_t* recurrent_to_output_row_sums = nullptr;
-int32_t* projection_weights_row_sums = nullptr;
-
-if (asymmetric_quantize_inputs) {
-int num_row_sums = use_cifg ? 6 : 8;
-if (aux_input_ptr != nullptr) {
-num_row_sums += use_cifg ? 3 : 4;
-}
-if (projection_weights_ptr != nullptr) {
-num_row_sums += ceil(n_output / n_cell);
-}
-TF_LITE_ASSERT(row_sums_size == num_row_sums);
-input_to_input_row_sums = row_sums;
-input_to_forget_row_sums =
-use_cifg ? input_to_input_row_sums : input_to_input_row_sums + n_cell;
-input_to_cell_row_sums = input_to_forget_row_sums + n_cell;
-input_to_output_row_sums = input_to_cell_row_sums + n_cell;
-if (aux_input_ptr != nullptr) {
-aux_input_to_input_row_sums = input_to_output_row_sums + n_cell;
-aux_input_to_forget_row_sums = use_cifg
-? aux_input_to_input_row_sums
-: aux_input_to_input_row_sums + n_cell;
-aux_input_to_cell_row_sums = aux_input_to_forget_row_sums + n_cell;
-aux_input_to_output_row_sums = aux_input_to_cell_row_sums + n_cell;
-}
-recurrent_to_input_row_sums = aux_input_ptr
-? aux_input_to_output_row_sums + n_cell
-: input_to_output_row_sums + n_cell;
-recurrent_to_forget_row_sums = use_cifg
-? recurrent_to_input_row_sums
-: recurrent_to_input_row_sums + n_cell;
-recurrent_to_cell_row_sums = recurrent_to_forget_row_sums + n_cell;
-recurrent_to_output_row_sums = recurrent_to_cell_row_sums + n_cell;
-if (projection_weights_ptr != nullptr) {
-projection_weights_row_sums = recurrent_to_output_row_sums + n_cell;
-}
-if (*compute_row_sums) {
-ComputeRowSums(
-input_to_input_row_sums, input_to_forget_row_sums,
-input_to_cell_row_sums, input_to_output_row_sums,
-aux_input_to_input_row_sums, aux_input_to_forget_row_sums,
-aux_input_to_cell_row_sums, aux_input_to_output_row_sums,
-recurrent_to_input_row_sums, recurrent_to_forget_row_sums,
-recurrent_to_cell_row_sums, recurrent_to_output_row_sums,
-projection_weights_row_sums, row_sums, n_cell, n_input, n_aux_input,
-n_output, input_to_input_weights_ptr, input_to_forget_weights_ptr,
-input_to_cell_weights_ptr, input_to_output_weights_ptr,
-aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
-aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
-recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
-recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
-projection_weights_ptr, use_cifg, aux_input_ptr);
-*compute_row_sums = false;
-}
-}
-
+// For each batch and cell: compute input_weight * input.
+// Skip if input is all zeros.
 if (!tensor_utils::IsZeroVector(input_ptr, n_batch * n_input)) {
 for (int b = 0; b < n_batch; ++b) {
 const int offset = b * n_input;
-if (asymmetric_quantize_inputs) {
-tensor_utils::AsymmetricQuantizeFloats(
-input_ptr + offset, n_input, quantized_input_ptr + offset,
-&scaling_factors[b], &zero_points[b]);
-} else {
 float unused_min, unused_max;
 tensor_utils::SymmetricQuantizeFloats(
 input_ptr + offset, n_input, quantized_input_ptr + offset,
 &unused_min, &unused_max, &scaling_factors[b]);
 }
-}
 if (!use_cifg) {
 for (int b = 0; b < n_batch; ++b) {
 product_scaling_factors[b] =
 scaling_factors[b] * input_to_input_weights_scale;
 }
-tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+MatrixBatchVectorMultiplyAccumulate(
 input_to_input_weights_ptr, n_cell, n_input, quantized_input_ptr,
-product_scaling_factors, n_batch, input_gate_scratch,
-/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-input_to_input_row_sums, compute_row_sums, context);
+product_scaling_factors, n_batch, accum_scratch_ptr,
+input_gate_scratch, context);
 }

 for (int b = 0; b < n_batch; ++b) {
 product_scaling_factors[b] =
 scaling_factors[b] * input_to_forget_weights_scale;
 }
-tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+MatrixBatchVectorMultiplyAccumulate(
 input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr,
-product_scaling_factors, n_batch, forget_gate_scratch,
-/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-input_to_forget_row_sums, compute_row_sums, context);
+product_scaling_factors, n_batch, accum_scratch_ptr,
+forget_gate_scratch, context);

 for (int b = 0; b < n_batch; ++b) {
 product_scaling_factors[b] =
 scaling_factors[b] * input_to_cell_weights_scale;
 }
-tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+MatrixBatchVectorMultiplyAccumulate(
 input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr,
-product_scaling_factors, n_batch, cell_scratch,
-/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-input_to_cell_row_sums, compute_row_sums, context);
+product_scaling_factors, n_batch, accum_scratch_ptr, cell_scratch,
+context);

 for (int b = 0; b < n_batch; ++b) {
 product_scaling_factors[b] =
 scaling_factors[b] * input_to_output_weights_scale;
 }
-tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+MatrixBatchVectorMultiplyAccumulate(
 input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr,
-product_scaling_factors, n_batch, output_gate_scratch,
-/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-input_to_output_row_sums, compute_row_sums, context);
+product_scaling_factors, n_batch, accum_scratch_ptr,
+output_gate_scratch, context);
 }

 // For each batch and cell: compute aux_input_weight * aux_input.
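Earlier in this hunk, the removed bookkeeping carves the flat row_sums buffer into consecutive n_cell-wide slices, one per hybrid weight matrix, in the order the matrices are consumed. A compact sketch of that layout for the simplest case (non-CIFG, no aux input, projection present); the struct and helper below are hypothetical, the real code keeps raw pointers and also handles the CIFG and aux-input variants:

#include <cstdint>

struct LstmRowSums {
  int32_t* input_to_input;
  int32_t* input_to_forget;
  int32_t* input_to_cell;
  int32_t* input_to_output;
  int32_t* recurrent_to_input;
  int32_t* recurrent_to_forget;
  int32_t* recurrent_to_cell;
  int32_t* recurrent_to_output;
  int32_t* projection;  // Only meaningful when a projection layer exists.
};

// Partitions a flat {9, n_cell} buffer into per-matrix slices.
inline LstmRowSums PartitionRowSums(int32_t* row_sums, int n_cell) {
  LstmRowSums s;
  s.input_to_input = row_sums;
  s.input_to_forget = s.input_to_input + n_cell;
  s.input_to_cell = s.input_to_forget + n_cell;
  s.input_to_output = s.input_to_cell + n_cell;
  s.recurrent_to_input = s.input_to_output + n_cell;
  s.recurrent_to_forget = s.recurrent_to_input + n_cell;
  s.recurrent_to_cell = s.recurrent_to_forget + n_cell;
  s.recurrent_to_output = s.recurrent_to_cell + n_cell;
  s.projection = s.recurrent_to_output + n_cell;
  return s;
}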
@@ -707,131 +558,98 @@ inline void LstmStepHybrid(
 !tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input)) {
 for (int b = 0; b < n_batch; ++b) {
 const int offset = b * n_aux_input;
-if (asymmetric_quantize_inputs) {
-tensor_utils::AsymmetricQuantizeFloats(
-aux_input_ptr + offset, n_aux_input,
-quantized_aux_input_ptr + offset, &scaling_factors[b],
-&zero_points[b]);
-} else {
 float unused_min, unused_max;
 tensor_utils::SymmetricQuantizeFloats(
-aux_input_ptr + offset, n_aux_input,
-quantized_aux_input_ptr + offset, &unused_min, &unused_max,
-&scaling_factors[b]);
+aux_input_ptr + offset, n_aux_input, quantized_aux_input_ptr + offset,
+&unused_min, &unused_max, &scaling_factors[b]);
 }
-}

 if (!use_cifg) {
 for (int b = 0; b < n_batch; ++b) {
 product_scaling_factors[b] =
 scaling_factors[b] * aux_input_to_input_weights_scale;
 }
-tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+MatrixBatchVectorMultiplyAccumulate(
 aux_input_to_input_weights_ptr, n_cell, n_aux_input,
 quantized_aux_input_ptr, product_scaling_factors, n_batch,
-input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-accum_scratch_ptr, aux_input_to_input_row_sums, compute_row_sums,
-context);
+accum_scratch_ptr, input_gate_scratch, context);
 }

 for (int b = 0; b < n_batch; ++b) {
 product_scaling_factors[b] =
 scaling_factors[b] * aux_input_to_forget_weights_scale;
 }
-tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+MatrixBatchVectorMultiplyAccumulate(
 aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
 quantized_aux_input_ptr, product_scaling_factors, n_batch,
-forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-accum_scratch_ptr, aux_input_to_forget_row_sums, compute_row_sums,
-context);
-row_sums += n_cell;
+accum_scratch_ptr, forget_gate_scratch, context);

 for (int b = 0; b < n_batch; ++b) {
 product_scaling_factors[b] =
 scaling_factors[b] * aux_input_to_cell_weights_scale;
 }
-tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+MatrixBatchVectorMultiplyAccumulate(
 aux_input_to_cell_weights_ptr, n_cell, n_aux_input,
-quantized_aux_input_ptr, product_scaling_factors, n_batch, cell_scratch,
-/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-aux_input_to_cell_row_sums, compute_row_sums, context);
+quantized_aux_input_ptr, product_scaling_factors, n_batch,
+accum_scratch_ptr, cell_scratch, context);

 for (int b = 0; b < n_batch; ++b) {
 product_scaling_factors[b] =
 scaling_factors[b] * aux_input_to_output_weights_scale;
 }
-tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+MatrixBatchVectorMultiplyAccumulate(
 aux_input_to_output_weights_ptr, n_cell, n_aux_input,
 quantized_aux_input_ptr, product_scaling_factors, n_batch,
-output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-accum_scratch_ptr, aux_input_to_output_row_sums, compute_row_sums,
-context);
+accum_scratch_ptr, output_gate_scratch, context);
 }

 if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
 // Save quantization and matmul computation for all zero input.
 for (int b = 0; b < n_batch; ++b) {
 const int offset = b * n_output;
-if (asymmetric_quantize_inputs) {
-tensor_utils::AsymmetricQuantizeFloats(
-output_state_ptr + offset, n_output,
-quantized_output_state_ptr + offset, &scaling_factors[b],
-&zero_points[b]);
-} else {
 float unused_min, unused_max;
-tensor_utils::SymmetricQuantizeFloats(
-output_state_ptr + offset, n_output,
-quantized_output_state_ptr + offset, &unused_min, &unused_max,
+tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
+quantized_output_state_ptr + offset,
+&unused_min, &unused_max,
 &scaling_factors[b]);
 }
-}
 // For each batch and cell: compute recurrent_weight * output_state.
 if (!use_cifg) {
 for (int b = 0; b < n_batch; ++b) {
 product_scaling_factors[b] =
 scaling_factors[b] * recurrent_to_input_weights_scale;
 }
-tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+MatrixBatchVectorMultiplyAccumulate(
 recurrent_to_input_weights_ptr, n_cell, n_output,
 quantized_output_state_ptr, product_scaling_factors, n_batch,
-input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-accum_scratch_ptr, recurrent_to_input_row_sums, compute_row_sums,
-context);
+accum_scratch_ptr, input_gate_scratch, context);
 }

 for (int b = 0; b < n_batch; ++b) {
 product_scaling_factors[b] =
 scaling_factors[b] * recurrent_to_forget_weights_scale;
 }
-tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+MatrixBatchVectorMultiplyAccumulate(
 recurrent_to_forget_weights_ptr, n_cell, n_output,
 quantized_output_state_ptr, product_scaling_factors, n_batch,
-forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-accum_scratch_ptr, recurrent_to_forget_row_sums, compute_row_sums,
-context);
+accum_scratch_ptr, forget_gate_scratch, context);

 for (int b = 0; b < n_batch; ++b) {
 product_scaling_factors[b] =
 scaling_factors[b] * recurrent_to_cell_weights_scale;
 }
-tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+MatrixBatchVectorMultiplyAccumulate(
 recurrent_to_cell_weights_ptr, n_cell, n_output,
 quantized_output_state_ptr, product_scaling_factors, n_batch,
-cell_scratch, /*per_channel_scale=*/nullptr, zero_points,
-accum_scratch_ptr, recurrent_to_cell_row_sums, compute_row_sums,
-context);
+accum_scratch_ptr, cell_scratch, context);

 for (int b = 0; b < n_batch; ++b) {
 product_scaling_factors[b] =
 scaling_factors[b] * recurrent_to_output_weights_scale;
 }
-tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+MatrixBatchVectorMultiplyAccumulate(
 recurrent_to_output_weights_ptr, n_cell, n_output,
 quantized_output_state_ptr, product_scaling_factors, n_batch,
-output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-accum_scratch_ptr, recurrent_to_output_row_sums, compute_row_sums,
-context);
+accum_scratch_ptr, output_gate_scratch, context);
 }

 // For each batch and cell: update input gate.
@@ -952,32 +770,22 @@ inline void LstmStepHybrid(
 // Save quantization and matmul computation for all zero input.
 for (int b = 0; b < n_batch; ++b) {
 const int offset = b * n_cell;
-if (asymmetric_quantize_inputs) {
-tensor_utils::AsymmetricQuantizeFloats(
-output_gate_scratch + offset, n_cell,
-quantized_cell_state_ptr + offset, &scaling_factors[b],
-&zero_points[b]);
-} else {
 float unused_min, unused_max;
 tensor_utils::SymmetricQuantizeFloats(
 output_gate_scratch + offset, n_cell,
 quantized_cell_state_ptr + offset, &unused_min, &unused_max,
 &scaling_factors[b]);
 }
-}
 for (int b = 0; b < n_batch; ++b) {
 product_scaling_factors[b] =
 scaling_factors[b] * projection_weights_scale;
 }
 for (int b = 0; b < n_batch; b++) {
-tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+MatrixBatchVectorMultiplyAccumulate(
 projection_weights_ptr, n_output, n_cell,
 quantized_cell_state_ptr + b * n_cell, &product_scaling_factors[b],
-/*n_batch=*/1, output_ptr + b * output_batch_leading_dim,
-/*per_channel_scale=*/nullptr,
-asymmetric_quantize_inputs ? &zero_points[b] : nullptr,
-accum_scratch_ptr, projection_weights_row_sums, compute_row_sums,
-context);
+/*n_batch=*/1, accum_scratch_ptr,
+output_ptr + b * output_batch_leading_dim, context);
 }
 }
 if (params->proj_clip > 0.0) {
@@ -1807,8 +1615,7 @@ TfLiteStatus EvalHybrid(
 TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
 TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
 TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer,
-TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* row_sums,
-int row_sums_size, bool* compute_row_sums, CpuBackendContext* context) {
+TfLiteTensor* output, CpuBackendContext* context) {
 TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
 const int n_input = input->dims->data[input->dims->size - 1];
 int max_time, n_batch;

@@ -1847,14 +1654,6 @@ TfLiteStatus EvalHybrid(
 const int output_batch_leading_dim =
 output->dims->data[output->dims->size - 1];

-int32_t* zero_points_ptr = nullptr;
-int32_t* row_sums_ptr = nullptr;
-if (params->asymmetric_quantize_inputs) {
-zero_points_ptr = GetTensorData<int32_t>(zero_points);
-row_sums_ptr = GetTensorData<int32_t>(row_sums);
-}
-
 if (time_major) {
 // Feed the sequence into the LSTM step-by-step.
 const int input_step = n_batch * n_input;

@@ -1922,9 +1721,7 @@ TfLiteStatus EvalHybrid(
 GetTensorData<int8_t>(output_state_quantized),
 GetTensorData<int8_t>(cell_state_quantized),
 GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
-GetTensorData<int32_t>(output_scratch_buffer), output_ptr,
-zero_points_ptr, row_sums_ptr, row_sums_size, compute_row_sums,
-params->asymmetric_quantize_inputs, context);
+GetTensorData<int32_t>(output_scratch_buffer), output_ptr, context);
 }
 } else {
 for (int b = 0; b < n_batch; b++) {

@@ -2009,8 +1806,7 @@ TfLiteStatus EvalHybrid(
 GetTensorData<int8_t>(output_state_quantized),
 GetTensorData<int8_t>(cell_state_quantized), output_state_ptr,
 cell_state_ptr, GetTensorData<int32_t>(output_scratch_buffer),
-output_ptr, zero_points_ptr, row_sums_ptr, row_sums_size,
-compute_row_sums, params->asymmetric_quantize_inputs, context);
+output_ptr, context);
 }
 }
 }
@@ -156,8 +156,7 @@ TfLiteStatus EvalHybrid(
 TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
 TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
 TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer,
-TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* row_sums,
-int row_sums_size, bool* compute_row_sums, CpuBackendContext* context);
+TfLiteTensor* output, CpuBackendContext* context);

 TfLiteStatus EvalInteger8x8_16(
 const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
@@ -38,8 +38,7 @@ class LSTMOpModel : public SingleOpModel {
              bool use_peephole, bool use_projection_weights,
              bool use_projection_bias, float cell_clip, float proj_clip,
              const std::vector<std::vector<int>>& input_shapes,
-             const TensorType weight_type, bool is_layer_norm,
-             bool asymmetric_quantize_inputs = false)
+             const TensorType weight_type, bool is_layer_norm)
       : n_batch_(n_batch),
         n_input_(n_input),
         n_cell_(n_cell),
@@ -130,11 +129,9 @@ class LSTMOpModel : public SingleOpModel {

     output_ = AddOutput(TensorType_FLOAT32);

-    SetBuiltinOp(
-        BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
-        CreateLSTMOptions(builder_, ActivationFunctionType_TANH, cell_clip,
-                          proj_clip, ::tflite::LSTMKernelType_FULL,
-                          asymmetric_quantize_inputs)
+    SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
+                 CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
+                                   cell_clip, proj_clip)
             .Union());

     // Do not apply delegate yet since tensor values are not known (and more
@@ -318,7 +315,7 @@ class LSTMOpModel : public SingleOpModel {
   const TensorType weight_type_;
 };

-class BaseLstmTest : public ::testing::TestWithParam<bool> {
+class BaseLstmTest : public ::testing::Test {
  protected:
   // Weights of the LSTM model. Some are optional.
   std::vector<float> input_to_input_weights_;
@@ -568,11 +565,8 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest,
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
 }

-TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
        HybridLstmBlackBoxTestUint8) {
-  if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
-    return;
-  }
   const int n_batch = 1;
   const int n_input = 2;
   // n_cell and n_output have the same size when there is no projection.
@@ -610,20 +604,14 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
                      {0},  // projection_bias tensor
                    },
                    /*weight_type=*/TensorType_UINT8,
-                   /*is_layer_norm=*/false, GetParam());
+                   /*is_layer_norm=*/false);

   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
                 /*tolerance=*/0.0157651);
 }

-class NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test
-    : public NoCifgNoPeepholeNoProjectionNoClippingLstmTest {};
-
-TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test,
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
        HybridLstmBlackBoxTestInt8) {
-  if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
-    return;
-  }
   const int n_batch = 1;
   const int n_input = 2;
   // n_cell and n_output have the same size when there is no projection.
@@ -661,7 +649,7 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test,
                      {0},  // projection_bias tensor
                    },
                    /*weight_type=*/TensorType_INT8,
-                   /*is_layer_norm=*/false, GetParam());
+                   /*is_layer_norm=*/false);

   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
                 /*tolerance=*/0.0157651);
@@ -757,11 +745,8 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
 }

-TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmTest,
+TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest,
        HybridLstmBlackBoxTestUint8) {
-  if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
-    return;
-  }
   const int n_batch = 1;
   const int n_input = 2;
   // n_cell and n_output have the same size when there is no projection.
@@ -799,18 +784,13 @@ TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmTest,
                      {0},  // projection_bias tensor
                    },
                    /*weight_type=*/TensorType_UINT8,
-                   /*is_layer_norm=*/false, GetParam());
+                   /*is_layer_norm=*/false);

   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
 }
-class CifgNoPeepholeNoProjectionNoClippingLstmInt8Test
-    : public CifgNoPeepholeNoProjectionNoClippingLstmTest {};

-TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmInt8Test,
+TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest,
        HybridLstmBlackBoxTestInt8) {
-  if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
-    return;
-  }
   const int n_batch = 1;
   const int n_input = 2;
   // n_cell and n_output have the same size when there is no projection.
@@ -848,7 +828,7 @@ TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmInt8Test,
                      {0},  // projection_bias tensor
                    },
                    /*weight_type=*/TensorType_INT8,
-                   /*is_layer_norm=*/false, GetParam());
+                   /*is_layer_norm=*/false);

   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
 }
@@ -1494,60 +1474,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) {
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
 }

-TEST_P(NoCifgPeepholeProjectionNoClippingLstmTest,
-       HybridLstmBlackBoxTestUint8) {
-  if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
-    return;
-  }
-  const int n_batch = 2;
-  const int n_input = 5;
-  const int n_cell = 20;
-  const int n_output = 16;
-
-  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
-                   /*use_cifg=*/false, /*use_peephole=*/true,
-                   /*use_projection_weights=*/true,
-                   /*use_projection_bias=*/false,
-                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
-                   {
-                     {n_batch, n_input},  // input tensor
-
-                     {n_cell, n_input},  // input_to_input_weight tensor
-                     {n_cell, n_input},  // input_to_forget_weight tensor
-                     {n_cell, n_input},  // input_to_cell_weight tensor
-                     {n_cell, n_input},  // input_to_output_weight tensor
-
-                     {n_cell, n_output},  // recurrent_to_input_weight tensor
-                     {n_cell, n_output},  // recurrent_to_forget_weight tensor
-                     {n_cell, n_output},  // recurrent_to_cell_weight tensor
-                     {n_cell, n_output},  // recurrent_to_output_weight tensor
-
-                     {n_cell},  // cell_to_input_weight tensor
-                     {n_cell},  // cell_to_forget_weight tensor
-                     {n_cell},  // cell_to_output_weight tensor
-
-                     {n_cell},  // input_gate_bias tensor
-                     {n_cell},  // forget_gate_bias tensor
-                     {n_cell},  // cell_bias tensor
-                     {n_cell},  // output_gate_bias tensor
-
-                     {n_output, n_cell},  // projection_weight tensor
-                     {0},  // projection_bias tensor
-                   },
-                   /*weight_type=*/TensorType_UINT8,
-                   /*is_layer_norm=*/false, GetParam());
-
-  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
-}
-
-class NoCifgPeepholeProjectionNoClippingLstmInt8Test
-    : public NoCifgPeepholeProjectionNoClippingLstmTest {};
-
-TEST_P(NoCifgPeepholeProjectionNoClippingLstmInt8Test,
-       HybridLstmBlackBoxTestInt8) {
-  if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
-    return;
-  }
+TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTestInt8) {
   const int n_batch = 2;
   const int n_input = 5;
   const int n_cell = 20;
@@ -1584,9 +1511,52 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLstmInt8Test,
                      {0},  // projection_bias tensor
                    },
                    /*weight_type=*/TensorType_INT8,
-                   /*is_layer_norm=*/false, GetParam());
+                   /*is_layer_norm=*/false);

-  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.0015);
+  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
+}
+
+TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest,
+       HybridLstmBlackBoxTestUint8) {
+  const int n_batch = 2;
+  const int n_input = 5;
+  const int n_cell = 20;
+  const int n_output = 16;
+
+  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+                   /*use_cifg=*/false, /*use_peephole=*/true,
+                   /*use_projection_weights=*/true,
+                   /*use_projection_bias=*/false,
+                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+                   {
+                     {n_batch, n_input},  // input tensor
+
+                     {n_cell, n_input},  // input_to_input_weight tensor
+                     {n_cell, n_input},  // input_to_forget_weight tensor
+                     {n_cell, n_input},  // input_to_cell_weight tensor
+                     {n_cell, n_input},  // input_to_output_weight tensor
+
+                     {n_cell, n_output},  // recurrent_to_input_weight tensor
+                     {n_cell, n_output},  // recurrent_to_forget_weight tensor
+                     {n_cell, n_output},  // recurrent_to_cell_weight tensor
+                     {n_cell, n_output},  // recurrent_to_output_weight tensor
+
+                     {n_cell},  // cell_to_input_weight tensor
+                     {n_cell},  // cell_to_forget_weight tensor
+                     {n_cell},  // cell_to_output_weight tensor
+
+                     {n_cell},  // input_gate_bias tensor
+                     {n_cell},  // forget_gate_bias tensor
+                     {n_cell},  // cell_bias tensor
+                     {n_cell},  // output_gate_bias tensor
+
+                     {n_output, n_cell},  // projection_weight tensor
+                     {0},  // projection_bias tensor
+                   },
+                   /*weight_type=*/TensorType_UINT8,
+                   /*is_layer_norm=*/false);
+
+  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
 }

 class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
@@ -1723,11 +1693,8 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
   VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
 }

-TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
+TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
        HybridLayerNormLstmBlackBoxTestUint8) {
-  if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
-    return;
-  }
   const int n_batch = 2;
   const int n_input = 5;
   const int n_cell = 4;
@@ -1774,7 +1741,7 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
                        {n_cell},  // output_layer_norm_coefficient tensor
                      },
                      /*weight_type=*/TensorType_UINT8,
-                     /*is_layer_norm=*/true, GetParam());
+                     /*is_layer_norm=*/true);

   lstm_golden_output_ = {{
       // Batch0: 3 (input_sequence_size) * 3 (n_output)
@@ -1793,14 +1760,8 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
                 /*tolerance=*/0.0010907);
 }

-class NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test
-    : public NoCifgPeepholeProjectionNoClippingLayerNormLstmTest {};
-
-TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
+TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
        HybridLayerNormLstmBlackBoxTestInt8) {
-  if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
-    return;
-  }
   const int n_batch = 2;
   const int n_input = 5;
   const int n_cell = 4;
@@ -1847,24 +1808,22 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
                        {n_cell},  // output_layer_norm_coefficient tensor
                      },
                      /*weight_type=*/TensorType_INT8,
-                     /*is_layer_norm=*/true, GetParam());
+                     /*is_layer_norm=*/true);

-  // Goldens are calculated from weight_type=TensorType_FLOAT32.
   lstm_golden_output_ = {{
       // Batch0: 3 (input_sequence_size) * 3 (n_output)
-      0.0244077, 0.128027, -0.00170918,  // seq 0
-      0.0137642, 0.140751, 0.0395835,    // seq 1
-      -0.00459233, 0.155278, 0.0837378,  // seq 2
+      0.0244576, 0.127847, -0.00181765,  // seq 0
+      0.0137518, 0.140892, 0.0402234,    // seq 1
+      -0.0048839, 0.155096, 0.0840309,   // seq 2
       },
       {
       // Batch1: 3 (input_sequence_size) * 3 (n_output)
-      -0.00692428, 0.0848741, 0.063445,  // seq 0
-      -0.00403911, 0.139963, 0.072681,   // seq 1
-      0.00752708, 0.161903, 0.0561371,   // seq 2
+      -0.00728636, 0.0843957, 0.0634786,  // seq 0
+      -0.00448382, 0.139278, 0.0737372,   // seq 1
+      0.00734616, 0.161793, 0.0560238,    // seq 2
       }};

-  VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm,
-                /*tolerance=*/1.06e-3);
+  VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
 }

 class CifgPeepholeProjectionNoClippingLayerNormLstmTest : public BaseLstmTest {
@@ -1981,11 +1940,8 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
   VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
 }

-TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
+TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
        HybridLayerNormLstmBlackBoxTestUint8) {
-  if (SingleOpModel::GetForceUseNnapi() && GetParam()) {
-    return;
-  }
   const int n_batch = 2;
   const int n_input = 5;
   const int n_cell = 4;
@@ -2032,7 +1988,7 @@ TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
                        {n_cell},  // output_layer_norm_coefficient tensor
                      },
                      /*weight_type=*/TensorType_UINT8,
-                     /*is_layer_norm=*/true, GetParam());
+                     /*is_layer_norm=*/true);

   // Verify the final output.
   lstm_golden_output_ = {
@@ -2053,10 +2009,7 @@ TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
                 /*tolerance=*/0.000902065);
 }

-class CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test
-    : public CifgPeepholeProjectionNoClippingLayerNormLstmTest {};
-
-TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
+TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
        HybridLayerNormLstmBlackBoxTestInt8) {
   const int n_batch = 2;
   const int n_input = 5;
@@ -2104,24 +2057,24 @@ TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
                        {n_cell},  // output_layer_norm_coefficient tensor
                      },
                      /*weight_type=*/TensorType_INT8,
-                     /*is_layer_norm=*/true, GetParam());
+                     /*is_layer_norm=*/true);

-  // Goldens are results using FLOAT32 inference.
-  lstm_golden_output_ = {{
+  // Verify the final output.
+  lstm_golden_output_ = {
+      {
       // Batch0: 3 (input_sequence_size) * 3 (n_output)
-      0.0212971, 0.140816, 0.0112733,  // seq 0
-      0.0132302, 0.152308, 0.0346313,  // seq 1
-      -0.0123688, 0.16579, 0.0893078,  // seq 2
+      0.0212250091, 0.140474007, 0.0115012666,   // seq 0
+      0.0130806509, 0.152660668, 0.0347516984,   // seq 1
+      -0.0124010444, 0.166042402, 0.0898982584,  // seq 2
       },
       {
       // Batch1: 3 (input_sequence_size) * 3 (n_output)
-      -0.0226351, 0.0916948, 0.0769176,  // seq 0
-      -0.0269967, 0.149708, 0.0941492,   // seq 1
-      -0.0103429, 0.173016, 0.0720509,   // seq 2
+      -0.0228835996, 0.0917588323, 0.0778886303,  // seq 0
+      -0.0275101066, 0.148769245, 0.0938384682,   // seq 1
+      -0.0103605557, 0.172605693, 0.0728750974,   // seq 2
       }};

-  VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm,
-                /*tolerance=*/1e-3);
+  VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
 }

 class LSTMIntegerOpModel : public SingleOpModel {
@@ -3358,22 +3311,5 @@ TEST(LSTMOpModel, InvalidTypeTest) {
                "");
 }
 #endif
-
-#define QUANTIZE_PARAMETER_TEST(test) \
-  INSTANTIATE_TEST_SUITE_P(test, test, ::testing::Bool())
-
-QUANTIZE_PARAMETER_TEST(NoCifgNoPeepholeNoProjectionNoClippingLstmTest);
-QUANTIZE_PARAMETER_TEST(NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test);
-QUANTIZE_PARAMETER_TEST(CifgNoPeepholeNoProjectionNoClippingLstmTest);
-QUANTIZE_PARAMETER_TEST(CifgNoPeepholeNoProjectionNoClippingLstmInt8Test);
-QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionNoClippingLstmTest);
-QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionNoClippingLstmInt8Test);
-QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest);
-QUANTIZE_PARAMETER_TEST(
-    NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test);
-QUANTIZE_PARAMETER_TEST(CifgPeepholeProjectionNoClippingLayerNormLstmTest);
-QUANTIZE_PARAMETER_TEST(CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test);
-#undef QUANTIZE_PARAMETER_TEST
-
 }  // namespace
 }  // namespace tflite
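The QUANTIZE_PARAMETER_TEST macro on the removed lines above wraps googletest's value-parameterized test support, so each hybrid test runs once with asymmetric input quantization off and once with it on. A minimal sketch of that pattern, with a hypothetical fixture name, looks like this:

    #include <gtest/gtest.h>

    // Hypothetical fixture; each TEST_P body receives false and then true
    // through GetParam().
    class HybridKernelTest : public ::testing::TestWithParam<bool> {};

    TEST_P(HybridKernelTest, RunsWithEitherQuantizationMode) {
      const bool asymmetric_quantize_inputs = GetParam();
      // A real test would build an op model with this flag and compare goldens.
      EXPECT_TRUE(asymmetric_quantize_inputs || !asymmetric_quantize_inputs);
    }

    INSTANTIATE_TEST_SUITE_P(HybridKernelTest, HybridKernelTest,
                             ::testing::Bool());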
@@ -43,7 +43,6 @@ struct OpData {
   int effective_scale_1_b;
   int32 effective_scale_2_a;
   int effective_scale_2_b;
-  bool compute_row_sums = false;
 };

 }  // namespace
@@ -62,8 +61,8 @@ constexpr int kOutputTensor = 0;
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   auto* op_data = new OpData();
   op_data->float_weights_time_initialized = false;
-  // Note: only needs 6 scratch tensors when is_hybrid_op, only 1 otherwise.
-  context->AddTensors(context, /*tensors_to_add=*/6,
+  // Note: only needs 4 scratch tensors when is_hybrid_op, only 1 otherwise.
+  context->AddTensors(context, /*tensors_to_add=*/4,
                       &op_data->scratch_tensor_index);
   return op_data;
 }
@@ -131,7 +130,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Resize scratch.
   TfLiteIntArrayFree(node->temporaries);
   if (is_hybrid_op) {
-    node->temporaries = TfLiteIntArrayCreate(6);
+    node->temporaries = TfLiteIntArrayCreate(4);
   } else if (is_full_integer) {
     node->temporaries = TfLiteIntArrayCreate(2);
   } else {
@@ -157,7 +156,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                                    scratch_size_array));

   if (is_hybrid_op) {
-    op_data->compute_row_sums = true;
     // Tell interpreter to allocate temporary tensors to store quantized values
     // of input tensors.
     node->temporaries->data[1] = scratch_tensor_index + 1;
@@ -197,30 +195,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                             context->ResizeTensor(context, float_weights_time,
                                                   float_weights_time_size));
     }
-
-    node->temporaries->data[4] = scratch_tensor_index + 4;
-    TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
-    zero_points->type = kTfLiteFloat32;
-    zero_points->allocation_type = kTfLiteArenaRw;
-    int zero_points_dims[1] = {batch_size};
-    if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
-      TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
-      zero_points_size->data[0] = zero_points_dims[0];
-      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
-                                                       zero_points_size));
-    }
-
-    node->temporaries->data[5] = scratch_tensor_index + 5;
-    TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
-    row_sums->type = kTfLiteFloat32;
-    row_sums->allocation_type = kTfLiteArenaRwPersistent;
-    int row_sums_dims[1] = {num_filters};
-    if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) {
-      TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1);
-      row_sums_size->data[0] = row_sums_dims[0];
-      TF_LITE_ENSURE_OK(
-          context, context->ResizeTensor(context, row_sums, row_sums_size));
-    }
   }
   if (is_full_integer) {
     // Allocated one extra tensor.
@@ -293,8 +267,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           GetTemporary(context, node, /*index=*/2);
       TfLiteTensor* float_weights_time =
           GetTemporary(context, node, /*index=*/3);
-      TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
-      TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
       // Dequantize weights time.
       // TODO(alanchiao): this dequantization initialization only needs to
       // happen once per model and should theoretically be placed in either
@@ -312,11 +285,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
         }
         op_data->float_weights_time_initialized = true;
       }
-      reference_ops::EvalHybridSVDF(
-          context, node, input, weights_feature, float_weights_time, bias,
-          params, scratch, scaling_factors, input_quantized, activation_state,
-          output, zero_points, row_sums, &op_data->compute_row_sums);
+      reference_ops::EvalHybridSVDF(context, node, input, weights_feature,
+                                    float_weights_time, bias, params, scratch,
+                                    scaling_factors, input_quantized,
+                                    activation_state, output);
       return kTfLiteOk;
     } else {
       auto* input_params = reinterpret_cast<TfLiteAffineQuantization*>(
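For context on the zero_points and row_sums temporaries in the SVDF hunks above: when the input carries an asymmetric zero point, an integer accumulator picks up a term proportional to the per-row sum of the weights, which is why a persistent row-sums buffer is worth precomputing once and reusing. A minimal sketch of that correction with illustrative names only (not the reference_ops implementation):

    #include <cstdint>
    #include <vector>

    // Precompute sum_j w[r][j] for each row of an int8 weight matrix.
    std::vector<int32_t> ComputeRowSums(const std::vector<int8_t>& weights,
                                        int rows, int cols) {
      std::vector<int32_t> row_sums(rows, 0);
      for (int r = 0; r < rows; ++r)
        for (int c = 0; c < cols; ++c) row_sums[r] += weights[r * cols + c];
      return row_sums;
    }

    // result[r] ~= w_scale * in_scale * (acc[r] - in_zero_point * row_sums[r])
    float DequantizeAccumulator(int32_t acc, int32_t row_sum,
                                int32_t input_zero_point, float weight_scale,
                                float input_scale) {
      return weight_scale * input_scale *
             static_cast<float>(acc - input_zero_point * row_sum);
    }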
@@ -131,8 +131,7 @@ class BaseSVDFOpModel : public SingleOpModel {
   BaseSVDFOpModel(int batches, int units, int input_size, int memory_size,
                   int rank,
                   TensorType weights_feature_type = TensorType_FLOAT32,
-                  TensorType weights_time_type = TensorType_FLOAT32,
-                  bool asymmetric_quantize_inputs = false)
+                  TensorType weights_time_type = TensorType_FLOAT32)
       : batches_(batches),
         units_(units),
         input_size_(input_size),
@@ -147,10 +146,9 @@ class BaseSVDFOpModel : public SingleOpModel {
         TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}},
         /*is_variable=*/true);
     output_ = AddOutput(TensorType_FLOAT32);
-    SetBuiltinOp(BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
-                 CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE,
-                                   asymmetric_quantize_inputs)
-                     .Union());
+    SetBuiltinOp(
+        BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
+        CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union());
     BuildInterpreter({
         {batches_, input_size_},       // input tensor
         {units_ * rank, input_size_},  // weights_feature tensor
@@ -205,10 +203,9 @@ class SVDFOpModel : public BaseSVDFOpModel {
 class HybridSVDFOpModel : public BaseSVDFOpModel {
  public:
   HybridSVDFOpModel(int batches, int units, int input_size, int memory_size,
-                    int rank, TensorType tensor_type,
-                    bool asymmetric_quantize_inputs)
+                    int rank, TensorType tensor_type)
       : BaseSVDFOpModel(batches, units, input_size, memory_size, rank,
-                        tensor_type, tensor_type, asymmetric_quantize_inputs) {
+                        tensor_type, tensor_type) {
     tensor_type_ = tensor_type;
   }

@@ -232,7 +229,7 @@ class HybridSVDFOpModel : public BaseSVDFOpModel {
   TensorType tensor_type_;
 };

-class SVDFOpTest : public ::testing::TestWithParam<bool> {
+class SVDFOpTest : public ::testing::Test {
  protected:
   void VerifyGoldens(float golden_input[], float golden_output[],
                      int golden_size, BaseSVDFOpModel* svdf,
@@ -265,9 +262,6 @@ class SVDFOpTest : public ::testing::TestWithParam<bool> {
   }
 };

-INSTANTIATE_TEST_SUITE_P(SVDFOpTest, SVDFOpTest,
-                         ::testing::ValuesIn({false, true}));
-
 TEST_F(SVDFOpTest, BlackBoxTestRank1) {
   SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
                    /*memory_size=*/10, /*rank=*/1);
@@ -331,10 +325,9 @@ TEST_F(SVDFOpTest, BlackBoxTestRank2) {
                 &svdf);
 }

-TEST_P(SVDFOpTest, BlackBoxTestHybridRank1Uint8) {
+TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Uint8) {
   HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
-                         /*memory_size=*/10, /*rank=*/1, TensorType_UINT8,
-                         GetParam());
+                         /*memory_size=*/10, /*rank=*/1, TensorType_UINT8);
   svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
                           0.22197971, 0.12416199, 0.27901134, 0.27557442,
                           0.3905206, -0.36137494, -0.06634006, -0.10640851});
@@ -354,13 +347,12 @@ TEST_P(SVDFOpTest, BlackBoxTestHybridRank1Uint8) {

   VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
                 &svdf,
-                /*tolerance=*/0.004285);
+                /*tolerance=*/0.002945);
 }

-TEST_P(SVDFOpTest, BlackBoxTestHybridRank2Uint8) {
+TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Uint8) {
   HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
-                         /*memory_size=*/10, /*rank=*/2, TensorType_UINT8,
-                         GetParam());
+                         /*memory_size=*/10, /*rank=*/2, TensorType_UINT8);
   svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347,
                           0.12416199, 0.15785322, 0.27901134, 0.3905206,
                           0.21931258, -0.36137494, -0.10640851, 0.31053296,
@@ -395,13 +387,12 @@ TEST_P(SVDFOpTest, BlackBoxTestHybridRank2Uint8) {

   VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
                 &svdf,
-                /*tolerance=*/0.007175);
+                /*tolerance=*/0.00625109);
 }

-TEST_P(SVDFOpTest, BlackBoxTestHybridRank1Int8) {
+TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Int8) {
   HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
-                         /*memory_size=*/10, /*rank=*/1, TensorType_INT8,
-                         GetParam());
+                         /*memory_size=*/10, /*rank=*/1, TensorType_INT8);
   svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
                           0.22197971, 0.12416199, 0.27901134, 0.27557442,
                           0.3905206, -0.36137494, -0.06634006, -0.10640851});
@@ -421,13 +412,12 @@ TEST_P(SVDFOpTest, BlackBoxTestHybridRank1Int8) {

   VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
                 &svdf,
-                /*tolerance=*/0.004285);
+                /*tolerance=*/0.002945);
 }

-TEST_P(SVDFOpTest, BlackBoxTestHybridRank2Int8) {
+TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Int8) {
   HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
-                         /*memory_size=*/10, /*rank=*/2, TensorType_INT8,
-                         GetParam());
+                         /*memory_size=*/10, /*rank=*/2, TensorType_INT8);
   svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347,
                           0.12416199, 0.15785322, 0.27901134, 0.3905206,
                           0.21931258, -0.36137494, -0.10640851, 0.31053296,
@@ -462,7 +452,7 @@ TEST_P(SVDFOpTest, BlackBoxTestHybridRank2Int8) {

   VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
                 &svdf,
-                /*tolerance=*/0.007175);
+                /*tolerance=*/0.00625109);
 }

 // Test case for full integer quantization of SVDF.
@@ -33,7 +33,6 @@ struct OpData {
   bool is_layer_norm_lstm;
   // The scratch tensor index.
   int scratch_tensor_index;
-  bool compute_row_sums = false;
 };

 // Input Tensors of size {max_time, n_batch, n_input}
@@ -93,9 +92,7 @@ enum TemporaryTensor {
   kProductScalingFactors = 5,
   kRecoveredCellWeights = 6,
   kAccumScratch = 7,
-  kZeroPoints = 8,
-  kRowSums = 9,
-  kNumTemporaryTensors = 10
+  kNumTemporaryTensors
 };

 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@@ -411,7 +408,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                                      scratch_buffer_size));

   if (IsHybridOp(input, input_to_output_weights)) {
-    op_data->compute_row_sums = true;
     // Allocate temporary tensors to store quantized values of input,
     // activation_state and cell_state tensors.
     node->temporaries->data[kInputQuantized] =
@@ -519,34 +515,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
       TF_LITE_ENSURE_OK(
           context, context->ResizeTensor(context, accum_scratch, accum_size));
     }
-    node->temporaries->data[kZeroPoints] = scratch_tensor_index + kZeroPoints;
-    TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
-    zero_points->type = kTfLiteFloat32;
-    zero_points->allocation_type = kTfLiteArenaRw;
-    if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, scaling_dims)) {
-      TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
-      zero_points_size->data[0] = n_batch;
-      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
-                                                       zero_points_size));
-    }
-    node->temporaries->data[kRowSums] = scratch_tensor_index + kRowSums;
-    TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
-    row_sums->type = kTfLiteInt32;
-    row_sums->allocation_type = kTfLiteArenaRwPersistent;
-    int row_sums_rows = use_cifg ? 6 : 8;
-    const TfLiteTensor* projection_weights =
-        GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
-    if (projection_weights != nullptr) {
-      row_sums_rows += ceil(n_output / n_cell);
-    }
-    int row_sums_dims[2] = {row_sums_rows, n_cell};
-    if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
-      TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
-      row_sums_size->data[0] = row_sums_dims[0];
-      row_sums_size->data[1] = row_sums_dims[1];
-      TF_LITE_ENSURE_OK(
-          context, context->ResizeTensor(context, row_sums, row_sums_size));
-    }
   }
   return kTfLiteOk;
 }
@@ -632,7 +600,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   lstm_params.activation = params->activation;
   lstm_params.cell_clip = params->cell_clip;
   lstm_params.proj_clip = params->proj_clip;
-  lstm_params.asymmetric_quantize_inputs = params->asymmetric_quantize_inputs;

   switch (input_to_output_weights->type) {
     case kTfLiteFloat32: {
@@ -656,7 +623,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     }
     case kTfLiteUInt8:
     case kTfLiteInt8: {
-      OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
       TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
       TfLiteTensor* activation_state_quantized =
           GetTemporary(context, node, /*index=*/2);
@@ -669,10 +635,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           GetTemporary(context, node, /*index=*/6);
       TfLiteTensor* accum_scratch =
           GetTemporary(context, node, /*index=*/kAccumScratch);
-      TfLiteTensor* zero_points =
-          GetTemporary(context, node, /*index=*/kZeroPoints);
-      TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/kRowSums);
-      const int row_sums_size = row_sums->dims->data[0];
       return lstm_eval::EvalHybrid(
           input, input_to_input_weights, input_to_forget_weights,
           input_to_cell_weights, input_to_output_weights,
@@ -692,9 +654,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           prod_scaling_factors, recovered_cell_weights, input_quantized,
           /*aux_input_quantized=*/nullptr, activation_state_quantized,
           cell_state_quantized, activation_state, cell_state, accum_scratch,
-          output, zero_points, row_sums, row_sums_size,
-          &op_data->compute_row_sums,
-          CpuBackendContext::GetFromContext(context));
+          output, CpuBackendContext::GetFromContext(context));
     }
     default:
       context->ReportError(context, "Type %d is not currently supported.",
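The Prepare() hunks above repeat one pattern for every scratch buffer: claim a slot in node->temporaries, set the tensor's type and allocation type, and resize it through the context. A condensed sketch of that pattern follows; the helper function, its parameters, and the chosen type/shape are hypothetical, while the TfLite calls are the ones used in the hunks themselves:

    #include "tensorflow/lite/c/common.h"
    #include "tensorflow/lite/kernels/kernel_util.h"

    // Hypothetical helper: set up one 2-D int32 scratch tensor in `slot`.
    TfLiteStatus PrepareScratchTensor(TfLiteContext* context, TfLiteNode* node,
                                      int scratch_tensor_index, int slot,
                                      int rows, int cols) {
      node->temporaries->data[slot] = scratch_tensor_index + slot;
      TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/slot);
      scratch->type = kTfLiteInt32;
      scratch->allocation_type = kTfLiteArenaRw;
      int dims[2] = {rows, cols};
      if (!TfLiteIntArrayEqualsArray(scratch->dims, 2, dims)) {
        TfLiteIntArray* size = TfLiteIntArrayCreate(2);
        size->data[0] = dims[0];
        size->data[1] = dims[1];
        TF_LITE_ENSURE_OK(context,
                          context->ResizeTensor(context, scratch, size));
      }
      return kTfLiteOk;
    }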
@@ -38,8 +38,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
                             float proj_clip,
                             const std::vector<std::vector<int>>& input_shapes,
                             const TensorType& weights_type = TensorType_FLOAT32,
-                            bool is_layer_norm = false,
-                            bool asymmetric_quantize_inputs = false)
+                            bool is_layer_norm = false)
       : n_batch_(n_batch),
         n_input_(n_input),
         n_cell_(n_cell),
@@ -132,7 +131,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
                  BuiltinOptions_UnidirectionalSequenceLSTMOptions,
                  CreateUnidirectionalSequenceLSTMOptions(
                      builder_, ActivationFunctionType_TANH, cell_clip,
-                     proj_clip, time_major, asymmetric_quantize_inputs)
+                     proj_clip, time_major)
                      .Union());
     BuildInterpreter(input_shapes);
   }
@@ -293,12 +292,11 @@ class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
       bool time_major, bool use_cifg, bool use_peephole,
       bool use_projection_weights, bool use_projection_bias, float cell_clip,
       float proj_clip, const std::vector<std::vector<int>>& input_shapes,
-      TensorType tensor_type, bool asymmetric_quantize_inputs)
+      TensorType tensor_type)
       : UnidirectionalLSTMOpModel(
             n_batch, n_input, n_cell, n_output, sequence_length, time_major,
             use_cifg, use_peephole, use_projection_weights, use_projection_bias,
-            cell_clip, proj_clip, input_shapes, tensor_type, false,
-            asymmetric_quantize_inputs) {
+            cell_clip, proj_clip, input_shapes, tensor_type) {
     tensor_type_ = tensor_type;
   }

@@ -362,7 +360,7 @@ class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
   TensorType tensor_type_;
 };

-class BaseUnidirectionalLstmTest : public ::testing::TestWithParam<bool> {
+class BaseUnidirectionalLstmTest : public ::testing::Test {
  protected:
   // Weights of the LSTM model. Some are optional.
   std::vector<float> input_to_input_weights_;
@@ -628,7 +626,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
                 /*time_major=*/false);
 }

-TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
        HybridLstmBlackBoxTestUint8) {
   const int n_batch = 1;
   const int n_input = 2;
@@ -670,7 +668,7 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
           {n_batch, n_output},  // activation_state tensor
           {n_batch, n_cell},    // cell_state tensor
       },
-      TensorType_UINT8, GetParam());
+      TensorType_UINT8);

   lstm.SetInputToInputWeights(input_to_input_weights_);
   lstm.SetInputToCellWeights(input_to_cell_weights_);
@@ -691,7 +689,7 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
                 /*tolerance=*/0.0157651);
 }

-TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
        HybridLstmBlackBoxTestInt8) {
   const int n_batch = 1;
   const int n_input = 2;
@@ -733,7 +731,7 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
           {n_batch, n_output},  // activation_state tensor
           {n_batch, n_cell},    // cell_state tensor
       },
-      TensorType_INT8, GetParam());
+      TensorType_INT8);

   lstm.SetInputToInputWeights(input_to_input_weights_);
   lstm.SetInputToCellWeights(input_to_cell_weights_);
@@ -864,7 +862,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
 }

-TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
+TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
        HybridLstmBlackBoxTestUint8) {
   const int n_batch = 1;
   const int n_input = 2;
@@ -886,6 +884,7 @@ TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
           {n_cell, n_input},  // input_to_forget_weight tensor
           {n_cell, n_input},  // input_to_cell_weight tensor
           {n_cell, n_input},  // input_to_output_weight tensor
+
           {0, 0},              // recurrent_to_input_weight tensor
           {n_cell, n_output},  // recurrent_to_forget_weight tensor
           {n_cell, n_output},  // recurrent_to_cell_weight tensor
@@ -906,7 +905,7 @@ TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
           {n_batch, n_output},  // activation_state tensor
           {n_batch, n_cell},    // cell_state tensor
       },
-      TensorType_UINT8, GetParam());
+      TensorType_UINT8);

   lstm.SetInputToCellWeights(input_to_cell_weights_);
   lstm.SetInputToForgetWeights(input_to_forget_weights_);
@@ -926,7 +925,7 @@ TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
 }

-TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
+TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
        HybridLstmBlackBoxTestInt8) {
   const int n_batch = 1;
   const int n_input = 2;
@@ -969,7 +968,7 @@ TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
           {n_batch, n_output},  // activation_state tensor
           {n_batch, n_cell},    // cell_state tensor
       },
-      TensorType_INT8, GetParam());
+      TensorType_INT8);

   lstm.SetInputToCellWeights(input_to_cell_weights_);
   lstm.SetInputToForgetWeights(input_to_forget_weights_);
@@ -1656,16 +1655,14 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
 }

-TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
+TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
        HybridLstmBlackBoxTestUint8) {
   const int n_batch = 2;
   const int n_input = 5;
   const int n_cell = 20;
   const int n_output = 16;
   const int sequence_length = 4;
-  if (GetParam()) {
-    return;
-  }
   HybridUnidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length,
       /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true,
@@ -1700,7 +1697,7 @@ TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
           {n_batch, n_output},  // activation_state tensor
           {n_batch, n_cell},    // cell_state tensor
       },
-      TensorType_UINT8, GetParam());
+      TensorType_UINT8);

   lstm.SetInputToInputWeights(input_to_input_weights_);
   lstm.SetInputToCellWeights(input_to_cell_weights_);
@@ -1726,11 +1723,8 @@ TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
 }

-TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
+TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
        HybridLstmBlackBoxTestInt8) {
-  if (GetParam()) {
-    return;
-  }
   const int n_batch = 2;
   const int n_input = 5;
   const int n_cell = 20;
@@ -1771,7 +1765,7 @@ TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
           {n_batch, n_output},  // activation_state tensor
           {n_batch, n_cell},    // cell_state tensor
       },
-      TensorType_INT8, GetParam());
+      TensorType_INT8);

   lstm.SetInputToInputWeights(input_to_input_weights_);
   lstm.SetInputToCellWeights(input_to_cell_weights_);
@@ -2743,14 +2737,5 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
 }

-#define QUANTIZE_PARAMETER_TEST(test) \
-  INSTANTIATE_TEST_SUITE_P(test, test, ::testing::ValuesIn({false, true}));
-
-QUANTIZE_PARAMETER_TEST(
-    CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest);
-QUANTIZE_PARAMETER_TEST(
-    NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest);
-QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest);
-#undef QUANTIZE_PARAMETER_TEST
 }  // namespace
 }  // namespace tflite
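The RNN hunks that follow toggle the per-node state handed between Init() and Free(): one side stores a full OpData struct on node->user_data, the other only a heap-allocated int holding the scratch tensor index. A small sketch of that ownership pattern, mirroring the struct and calls in the diff below (the include paths are the usual TF Lite ones and are assumptions of this sketch):

    #include "tensorflow/lite/c/common.h"

    struct OpData {
      int scratch_tensor_index;
      bool compute_row_sums = false;
    };

    void* Init(TfLiteContext* context, const char* buffer, size_t length) {
      auto* op_data = new OpData();
      context->AddTensors(context, /*tensors_to_add=*/6,
                          &op_data->scratch_tensor_index);
      return op_data;  // Stored by the interpreter as node->user_data.
    }

    void Free(TfLiteContext* context, void* buffer) {
      delete reinterpret_cast<OpData*>(buffer);  // Must match Init's allocation.
    }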
@ -26,15 +26,6 @@ namespace ops {
|
|||||||
namespace builtin {
|
namespace builtin {
|
||||||
namespace unidirectional_sequence_rnn {
|
namespace unidirectional_sequence_rnn {
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
struct OpData {
|
|
||||||
int scratch_tensor_index;
|
|
||||||
bool compute_row_sums = false;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// Input tensors.
|
// Input tensors.
|
||||||
constexpr int kInputTensor = 0;
|
constexpr int kInputTensor = 0;
|
||||||
constexpr int kWeightsTensor = 1;
|
constexpr int kWeightsTensor = 1;
|
||||||
@ -46,14 +37,13 @@ constexpr int kHiddenStateTensor = 4;
|
|||||||
constexpr int kOutputTensor = 0;
|
constexpr int kOutputTensor = 0;
|
||||||
|
|
||||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
auto* op_data = new OpData();
|
auto* scratch_tensor_index = new int;
|
||||||
context->AddTensors(context, /*tensors_to_add=*/6,
|
context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index);
|
||||||
&op_data->scratch_tensor_index);
|
return scratch_tensor_index;
|
||||||
return op_data;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Free(TfLiteContext* context, void* buffer) {
|
void Free(TfLiteContext* context, void* buffer) {
|
||||||
delete reinterpret_cast<OpData*>(buffer);
|
delete reinterpret_cast<int*>(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
@@ -106,11 +96,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Allocate temporary tensors to store quantized values of input and
   // hidden_state tensors.
   if (is_hybrid) {
-    auto* op_data = reinterpret_cast<OpData*>(node->user_data);
-    op_data->compute_row_sums = true;
+    int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
     TfLiteIntArrayFree(node->temporaries);
-    node->temporaries = TfLiteIntArrayCreate(6);
-    node->temporaries->data[0] = op_data->scratch_tensor_index;
+    node->temporaries = TfLiteIntArrayCreate(3);
+    node->temporaries->data[0] = *scratch_tensor_index;
     TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
     input_quantized->type = input_weights->type;
     input_quantized->allocation_type = kTfLiteArenaRw;
@@ -119,7 +108,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
       TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                        input_quantized_size));
     }
-    node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
+    node->temporaries->data[1] = *scratch_tensor_index + 1;
     TfLiteTensor* hidden_state_quantized =
         GetTemporary(context, node, /*index=*/1);
     hidden_state_quantized->type = input_weights->type;
@@ -132,7 +121,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                         context->ResizeTensor(context, hidden_state_quantized,
                                               hidden_state_quantized_size));
     }
-    node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
+    node->temporaries->data[2] = *scratch_tensor_index + 2;
     TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
     scaling_factors->type = kTfLiteFloat32;
     scaling_factors->allocation_type = kTfLiteArenaRw;
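Prepare wires consecutive arena tensor indices into node->temporaries, which is why GetTemporary(context, node, i) later resolves to scratch_tensor_index + i. A hypothetical helper equivalent to the repeated assignments above (assumes the TFLite C API; not part of this commit):

#include "tensorflow/lite/c/common.h"

// Illustrative: point temporary slot i at the i-th reserved arena tensor.
// Equivalent to the repeated "node->temporaries->data[i] = base + i" lines.
void AssignTemporaries(TfLiteNode* node, int first_scratch_tensor_index,
                       int num_temporaries) {
  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(num_temporaries);
  for (int i = 0; i < num_temporaries; ++i) {
    node->temporaries->data[i] = first_scratch_tensor_index + i;
  }
}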
@@ -143,42 +132,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
       TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                        scaling_factors_size));
     }
-    node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
-    TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/3);
-    accum_scratch->type = kTfLiteInt32;
-    accum_scratch->allocation_type = kTfLiteArenaRw;
-    int accum_scratch_dims[2] = {num_units, batch_size};
-    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
-                                   accum_scratch_dims)) {
-      TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
-      accum_scratch_size->data[0] = accum_scratch_dims[0];
-      accum_scratch_size->data[1] = accum_scratch_dims[1];
-      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
-                                                       accum_scratch_size));
-    }
-    node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
-    TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4);
-    zero_points->type = kTfLiteInt32;
-    zero_points->allocation_type = kTfLiteArenaRw;
-    int zero_points_dims[1] = {batch_size};
-    if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
-      TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
-      zero_points_size->data[0] = batch_size;
-      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
-                                                       zero_points_size));
-    }
-    node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
-    TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5);
-    row_sums->type = kTfLiteInt32;
-    row_sums->allocation_type = kTfLiteArenaRwPersistent;
-    int row_sums_dims[2] = {2, num_units};
-    if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
-      TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
-      row_sums_size->data[0] = row_sums_dims[0];
-      row_sums_size->data[1] = row_sums_dims[1];
-      TF_LITE_ENSURE_OK(
-          context, context->ResizeTensor(context, row_sums, row_sums_size));
-    }
   }
   return kTfLiteOk;
 }
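On the side of the diff that carries the feature, three extra temporaries back the asymmetric path: an int32 accumulator scratch, per-batch input zero points, and row sums of the quantized weights (the {2, num_units} shape appears to hold one row-sum vector each for the input and recurrent weight matrices). Row sums depend only on the constant weights, which is presumably why they sit in kTfLiteArenaRwPersistent and why compute_row_sums tracks whether they still need to be filled. A short illustrative sketch of that precomputation (not taken from the kernel):

#include <cstdint>

// Illustrative: precompute per-row sums of a quantized weight matrix. With an
// input zero point z, the dot product of weight row w against the quantized
// input q can later be corrected as  sum(w*q) - z * row_sum(w).
void ComputeRowSums(const int8_t* weights, int num_rows, int num_cols,
                    int32_t* row_sums) {
  for (int r = 0; r < num_rows; ++r) {
    int32_t sum = 0;
    for (int c = 0; c < num_cols; ++c) {
      sum += weights[r * num_cols + c];
    }
    row_sums[r] = sum;
  }
}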
@@ -249,9 +202,7 @@ TfLiteStatus EvalHybrid(
     const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias,
     const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch,
     TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors,
-    TfLiteTensor* hidden_state, TfLiteTensor* output, TfLiteTensor* zero_points,
-    TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
-    bool* compute_row_sums) {
+    TfLiteTensor* hidden_state, TfLiteTensor* output) {
   const bool time_major = params->time_major;
   const int batch_size =
       (time_major) ? input->dims->data[1] : input->dims->data[0];
@@ -276,14 +227,6 @@ TfLiteStatus EvalHybrid(
   float input_weights_scale = input_weights->params.scale;
   float recurrent_weights_scale = recurrent_weights->params.scale;
   float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
-  int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
-  int32_t* zero_points_ptr = nullptr;
-  int32_t* row_sums_ptr = nullptr;
-
-  if (params->asymmetric_quantize_inputs) {
-    zero_points_ptr = GetTensorData<int32_t>(zero_points);
-    row_sums_ptr = GetTensorData<int32_t>(row_sums);
-  }

   if (time_major) {
     // Initialize the pointer to hidden state.
@@ -301,9 +244,7 @@ TfLiteStatus EvalHybrid(
           recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
           num_units, batch_size, num_units, params->activation,
           quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr,
-          hidden_state_ptr_batch, output_ptr_batch,
-          params->asymmetric_quantize_inputs, zero_points_ptr,
-          accum_scratch_ptr, row_sums_ptr, compute_row_sums);
+          hidden_state_ptr_batch, output_ptr_batch);
     }
   } else {
     // For each batch
@@ -318,14 +259,13 @@ TfLiteStatus EvalHybrid(
                                      s * input_size;
         float* output_ptr_batch = GetTensorData<float>(output) +
                                   b * num_units * max_time + s * num_units;

         kernel_utils::RnnBatchStep(
             input_ptr_batch, input_weights_ptr, input_weights_scale,
             recurrent_weights_ptr, recurrent_weights_scale, bias_ptr,
             input_size, num_units, /*batch_size=*/1, num_units,
             params->activation, quantized_input_ptr, quantized_hidden_state_ptr,
-            scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch,
-            params->asymmetric_quantize_inputs, zero_points_ptr,
-            accum_scratch_ptr, row_sums_ptr, compute_row_sums);
+            scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch);
       }
     }
   }
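In the hybrid path the float input and hidden state are quantized on the fly for every invocation; with the flag enabled, an asymmetric scheme (per-row scale plus zero point) is used instead of a purely symmetric scale. A minimal sketch of asymmetric int8 quantization of one float row follows; this is the standard textbook formulation, not the exact library routine the kernel calls:

#include <cmath>
#include <cstdint>

// Illustrative asymmetric quantization of one row: widen [min, max] to
// include 0, map it onto the int8 range, and keep scale and zero point so the
// integer matmul result can be rescaled and zero-point-corrected afterwards.
void AsymmetricQuantize(const float* values, int size, int8_t* quantized,
                        float* scale, int32_t* zero_point) {
  float rmin = 0.0f, rmax = 0.0f;
  for (int i = 0; i < size; ++i) {
    if (values[i] < rmin) rmin = values[i];
    if (values[i] > rmax) rmax = values[i];
  }
  *scale = (rmax - rmin) / 255.0f;
  if (*scale == 0.0f) *scale = 1.0f;  // all-zero row; any scale works
  int32_t zp = static_cast<int32_t>(std::round(-128.0f - rmin / *scale));
  if (zp < -128) zp = -128;
  if (zp > 127) zp = 127;
  *zero_point = zp;
  for (int i = 0; i < size; ++i) {
    int32_t q = static_cast<int32_t>(std::round(values[i] / *scale)) + zp;
    if (q < -128) q = -128;
    if (q > 127) q = 127;
    quantized[i] = static_cast<int8_t>(q);
  }
}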
@@ -334,6 +274,7 @@ TfLiteStatus EvalHybrid(

 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);

   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
   const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
   const TfLiteTensor* recurrent_weights =
@@ -351,17 +292,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     case kTfLiteUInt8:
     case kTfLiteInt8: {
       // TODO(mirkov): implement eval with quantized inputs as well.
-      auto* op_data = reinterpret_cast<OpData*>(node->user_data);
       TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
       TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
       TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
-      TfLiteTensor* accum_scratch = GetTemporary(context, node, 3);
-      TfLiteTensor* zero_points = GetTemporary(context, node, 4);
-      TfLiteTensor* row_sums = GetTemporary(context, node, 5);
       return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
                         input_quantized, hidden_state_quantized,
-                        scaling_factors, hidden_state, output, zero_points,
-                        accum_scratch, row_sums, &op_data->compute_row_sums);
+                        scaling_factors, hidden_state, output);
     }
     default:
       context->ReportError(context, "Type %d not currently supported.",
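The zero_points, row_sums, and accum_scratch arguments exist because an asymmetric input makes the integer matmul accumulate w * (q - z); expanding that gives sum(w * q) - z * sum(w), so the kernel can run a plain int8 GEMM and subtract zero_point times the precomputed row sum per output unit. A small scalar sketch of that correction (illustrative only, not the vectorized kernel):

#include <cstdint>

// Illustrative zero-point correction for one output unit:
//   y = scale * sum_c w[c] * (q[c] - z)
//     = scale * (sum_c w[c] * q[c]  -  z * row_sum)
float DotWithZeroPointCorrection(const int8_t* w, const int8_t* q, int size,
                                 int32_t zero_point, int32_t row_sum,
                                 float scale) {
  int32_t acc = 0;
  for (int c = 0; c < size; ++c) {
    acc += static_cast<int32_t>(w[c]) * static_cast<int32_t>(q[c]);
  }
  acc -= zero_point * row_sum;  // row_sum == sum_c w[c], precomputed once
  return scale * static_cast<float>(acc);
}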
@@ -174,8 +174,7 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
   UnidirectionalRNNOpModel(
       int batches, int sequence_len, int units, int size, bool time_major,
       const TensorType& weights = TensorType_FLOAT32,
-      const TensorType& recurrent_weights = TensorType_FLOAT32,
-      bool asymmetric_quantize_inputs = false)
+      const TensorType& recurrent_weights = TensorType_FLOAT32)
       : batches_(batches),
         sequence_len_(sequence_len),
         units_(units),
@@ -189,8 +188,7 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
     SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
                  BuiltinOptions_SequenceRNNOptions,
                  CreateSequenceRNNOptions(builder_, time_major,
-                                          ActivationFunctionType_RELU,
-                                          asymmetric_quantize_inputs)
+                                          ActivationFunctionType_RELU)
                      .Union());
     if (time_major) {
       BuildInterpreter({{sequence_len_, batches_, input_size_},
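The test model serializes its builtin options with the generated CreateSequenceRNNOptions helper; on the side of the diff that still carries the field, the call simply gains a trailing bool. A minimal usage sketch under that assumption (the four-argument signature exists only on that side; values are illustrative):

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Sketch: building SequenceRNNOptions as the test model above does, with the
// extra flag forwarded. Assumes the generated header that declares
// tflite::CreateSequenceRNNOptions with the asymmetric_quantize_inputs param.
flatbuffers::Offset<tflite::SequenceRNNOptions> BuildOptions(
    flatbuffers::FlatBufferBuilder& builder, bool asymmetric_quantize_inputs) {
  return tflite::CreateSequenceRNNOptions(
      builder, /*time_major=*/false, tflite::ActivationFunctionType_RELU,
      asymmetric_quantize_inputs);
}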
@@ -251,11 +249,9 @@ class HybridUnidirectionalRNNOpModel : public UnidirectionalRNNOpModel {
  public:
  HybridUnidirectionalRNNOpModel(int batches, int sequence_len, int units,
                                 int size, bool time_major,
-                                 TensorType tensor_type,
-                                 bool asymmetric_quantize_inputs)
+                                 TensorType tensor_type)
      : UnidirectionalRNNOpModel(batches, sequence_len, units, size, time_major,
-                                 tensor_type, tensor_type,
-                                 asymmetric_quantize_inputs) {
+                                 tensor_type, tensor_type) {
    tensor_type_ = tensor_type;
  }

@@ -301,14 +297,10 @@ TEST(UnidirectionalRNNOpTest, BlackBoxTest) {
   EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
 }

-class HybridUnidirectionalRNNOpModelOpTest
-    : public ::testing::TestWithParam<bool> {};
-
-TEST_P(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestUint8) {
+TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestUint8) {
   HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                                      /*units=*/16, /*size=*/8,
-                                     /*time_major=*/false, TensorType_UINT8,
-                                     GetParam());
+                                     /*time_major=*/false, TensorType_UINT8);
   rnn.SetWeights(rnn_weights);
   rnn.SetBias(rnn_bias);
   rnn.SetRecurrentWeights(rnn_recurrent_weights);
@@ -331,11 +323,10 @@ TEST_P(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestUint8) {
                   expected, /*max_abs_error=*/0.013)));
 }

-TEST_P(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestInt8) {
+TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestInt8) {
   HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                                      /*units=*/16, /*size=*/8,
-                                     /*time_major=*/false, TensorType_INT8,
-                                     GetParam());
+                                     /*time_major=*/false, TensorType_INT8);
   rnn.SetWeights(rnn_weights);
   rnn.SetBias(rnn_bias);
   rnn.SetRecurrentWeights(rnn_recurrent_weights);
@@ -387,11 +378,10 @@ TEST(UnidirectionalRNNOpTest, TimeMajorBlackBoxTest) {
   EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
 }

-TEST_P(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestUint8) {
+TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestUint8) {
   HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                                      /*units=*/16, /*size=*/8,
-                                     /*time_major=*/true, TensorType_UINT8,
-                                     GetParam());
+                                     /*time_major=*/true, TensorType_UINT8);
   rnn.SetWeights(rnn_weights);
   rnn.SetBias(rnn_bias);
   rnn.SetRecurrentWeights(rnn_recurrent_weights);
@@ -418,11 +408,10 @@ TEST_P(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestUint8) {
                   expected, /*max_abs_error=*/0.013)));
 }

-TEST_P(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestInt8) {
+TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestInt8) {
   HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                                      /*units=*/16, /*size=*/8,
-                                     /*time_major=*/true, TensorType_INT8,
-                                     GetParam());
+                                     /*time_major=*/true, TensorType_INT8);
   rnn.SetWeights(rnn_weights);
   rnn.SetBias(rnn_bias);
   rnn.SetRecurrentWeights(rnn_recurrent_weights);
@@ -449,9 +438,5 @@ TEST_P(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestInt8) {
                   expected, /*max_abs_error=*/0.013)));
 }

-INSTANTIATE_TEST_SUITE_P(HybridUnidirectionalRNNOpModelOpTest,
-                         HybridUnidirectionalRNNOpModelOpTest,
-                         ::testing::ValuesIn({true, false}));
-
 } // namespace
 } // namespace tflite
@@ -519,22 +519,17 @@ table LSHProjectionOptions {
 table SVDFOptions {
   rank:int;
   fused_activation_function:ActivationFunctionType;
-  // For weights-only quantization, use asymmetric quantization for non
-  // constant inputs at evaluation time.
-  asymmetric_quantize_inputs:bool;
 }

 // An implementation of TensorFlow RNNCell.
 table RNNOptions {
   fused_activation_function:ActivationFunctionType;
-  asymmetric_quantize_inputs:bool;
 }

 // An implementation of TensorFlow dynamic_rnn with RNNCell.
 table SequenceRNNOptions {
   time_major:bool;
   fused_activation_function:ActivationFunctionType;
-  asymmetric_quantize_inputs:bool;
 }

 // An implementation of TensorFlow bidrectional_dynamic_rnn with RNNCell.
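Because a FlatBuffers scalar that is absent from a serialized table reads back as its declared default, a model that never writes asymmetric_quantize_inputs simply yields false on the reader side, and readers built before the field existed skip it via the vtable. A small sketch of reading the flag with that default, using the generated accessor as it appears later in this diff (assumes the generated schema header):

#include "tensorflow/lite/schema/schema_generated.h"

// Sketch: reading the flag from parsed options; returns false when the field
// was never written, since FlatBuffers falls back to the schema default of 0.
// The accessor exists only on the side of the diff that carries the field:
//   GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0
bool UsesAsymmetricQuantization(const tflite::SVDFOptions* options) {
  return options != nullptr && options->asymmetric_quantize_inputs();
}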
@@ -542,7 +537,6 @@ table BidirectionalSequenceRNNOptions {
   time_major:bool;
   fused_activation_function:ActivationFunctionType;
   merge_outputs: bool;
-  asymmetric_quantize_inputs:bool;
 }

 enum FullyConnectedOptionsWeightsFormat: byte {
@@ -562,11 +556,6 @@ table FullyConnectedOptions {
   // If set to true, then the number of dimension is preserved. Furthermore,
   // all but the last dimension of the input and output shapes will be equal.
   keep_num_dims: bool;
-
-  // Parameters for FullyConnected version 7 or above.
-  // If set to true, then weights-only op will use asymmetric quantization for
-  // inputs.
-  asymmetric_quantize_inputs: bool;
 }

 table SoftmaxOptions {
@@ -615,9 +604,6 @@ table LSTMOptions {
   // Parameters for LSTM version 2 or above.
   // Basic kernel is only supported in version 2 or above.
   kernel_type: LSTMKernelType = FULL;
-
-  // Parameters for LSTM version 4 or above.
-  asymmetric_quantize_inputs: bool;
 }

 // An implementation of TensorFlow dynamic_rnn with LSTMCell.
@@ -628,9 +614,6 @@ table UnidirectionalSequenceLSTMOptions {

   // If true then first dimension is sequence, otherwise batch.
   time_major:bool;
-
-  // Parameter for Unidirectional Sequence LSTM version 4.
-  asymmetric_quantize_inputs:bool;
 }

 table BidirectionalSequenceLSTMOptions {
@@ -647,9 +630,6 @@ table BidirectionalSequenceLSTMOptions {
   // Version 1 implementations assumed time_major to be true, so this default
   // value should never change.
   time_major: bool = true;
-
-  // Parameters for version 3 or above.
-  asymmetric_quantize_inputs:bool;
 }

 table ResizeBilinearOptions {
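In the generated header that follows, every schema field appended to a table takes the next even vtable slot, which is why VT_ASYMMETRIC_QUANTIZE_INPUTS appears at 8 for SVDFOptions (after rank = 4 and fused_activation_function = 6), at 6 for RNNOptions, and at 10 for FullyConnectedOptions; readers that predate the field simply see a shorter vtable and fall back to the default. An illustrative sketch of that layout (mirrors the generated structs below, not copied from them):

#include <cstdint>

// Illustrative vtable-offset enum for a three-field table: each field takes
// the next 2-byte slot, so an appended bool just extends the list and is
// ignored by readers that do not know about it.
enum ExampleVTableOffset : uint16_t {
  VT_RANK = 4,
  VT_FUSED_ACTIVATION_FUNCTION = 6,
  VT_ASYMMETRIC_QUANTIZE_INPUTS = 8,  // appended last
};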
@ -4216,11 +4216,9 @@ struct SVDFOptionsT : public flatbuffers::NativeTable {
|
|||||||
typedef SVDFOptions TableType;
|
typedef SVDFOptions TableType;
|
||||||
int32_t rank;
|
int32_t rank;
|
||||||
tflite::ActivationFunctionType fused_activation_function;
|
tflite::ActivationFunctionType fused_activation_function;
|
||||||
bool asymmetric_quantize_inputs;
|
|
||||||
SVDFOptionsT()
|
SVDFOptionsT()
|
||||||
: rank(0),
|
: rank(0),
|
||||||
fused_activation_function(tflite::ActivationFunctionType_NONE),
|
fused_activation_function(tflite::ActivationFunctionType_NONE) {
|
||||||
asymmetric_quantize_inputs(false) {
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -4228,8 +4226,7 @@ struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||||||
typedef SVDFOptionsT NativeTableType;
|
typedef SVDFOptionsT NativeTableType;
|
||||||
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
||||||
VT_RANK = 4,
|
VT_RANK = 4,
|
||||||
VT_FUSED_ACTIVATION_FUNCTION = 6,
|
VT_FUSED_ACTIVATION_FUNCTION = 6
|
||||||
VT_ASYMMETRIC_QUANTIZE_INPUTS = 8
|
|
||||||
};
|
};
|
||||||
int32_t rank() const {
|
int32_t rank() const {
|
||||||
return GetField<int32_t>(VT_RANK, 0);
|
return GetField<int32_t>(VT_RANK, 0);
|
||||||
@ -4237,14 +4234,10 @@ struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||||||
tflite::ActivationFunctionType fused_activation_function() const {
|
tflite::ActivationFunctionType fused_activation_function() const {
|
||||||
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
||||||
}
|
}
|
||||||
bool asymmetric_quantize_inputs() const {
|
|
||||||
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
|
|
||||||
}
|
|
||||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
return VerifyTableStart(verifier) &&
|
return VerifyTableStart(verifier) &&
|
||||||
VerifyField<int32_t>(verifier, VT_RANK) &&
|
VerifyField<int32_t>(verifier, VT_RANK) &&
|
||||||
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
||||||
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
|
|
||||||
verifier.EndTable();
|
verifier.EndTable();
|
||||||
}
|
}
|
||||||
SVDFOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
SVDFOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
@ -4261,9 +4254,6 @@ struct SVDFOptionsBuilder {
|
|||||||
void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
|
void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
|
||||||
fbb_.AddElement<int8_t>(SVDFOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
|
fbb_.AddElement<int8_t>(SVDFOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
|
||||||
}
|
}
|
||||||
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
|
|
||||||
fbb_.AddElement<uint8_t>(SVDFOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
|
|
||||||
}
|
|
||||||
explicit SVDFOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
explicit SVDFOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
: fbb_(_fbb) {
|
: fbb_(_fbb) {
|
||||||
start_ = fbb_.StartTable();
|
start_ = fbb_.StartTable();
|
||||||
@ -4279,11 +4269,9 @@ struct SVDFOptionsBuilder {
|
|||||||
inline flatbuffers::Offset<SVDFOptions> CreateSVDFOptions(
|
inline flatbuffers::Offset<SVDFOptions> CreateSVDFOptions(
|
||||||
flatbuffers::FlatBufferBuilder &_fbb,
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
int32_t rank = 0,
|
int32_t rank = 0,
|
||||||
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) {
|
||||||
bool asymmetric_quantize_inputs = false) {
|
|
||||||
SVDFOptionsBuilder builder_(_fbb);
|
SVDFOptionsBuilder builder_(_fbb);
|
||||||
builder_.add_rank(rank);
|
builder_.add_rank(rank);
|
||||||
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
|
|
||||||
builder_.add_fused_activation_function(fused_activation_function);
|
builder_.add_fused_activation_function(fused_activation_function);
|
||||||
return builder_.Finish();
|
return builder_.Finish();
|
||||||
}
|
}
|
||||||
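CreateSVDFOptions above is a convenience wrapper over the field-by-field builder. A sketch of the equivalent long-hand form, assuming the builder methods shown on the side of the diff that still carries the field (add_asymmetric_quantize_inputs does not exist on the other side):

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Sketch: the long-hand builder form equivalent to CreateSVDFOptions.
flatbuffers::Offset<tflite::SVDFOptions> BuildSvdfOptionsLongHand(
    flatbuffers::FlatBufferBuilder& fbb, bool asymmetric_quantize_inputs) {
  tflite::SVDFOptionsBuilder builder(fbb);
  builder.add_rank(1);
  builder.add_fused_activation_function(tflite::ActivationFunctionType_RELU);
  // Only present when the schema field exists; defaults to false otherwise.
  builder.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
  return builder.Finish();
}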
@ -4293,29 +4281,22 @@ flatbuffers::Offset<SVDFOptions> CreateSVDFOptions(flatbuffers::FlatBufferBuilde
|
|||||||
struct RNNOptionsT : public flatbuffers::NativeTable {
|
struct RNNOptionsT : public flatbuffers::NativeTable {
|
||||||
typedef RNNOptions TableType;
|
typedef RNNOptions TableType;
|
||||||
tflite::ActivationFunctionType fused_activation_function;
|
tflite::ActivationFunctionType fused_activation_function;
|
||||||
bool asymmetric_quantize_inputs;
|
|
||||||
RNNOptionsT()
|
RNNOptionsT()
|
||||||
: fused_activation_function(tflite::ActivationFunctionType_NONE),
|
: fused_activation_function(tflite::ActivationFunctionType_NONE) {
|
||||||
asymmetric_quantize_inputs(false) {
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct RNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
struct RNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||||
typedef RNNOptionsT NativeTableType;
|
typedef RNNOptionsT NativeTableType;
|
||||||
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
||||||
VT_FUSED_ACTIVATION_FUNCTION = 4,
|
VT_FUSED_ACTIVATION_FUNCTION = 4
|
||||||
VT_ASYMMETRIC_QUANTIZE_INPUTS = 6
|
|
||||||
};
|
};
|
||||||
tflite::ActivationFunctionType fused_activation_function() const {
|
tflite::ActivationFunctionType fused_activation_function() const {
|
||||||
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
||||||
}
|
}
|
||||||
bool asymmetric_quantize_inputs() const {
|
|
||||||
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
|
|
||||||
}
|
|
||||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
return VerifyTableStart(verifier) &&
|
return VerifyTableStart(verifier) &&
|
||||||
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
||||||
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
|
|
||||||
verifier.EndTable();
|
verifier.EndTable();
|
||||||
}
|
}
|
||||||
RNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
RNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
@ -4329,9 +4310,6 @@ struct RNNOptionsBuilder {
|
|||||||
void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
|
void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
|
||||||
fbb_.AddElement<int8_t>(RNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
|
fbb_.AddElement<int8_t>(RNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
|
||||||
}
|
}
|
||||||
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
|
|
||||||
fbb_.AddElement<uint8_t>(RNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
|
|
||||||
}
|
|
||||||
explicit RNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
explicit RNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
: fbb_(_fbb) {
|
: fbb_(_fbb) {
|
||||||
start_ = fbb_.StartTable();
|
start_ = fbb_.StartTable();
|
||||||
@ -4346,10 +4324,8 @@ struct RNNOptionsBuilder {
|
|||||||
|
|
||||||
inline flatbuffers::Offset<RNNOptions> CreateRNNOptions(
|
inline flatbuffers::Offset<RNNOptions> CreateRNNOptions(
|
||||||
flatbuffers::FlatBufferBuilder &_fbb,
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) {
|
||||||
bool asymmetric_quantize_inputs = false) {
|
|
||||||
RNNOptionsBuilder builder_(_fbb);
|
RNNOptionsBuilder builder_(_fbb);
|
||||||
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
|
|
||||||
builder_.add_fused_activation_function(fused_activation_function);
|
builder_.add_fused_activation_function(fused_activation_function);
|
||||||
return builder_.Finish();
|
return builder_.Finish();
|
||||||
}
|
}
|
||||||
@ -4360,11 +4336,9 @@ struct SequenceRNNOptionsT : public flatbuffers::NativeTable {
|
|||||||
typedef SequenceRNNOptions TableType;
|
typedef SequenceRNNOptions TableType;
|
||||||
bool time_major;
|
bool time_major;
|
||||||
tflite::ActivationFunctionType fused_activation_function;
|
tflite::ActivationFunctionType fused_activation_function;
|
||||||
bool asymmetric_quantize_inputs;
|
|
||||||
SequenceRNNOptionsT()
|
SequenceRNNOptionsT()
|
||||||
: time_major(false),
|
: time_major(false),
|
||||||
fused_activation_function(tflite::ActivationFunctionType_NONE),
|
fused_activation_function(tflite::ActivationFunctionType_NONE) {
|
||||||
asymmetric_quantize_inputs(false) {
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -4372,8 +4346,7 @@ struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||||||
typedef SequenceRNNOptionsT NativeTableType;
|
typedef SequenceRNNOptionsT NativeTableType;
|
||||||
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
||||||
VT_TIME_MAJOR = 4,
|
VT_TIME_MAJOR = 4,
|
||||||
VT_FUSED_ACTIVATION_FUNCTION = 6,
|
VT_FUSED_ACTIVATION_FUNCTION = 6
|
||||||
VT_ASYMMETRIC_QUANTIZE_INPUTS = 8
|
|
||||||
};
|
};
|
||||||
bool time_major() const {
|
bool time_major() const {
|
||||||
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
|
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
|
||||||
@ -4381,14 +4354,10 @@ struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||||||
tflite::ActivationFunctionType fused_activation_function() const {
|
tflite::ActivationFunctionType fused_activation_function() const {
|
||||||
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
||||||
}
|
}
|
||||||
bool asymmetric_quantize_inputs() const {
|
|
||||||
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
|
|
||||||
}
|
|
||||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
return VerifyTableStart(verifier) &&
|
return VerifyTableStart(verifier) &&
|
||||||
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
|
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
|
||||||
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
||||||
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
|
|
||||||
verifier.EndTable();
|
verifier.EndTable();
|
||||||
}
|
}
|
||||||
SequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
SequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
@ -4405,9 +4374,6 @@ struct SequenceRNNOptionsBuilder {
|
|||||||
void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
|
void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
|
||||||
fbb_.AddElement<int8_t>(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
|
fbb_.AddElement<int8_t>(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
|
||||||
}
|
}
|
||||||
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
|
|
||||||
fbb_.AddElement<uint8_t>(SequenceRNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
|
|
||||||
}
|
|
||||||
explicit SequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
explicit SequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
: fbb_(_fbb) {
|
: fbb_(_fbb) {
|
||||||
start_ = fbb_.StartTable();
|
start_ = fbb_.StartTable();
|
||||||
@ -4423,10 +4389,8 @@ struct SequenceRNNOptionsBuilder {
|
|||||||
inline flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions(
|
inline flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions(
|
||||||
flatbuffers::FlatBufferBuilder &_fbb,
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
bool time_major = false,
|
bool time_major = false,
|
||||||
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) {
|
||||||
bool asymmetric_quantize_inputs = false) {
|
|
||||||
SequenceRNNOptionsBuilder builder_(_fbb);
|
SequenceRNNOptionsBuilder builder_(_fbb);
|
||||||
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
|
|
||||||
builder_.add_fused_activation_function(fused_activation_function);
|
builder_.add_fused_activation_function(fused_activation_function);
|
||||||
builder_.add_time_major(time_major);
|
builder_.add_time_major(time_major);
|
||||||
return builder_.Finish();
|
return builder_.Finish();
|
||||||
@ -4439,12 +4403,10 @@ struct BidirectionalSequenceRNNOptionsT : public flatbuffers::NativeTable {
|
|||||||
bool time_major;
|
bool time_major;
|
||||||
tflite::ActivationFunctionType fused_activation_function;
|
tflite::ActivationFunctionType fused_activation_function;
|
||||||
bool merge_outputs;
|
bool merge_outputs;
|
||||||
bool asymmetric_quantize_inputs;
|
|
||||||
BidirectionalSequenceRNNOptionsT()
|
BidirectionalSequenceRNNOptionsT()
|
||||||
: time_major(false),
|
: time_major(false),
|
||||||
fused_activation_function(tflite::ActivationFunctionType_NONE),
|
fused_activation_function(tflite::ActivationFunctionType_NONE),
|
||||||
merge_outputs(false),
|
merge_outputs(false) {
|
||||||
asymmetric_quantize_inputs(false) {
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -4453,8 +4415,7 @@ struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuf
|
|||||||
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
||||||
VT_TIME_MAJOR = 4,
|
VT_TIME_MAJOR = 4,
|
||||||
VT_FUSED_ACTIVATION_FUNCTION = 6,
|
VT_FUSED_ACTIVATION_FUNCTION = 6,
|
||||||
VT_MERGE_OUTPUTS = 8,
|
VT_MERGE_OUTPUTS = 8
|
||||||
VT_ASYMMETRIC_QUANTIZE_INPUTS = 10
|
|
||||||
};
|
};
|
||||||
bool time_major() const {
|
bool time_major() const {
|
||||||
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
|
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
|
||||||
@ -4465,15 +4426,11 @@ struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuf
|
|||||||
bool merge_outputs() const {
|
bool merge_outputs() const {
|
||||||
return GetField<uint8_t>(VT_MERGE_OUTPUTS, 0) != 0;
|
return GetField<uint8_t>(VT_MERGE_OUTPUTS, 0) != 0;
|
||||||
}
|
}
|
||||||
bool asymmetric_quantize_inputs() const {
|
|
||||||
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
|
|
||||||
}
|
|
||||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
return VerifyTableStart(verifier) &&
|
return VerifyTableStart(verifier) &&
|
||||||
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
|
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
|
||||||
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
||||||
VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) &&
|
VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) &&
|
||||||
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
|
|
||||||
verifier.EndTable();
|
verifier.EndTable();
|
||||||
}
|
}
|
||||||
BidirectionalSequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
BidirectionalSequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
@ -4493,9 +4450,6 @@ struct BidirectionalSequenceRNNOptionsBuilder {
|
|||||||
void add_merge_outputs(bool merge_outputs) {
|
void add_merge_outputs(bool merge_outputs) {
|
||||||
fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_MERGE_OUTPUTS, static_cast<uint8_t>(merge_outputs), 0);
|
fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_MERGE_OUTPUTS, static_cast<uint8_t>(merge_outputs), 0);
|
||||||
}
|
}
|
||||||
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
|
|
||||||
fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
|
|
||||||
}
|
|
||||||
explicit BidirectionalSequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
explicit BidirectionalSequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
: fbb_(_fbb) {
|
: fbb_(_fbb) {
|
||||||
start_ = fbb_.StartTable();
|
start_ = fbb_.StartTable();
|
||||||
@ -4512,10 +4466,8 @@ inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> CreateBidirectionalS
|
|||||||
flatbuffers::FlatBufferBuilder &_fbb,
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
bool time_major = false,
|
bool time_major = false,
|
||||||
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
||||||
bool merge_outputs = false,
|
bool merge_outputs = false) {
|
||||||
bool asymmetric_quantize_inputs = false) {
|
|
||||||
BidirectionalSequenceRNNOptionsBuilder builder_(_fbb);
|
BidirectionalSequenceRNNOptionsBuilder builder_(_fbb);
|
||||||
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
|
|
||||||
builder_.add_merge_outputs(merge_outputs);
|
builder_.add_merge_outputs(merge_outputs);
|
||||||
builder_.add_fused_activation_function(fused_activation_function);
|
builder_.add_fused_activation_function(fused_activation_function);
|
||||||
builder_.add_time_major(time_major);
|
builder_.add_time_major(time_major);
|
||||||
@ -4529,12 +4481,10 @@ struct FullyConnectedOptionsT : public flatbuffers::NativeTable {
|
|||||||
tflite::ActivationFunctionType fused_activation_function;
|
tflite::ActivationFunctionType fused_activation_function;
|
||||||
tflite::FullyConnectedOptionsWeightsFormat weights_format;
|
tflite::FullyConnectedOptionsWeightsFormat weights_format;
|
||||||
bool keep_num_dims;
|
bool keep_num_dims;
|
||||||
bool asymmetric_quantize_inputs;
|
|
||||||
FullyConnectedOptionsT()
|
FullyConnectedOptionsT()
|
||||||
: fused_activation_function(tflite::ActivationFunctionType_NONE),
|
: fused_activation_function(tflite::ActivationFunctionType_NONE),
|
||||||
weights_format(tflite::FullyConnectedOptionsWeightsFormat_DEFAULT),
|
weights_format(tflite::FullyConnectedOptionsWeightsFormat_DEFAULT),
|
||||||
keep_num_dims(false),
|
keep_num_dims(false) {
|
||||||
asymmetric_quantize_inputs(false) {
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -4543,8 +4493,7 @@ struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tabl
|
|||||||
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
||||||
VT_FUSED_ACTIVATION_FUNCTION = 4,
|
VT_FUSED_ACTIVATION_FUNCTION = 4,
|
||||||
VT_WEIGHTS_FORMAT = 6,
|
VT_WEIGHTS_FORMAT = 6,
|
||||||
VT_KEEP_NUM_DIMS = 8,
|
VT_KEEP_NUM_DIMS = 8
|
||||||
VT_ASYMMETRIC_QUANTIZE_INPUTS = 10
|
|
||||||
};
|
};
|
||||||
tflite::ActivationFunctionType fused_activation_function() const {
|
tflite::ActivationFunctionType fused_activation_function() const {
|
||||||
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
||||||
@ -4555,15 +4504,11 @@ struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tabl
|
|||||||
bool keep_num_dims() const {
|
bool keep_num_dims() const {
|
||||||
return GetField<uint8_t>(VT_KEEP_NUM_DIMS, 0) != 0;
|
return GetField<uint8_t>(VT_KEEP_NUM_DIMS, 0) != 0;
|
||||||
}
|
}
|
||||||
bool asymmetric_quantize_inputs() const {
|
|
||||||
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
|
|
||||||
}
|
|
||||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
return VerifyTableStart(verifier) &&
|
return VerifyTableStart(verifier) &&
|
||||||
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
||||||
VerifyField<int8_t>(verifier, VT_WEIGHTS_FORMAT) &&
|
VerifyField<int8_t>(verifier, VT_WEIGHTS_FORMAT) &&
|
||||||
VerifyField<uint8_t>(verifier, VT_KEEP_NUM_DIMS) &&
|
VerifyField<uint8_t>(verifier, VT_KEEP_NUM_DIMS) &&
|
||||||
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
|
|
||||||
verifier.EndTable();
|
verifier.EndTable();
|
||||||
}
|
}
|
||||||
FullyConnectedOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
FullyConnectedOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
@ -4583,9 +4528,6 @@ struct FullyConnectedOptionsBuilder {
|
|||||||
void add_keep_num_dims(bool keep_num_dims) {
|
void add_keep_num_dims(bool keep_num_dims) {
|
||||||
fbb_.AddElement<uint8_t>(FullyConnectedOptions::VT_KEEP_NUM_DIMS, static_cast<uint8_t>(keep_num_dims), 0);
|
fbb_.AddElement<uint8_t>(FullyConnectedOptions::VT_KEEP_NUM_DIMS, static_cast<uint8_t>(keep_num_dims), 0);
|
||||||
}
|
}
|
||||||
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
|
|
||||||
fbb_.AddElement<uint8_t>(FullyConnectedOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
|
|
||||||
}
|
|
||||||
explicit FullyConnectedOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
explicit FullyConnectedOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
: fbb_(_fbb) {
|
: fbb_(_fbb) {
|
||||||
start_ = fbb_.StartTable();
|
start_ = fbb_.StartTable();
|
||||||
@ -4602,10 +4544,8 @@ inline flatbuffers::Offset<FullyConnectedOptions> CreateFullyConnectedOptions(
|
|||||||
flatbuffers::FlatBufferBuilder &_fbb,
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
||||||
tflite::FullyConnectedOptionsWeightsFormat weights_format = tflite::FullyConnectedOptionsWeightsFormat_DEFAULT,
|
tflite::FullyConnectedOptionsWeightsFormat weights_format = tflite::FullyConnectedOptionsWeightsFormat_DEFAULT,
|
||||||
bool keep_num_dims = false,
|
bool keep_num_dims = false) {
|
||||||
bool asymmetric_quantize_inputs = false) {
|
|
||||||
FullyConnectedOptionsBuilder builder_(_fbb);
|
FullyConnectedOptionsBuilder builder_(_fbb);
|
||||||
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
|
|
||||||
builder_.add_keep_num_dims(keep_num_dims);
|
builder_.add_keep_num_dims(keep_num_dims);
|
||||||
builder_.add_weights_format(weights_format);
|
builder_.add_weights_format(weights_format);
|
||||||
builder_.add_fused_activation_function(fused_activation_function);
|
builder_.add_fused_activation_function(fused_activation_function);
|
||||||
@ -4992,13 +4932,11 @@ struct LSTMOptionsT : public flatbuffers::NativeTable {
|
|||||||
float cell_clip;
|
float cell_clip;
|
||||||
float proj_clip;
|
float proj_clip;
|
||||||
tflite::LSTMKernelType kernel_type;
|
tflite::LSTMKernelType kernel_type;
|
||||||
bool asymmetric_quantize_inputs;
|
|
||||||
LSTMOptionsT()
|
LSTMOptionsT()
|
||||||
: fused_activation_function(tflite::ActivationFunctionType_NONE),
|
: fused_activation_function(tflite::ActivationFunctionType_NONE),
|
||||||
cell_clip(0.0f),
|
cell_clip(0.0f),
|
||||||
proj_clip(0.0f),
|
proj_clip(0.0f),
|
||||||
kernel_type(tflite::LSTMKernelType_FULL),
|
kernel_type(tflite::LSTMKernelType_FULL) {
|
||||||
asymmetric_quantize_inputs(false) {
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -5008,8 +4946,7 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||||||
VT_FUSED_ACTIVATION_FUNCTION = 4,
|
VT_FUSED_ACTIVATION_FUNCTION = 4,
|
||||||
VT_CELL_CLIP = 6,
|
VT_CELL_CLIP = 6,
|
||||||
VT_PROJ_CLIP = 8,
|
VT_PROJ_CLIP = 8,
|
||||||
VT_KERNEL_TYPE = 10,
|
VT_KERNEL_TYPE = 10
|
||||||
VT_ASYMMETRIC_QUANTIZE_INPUTS = 12
|
|
||||||
};
|
};
|
||||||
tflite::ActivationFunctionType fused_activation_function() const {
|
tflite::ActivationFunctionType fused_activation_function() const {
|
||||||
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
||||||
@ -5023,16 +4960,12 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||||||
tflite::LSTMKernelType kernel_type() const {
|
tflite::LSTMKernelType kernel_type() const {
|
||||||
return static_cast<tflite::LSTMKernelType>(GetField<int8_t>(VT_KERNEL_TYPE, 0));
|
return static_cast<tflite::LSTMKernelType>(GetField<int8_t>(VT_KERNEL_TYPE, 0));
|
||||||
}
|
}
|
||||||
bool asymmetric_quantize_inputs() const {
|
|
||||||
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
|
|
||||||
}
|
|
||||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
return VerifyTableStart(verifier) &&
|
return VerifyTableStart(verifier) &&
|
||||||
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
||||||
VerifyField<float>(verifier, VT_CELL_CLIP) &&
|
VerifyField<float>(verifier, VT_CELL_CLIP) &&
|
||||||
VerifyField<float>(verifier, VT_PROJ_CLIP) &&
|
VerifyField<float>(verifier, VT_PROJ_CLIP) &&
|
||||||
VerifyField<int8_t>(verifier, VT_KERNEL_TYPE) &&
|
VerifyField<int8_t>(verifier, VT_KERNEL_TYPE) &&
|
||||||
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
|
|
||||||
verifier.EndTable();
|
verifier.EndTable();
|
||||||
}
|
}
|
||||||
LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
@ -5055,9 +4988,6 @@ struct LSTMOptionsBuilder {
|
|||||||
void add_kernel_type(tflite::LSTMKernelType kernel_type) {
|
void add_kernel_type(tflite::LSTMKernelType kernel_type) {
|
||||||
fbb_.AddElement<int8_t>(LSTMOptions::VT_KERNEL_TYPE, static_cast<int8_t>(kernel_type), 0);
|
fbb_.AddElement<int8_t>(LSTMOptions::VT_KERNEL_TYPE, static_cast<int8_t>(kernel_type), 0);
|
||||||
}
|
}
|
||||||
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
|
|
||||||
fbb_.AddElement<uint8_t>(LSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
|
|
||||||
}
|
|
||||||
explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
: fbb_(_fbb) {
|
: fbb_(_fbb) {
|
||||||
start_ = fbb_.StartTable();
|
start_ = fbb_.StartTable();
|
||||||
@ -5075,12 +5005,10 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(
|
|||||||
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
||||||
float cell_clip = 0.0f,
|
float cell_clip = 0.0f,
|
||||||
float proj_clip = 0.0f,
|
float proj_clip = 0.0f,
|
||||||
tflite::LSTMKernelType kernel_type = tflite::LSTMKernelType_FULL,
|
tflite::LSTMKernelType kernel_type = tflite::LSTMKernelType_FULL) {
|
||||||
bool asymmetric_quantize_inputs = false) {
|
|
||||||
LSTMOptionsBuilder builder_(_fbb);
|
LSTMOptionsBuilder builder_(_fbb);
|
||||||
builder_.add_proj_clip(proj_clip);
|
builder_.add_proj_clip(proj_clip);
|
||||||
builder_.add_cell_clip(cell_clip);
|
builder_.add_cell_clip(cell_clip);
|
||||||
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
|
|
||||||
builder_.add_kernel_type(kernel_type);
|
builder_.add_kernel_type(kernel_type);
|
||||||
builder_.add_fused_activation_function(fused_activation_function);
|
builder_.add_fused_activation_function(fused_activation_function);
|
||||||
return builder_.Finish();
|
return builder_.Finish();
|
||||||
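The generated object API mirrors the table accessors above: UnPack copies the table fields into LSTMOptionsT, where the extra bool (on the side of the diff that has it) defaults to false. A minimal sketch of going through that object API (model parsing elided; assumes the generated header):

#include <memory>
#include "tensorflow/lite/schema/schema_generated.h"

// Sketch: copy LSTM options into the mutable object-API struct. UnPack()
// allocates an LSTMOptionsT whose members mirror the accessors above.
std::unique_ptr<tflite::LSTMOptionsT> CopyLstmOptions(
    const tflite::LSTMOptions* options) {
  return std::unique_ptr<tflite::LSTMOptionsT>(options->UnPack());
}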
@ -5094,13 +5022,11 @@ struct UnidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
|
|||||||
float cell_clip;
|
float cell_clip;
|
||||||
float proj_clip;
|
float proj_clip;
|
||||||
bool time_major;
|
bool time_major;
|
||||||
bool asymmetric_quantize_inputs;
|
|
||||||
UnidirectionalSequenceLSTMOptionsT()
|
UnidirectionalSequenceLSTMOptionsT()
|
||||||
: fused_activation_function(tflite::ActivationFunctionType_NONE),
|
: fused_activation_function(tflite::ActivationFunctionType_NONE),
|
||||||
cell_clip(0.0f),
|
cell_clip(0.0f),
|
||||||
proj_clip(0.0f),
|
proj_clip(0.0f),
|
||||||
time_major(false),
|
time_major(false) {
|
||||||
asymmetric_quantize_inputs(false) {
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -5110,8 +5036,7 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb
|
|||||||
VT_FUSED_ACTIVATION_FUNCTION = 4,
|
VT_FUSED_ACTIVATION_FUNCTION = 4,
|
||||||
VT_CELL_CLIP = 6,
|
VT_CELL_CLIP = 6,
|
||||||
VT_PROJ_CLIP = 8,
|
VT_PROJ_CLIP = 8,
|
||||||
VT_TIME_MAJOR = 10,
|
VT_TIME_MAJOR = 10
|
||||||
VT_ASYMMETRIC_QUANTIZE_INPUTS = 12
|
|
||||||
};
|
};
|
||||||
tflite::ActivationFunctionType fused_activation_function() const {
|
tflite::ActivationFunctionType fused_activation_function() const {
|
||||||
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
||||||
@ -5125,16 +5050,12 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb
|
|||||||
bool time_major() const {
|
bool time_major() const {
|
||||||
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
|
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
|
||||||
}
|
}
|
||||||
bool asymmetric_quantize_inputs() const {
|
|
||||||
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
|
|
||||||
}
|
|
||||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
return VerifyTableStart(verifier) &&
|
return VerifyTableStart(verifier) &&
|
||||||
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
||||||
VerifyField<float>(verifier, VT_CELL_CLIP) &&
|
VerifyField<float>(verifier, VT_CELL_CLIP) &&
|
||||||
VerifyField<float>(verifier, VT_PROJ_CLIP) &&
|
VerifyField<float>(verifier, VT_PROJ_CLIP) &&
|
||||||
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
|
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
|
||||||
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
|
|
||||||
verifier.EndTable();
|
verifier.EndTable();
|
||||||
}
|
}
|
||||||
UnidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
UnidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
@ -5157,9 +5078,6 @@ struct UnidirectionalSequenceLSTMOptionsBuilder {
|
|||||||
void add_time_major(bool time_major) {
|
void add_time_major(bool time_major) {
|
||||||
fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major), 0);
|
fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major), 0);
|
||||||
}
|
}
|
||||||
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
|
|
||||||
fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
|
|
||||||
}
|
|
||||||
explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
: fbb_(_fbb) {
|
: fbb_(_fbb) {
|
||||||
start_ = fbb_.StartTable();
|
start_ = fbb_.StartTable();
|
||||||
@ -5177,12 +5095,10 @@ inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirection
|
|||||||
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
|
||||||
float cell_clip = 0.0f,
|
float cell_clip = 0.0f,
|
||||||
float proj_clip = 0.0f,
|
float proj_clip = 0.0f,
|
||||||
bool time_major = false,
|
bool time_major = false) {
|
||||||
bool asymmetric_quantize_inputs = false) {
|
|
||||||
UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
|
UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
|
||||||
builder_.add_proj_clip(proj_clip);
|
builder_.add_proj_clip(proj_clip);
|
||||||
builder_.add_cell_clip(cell_clip);
|
builder_.add_cell_clip(cell_clip);
|
||||||
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
|
|
||||||
builder_.add_time_major(time_major);
|
builder_.add_time_major(time_major);
|
||||||
builder_.add_fused_activation_function(fused_activation_function);
|
builder_.add_fused_activation_function(fused_activation_function);
|
||||||
return builder_.Finish();
|
return builder_.Finish();
|
||||||
@ -5197,14 +5113,12 @@ struct BidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
  float proj_clip;
  bool merge_outputs;
  bool time_major;
  bool asymmetric_quantize_inputs;
  BidirectionalSequenceLSTMOptionsT()
      : fused_activation_function(tflite::ActivationFunctionType_NONE),
        cell_clip(0.0f),
        proj_clip(0.0f),
        merge_outputs(false),
        time_major(true),
        time_major(true) {
        asymmetric_quantize_inputs(false) {
  }
};

@ -5215,8 +5129,7 @@ struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbu
    VT_CELL_CLIP = 6,
    VT_PROJ_CLIP = 8,
    VT_MERGE_OUTPUTS = 10,
    VT_TIME_MAJOR = 12,
    VT_TIME_MAJOR = 12
    VT_ASYMMETRIC_QUANTIZE_INPUTS = 14
  };
  tflite::ActivationFunctionType fused_activation_function() const {
    return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
@ -5233,9 +5146,6 @@ struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbu
  bool time_major() const {
    return GetField<uint8_t>(VT_TIME_MAJOR, 1) != 0;
  }
  bool asymmetric_quantize_inputs() const {
    return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
  }
  bool Verify(flatbuffers::Verifier &verifier) const {
    return VerifyTableStart(verifier) &&
           VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
@ -5243,7 +5153,6 @@ struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbu
           VerifyField<float>(verifier, VT_PROJ_CLIP) &&
           VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) &&
           VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
           VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) &&
           verifier.EndTable();
  }
  BidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
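
A note on the accessor shape above: the second argument of GetField is the field default, so a buffer serialized without VT_ASYMMETRIC_QUANTIZE_INPUTS in its vtable reads back as false. A hedged sketch of that behavior (the helper name below is hypothetical):

// Sketch: older buffers that lack the field simply fall back to the default.
bool UsesAsymmetricQuantization(
    const tflite::BidirectionalSequenceLSTMOptions *opts) {
  // Returns false when the field is absent, because GetField's default is 0.
  return opts->asymmetric_quantize_inputs();
}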
@ -5269,9 +5178,6 @@ struct BidirectionalSequenceLSTMOptionsBuilder {
  void add_time_major(bool time_major) {
    fbb_.AddElement<uint8_t>(BidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major), 1);
  }
  void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
    fbb_.AddElement<uint8_t>(BidirectionalSequenceLSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
  }
  explicit BidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
      : fbb_(_fbb) {
    start_ = fbb_.StartTable();
@ -5290,12 +5196,10 @@ inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> CreateBidirectional
    float cell_clip = 0.0f,
    float proj_clip = 0.0f,
    bool merge_outputs = false,
    bool time_major = true,
    bool time_major = true) {
    bool asymmetric_quantize_inputs = false) {
  BidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
  builder_.add_proj_clip(proj_clip);
  builder_.add_cell_clip(cell_clip);
  builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
  builder_.add_time_major(time_major);
  builder_.add_merge_outputs(merge_outputs);
  builder_.add_fused_activation_function(fused_activation_function);
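
The same table can also be assembled through the low-level builder whose add_* methods appear above; a minimal sketch under the same assumptions as earlier (header variant with the flag, assumed include path and activation value):

// Sketch: direct use of BidirectionalSequenceLSTMOptionsBuilder.
flatbuffers::Offset<tflite::BidirectionalSequenceLSTMOptions> BuildBidiOptions(
    flatbuffers::FlatBufferBuilder &fbb) {
  tflite::BidirectionalSequenceLSTMOptionsBuilder builder(fbb);
  builder.add_proj_clip(0.0f);
  builder.add_cell_clip(0.0f);
  builder.add_asymmetric_quantize_inputs(true);
  builder.add_time_major(true);
  builder.add_merge_outputs(false);
  builder.add_fused_activation_function(tflite::ActivationFunctionType_TANH);
  return builder.Finish();
}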
@ -11130,7 +11034,6 @@ inline void SVDFOptions::UnPackTo(SVDFOptionsT *_o, const flatbuffers::resolver_
  (void)_resolver;
  { auto _e = rank(); _o->rank = _e; }
  { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
  { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
}

inline flatbuffers::Offset<SVDFOptions> SVDFOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11143,12 +11046,10 @@ inline flatbuffers::Offset<SVDFOptions> CreateSVDFOptions(flatbuffers::FlatBuffe
  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SVDFOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
  auto _rank = _o->rank;
  auto _fused_activation_function = _o->fused_activation_function;
  auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
  return tflite::CreateSVDFOptions(
      _fbb,
      _rank,
      _fused_activation_function,
      _fused_activation_function);
      _asymmetric_quantize_inputs);
}

inline RNNOptionsT *RNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -11161,7 +11062,6 @@ inline void RNNOptions::UnPackTo(RNNOptionsT *_o, const flatbuffers::resolver_fu
  (void)_o;
  (void)_resolver;
  { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
  { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
}

inline flatbuffers::Offset<RNNOptions> RNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11173,11 +11073,9 @@ inline flatbuffers::Offset<RNNOptions> CreateRNNOptions(flatbuffers::FlatBufferB
  (void)_o;
  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
  auto _fused_activation_function = _o->fused_activation_function;
  auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
  return tflite::CreateRNNOptions(
      _fbb,
      _fused_activation_function,
      _fused_activation_function);
      _asymmetric_quantize_inputs);
}

inline SequenceRNNOptionsT *SequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -11191,7 +11089,6 @@ inline void SequenceRNNOptions::UnPackTo(SequenceRNNOptionsT *_o, const flatbuff
  (void)_resolver;
  { auto _e = time_major(); _o->time_major = _e; }
  { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
  { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
}

inline flatbuffers::Offset<SequenceRNNOptions> SequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11204,12 +11101,10 @@ inline flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions(flatbuff
  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SequenceRNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
  auto _time_major = _o->time_major;
  auto _fused_activation_function = _o->fused_activation_function;
  auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
  return tflite::CreateSequenceRNNOptions(
      _fbb,
      _time_major,
      _fused_activation_function,
      _fused_activation_function);
      _asymmetric_quantize_inputs);
}

inline BidirectionalSequenceRNNOptionsT *BidirectionalSequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -11224,7 +11119,6 @@ inline void BidirectionalSequenceRNNOptions::UnPackTo(BidirectionalSequenceRNNOp
  { auto _e = time_major(); _o->time_major = _e; }
  { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
  { auto _e = merge_outputs(); _o->merge_outputs = _e; }
  { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
}

inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> BidirectionalSequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11238,13 +11132,11 @@ inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> CreateBidirectionalS
  auto _time_major = _o->time_major;
  auto _fused_activation_function = _o->fused_activation_function;
  auto _merge_outputs = _o->merge_outputs;
  auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
  return tflite::CreateBidirectionalSequenceRNNOptions(
      _fbb,
      _time_major,
      _fused_activation_function,
      _merge_outputs,
      _merge_outputs);
      _asymmetric_quantize_inputs);
}

inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -11259,7 +11151,6 @@ inline void FullyConnectedOptions::UnPackTo(FullyConnectedOptionsT *_o, const fl
  { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
  { auto _e = weights_format(); _o->weights_format = _e; }
  { auto _e = keep_num_dims(); _o->keep_num_dims = _e; }
  { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
}

inline flatbuffers::Offset<FullyConnectedOptions> FullyConnectedOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11273,13 +11164,11 @@ inline flatbuffers::Offset<FullyConnectedOptions> CreateFullyConnectedOptions(fl
  auto _fused_activation_function = _o->fused_activation_function;
  auto _weights_format = _o->weights_format;
  auto _keep_num_dims = _o->keep_num_dims;
  auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
  return tflite::CreateFullyConnectedOptions(
      _fbb,
      _fused_activation_function,
      _weights_format,
      _keep_num_dims,
      _keep_num_dims);
      _asymmetric_quantize_inputs);
}

inline SoftmaxOptionsT *SoftmaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
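
The Pack/UnPackTo pair above carries the flag through the FlatBuffers object API; a minimal round-trip sketch, assuming the header variant that has the field and the same assumed include path as before:

// Sketch: object-API round trip of FullyConnectedOptions with the flag set.
#include <memory>

void RoundTripExample() {
  tflite::FullyConnectedOptionsT native;
  native.fused_activation_function = tflite::ActivationFunctionType_RELU;
  native.asymmetric_quantize_inputs = true;

  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(tflite::FullyConnectedOptions::Pack(fbb, &native));

  const auto *table =
      flatbuffers::GetRoot<tflite::FullyConnectedOptions>(fbb.GetBufferPointer());
  std::unique_ptr<tflite::FullyConnectedOptionsT> back(table->UnPack());
  // back->asymmetric_quantize_inputs is true after the round trip.
}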
@ -11463,7 +11352,6 @@ inline void LSTMOptions::UnPackTo(LSTMOptionsT *_o, const flatbuffers::resolver_
  { auto _e = cell_clip(); _o->cell_clip = _e; }
  { auto _e = proj_clip(); _o->proj_clip = _e; }
  { auto _e = kernel_type(); _o->kernel_type = _e; }
  { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
}

inline flatbuffers::Offset<LSTMOptions> LSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11478,14 +11366,12 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(flatbuffers::FlatBuffe
  auto _cell_clip = _o->cell_clip;
  auto _proj_clip = _o->proj_clip;
  auto _kernel_type = _o->kernel_type;
  auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
  return tflite::CreateLSTMOptions(
      _fbb,
      _fused_activation_function,
      _cell_clip,
      _proj_clip,
      _kernel_type,
      _kernel_type);
      _asymmetric_quantize_inputs);
}

inline UnidirectionalSequenceLSTMOptionsT *UnidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -11501,7 +11387,6 @@ inline void UnidirectionalSequenceLSTMOptions::UnPackTo(UnidirectionalSequenceLS
  { auto _e = cell_clip(); _o->cell_clip = _e; }
  { auto _e = proj_clip(); _o->proj_clip = _e; }
  { auto _e = time_major(); _o->time_major = _e; }
  { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
}

inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> UnidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11516,14 +11401,12 @@ inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirection
  auto _cell_clip = _o->cell_clip;
  auto _proj_clip = _o->proj_clip;
  auto _time_major = _o->time_major;
  auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
  return tflite::CreateUnidirectionalSequenceLSTMOptions(
      _fbb,
      _fused_activation_function,
      _cell_clip,
      _proj_clip,
      _time_major,
      _time_major);
      _asymmetric_quantize_inputs);
}

inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -11540,7 +11423,6 @@ inline void BidirectionalSequenceLSTMOptions::UnPackTo(BidirectionalSequenceLSTM
  { auto _e = proj_clip(); _o->proj_clip = _e; }
  { auto _e = merge_outputs(); _o->merge_outputs = _e; }
  { auto _e = time_major(); _o->time_major = _e; }
  { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
}

inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> BidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11556,15 +11438,13 @@ inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> CreateBidirectional
  auto _proj_clip = _o->proj_clip;
  auto _merge_outputs = _o->merge_outputs;
  auto _time_major = _o->time_major;
  auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
  return tflite::CreateBidirectionalSequenceLSTMOptions(
      _fbb,
      _fused_activation_function,
      _cell_clip,
      _proj_clip,
      _merge_outputs,
      _time_major,
      _time_major);
      _asymmetric_quantize_inputs);
}

inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {