Add builtin sparse LSTM kernel.
PiperOrigin-RevId: 329562447 Change-Id: I5c407b513fbc86d21f6ea2d626da7b69dcd38bc7
This commit is contained in:
parent
b4ee2c4294
commit
8744e4b2b9
@ -1136,10 +1136,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
const int fw_row_sums_size = fw_row_sums->dims->data[0];
|
const int fw_row_sums_size = fw_row_sums->dims->data[0];
|
||||||
const int bw_row_sums_size = bw_row_sums->dims->data[0];
|
const int bw_row_sums_size = bw_row_sums->dims->data[0];
|
||||||
TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
|
TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
|
||||||
input, fw_input_to_input_weights, fw_input_to_forget_weights,
|
input, fw_input_to_input_weights,
|
||||||
fw_input_to_cell_weights, fw_input_to_output_weights,
|
/*input_to_input_weights_ledger*/ nullptr, fw_input_to_forget_weights,
|
||||||
fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
|
/*input_to_forget_weights_ledger*/ nullptr, fw_input_to_cell_weights,
|
||||||
fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
|
/*input_to_cell_weights_ledger*/ nullptr, fw_input_to_output_weights,
|
||||||
|
/*input_to_output_weights_ledger*/ nullptr,
|
||||||
|
fw_recurrent_to_input_weights,
|
||||||
|
/*recurrent_to_input_weights_ledger*/ nullptr,
|
||||||
|
fw_recurrent_to_forget_weights,
|
||||||
|
/*recurrent_to_forget_weights_ledger*/ nullptr,
|
||||||
|
fw_recurrent_to_cell_weights,
|
||||||
|
/*recurrent_to_cell_weights_ledger*/ nullptr,
|
||||||
|
fw_recurrent_to_output_weights,
|
||||||
|
/*recurrent_to_output_weights_ledger*/ nullptr,
|
||||||
fw_cell_to_input_weights, fw_cell_to_forget_weights,
|
fw_cell_to_input_weights, fw_cell_to_forget_weights,
|
||||||
fw_cell_to_output_weights,
|
fw_cell_to_output_weights,
|
||||||
/*input_layer_norm_coefficients=*/nullptr,
|
/*input_layer_norm_coefficients=*/nullptr,
|
||||||
@ -1149,7 +1158,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
|
fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
|
||||||
fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
|
fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
|
||||||
fw_input_gate_bias, fw_forget_gate_bias, fw_cell_gate_bias,
|
fw_input_gate_bias, fw_forget_gate_bias, fw_cell_gate_bias,
|
||||||
fw_output_gate_bias, fw_projection_weights, fw_projection_bias,
|
fw_output_gate_bias, fw_projection_weights,
|
||||||
|
/*projection_weights_ledger*/ nullptr, fw_projection_bias,
|
||||||
&lstm_params,
|
&lstm_params,
|
||||||
/*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
|
/*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
|
||||||
fw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors),
|
fw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors),
|
||||||
@ -1167,10 +1177,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_ENSURE_OK(context, fw_pass_status);
|
TF_LITE_ENSURE_OK(context, fw_pass_status);
|
||||||
|
|
||||||
TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
|
TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
|
||||||
bw_input, bw_input_to_input_weights, bw_input_to_forget_weights,
|
bw_input, bw_input_to_input_weights,
|
||||||
bw_input_to_cell_weights, bw_input_to_output_weights,
|
/*input_to_input_weights_ledger*/ nullptr, bw_input_to_forget_weights,
|
||||||
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
|
/*input_to_forget_weights_ledger*/ nullptr, bw_input_to_cell_weights,
|
||||||
bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
|
/*input_to_cell_weights_ledger*/ nullptr, bw_input_to_output_weights,
|
||||||
|
/*input_to_output_weights_ledger*/ nullptr,
|
||||||
|
bw_recurrent_to_input_weights,
|
||||||
|
/*recurrent_to_input_weights_ledger*/ nullptr,
|
||||||
|
bw_recurrent_to_forget_weights,
|
||||||
|
/*recurrent_to_forget_weights_ledger*/ nullptr,
|
||||||
|
bw_recurrent_to_cell_weights,
|
||||||
|
/*recurrent_to_cell_weights_ledger*/ nullptr,
|
||||||
|
bw_recurrent_to_output_weights,
|
||||||
|
/*recurrent_to_output_weights_ledger*/ nullptr,
|
||||||
bw_cell_to_input_weights, bw_cell_to_forget_weights,
|
bw_cell_to_input_weights, bw_cell_to_forget_weights,
|
||||||
bw_cell_to_output_weights,
|
bw_cell_to_output_weights,
|
||||||
/*input_layer_norm_coefficients=*/nullptr,
|
/*input_layer_norm_coefficients=*/nullptr,
|
||||||
@ -1180,7 +1199,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
|
bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
|
||||||
bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
|
bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
|
||||||
bw_input_gate_bias, bw_forget_gate_bias, bw_cell_gate_bias,
|
bw_input_gate_bias, bw_forget_gate_bias, bw_cell_gate_bias,
|
||||||
bw_output_gate_bias, bw_projection_weights, bw_projection_bias,
|
bw_output_gate_bias, bw_projection_weights,
|
||||||
|
/*projection_weights_ledger*/ nullptr, bw_projection_bias,
|
||||||
&lstm_params,
|
&lstm_params,
|
||||||
/*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
|
/*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
|
||||||
bw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors),
|
bw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors),
|
||||||
|
@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <initializer_list>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -42,7 +41,7 @@ using ::testing::ElementsAreArray;
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
class DensifyOpModel : public SingleOpModel {
|
class DensifyOpModel : public SingleOpModel {
|
||||||
public:
|
public:
|
||||||
DensifyOpModel(const TensorData& input, std::initializer_list<T> input_data,
|
DensifyOpModel(const TensorData& input, const std::vector<T>& input_data,
|
||||||
int version = 1) {
|
int version = 1) {
|
||||||
input_ = AddConstSparseInput(input, input_data);
|
input_ = AddConstSparseInput(input, input_data);
|
||||||
output_ = AddOutput({input.type, input.shape});
|
output_ = AddOutput({input.type, input.shape});
|
||||||
@ -65,9 +64,8 @@ class DensifyOpModel : public SingleOpModel {
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST(DensifyOpTest, Float) {
|
TEST(DensifyOpTest, Float) {
|
||||||
std::initializer_list<float> dense_values = {6, 0, 9, 8, 0, 0,
|
std::vector<float> dense_values = {6, 0, 9, 8, 0, 0, 0, 0, 5, 0, 0, 7};
|
||||||
0, 0, 5, 0, 0, 7};
|
std::vector<float> sparse_values = {6, 9, 8, 5, 7};
|
||||||
std::initializer_list<float> sparse_values = {6, 9, 8, 5, 7};
|
|
||||||
TensorData input = {};
|
TensorData input = {};
|
||||||
input.type = TensorType_FLOAT32;
|
input.type = TensorType_FLOAT32;
|
||||||
input.shape = {3, 4};
|
input.shape = {3, 4};
|
||||||
@ -80,9 +78,8 @@ TEST(DensifyOpTest, Float) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(DensifyOpTest, Float3D) {
|
TEST(DensifyOpTest, Float3D) {
|
||||||
std::initializer_list<float> dense_values = {6, 0, 9, 8, 0, 0,
|
std::vector<float> dense_values = {6, 0, 9, 8, 0, 0, 0, 0, 5, 0, 0, 7};
|
||||||
0, 0, 5, 0, 0, 7};
|
std::vector<float> sparse_values = {6, 9, 8, 5, 7};
|
||||||
std::initializer_list<float> sparse_values = {6, 9, 8, 5, 7};
|
|
||||||
TensorData input = {};
|
TensorData input = {};
|
||||||
input.type = TensorType_FLOAT32;
|
input.type = TensorType_FLOAT32;
|
||||||
input.shape = {3, 2, 2};
|
input.shape = {3, 2, 2};
|
||||||
@ -95,9 +92,8 @@ TEST(DensifyOpTest, Float3D) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(DensifyOpTest, Int8) {
|
TEST(DensifyOpTest, Int8) {
|
||||||
std::initializer_list<int8_t> dense_values = {6, 0, 9, 8, 0, 0,
|
std::vector<int8_t> dense_values = {6, 0, 9, 8, 0, 0, 0, 0, 5, 0, 0, 7};
|
||||||
0, 0, 5, 0, 0, 7};
|
std::vector<int8_t> sparse_values = {6, 9, 8, 5, 7};
|
||||||
std::initializer_list<int8_t> sparse_values = {6, 9, 8, 5, 7};
|
|
||||||
TensorData input = {};
|
TensorData input = {};
|
||||||
input.type = TensorType_INT8;
|
input.type = TensorType_INT8;
|
||||||
input.shape = {3, 4};
|
input.shape = {3, 4};
|
||||||
|
@ -1144,7 +1144,7 @@ class SparseFullyConnectedOpModel : public SingleOpModel {
|
|||||||
SparseFullyConnectedOpModel(TfLiteRegistration* registration, int units,
|
SparseFullyConnectedOpModel(TfLiteRegistration* registration, int units,
|
||||||
int batches, const TensorData& input,
|
int batches, const TensorData& input,
|
||||||
const TensorData& weights,
|
const TensorData& weights,
|
||||||
std::initializer_list<T> weights_data,
|
const std::vector<T>& weights_data,
|
||||||
int num_threads = 1)
|
int num_threads = 1)
|
||||||
: batches_(batches), units_(units) {
|
: batches_(batches), units_(units) {
|
||||||
int total_input_size = 1;
|
int total_input_size = 1;
|
||||||
|
@ -55,6 +55,10 @@ struct OpData {
|
|||||||
int scratch_tensor_index;
|
int scratch_tensor_index;
|
||||||
lstm_eval::IntegerLstmParameter integer_lstm_param;
|
lstm_eval::IntegerLstmParameter integer_lstm_param;
|
||||||
bool compute_row_sums;
|
bool compute_row_sums;
|
||||||
|
|
||||||
|
// Only used for sparse hybrid lstm kernels.
|
||||||
|
int ledger_index;
|
||||||
|
bool ledger_initialized;
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace full {
|
namespace full {
|
||||||
@ -77,6 +81,63 @@ enum HybridTemporaryTensor {
|
|||||||
kNumHybridTemporaryTensors = 12,
|
kNumHybridTemporaryTensors = 12,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
constexpr int kLedgersToAdd = 9;
|
||||||
|
constexpr int kInputToInputWeightsLedgerOffset = 0;
|
||||||
|
constexpr int kInputToForgetWeightsLedgerOffset = 1;
|
||||||
|
constexpr int kInputToCellWeightsLedgerOffset = 2;
|
||||||
|
constexpr int kInputToOutputWeightsLedgerOffset = 3;
|
||||||
|
constexpr int kRecurrentToInputWeightsLedgerOffset = 4;
|
||||||
|
constexpr int kRecurrentToForgetWeightsLedgerOffset = 5;
|
||||||
|
constexpr int kRecurrentToCellWeightsLedgerOffset = 6;
|
||||||
|
constexpr int kRecurrentToOutputWeightsLedgerOffset = 7;
|
||||||
|
constexpr int kProjectionWeightsLedgerOffset = 8;
|
||||||
|
|
||||||
|
TfLiteStatus make_ledger(const TfLiteSparsity* sparsity, TfLiteContext* context,
|
||||||
|
TfLiteTensor* ledger) {
|
||||||
|
ledger->type = kTfLiteUInt8;
|
||||||
|
ledger->allocation_type = kTfLiteArenaRwPersistent;
|
||||||
|
if (sparsity == nullptr) {
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
TfLiteIntArray* ledger_size = TfLiteIntArrayCreate(1);
|
||||||
|
ledger_size->data[0] = sparsity->dim_metadata[1].array_indices->size +
|
||||||
|
sparsity->dim_metadata[1].array_segments->size - 1;
|
||||||
|
return context->ResizeTensor(context, ledger, ledger_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus copy_ledger(const TfLiteSparsity* sparsity, TfLiteTensor* ledger) {
|
||||||
|
if (sparsity == nullptr) {
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto* array_segments = sparsity->dim_metadata[1].array_segments;
|
||||||
|
const auto* array_indices = sparsity->dim_metadata[1].array_indices;
|
||||||
|
uint8_t* output_data = GetTensorData<uint8_t>(ledger);
|
||||||
|
int output_data_ptr = 0;
|
||||||
|
|
||||||
|
for (int i = 0; i < array_segments->size - 1; i++) {
|
||||||
|
int row_start = array_segments->data[i];
|
||||||
|
int row_end = array_segments->data[i + 1];
|
||||||
|
if (row_end - row_start > UINT8_MAX) {
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
// Copy num of non-zero blocks in row i.
|
||||||
|
output_data[output_data_ptr] = static_cast<uint8_t>(row_end - row_start);
|
||||||
|
output_data_ptr++;
|
||||||
|
|
||||||
|
for (int j = row_start; j < row_end; j++) {
|
||||||
|
if (array_indices->data[j] > UINT8_MAX) {
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
// Copy indices of non-zero blocks in row i.
|
||||||
|
output_data[output_data_ptr] =
|
||||||
|
static_cast<uint8_t>(array_indices->data[j]);
|
||||||
|
output_data_ptr++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
TfLiteStatus PopulateQuantizedLstmParams8x8_16(
|
TfLiteStatus PopulateQuantizedLstmParams8x8_16(
|
||||||
TfLiteContext* context, TfLiteNode* node,
|
TfLiteContext* context, TfLiteNode* node,
|
||||||
lstm_eval::IntegerLstmParameter* integer_lstm_param) {
|
lstm_eval::IntegerLstmParameter* integer_lstm_param) {
|
||||||
@ -744,6 +805,9 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
|||||||
// TODO(b/159066113): maybe just add the minimum required temp tensors?
|
// TODO(b/159066113): maybe just add the minimum required temp tensors?
|
||||||
context->AddTensors(context, kNumHybridTemporaryTensors,
|
context->AddTensors(context, kNumHybridTemporaryTensors,
|
||||||
&op_data->scratch_tensor_index);
|
&op_data->scratch_tensor_index);
|
||||||
|
// Tensors used for the sparse hybrid kernel.
|
||||||
|
context->AddTensors(context, /*tensors_to_add=*/kLedgersToAdd,
|
||||||
|
&op_data->ledger_index);
|
||||||
return op_data;
|
return op_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1239,6 +1303,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
// The weights are of consistent type, so it suffices to check one.
|
// The weights are of consistent type, so it suffices to check one.
|
||||||
const bool is_hybrid_op = IsHybridOp(input, input_to_output_weights);
|
const bool is_hybrid_op = IsHybridOp(input, input_to_output_weights);
|
||||||
|
|
||||||
|
const bool is_sparse_op = (input_to_output_weights->sparsity != nullptr);
|
||||||
|
|
||||||
// The type of Integer LSTM.
|
// The type of Integer LSTM.
|
||||||
const int num_intermediate_tensors = node->intermediates->size;
|
const int num_intermediate_tensors = node->intermediates->size;
|
||||||
if (is_integer) {
|
if (is_integer) {
|
||||||
@ -1251,7 +1317,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
|
|
||||||
TfLiteIntArrayFree(node->temporaries);
|
TfLiteIntArrayFree(node->temporaries);
|
||||||
if (is_hybrid_op) {
|
if (is_hybrid_op) {
|
||||||
node->temporaries = TfLiteIntArrayCreate(kNumHybridTemporaryTensors);
|
if (is_sparse_op) {
|
||||||
|
node->temporaries =
|
||||||
|
TfLiteIntArrayCreate(kNumHybridTemporaryTensors + kLedgersToAdd);
|
||||||
|
} else {
|
||||||
|
node->temporaries = TfLiteIntArrayCreate(kNumHybridTemporaryTensors);
|
||||||
|
}
|
||||||
} else if (is_integer) {
|
} else if (is_integer) {
|
||||||
if (is_8x8_16) {
|
if (is_8x8_16) {
|
||||||
node->temporaries = TfLiteIntArrayCreate(6);
|
node->temporaries = TfLiteIntArrayCreate(6);
|
||||||
@ -1289,7 +1360,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (is_hybrid_op) {
|
if (is_hybrid_op) {
|
||||||
op_data->compute_row_sums = true;
|
if (!is_sparse_op) {
|
||||||
|
op_data->compute_row_sums = true;
|
||||||
|
}
|
||||||
// Allocate temporary tensors to store quantized values of input,
|
// Allocate temporary tensors to store quantized values of input,
|
||||||
// output_state and cell_state tensors.
|
// output_state and cell_state tensors.
|
||||||
node->temporaries->data[kInputQuantized] =
|
node->temporaries->data[kInputQuantized] =
|
||||||
@ -1454,6 +1527,125 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_ENSURE_OK(
|
TF_LITE_ENSURE_OK(
|
||||||
context, context->ResizeTensor(context, row_sums, row_sums_size));
|
context, context->ResizeTensor(context, row_sums, row_sums_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (is_sparse_op) {
|
||||||
|
op_data->ledger_initialized = false;
|
||||||
|
int offset = kNumHybridTemporaryTensors;
|
||||||
|
{
|
||||||
|
node->temporaries->data[offset + kInputToInputWeightsLedgerOffset] =
|
||||||
|
op_data->ledger_index + kInputToInputWeightsLedgerOffset;
|
||||||
|
const TfLiteTensor* input_to_input_weights =
|
||||||
|
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
|
||||||
|
TfLiteTensor* input_to_input_weights_ledger =
|
||||||
|
&context->tensors[op_data->ledger_index +
|
||||||
|
kInputToInputWeightsLedgerOffset];
|
||||||
|
auto status = make_ledger(input_to_input_weights == nullptr
|
||||||
|
? nullptr
|
||||||
|
: input_to_input_weights->sparsity,
|
||||||
|
context, input_to_input_weights_ledger);
|
||||||
|
if (status != kTfLiteOk) return status;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
node->temporaries->data[offset + kInputToForgetWeightsLedgerOffset] =
|
||||||
|
op_data->ledger_index + kInputToForgetWeightsLedgerOffset;
|
||||||
|
const TfLiteTensor* input_to_forget_weights =
|
||||||
|
GetInput(context, node, kInputToForgetWeightsTensor);
|
||||||
|
TfLiteTensor* input_to_forget_weights_ledger =
|
||||||
|
&context->tensors[op_data->ledger_index +
|
||||||
|
kInputToForgetWeightsLedgerOffset];
|
||||||
|
auto status = make_ledger(input_to_forget_weights->sparsity, context,
|
||||||
|
input_to_forget_weights_ledger);
|
||||||
|
if (status != kTfLiteOk) return status;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
node->temporaries->data[offset + kInputToCellWeightsLedgerOffset] =
|
||||||
|
op_data->ledger_index + kInputToCellWeightsLedgerOffset;
|
||||||
|
const TfLiteTensor* input_to_cell_weights =
|
||||||
|
GetInput(context, node, kInputToCellWeightsTensor);
|
||||||
|
TfLiteTensor* input_to_cell_weights_ledger =
|
||||||
|
&context->tensors[op_data->ledger_index +
|
||||||
|
kInputToCellWeightsLedgerOffset];
|
||||||
|
auto status = make_ledger(input_to_cell_weights->sparsity, context,
|
||||||
|
input_to_cell_weights_ledger);
|
||||||
|
if (status != kTfLiteOk) return status;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
node->temporaries->data[offset + kInputToOutputWeightsLedgerOffset] =
|
||||||
|
op_data->ledger_index + kInputToOutputWeightsLedgerOffset;
|
||||||
|
const TfLiteTensor* input_to_output_weights =
|
||||||
|
GetInput(context, node, kInputToOutputWeightsTensor);
|
||||||
|
TfLiteTensor* input_to_output_weights_ledger =
|
||||||
|
&context->tensors[op_data->ledger_index +
|
||||||
|
kInputToOutputWeightsLedgerOffset];
|
||||||
|
auto status = make_ledger(input_to_output_weights->sparsity, context,
|
||||||
|
input_to_output_weights_ledger);
|
||||||
|
if (status != kTfLiteOk) return status;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
node->temporaries->data[offset + kRecurrentToInputWeightsLedgerOffset] =
|
||||||
|
op_data->ledger_index + kRecurrentToInputWeightsLedgerOffset;
|
||||||
|
const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
|
||||||
|
context, node, kRecurrentToInputWeightsTensor);
|
||||||
|
TfLiteTensor* recurrent_to_input_weights_ledger =
|
||||||
|
&context->tensors[op_data->ledger_index +
|
||||||
|
kRecurrentToInputWeightsLedgerOffset];
|
||||||
|
auto status = make_ledger(recurrent_to_input_weights == nullptr
|
||||||
|
? nullptr
|
||||||
|
: recurrent_to_input_weights->sparsity,
|
||||||
|
context, recurrent_to_input_weights_ledger);
|
||||||
|
if (status != kTfLiteOk) return status;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
node->temporaries
|
||||||
|
->data[offset + kRecurrentToForgetWeightsLedgerOffset] =
|
||||||
|
op_data->ledger_index + kRecurrentToForgetWeightsLedgerOffset;
|
||||||
|
const TfLiteTensor* recurrent_to_forget_weights =
|
||||||
|
GetInput(context, node, kRecurrentToForgetWeightsTensor);
|
||||||
|
TfLiteTensor* recurrent_to_forget_weights_ledger =
|
||||||
|
&context->tensors[op_data->ledger_index +
|
||||||
|
kRecurrentToForgetWeightsLedgerOffset];
|
||||||
|
auto status = make_ledger(recurrent_to_forget_weights->sparsity,
|
||||||
|
context, recurrent_to_forget_weights_ledger);
|
||||||
|
if (status != kTfLiteOk) return status;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
node->temporaries->data[offset + kRecurrentToCellWeightsLedgerOffset] =
|
||||||
|
op_data->ledger_index + kRecurrentToCellWeightsLedgerOffset;
|
||||||
|
const TfLiteTensor* recurrent_to_cell_weights =
|
||||||
|
GetInput(context, node, kRecurrentToCellWeightsTensor);
|
||||||
|
TfLiteTensor* recurrent_to_cell_weights_ledger =
|
||||||
|
&context->tensors[op_data->ledger_index +
|
||||||
|
kRecurrentToCellWeightsLedgerOffset];
|
||||||
|
auto status = make_ledger(recurrent_to_cell_weights->sparsity, context,
|
||||||
|
recurrent_to_cell_weights_ledger);
|
||||||
|
if (status != kTfLiteOk) return status;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
node->temporaries
|
||||||
|
->data[offset + kRecurrentToOutputWeightsLedgerOffset] =
|
||||||
|
op_data->ledger_index + kRecurrentToOutputWeightsLedgerOffset;
|
||||||
|
const TfLiteTensor* recurrent_to_output_weights =
|
||||||
|
GetInput(context, node, kRecurrentToOutputWeightsTensor);
|
||||||
|
TfLiteTensor* recurrent_to_output_weights_ledger =
|
||||||
|
&context->tensors[op_data->ledger_index +
|
||||||
|
kRecurrentToOutputWeightsLedgerOffset];
|
||||||
|
auto status = make_ledger(recurrent_to_output_weights->sparsity,
|
||||||
|
context, recurrent_to_output_weights_ledger);
|
||||||
|
if (status != kTfLiteOk) return status;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
node->temporaries->data[offset + kProjectionWeightsLedgerOffset] =
|
||||||
|
op_data->ledger_index + kProjectionWeightsLedgerOffset;
|
||||||
|
const TfLiteTensor* projection_weights =
|
||||||
|
GetInput(context, node, kProjectionWeightsTensor);
|
||||||
|
TfLiteTensor* projection_weights_ledger =
|
||||||
|
&context->tensors[op_data->ledger_index +
|
||||||
|
kProjectionWeightsLedgerOffset];
|
||||||
|
auto status = make_ledger(projection_weights->sparsity, context,
|
||||||
|
projection_weights_ledger);
|
||||||
|
if (status != kTfLiteOk) return status;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (is_integer) {
|
if (is_integer) {
|
||||||
@ -1624,14 +1816,116 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
case kTfLiteUInt8:
|
case kTfLiteUInt8:
|
||||||
case kTfLiteInt8: {
|
case kTfLiteInt8: {
|
||||||
const bool is_hybrid = (input->type == kTfLiteFloat32);
|
const bool is_hybrid = (input->type == kTfLiteFloat32);
|
||||||
|
const bool is_sparse = input_to_output_weights->sparsity != nullptr;
|
||||||
if (is_hybrid) {
|
if (is_hybrid) {
|
||||||
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
|
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
|
||||||
const int row_sums_size = row_sums->dims->data[0];
|
const int row_sums_size = row_sums->dims->data[0];
|
||||||
|
if (is_sparse) {
|
||||||
|
TfLiteTensor* input_to_input_weights_ledger =
|
||||||
|
&context->tensors[op_data->ledger_index +
|
||||||
|
kInputToInputWeightsLedgerOffset];
|
||||||
|
TfLiteTensor* input_to_forget_weights_ledger =
|
||||||
|
&context->tensors[op_data->ledger_index +
|
||||||
|
kInputToForgetWeightsLedgerOffset];
|
||||||
|
TfLiteTensor* input_to_cell_weights_ledger =
|
||||||
|
&context->tensors[op_data->ledger_index +
|
||||||
|
kInputToCellWeightsLedgerOffset];
|
||||||
|
TfLiteTensor* input_to_output_weights_ledger =
|
||||||
|
&context->tensors[op_data->ledger_index +
|
||||||
|
kInputToOutputWeightsLedgerOffset];
|
||||||
|
TfLiteTensor* recurrent_to_input_weights_ledger =
|
||||||
|
&context->tensors[op_data->ledger_index +
|
||||||
|
kRecurrentToInputWeightsLedgerOffset];
|
||||||
|
TfLiteTensor* recurrent_to_forget_weights_ledger =
|
||||||
|
&context->tensors[op_data->ledger_index +
|
||||||
|
kRecurrentToForgetWeightsLedgerOffset];
|
||||||
|
TfLiteTensor* recurrent_to_cell_weights_ledger =
|
||||||
|
&context->tensors[op_data->ledger_index +
|
||||||
|
kRecurrentToCellWeightsLedgerOffset];
|
||||||
|
TfLiteTensor* recurrent_to_output_weights_ledger =
|
||||||
|
&context->tensors[op_data->ledger_index +
|
||||||
|
kRecurrentToOutputWeightsLedgerOffset];
|
||||||
|
TfLiteTensor* projection_weights_ledger =
|
||||||
|
&context->tensors[op_data->ledger_index +
|
||||||
|
kProjectionWeightsLedgerOffset];
|
||||||
|
if (!op_data->ledger_initialized) {
|
||||||
|
copy_ledger(input_to_input_weights == nullptr
|
||||||
|
? nullptr
|
||||||
|
: input_to_input_weights->sparsity,
|
||||||
|
input_to_input_weights_ledger);
|
||||||
|
copy_ledger(input_to_forget_weights->sparsity,
|
||||||
|
input_to_forget_weights_ledger);
|
||||||
|
copy_ledger(input_to_cell_weights->sparsity,
|
||||||
|
input_to_cell_weights_ledger);
|
||||||
|
copy_ledger(input_to_output_weights->sparsity,
|
||||||
|
input_to_output_weights_ledger);
|
||||||
|
copy_ledger(recurrent_to_input_weights == nullptr
|
||||||
|
? nullptr
|
||||||
|
: recurrent_to_input_weights->sparsity,
|
||||||
|
recurrent_to_input_weights_ledger);
|
||||||
|
copy_ledger(recurrent_to_forget_weights->sparsity,
|
||||||
|
recurrent_to_forget_weights_ledger);
|
||||||
|
copy_ledger(recurrent_to_cell_weights->sparsity,
|
||||||
|
recurrent_to_cell_weights_ledger);
|
||||||
|
copy_ledger(recurrent_to_output_weights->sparsity,
|
||||||
|
recurrent_to_output_weights_ledger);
|
||||||
|
copy_ledger(projection_weights->sparsity,
|
||||||
|
projection_weights_ledger);
|
||||||
|
op_data->ledger_initialized = true;
|
||||||
|
}
|
||||||
|
return lstm_eval::EvalHybrid(
|
||||||
|
input, input_to_input_weights, input_to_input_weights_ledger,
|
||||||
|
input_to_forget_weights, input_to_forget_weights_ledger,
|
||||||
|
input_to_cell_weights, input_to_cell_weights_ledger,
|
||||||
|
input_to_output_weights, input_to_output_weights_ledger,
|
||||||
|
recurrent_to_input_weights, recurrent_to_input_weights_ledger,
|
||||||
|
recurrent_to_forget_weights, recurrent_to_forget_weights_ledger,
|
||||||
|
recurrent_to_cell_weights, recurrent_to_cell_weights_ledger,
|
||||||
|
recurrent_to_output_weights, recurrent_to_output_weights_ledger,
|
||||||
|
cell_to_input_weights, cell_to_forget_weights,
|
||||||
|
cell_to_output_weights, input_layer_norm_coefficients,
|
||||||
|
forget_layer_norm_coefficients, cell_layer_norm_coefficients,
|
||||||
|
output_layer_norm_coefficients,
|
||||||
|
/*aux_input=*/nullptr,
|
||||||
|
/*aux_input_to_input_weights=*/nullptr,
|
||||||
|
/*aux_input_to_forget_weights=*/nullptr,
|
||||||
|
/*aux_input_to_cell_weights=*/nullptr,
|
||||||
|
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
||||||
|
forget_gate_bias, cell_gate_bias, output_gate_bias,
|
||||||
|
projection_weights, projection_weights_ledger, projection_bias,
|
||||||
|
params,
|
||||||
|
/*forward_sequence=*/true, /*time_major=*/true,
|
||||||
|
/*output_offset=*/0, GetTemporary(context, node, kScratchBuffer),
|
||||||
|
GetTemporary(context, node, kInputScalingFactors),
|
||||||
|
/*aux_input_sf=*/nullptr,
|
||||||
|
GetTemporary(context, node, kOutputStateScalingFactors),
|
||||||
|
GetTemporary(context, node, kProductScalingFactors),
|
||||||
|
GetTemporary(context, node, kRecoveredCellWeights),
|
||||||
|
GetTemporary(context, node, kInputQuantized),
|
||||||
|
/*aux_input_quantized=*/nullptr,
|
||||||
|
GetTemporary(context, node, kOutputStateQuantized),
|
||||||
|
GetTemporary(context, node, kCellStateQuantized), output_state,
|
||||||
|
cell_state, GetTemporary(context, node, kAccumScratch), output,
|
||||||
|
GetTemporary(context, node, kInputZeroPoints),
|
||||||
|
/*aux_input_zp=*/nullptr,
|
||||||
|
GetTemporary(context, node, kOutputStateZeroPoints), row_sums,
|
||||||
|
row_sums_size, &op_data->compute_row_sums,
|
||||||
|
CpuBackendContext::GetFromContext(context));
|
||||||
|
}
|
||||||
return lstm_eval::EvalHybrid(
|
return lstm_eval::EvalHybrid(
|
||||||
input, input_to_input_weights, input_to_forget_weights,
|
input, input_to_input_weights,
|
||||||
input_to_cell_weights, input_to_output_weights,
|
/*input_to_input_weights_ledger*/ nullptr, input_to_forget_weights,
|
||||||
recurrent_to_input_weights, recurrent_to_forget_weights,
|
/*input_to_forget_weights_ledger*/ nullptr, input_to_cell_weights,
|
||||||
recurrent_to_cell_weights, recurrent_to_output_weights,
|
/*input_to_cell_weights_ledger*/ nullptr, input_to_output_weights,
|
||||||
|
/*input_to_output_weights_ledger*/ nullptr,
|
||||||
|
recurrent_to_input_weights,
|
||||||
|
/*recurrent_to_input_weights_ledger*/ nullptr,
|
||||||
|
recurrent_to_forget_weights,
|
||||||
|
/*recurrent_to_forget_weights_ledger*/ nullptr,
|
||||||
|
recurrent_to_cell_weights,
|
||||||
|
/*recurrent_to_cell_weights_ledger*/ nullptr,
|
||||||
|
recurrent_to_output_weights,
|
||||||
|
/*recurrent_to_output_weights_ledger*/ nullptr,
|
||||||
cell_to_input_weights, cell_to_forget_weights,
|
cell_to_input_weights, cell_to_forget_weights,
|
||||||
cell_to_output_weights, input_layer_norm_coefficients,
|
cell_to_output_weights, input_layer_norm_coefficients,
|
||||||
forget_layer_norm_coefficients, cell_layer_norm_coefficients,
|
forget_layer_norm_coefficients, cell_layer_norm_coefficients,
|
||||||
@ -1641,7 +1935,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
/*aux_input_to_cell_weights=*/nullptr,
|
/*aux_input_to_cell_weights=*/nullptr,
|
||||||
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
||||||
forget_gate_bias, cell_gate_bias, output_gate_bias,
|
forget_gate_bias, cell_gate_bias, output_gate_bias,
|
||||||
projection_weights, projection_bias, params,
|
projection_weights, /*projection_weights_ledger*/ nullptr,
|
||||||
|
projection_bias, params,
|
||||||
/*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
|
/*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
|
||||||
GetTemporary(context, node, kScratchBuffer),
|
GetTemporary(context, node, kScratchBuffer),
|
||||||
GetTemporary(context, node, kInputScalingFactors),
|
GetTemporary(context, node, kInputScalingFactors),
|
||||||
|
@ -312,6 +312,7 @@ void CalculateLstmGateHybrid(
|
|||||||
// Input and weights
|
// Input and weights
|
||||||
const int8_t* input, const float* input_sf, const int32_t* input_zp,
|
const int8_t* input, const float* input_sf, const int32_t* input_zp,
|
||||||
const int8_t* input_to_gate_weights,
|
const int8_t* input_to_gate_weights,
|
||||||
|
const uint8_t* input_to_gate_weights_ledger,
|
||||||
const float input_to_gate_weights_scale, int32_t* input_to_gate_row_sums,
|
const float input_to_gate_weights_scale, int32_t* input_to_gate_row_sums,
|
||||||
// Aux input and weights
|
// Aux input and weights
|
||||||
const int8_t* aux_input, const float* aux_input_sf,
|
const int8_t* aux_input, const float* aux_input_sf,
|
||||||
@ -321,6 +322,7 @@ void CalculateLstmGateHybrid(
|
|||||||
// Output state and weights
|
// Output state and weights
|
||||||
const int8_t* output_state, const float* output_state_sf,
|
const int8_t* output_state, const float* output_state_sf,
|
||||||
const int32_t* output_state_zp, const int8_t* recurrent_to_gate_weights,
|
const int32_t* output_state_zp, const int8_t* recurrent_to_gate_weights,
|
||||||
|
const uint8_t* recurrent_to_gate_weights_ledger,
|
||||||
const float recurrent_to_gate_weights_scale,
|
const float recurrent_to_gate_weights_scale,
|
||||||
int32_t* recurrent_to_gate_row_sums,
|
int32_t* recurrent_to_gate_row_sums,
|
||||||
// Cell state and weights (peephole LSTM)
|
// Cell state and weights (peephole LSTM)
|
||||||
@ -356,11 +358,22 @@ void CalculateLstmGateHybrid(
|
|||||||
// For each batch and cell: compute input_weight * input.
|
// For each batch and cell: compute input_weight * input.
|
||||||
// Skip if input is all zeros.
|
// Skip if input is all zeros.
|
||||||
if (!is_input_all_zeros) {
|
if (!is_input_all_zeros) {
|
||||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
if (input_to_gate_weights_ledger != nullptr) {
|
||||||
input_to_gate_weights, n_cell, n_input, input,
|
std::vector<float> scales(n_batch);
|
||||||
input_to_gate_weights_scale, input_sf, n_batch, gate,
|
for (int i = 0; i < n_batch; i++) {
|
||||||
/*per_channel_scale=*/nullptr, input_zp, accum_scratch,
|
scales[i] = input_to_gate_weights_scale * input_sf[i];
|
||||||
input_to_gate_row_sums, compute_row_sums, scratch0, context);
|
}
|
||||||
|
tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
|
||||||
|
input_to_gate_weights, input_to_gate_weights_ledger, n_cell, n_input,
|
||||||
|
input, scales.data(), n_batch, gate);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
|
input_to_gate_weights, n_cell, n_input, input,
|
||||||
|
input_to_gate_weights_scale, input_sf, n_batch, gate,
|
||||||
|
/*per_channel_scale=*/nullptr, input_zp, accum_scratch,
|
||||||
|
input_to_gate_row_sums, compute_row_sums, scratch0, context);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// For each batch and cell: compute aux_input_weight * aux_input.
|
// For each batch and cell: compute aux_input_weight * aux_input.
|
||||||
// Skip if auxiliary input is not available or all zeros.
|
// Skip if auxiliary input is not available or all zeros.
|
||||||
@ -374,11 +387,21 @@ void CalculateLstmGateHybrid(
|
|||||||
// For each batch and cell: compute recurrent_weight * output_state.
|
// For each batch and cell: compute recurrent_weight * output_state.
|
||||||
// Skip if output state is all zeros.
|
// Skip if output state is all zeros.
|
||||||
if (!is_output_state_all_zeros) {
|
if (!is_output_state_all_zeros) {
|
||||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
if (recurrent_to_gate_weights_ledger != nullptr) {
|
||||||
recurrent_to_gate_weights, n_cell, n_output, output_state,
|
std::vector<float> scales(n_batch);
|
||||||
recurrent_to_gate_weights_scale, output_state_sf, n_batch, gate,
|
for (int i = 0; i < n_batch; i++) {
|
||||||
/*per_channel_scale=*/nullptr, output_state_zp, accum_scratch,
|
scales[i] = recurrent_to_gate_weights_scale * input_sf[i];
|
||||||
recurrent_to_gate_row_sums, compute_row_sums, scratch0, context);
|
}
|
||||||
|
tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
|
||||||
|
recurrent_to_gate_weights, recurrent_to_gate_weights_ledger, n_cell,
|
||||||
|
n_output, output_state, scales.data(), n_batch, gate);
|
||||||
|
} else {
|
||||||
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
|
recurrent_to_gate_weights, n_cell, n_output, output_state,
|
||||||
|
recurrent_to_gate_weights_scale, output_state_sf, n_batch, gate,
|
||||||
|
/*per_channel_scale=*/nullptr, output_state_zp, accum_scratch,
|
||||||
|
recurrent_to_gate_row_sums, compute_row_sums, scratch0, context);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
|
// For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
|
||||||
if (use_peephole) {
|
if (use_peephole) {
|
||||||
@ -422,11 +445,12 @@ void CalculateLstmGateHybrid(
|
|||||||
void CalculateLstmOutputHybrid(
|
void CalculateLstmOutputHybrid(
|
||||||
int n_batch, int n_cell, int n_output, const float* cell_state,
|
int n_batch, int n_cell, int n_output, const float* cell_state,
|
||||||
const float* output_gate, TfLiteFusedActivation activation,
|
const float* output_gate, TfLiteFusedActivation activation,
|
||||||
const int8_t* projection_weights, float projection_weights_scale,
|
const int8_t* projection_weights, const uint8_t* projection_weights_ledger,
|
||||||
const float* projection_bias, const float proj_clip, float* output_state,
|
float projection_weights_scale, const float* projection_bias,
|
||||||
bool asymmetric_quantize_inputs, int32_t* projection_weights_row_sums,
|
const float proj_clip, float* output_state, bool asymmetric_quantize_inputs,
|
||||||
bool* compute_row_sums, CpuBackendContext* context, float* scratch0,
|
int32_t* projection_weights_row_sums, bool* compute_row_sums,
|
||||||
int8_t* scratch1, float* scratch2, int32_t* scratch3, int32_t* scratch4) {
|
CpuBackendContext* context, float* scratch0, int8_t* scratch1,
|
||||||
|
float* scratch2, int32_t* scratch3, int32_t* scratch4) {
|
||||||
tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
|
tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
|
||||||
activation, scratch0);
|
activation, scratch0);
|
||||||
tensor_utils::VectorVectorCwiseProduct(output_gate, scratch0,
|
tensor_utils::VectorVectorCwiseProduct(output_gate, scratch0,
|
||||||
@ -447,11 +471,21 @@ void CalculateLstmOutputHybrid(
|
|||||||
tensor_utils::BatchQuantizeFloats(scratch0, n_batch, n_cell, scratch1,
|
tensor_utils::BatchQuantizeFloats(scratch0, n_batch, n_cell, scratch1,
|
||||||
scratch2, scratch3,
|
scratch2, scratch3,
|
||||||
asymmetric_quantize_inputs);
|
asymmetric_quantize_inputs);
|
||||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
if (projection_weights_ledger != nullptr) {
|
||||||
projection_weights, n_output, n_cell, scratch1,
|
std::vector<float> scales(n_batch);
|
||||||
projection_weights_scale, scratch2, n_batch, output_state,
|
for (int i = 0; i < n_batch; i++) {
|
||||||
/*per_channel_scale=*/nullptr, scratch3, scratch4,
|
scales[i] = projection_weights_scale * scratch2[i];
|
||||||
projection_weights_row_sums, compute_row_sums, scratch2, context);
|
}
|
||||||
|
tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
|
||||||
|
projection_weights, projection_weights_ledger, n_output, n_cell,
|
||||||
|
scratch1, scales.data(), n_batch, output_state);
|
||||||
|
} else {
|
||||||
|
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||||
|
projection_weights, n_output, n_cell, scratch1,
|
||||||
|
projection_weights_scale, scratch2, n_batch, output_state,
|
||||||
|
/*per_channel_scale=*/nullptr, scratch3, scratch4,
|
||||||
|
projection_weights_row_sums, compute_row_sums, scratch2, context);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (proj_clip > 0.0f) {
|
if (proj_clip > 0.0f) {
|
||||||
tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
|
tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
|
||||||
@ -955,11 +989,16 @@ inline void LstmStepFloat(
|
|||||||
// output_ptr - size 'n_batch * output_batch_leading_dim'
|
// output_ptr - size 'n_batch * output_batch_leading_dim'
|
||||||
inline void LstmStepHybrid(
|
inline void LstmStepHybrid(
|
||||||
const float* input_ptr, const int8_t* input_to_input_weights_ptr,
|
const float* input_ptr, const int8_t* input_to_input_weights_ptr,
|
||||||
|
const uint8_t* input_to_input_weights_ledger_ptr,
|
||||||
float input_to_input_weights_scale,
|
float input_to_input_weights_scale,
|
||||||
const int8_t* input_to_forget_weights_ptr,
|
const int8_t* input_to_forget_weights_ptr,
|
||||||
|
const uint8_t* input_to_forget_weights_ledger_ptr,
|
||||||
float input_to_forget_weights_scale,
|
float input_to_forget_weights_scale,
|
||||||
const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
|
const int8_t* input_to_cell_weights_ptr,
|
||||||
|
const uint8_t* input_to_cell_weights_ledger_ptr,
|
||||||
|
float input_to_cell_weights_scale,
|
||||||
const int8_t* input_to_output_weights_ptr,
|
const int8_t* input_to_output_weights_ptr,
|
||||||
|
const uint8_t* input_to_output_weights_ledger_ptr,
|
||||||
float input_to_output_weights_scale, const float* aux_input_ptr,
|
float input_to_output_weights_scale, const float* aux_input_ptr,
|
||||||
const int8_t* aux_input_to_input_weights_ptr,
|
const int8_t* aux_input_to_input_weights_ptr,
|
||||||
float aux_input_to_input_weights_scale,
|
float aux_input_to_input_weights_scale,
|
||||||
@ -970,12 +1009,16 @@ inline void LstmStepHybrid(
|
|||||||
const int8_t* aux_input_to_output_weights_ptr,
|
const int8_t* aux_input_to_output_weights_ptr,
|
||||||
float aux_input_to_output_weights_scale,
|
float aux_input_to_output_weights_scale,
|
||||||
const int8_t* recurrent_to_input_weights_ptr,
|
const int8_t* recurrent_to_input_weights_ptr,
|
||||||
|
const uint8_t* recurrent_to_input_weights_ledger_ptr,
|
||||||
float recurrent_to_input_weights_scale,
|
float recurrent_to_input_weights_scale,
|
||||||
const int8_t* recurrent_to_forget_weights_ptr,
|
const int8_t* recurrent_to_forget_weights_ptr,
|
||||||
|
const uint8_t* recurrent_to_forget_weights_ledger_ptr,
|
||||||
float recurrent_to_forget_weights_scale,
|
float recurrent_to_forget_weights_scale,
|
||||||
const int8_t* recurrent_to_cell_weights_ptr,
|
const int8_t* recurrent_to_cell_weights_ptr,
|
||||||
|
const uint8_t* recurrent_to_cell_weights_ledger_ptr,
|
||||||
float recurrent_to_cell_weights_scale,
|
float recurrent_to_cell_weights_scale,
|
||||||
const int8_t* recurrent_to_output_weights_ptr,
|
const int8_t* recurrent_to_output_weights_ptr,
|
||||||
|
const uint8_t* recurrent_to_output_weights_ledger_ptr,
|
||||||
float recurrent_to_output_weights_scale,
|
float recurrent_to_output_weights_scale,
|
||||||
const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
|
const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
|
||||||
const int8_t* cell_to_forget_weights_ptr,
|
const int8_t* cell_to_forget_weights_ptr,
|
||||||
@ -988,19 +1031,21 @@ inline void LstmStepHybrid(
|
|||||||
const float* output_layer_norm_coefficients_ptr,
|
const float* output_layer_norm_coefficients_ptr,
|
||||||
const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
|
const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
|
||||||
const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
|
const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
|
||||||
const int8_t* projection_weights_ptr, float projection_weights_scale,
|
const int8_t* projection_weights_ptr,
|
||||||
const float* projection_bias_ptr, const TfLiteLSTMParams* params,
|
const uint8_t* projection_weights_ledger_ptr,
|
||||||
int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
|
float projection_weights_scale, const float* projection_bias_ptr,
|
||||||
int output_batch_leading_dim, float* scratch0, float* scratch1,
|
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
|
||||||
float* scratch2, float* scratch3, float* input_sf, float* aux_input_sf,
|
int n_aux_input, int n_output, int output_batch_leading_dim,
|
||||||
float* output_state_sf, float* scaling_factors_scratch,
|
float* scratch0, float* scratch1, float* scratch2, float* scratch3,
|
||||||
float* recovered_cell_weights, int8_t* quantized_input_ptr,
|
float* input_sf, float* aux_input_sf, float* output_state_sf,
|
||||||
int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr,
|
float* scaling_factors_scratch, float* recovered_cell_weights,
|
||||||
int8_t* quantized_output_scratch, float* output_state_ptr,
|
int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
|
||||||
float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr,
|
int8_t* quantized_output_state_ptr, int8_t* quantized_output_scratch,
|
||||||
int32_t* input_zp, int32_t* aux_input_zp, int32_t* output_state_zp,
|
float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr,
|
||||||
int32_t* row_sums, int row_sums_size, bool* compute_row_sums,
|
float* output_ptr, int32_t* input_zp, int32_t* aux_input_zp,
|
||||||
bool asymmetric_quantize_inputs, CpuBackendContext* context) {
|
int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
|
||||||
|
bool* compute_row_sums, bool asymmetric_quantize_inputs,
|
||||||
|
CpuBackendContext* context) {
|
||||||
ruy::profiler::ScopeLabel label("LstmStepHybrid");
|
ruy::profiler::ScopeLabel label("LstmStepHybrid");
|
||||||
// Since we have already checked that weights are all there or none, we
|
// Since we have already checked that weights are all there or none, we
|
||||||
// can check the existence of only one to the get the condition.
|
// can check the existence of only one to the get the condition.
|
||||||
@ -1106,11 +1151,12 @@ inline void LstmStepHybrid(
|
|||||||
// Calculate the input gate. (If not CIFG.)
|
// Calculate the input gate. (If not CIFG.)
|
||||||
CalculateLstmGateHybrid(
|
CalculateLstmGateHybrid(
|
||||||
quantized_input_ptr, input_sf, input_zp, input_to_input_weights_ptr,
|
quantized_input_ptr, input_sf, input_zp, input_to_input_weights_ptr,
|
||||||
input_to_input_weights_scale, input_to_input_row_sums,
|
input_to_input_weights_ledger_ptr, input_to_input_weights_scale,
|
||||||
quantized_aux_input_ptr, aux_input_sf, aux_input_zp,
|
input_to_input_row_sums, quantized_aux_input_ptr, aux_input_sf,
|
||||||
aux_input_to_input_weights_ptr, aux_input_to_input_weights_scale,
|
aux_input_zp, aux_input_to_input_weights_ptr,
|
||||||
aux_input_to_input_row_sums, quantized_output_state_ptr,
|
aux_input_to_input_weights_scale, aux_input_to_input_row_sums,
|
||||||
output_state_sf, output_state_zp, recurrent_to_input_weights_ptr,
|
quantized_output_state_ptr, output_state_sf, output_state_zp,
|
||||||
|
recurrent_to_input_weights_ptr, recurrent_to_input_weights_ledger_ptr,
|
||||||
recurrent_to_input_weights_scale, recurrent_to_input_row_sums,
|
recurrent_to_input_weights_scale, recurrent_to_input_row_sums,
|
||||||
cell_state_ptr, cell_to_input_weights_ptr, cell_to_input_weights_scale,
|
cell_state_ptr, cell_to_input_weights_ptr, cell_to_input_weights_scale,
|
||||||
input_layer_norm_coefficients_ptr, input_gate_bias_ptr, n_batch,
|
input_layer_norm_coefficients_ptr, input_gate_bias_ptr, n_batch,
|
||||||
@ -1122,11 +1168,12 @@ inline void LstmStepHybrid(
|
|||||||
// Calculate the forget gate.
|
// Calculate the forget gate.
|
||||||
CalculateLstmGateHybrid(
|
CalculateLstmGateHybrid(
|
||||||
quantized_input_ptr, input_sf, input_zp, input_to_forget_weights_ptr,
|
quantized_input_ptr, input_sf, input_zp, input_to_forget_weights_ptr,
|
||||||
input_to_forget_weights_scale, input_to_forget_row_sums,
|
input_to_forget_weights_ledger_ptr, input_to_forget_weights_scale,
|
||||||
quantized_aux_input_ptr, aux_input_sf, aux_input_zp,
|
input_to_forget_row_sums, quantized_aux_input_ptr, aux_input_sf,
|
||||||
aux_input_to_forget_weights_ptr, aux_input_to_forget_weights_scale,
|
aux_input_zp, aux_input_to_forget_weights_ptr,
|
||||||
aux_input_to_forget_row_sums, quantized_output_state_ptr, output_state_sf,
|
aux_input_to_forget_weights_scale, aux_input_to_forget_row_sums,
|
||||||
output_state_zp, recurrent_to_forget_weights_ptr,
|
quantized_output_state_ptr, output_state_sf, output_state_zp,
|
||||||
|
recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_ledger_ptr,
|
||||||
recurrent_to_forget_weights_scale, recurrent_to_forget_row_sums,
|
recurrent_to_forget_weights_scale, recurrent_to_forget_row_sums,
|
||||||
cell_state_ptr, cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
|
cell_state_ptr, cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
|
||||||
forget_layer_norm_coefficients_ptr, forget_gate_bias_ptr, n_batch,
|
forget_layer_norm_coefficients_ptr, forget_gate_bias_ptr, n_batch,
|
||||||
@ -1137,11 +1184,12 @@ inline void LstmStepHybrid(
|
|||||||
// Calculate the cell update gate.
|
// Calculate the cell update gate.
|
||||||
CalculateLstmGateHybrid(
|
CalculateLstmGateHybrid(
|
||||||
quantized_input_ptr, input_sf, input_zp, input_to_cell_weights_ptr,
|
quantized_input_ptr, input_sf, input_zp, input_to_cell_weights_ptr,
|
||||||
input_to_cell_weights_scale, input_to_cell_row_sums,
|
input_to_cell_weights_ledger_ptr, input_to_cell_weights_scale,
|
||||||
quantized_aux_input_ptr, aux_input_sf, aux_input_zp,
|
input_to_cell_row_sums, quantized_aux_input_ptr, aux_input_sf,
|
||||||
aux_input_to_cell_weights_ptr, aux_input_to_cell_weights_scale,
|
aux_input_zp, aux_input_to_cell_weights_ptr,
|
||||||
aux_input_to_cell_row_sums, quantized_output_state_ptr, output_state_sf,
|
aux_input_to_cell_weights_scale, aux_input_to_cell_row_sums,
|
||||||
output_state_zp, recurrent_to_cell_weights_ptr,
|
quantized_output_state_ptr, output_state_sf, output_state_zp,
|
||||||
|
recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_ledger_ptr,
|
||||||
recurrent_to_cell_weights_scale, recurrent_to_cell_row_sums,
|
recurrent_to_cell_weights_scale, recurrent_to_cell_row_sums,
|
||||||
/*cell_state=*/nullptr, /*cell_to_gate_weights=*/nullptr,
|
/*cell_state=*/nullptr, /*cell_to_gate_weights=*/nullptr,
|
||||||
/*cell_to_gate_weights_scale=*/0.0f, cell_layer_norm_coefficients_ptr,
|
/*cell_to_gate_weights_scale=*/0.0f, cell_layer_norm_coefficients_ptr,
|
||||||
@ -1157,11 +1205,12 @@ inline void LstmStepHybrid(
|
|||||||
// Calculate the output gate.
|
// Calculate the output gate.
|
||||||
CalculateLstmGateHybrid(
|
CalculateLstmGateHybrid(
|
||||||
quantized_input_ptr, input_sf, input_zp, input_to_output_weights_ptr,
|
quantized_input_ptr, input_sf, input_zp, input_to_output_weights_ptr,
|
||||||
input_to_output_weights_scale, input_to_output_row_sums,
|
input_to_output_weights_ledger_ptr, input_to_output_weights_scale,
|
||||||
quantized_aux_input_ptr, aux_input_sf, aux_input_zp,
|
input_to_output_row_sums, quantized_aux_input_ptr, aux_input_sf,
|
||||||
aux_input_to_output_weights_ptr, aux_input_to_output_weights_scale,
|
aux_input_zp, aux_input_to_output_weights_ptr,
|
||||||
aux_input_to_output_row_sums, quantized_output_state_ptr, output_state_sf,
|
aux_input_to_output_weights_scale, aux_input_to_output_row_sums,
|
||||||
output_state_zp, recurrent_to_output_weights_ptr,
|
quantized_output_state_ptr, output_state_sf, output_state_zp,
|
||||||
|
recurrent_to_output_weights_ptr, recurrent_to_output_weights_ledger_ptr,
|
||||||
recurrent_to_output_weights_scale, recurrent_to_output_row_sums,
|
recurrent_to_output_weights_scale, recurrent_to_output_row_sums,
|
||||||
cell_state_ptr, cell_to_output_weights_ptr, cell_to_output_weights_scale,
|
cell_state_ptr, cell_to_output_weights_ptr, cell_to_output_weights_scale,
|
||||||
output_layer_norm_coefficients_ptr, output_gate_bias_ptr, n_batch,
|
output_layer_norm_coefficients_ptr, output_gate_bias_ptr, n_batch,
|
||||||
@ -1172,11 +1221,11 @@ inline void LstmStepHybrid(
|
|||||||
// Update the output state.
|
// Update the output state.
|
||||||
CalculateLstmOutputHybrid(
|
CalculateLstmOutputHybrid(
|
||||||
n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
|
n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
|
||||||
params->activation, projection_weights_ptr, projection_weights_scale,
|
params->activation, projection_weights_ptr, projection_weights_ledger_ptr,
|
||||||
projection_bias_ptr, params->proj_clip, output_state_ptr,
|
projection_weights_scale, projection_bias_ptr, params->proj_clip,
|
||||||
asymmetric_quantize_inputs, projection_weights_row_sums, compute_row_sums,
|
output_state_ptr, asymmetric_quantize_inputs, projection_weights_row_sums,
|
||||||
context, scratch2, quantized_output_scratch, input_sf, input_zp,
|
compute_row_sums, context, scratch2, quantized_output_scratch, input_sf,
|
||||||
accum_scratch_ptr);
|
input_zp, accum_scratch_ptr);
|
||||||
// Copy output state to the output. Note that the output's rows may not be
|
// Copy output state to the output. Note that the output's rows may not be
|
||||||
// contiguous (output_batch_leading_dim != n_output).
|
// contiguous (output_batch_leading_dim != n_output).
|
||||||
for (int b = 0; b < n_batch; b++) {
|
for (int b = 0; b < n_batch; b++) {
|
||||||
@ -1829,13 +1878,21 @@ TfLiteStatus EvalFloat(
|
|||||||
|
|
||||||
TfLiteStatus EvalHybrid(
|
TfLiteStatus EvalHybrid(
|
||||||
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
|
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
|
||||||
|
const TfLiteTensor* input_to_input_weights_ledger,
|
||||||
const TfLiteTensor* input_to_forget_weights,
|
const TfLiteTensor* input_to_forget_weights,
|
||||||
|
const TfLiteTensor* input_to_forget_weights_ledger,
|
||||||
const TfLiteTensor* input_to_cell_weights,
|
const TfLiteTensor* input_to_cell_weights,
|
||||||
|
const TfLiteTensor* input_to_cell_weights_ledger,
|
||||||
const TfLiteTensor* input_to_output_weights,
|
const TfLiteTensor* input_to_output_weights,
|
||||||
|
const TfLiteTensor* input_to_output_weights_ledger,
|
||||||
const TfLiteTensor* recurrent_to_input_weights,
|
const TfLiteTensor* recurrent_to_input_weights,
|
||||||
|
const TfLiteTensor* recurrent_to_input_weights_ledger,
|
||||||
const TfLiteTensor* recurrent_to_forget_weights,
|
const TfLiteTensor* recurrent_to_forget_weights,
|
||||||
|
const TfLiteTensor* recurrent_to_forget_weights_ledger,
|
||||||
const TfLiteTensor* recurrent_to_cell_weights,
|
const TfLiteTensor* recurrent_to_cell_weights,
|
||||||
|
const TfLiteTensor* recurrent_to_cell_weights_ledger,
|
||||||
const TfLiteTensor* recurrent_to_output_weights,
|
const TfLiteTensor* recurrent_to_output_weights,
|
||||||
|
const TfLiteTensor* recurrent_to_output_weights_ledger,
|
||||||
const TfLiteTensor* cell_to_input_weights,
|
const TfLiteTensor* cell_to_input_weights,
|
||||||
const TfLiteTensor* cell_to_forget_weights,
|
const TfLiteTensor* cell_to_forget_weights,
|
||||||
const TfLiteTensor* cell_to_output_weights,
|
const TfLiteTensor* cell_to_output_weights,
|
||||||
@ -1850,9 +1907,11 @@ TfLiteStatus EvalHybrid(
|
|||||||
const TfLiteTensor* aux_input_to_output_weights,
|
const TfLiteTensor* aux_input_to_output_weights,
|
||||||
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
|
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
|
||||||
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
|
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
|
||||||
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
const TfLiteTensor* projection_weights,
|
||||||
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
|
const TfLiteTensor* projection_weights_ledger,
|
||||||
int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
|
const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params,
|
||||||
|
bool forward_sequence, bool time_major, int output_offset,
|
||||||
|
TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
|
||||||
TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
|
TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
|
||||||
TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
|
TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
|
||||||
TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
|
TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
|
||||||
@ -1929,12 +1988,16 @@ TfLiteStatus EvalHybrid(
|
|||||||
GetTensorData<float>(output) + t_rel * output_step + output_offset;
|
GetTensorData<float>(output) + t_rel * output_step + output_offset;
|
||||||
LstmStepHybrid(
|
LstmStepHybrid(
|
||||||
input_ptr, GetTensorData<int8_t>(input_to_input_weights),
|
input_ptr, GetTensorData<int8_t>(input_to_input_weights),
|
||||||
|
GetTensorData<uint8_t>(input_to_input_weights_ledger),
|
||||||
GetTensorScale(input_to_input_weights),
|
GetTensorScale(input_to_input_weights),
|
||||||
GetTensorData<int8_t>(input_to_forget_weights),
|
GetTensorData<int8_t>(input_to_forget_weights),
|
||||||
|
GetTensorData<uint8_t>(input_to_forget_weights_ledger),
|
||||||
GetTensorScale(input_to_forget_weights),
|
GetTensorScale(input_to_forget_weights),
|
||||||
GetTensorData<int8_t>(input_to_cell_weights),
|
GetTensorData<int8_t>(input_to_cell_weights),
|
||||||
|
GetTensorData<uint8_t>(input_to_cell_weights_ledger),
|
||||||
GetTensorScale(input_to_cell_weights),
|
GetTensorScale(input_to_cell_weights),
|
||||||
GetTensorData<int8_t>(input_to_output_weights),
|
GetTensorData<int8_t>(input_to_output_weights),
|
||||||
|
GetTensorData<uint8_t>(input_to_output_weights_ledger),
|
||||||
GetTensorScale(input_to_output_weights), aux_input_ptr,
|
GetTensorScale(input_to_output_weights), aux_input_ptr,
|
||||||
GetTensorData<int8_t>(aux_input_to_input_weights),
|
GetTensorData<int8_t>(aux_input_to_input_weights),
|
||||||
GetTensorScale(aux_input_to_input_weights),
|
GetTensorScale(aux_input_to_input_weights),
|
||||||
@ -1945,12 +2008,16 @@ TfLiteStatus EvalHybrid(
|
|||||||
GetTensorData<int8_t>(aux_input_to_output_weights),
|
GetTensorData<int8_t>(aux_input_to_output_weights),
|
||||||
GetTensorScale(aux_input_to_output_weights),
|
GetTensorScale(aux_input_to_output_weights),
|
||||||
GetTensorData<int8_t>(recurrent_to_input_weights),
|
GetTensorData<int8_t>(recurrent_to_input_weights),
|
||||||
|
GetTensorData<uint8_t>(recurrent_to_input_weights_ledger),
|
||||||
GetTensorScale(recurrent_to_input_weights),
|
GetTensorScale(recurrent_to_input_weights),
|
||||||
GetTensorData<int8_t>(recurrent_to_forget_weights),
|
GetTensorData<int8_t>(recurrent_to_forget_weights),
|
||||||
|
GetTensorData<uint8_t>(recurrent_to_forget_weights_ledger),
|
||||||
GetTensorScale(recurrent_to_forget_weights),
|
GetTensorScale(recurrent_to_forget_weights),
|
||||||
GetTensorData<int8_t>(recurrent_to_cell_weights),
|
GetTensorData<int8_t>(recurrent_to_cell_weights),
|
||||||
|
GetTensorData<uint8_t>(recurrent_to_cell_weights_ledger),
|
||||||
GetTensorScale(recurrent_to_cell_weights),
|
GetTensorScale(recurrent_to_cell_weights),
|
||||||
GetTensorData<int8_t>(recurrent_to_output_weights),
|
GetTensorData<int8_t>(recurrent_to_output_weights),
|
||||||
|
GetTensorData<uint8_t>(recurrent_to_output_weights_ledger),
|
||||||
GetTensorScale(recurrent_to_output_weights),
|
GetTensorScale(recurrent_to_output_weights),
|
||||||
GetTensorData<int8_t>(cell_to_input_weights),
|
GetTensorData<int8_t>(cell_to_input_weights),
|
||||||
GetTensorScale(cell_to_input_weights),
|
GetTensorScale(cell_to_input_weights),
|
||||||
@ -1967,6 +2034,7 @@ TfLiteStatus EvalHybrid(
|
|||||||
GetTensorData<float>(cell_gate_bias),
|
GetTensorData<float>(cell_gate_bias),
|
||||||
GetTensorData<float>(output_gate_bias),
|
GetTensorData<float>(output_gate_bias),
|
||||||
GetTensorData<int8_t>(projection_weights),
|
GetTensorData<int8_t>(projection_weights),
|
||||||
|
GetTensorData<uint8_t>(projection_weights_ledger),
|
||||||
GetTensorScale(projection_weights),
|
GetTensorScale(projection_weights),
|
||||||
GetTensorData<float>(projection_bias), params, n_batch, n_cell,
|
GetTensorData<float>(projection_bias), params, n_batch, n_cell,
|
||||||
n_input, aux_input_size, n_output, output_batch_leading_dim,
|
n_input, aux_input_size, n_output, output_batch_leading_dim,
|
||||||
@ -2018,12 +2086,16 @@ TfLiteStatus EvalHybrid(
|
|||||||
|
|
||||||
LstmStepHybrid(
|
LstmStepHybrid(
|
||||||
input_ptr, GetTensorData<int8_t>(input_to_input_weights),
|
input_ptr, GetTensorData<int8_t>(input_to_input_weights),
|
||||||
|
GetTensorData<uint8_t>(input_to_input_weights_ledger),
|
||||||
GetTensorScale(input_to_input_weights),
|
GetTensorScale(input_to_input_weights),
|
||||||
GetTensorData<int8_t>(input_to_forget_weights),
|
GetTensorData<int8_t>(input_to_forget_weights),
|
||||||
|
GetTensorData<uint8_t>(input_to_forget_weights_ledger),
|
||||||
GetTensorScale(input_to_forget_weights),
|
GetTensorScale(input_to_forget_weights),
|
||||||
GetTensorData<int8_t>(input_to_cell_weights),
|
GetTensorData<int8_t>(input_to_cell_weights),
|
||||||
|
GetTensorData<uint8_t>(input_to_cell_weights_ledger),
|
||||||
GetTensorScale(input_to_cell_weights),
|
GetTensorScale(input_to_cell_weights),
|
||||||
GetTensorData<int8_t>(input_to_output_weights),
|
GetTensorData<int8_t>(input_to_output_weights),
|
||||||
|
GetTensorData<uint8_t>(input_to_output_weights_ledger),
|
||||||
GetTensorScale(input_to_output_weights), aux_input_ptr,
|
GetTensorScale(input_to_output_weights), aux_input_ptr,
|
||||||
GetTensorData<int8_t>(aux_input_to_input_weights),
|
GetTensorData<int8_t>(aux_input_to_input_weights),
|
||||||
GetTensorScale(aux_input_to_input_weights),
|
GetTensorScale(aux_input_to_input_weights),
|
||||||
@ -2034,12 +2106,16 @@ TfLiteStatus EvalHybrid(
|
|||||||
GetTensorData<int8_t>(aux_input_to_output_weights),
|
GetTensorData<int8_t>(aux_input_to_output_weights),
|
||||||
GetTensorScale(aux_input_to_output_weights),
|
GetTensorScale(aux_input_to_output_weights),
|
||||||
GetTensorData<int8_t>(recurrent_to_input_weights),
|
GetTensorData<int8_t>(recurrent_to_input_weights),
|
||||||
|
GetTensorData<uint8_t>(recurrent_to_input_weights_ledger),
|
||||||
GetTensorScale(recurrent_to_input_weights),
|
GetTensorScale(recurrent_to_input_weights),
|
||||||
GetTensorData<int8_t>(recurrent_to_forget_weights),
|
GetTensorData<int8_t>(recurrent_to_forget_weights),
|
||||||
|
GetTensorData<uint8_t>(recurrent_to_forget_weights_ledger),
|
||||||
GetTensorScale(recurrent_to_forget_weights),
|
GetTensorScale(recurrent_to_forget_weights),
|
||||||
GetTensorData<int8_t>(recurrent_to_cell_weights),
|
GetTensorData<int8_t>(recurrent_to_cell_weights),
|
||||||
|
GetTensorData<uint8_t>(recurrent_to_cell_weights_ledger),
|
||||||
GetTensorScale(recurrent_to_cell_weights),
|
GetTensorScale(recurrent_to_cell_weights),
|
||||||
GetTensorData<int8_t>(recurrent_to_output_weights),
|
GetTensorData<int8_t>(recurrent_to_output_weights),
|
||||||
|
GetTensorData<uint8_t>(recurrent_to_output_weights_ledger),
|
||||||
GetTensorScale(recurrent_to_output_weights),
|
GetTensorScale(recurrent_to_output_weights),
|
||||||
GetTensorData<int8_t>(cell_to_input_weights),
|
GetTensorData<int8_t>(cell_to_input_weights),
|
||||||
GetTensorScale(cell_to_input_weights),
|
GetTensorScale(cell_to_input_weights),
|
||||||
@ -2056,6 +2132,7 @@ TfLiteStatus EvalHybrid(
|
|||||||
GetTensorData<float>(cell_gate_bias),
|
GetTensorData<float>(cell_gate_bias),
|
||||||
GetTensorData<float>(output_gate_bias),
|
GetTensorData<float>(output_gate_bias),
|
||||||
GetTensorData<int8_t>(projection_weights),
|
GetTensorData<int8_t>(projection_weights),
|
||||||
|
GetTensorData<uint8_t>(projection_weights_ledger),
|
||||||
GetTensorScale(projection_weights),
|
GetTensorScale(projection_weights),
|
||||||
GetTensorData<float>(projection_bias), params,
|
GetTensorData<float>(projection_bias), params,
|
||||||
/*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
|
/*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
|
||||||
|
@ -125,13 +125,21 @@ TfLiteStatus EvalFloat(
|
|||||||
|
|
||||||
TfLiteStatus EvalHybrid(
|
TfLiteStatus EvalHybrid(
|
||||||
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
|
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
|
||||||
|
const TfLiteTensor* input_to_input_weights_ledger,
|
||||||
const TfLiteTensor* input_to_forget_weights,
|
const TfLiteTensor* input_to_forget_weights,
|
||||||
|
const TfLiteTensor* input_to_forget_weights_ledger,
|
||||||
const TfLiteTensor* input_to_cell_weights,
|
const TfLiteTensor* input_to_cell_weights,
|
||||||
|
const TfLiteTensor* input_to_cell_weights_ledger,
|
||||||
const TfLiteTensor* input_to_output_weights,
|
const TfLiteTensor* input_to_output_weights,
|
||||||
|
const TfLiteTensor* input_to_output_weights_ledger,
|
||||||
const TfLiteTensor* recurrent_to_input_weights,
|
const TfLiteTensor* recurrent_to_input_weights,
|
||||||
|
const TfLiteTensor* recurrent_to_input_weights_ledger,
|
||||||
const TfLiteTensor* recurrent_to_forget_weights,
|
const TfLiteTensor* recurrent_to_forget_weights,
|
||||||
|
const TfLiteTensor* recurrent_to_forget_weights_ledger,
|
||||||
const TfLiteTensor* recurrent_to_cell_weights,
|
const TfLiteTensor* recurrent_to_cell_weights,
|
||||||
|
const TfLiteTensor* recurrent_to_cell_weights_ledger,
|
||||||
const TfLiteTensor* recurrent_to_output_weights,
|
const TfLiteTensor* recurrent_to_output_weights,
|
||||||
|
const TfLiteTensor* recurrent_to_output_weights_ledger,
|
||||||
const TfLiteTensor* cell_to_input_weights,
|
const TfLiteTensor* cell_to_input_weights,
|
||||||
const TfLiteTensor* cell_to_forget_weights,
|
const TfLiteTensor* cell_to_forget_weights,
|
||||||
const TfLiteTensor* cell_to_output_weights,
|
const TfLiteTensor* cell_to_output_weights,
|
||||||
@ -146,9 +154,11 @@ TfLiteStatus EvalHybrid(
|
|||||||
const TfLiteTensor* aux_input_to_output_weights,
|
const TfLiteTensor* aux_input_to_output_weights,
|
||||||
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
|
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
|
||||||
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
|
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
|
||||||
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
const TfLiteTensor* projection_weights,
|
||||||
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
|
const TfLiteTensor* projection_weights_ledger,
|
||||||
int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
|
const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params,
|
||||||
|
bool forward_sequence, bool time_major, int output_offset,
|
||||||
|
TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
|
||||||
TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
|
TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
|
||||||
TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
|
TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
|
||||||
TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
|
TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
|
||||||
|
@ -906,14 +906,14 @@ void TestOneHybridAsymmLSTM() {
|
|||||||
constexpr float kDefaultScale = 18.0;
|
constexpr float kDefaultScale = 18.0;
|
||||||
ops::builtin::lstm_eval::EvalHybrid(
|
ops::builtin::lstm_eval::EvalHybrid(
|
||||||
one_parameter.GetFloatInput(),
|
one_parameter.GetFloatInput(),
|
||||||
HybridLstmParam::addScale(one_parameter.Geti2i(), kDefaultScale),
|
HybridLstmParam::addScale(one_parameter.Geti2i(), kDefaultScale), nullptr,
|
||||||
HybridLstmParam::addScale(one_parameter.Geti2f(), kDefaultScale),
|
HybridLstmParam::addScale(one_parameter.Geti2f(), kDefaultScale), nullptr,
|
||||||
HybridLstmParam::addScale(one_parameter.Geti2c(), kDefaultScale),
|
HybridLstmParam::addScale(one_parameter.Geti2c(), kDefaultScale), nullptr,
|
||||||
HybridLstmParam::addScale(one_parameter.Geti2o(), kDefaultScale),
|
HybridLstmParam::addScale(one_parameter.Geti2o(), kDefaultScale), nullptr,
|
||||||
HybridLstmParam::addScale(one_parameter.Getr2i(), kDefaultScale),
|
HybridLstmParam::addScale(one_parameter.Getr2i(), kDefaultScale), nullptr,
|
||||||
HybridLstmParam::addScale(one_parameter.Getr2f(), kDefaultScale),
|
HybridLstmParam::addScale(one_parameter.Getr2f(), kDefaultScale), nullptr,
|
||||||
HybridLstmParam::addScale(one_parameter.Getr2c(), kDefaultScale),
|
HybridLstmParam::addScale(one_parameter.Getr2c(), kDefaultScale), nullptr,
|
||||||
HybridLstmParam::addScale(one_parameter.Getr2o(), kDefaultScale),
|
HybridLstmParam::addScale(one_parameter.Getr2o(), kDefaultScale), nullptr,
|
||||||
/*cell_to_input_weights=*/nullptr,
|
/*cell_to_input_weights=*/nullptr,
|
||||||
/*cell_to_forget_weights=*/nullptr,
|
/*cell_to_forget_weights=*/nullptr,
|
||||||
/*cell_to_output_weights=*/nullptr, one_parameter.GetInputLayerNorm(),
|
/*cell_to_output_weights=*/nullptr, one_parameter.GetInputLayerNorm(),
|
||||||
@ -926,7 +926,7 @@ void TestOneHybridAsymmLSTM() {
|
|||||||
/*aux_input_to_output_weights=*/nullptr, one_parameter.GetInputBias(),
|
/*aux_input_to_output_weights=*/nullptr, one_parameter.GetInputBias(),
|
||||||
one_parameter.GetForgetBias(), one_parameter.GetCellBias(),
|
one_parameter.GetForgetBias(), one_parameter.GetCellBias(),
|
||||||
one_parameter.GetOutputBias(),
|
one_parameter.GetOutputBias(),
|
||||||
HybridLstmParam::addScale(one_parameter.GetProjection(), 1.0),
|
HybridLstmParam::addScale(one_parameter.GetProjection(), 1.0), nullptr,
|
||||||
one_parameter.GetProjectionBias(), ¶m,
|
one_parameter.GetProjectionBias(), ¶m,
|
||||||
/*forward_sequence=*/true,
|
/*forward_sequence=*/true,
|
||||||
/*time_major=*/true,
|
/*time_major=*/true,
|
||||||
|
@ -2114,6 +2114,620 @@ TEST(LstmOpTest, InvalidTypes) {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
class HybridSparseLSTMOpModel : public ::tflite::SingleOpModel {
|
||||||
|
public:
|
||||||
|
HybridSparseLSTMOpModel(
|
||||||
|
int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
|
||||||
|
bool use_peephole, bool use_projection_weights, bool use_projection_bias,
|
||||||
|
float cell_clip, float proj_clip,
|
||||||
|
const std::vector<std::vector<int>>& input_shapes,
|
||||||
|
const TensorData& input_weights_td,
|
||||||
|
const std::vector<float>& input_to_input_weights,
|
||||||
|
const std::vector<float>& input_to_forget_weights,
|
||||||
|
const std::vector<float>& input_to_cell_weights,
|
||||||
|
const std::vector<float>& input_to_output_weights,
|
||||||
|
const TensorData& recurrent_weights_td,
|
||||||
|
const std::vector<float>& recurrent_to_input_weights,
|
||||||
|
const std::vector<float>& recurrent_to_forget_weights,
|
||||||
|
const std::vector<float>& recurrent_to_cell_weights,
|
||||||
|
const std::vector<float>& recurrent_to_output_weights,
|
||||||
|
const ::tflite::TensorType& weight_type = ::tflite::TensorType_INT8)
|
||||||
|
: n_batch_(n_batch),
|
||||||
|
n_input_(n_input),
|
||||||
|
n_cell_(n_cell),
|
||||||
|
n_output_(n_output) {
|
||||||
|
input_ = AddInput(::tflite::TensorType_FLOAT32);
|
||||||
|
|
||||||
|
if (use_cifg) {
|
||||||
|
input_to_input_weights_ = AddNullInput();
|
||||||
|
} else {
|
||||||
|
input_to_input_weights_ =
|
||||||
|
AddConstSparseInput(input_weights_td, input_to_input_weights, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
input_to_forget_weights_ =
|
||||||
|
AddConstSparseInput(input_weights_td, input_to_forget_weights, true);
|
||||||
|
|
||||||
|
input_to_cell_weights_ =
|
||||||
|
AddConstSparseInput(input_weights_td, input_to_cell_weights, true);
|
||||||
|
|
||||||
|
input_to_output_weights_ =
|
||||||
|
AddConstSparseInput(input_weights_td, input_to_output_weights, true);
|
||||||
|
|
||||||
|
if (use_cifg) {
|
||||||
|
recurrent_to_input_weights_ = AddNullInput();
|
||||||
|
} else {
|
||||||
|
recurrent_to_input_weights_ = AddConstSparseInput(
|
||||||
|
recurrent_weights_td, recurrent_to_input_weights, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
recurrent_to_forget_weights_ = AddConstSparseInput(
|
||||||
|
recurrent_weights_td, recurrent_to_forget_weights, true);
|
||||||
|
recurrent_to_cell_weights_ = AddConstSparseInput(
|
||||||
|
recurrent_weights_td, recurrent_to_cell_weights, true);
|
||||||
|
recurrent_to_output_weights_ = AddConstSparseInput(
|
||||||
|
recurrent_weights_td, recurrent_to_output_weights, true);
|
||||||
|
|
||||||
|
if (use_peephole) {
|
||||||
|
if (use_cifg) {
|
||||||
|
cell_to_input_weights_ = AddNullInput();
|
||||||
|
} else {
|
||||||
|
cell_to_input_weights_ = AddInput(weight_type);
|
||||||
|
}
|
||||||
|
cell_to_forget_weights_ = AddInput(weight_type);
|
||||||
|
cell_to_output_weights_ = AddInput(weight_type);
|
||||||
|
} else {
|
||||||
|
cell_to_input_weights_ = AddNullInput();
|
||||||
|
cell_to_forget_weights_ = AddNullInput();
|
||||||
|
cell_to_output_weights_ = AddNullInput();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (use_cifg) {
|
||||||
|
input_gate_bias_ = AddNullInput();
|
||||||
|
} else {
|
||||||
|
input_gate_bias_ = AddInput(::tflite::TensorType_FLOAT32);
|
||||||
|
}
|
||||||
|
forget_gate_bias_ = AddInput(::tflite::TensorType_FLOAT32);
|
||||||
|
cell_bias_ = AddInput(::tflite::TensorType_FLOAT32);
|
||||||
|
output_gate_bias_ = AddInput(::tflite::TensorType_FLOAT32);
|
||||||
|
|
||||||
|
if (use_projection_weights) {
|
||||||
|
projection_weights_ = AddInput(weight_type);
|
||||||
|
if (use_projection_bias) {
|
||||||
|
projection_bias_ = AddInput(::tflite::TensorType_FLOAT32);
|
||||||
|
} else {
|
||||||
|
projection_bias_ = AddNullInput();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
projection_weights_ = AddNullInput();
|
||||||
|
projection_bias_ = AddNullInput();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adding the 2 state tensors.
|
||||||
|
output_state_ = AddInput(::tflite::TensorData{::tflite::TensorType_FLOAT32,
|
||||||
|
{n_output_ * n_batch_}},
|
||||||
|
true);
|
||||||
|
cell_state_ = AddInput(::tflite::TensorData{::tflite::TensorType_FLOAT32,
|
||||||
|
{n_cell_ * n_batch_}},
|
||||||
|
true);
|
||||||
|
|
||||||
|
if (use_cifg) {
|
||||||
|
input_layer_norm_weights_ = AddNullInput();
|
||||||
|
} else {
|
||||||
|
input_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32);
|
||||||
|
}
|
||||||
|
forget_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32);
|
||||||
|
cell_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32);
|
||||||
|
output_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32);
|
||||||
|
|
||||||
|
output_ = AddOutput(::tflite::TensorType_FLOAT32);
|
||||||
|
|
||||||
|
SetBuiltinOp(
|
||||||
|
BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
|
||||||
|
CreateLSTMOptions(builder_, ActivationFunctionType_TANH, cell_clip,
|
||||||
|
proj_clip, LSTMKernelType_FULL, false)
|
||||||
|
.Union());
|
||||||
|
BuildInterpreter(input_shapes);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetCellToInputWeights(std::vector<float> f) {
|
||||||
|
SignedSymmetricQuantizeAndPopulate(cell_to_input_weights_, f);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetCellToForgetWeights(std::vector<float> f) {
|
||||||
|
SignedSymmetricQuantizeAndPopulate(cell_to_forget_weights_, f);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetCellToOutputWeights(std::vector<float> f) {
|
||||||
|
SignedSymmetricQuantizeAndPopulate(cell_to_output_weights_, f);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetInputLayerNormWeights(std::vector<float> f) {
|
||||||
|
PopulateTensor(input_layer_norm_weights_, f);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetForgetLayerNormWeights(std::vector<float> f) {
|
||||||
|
PopulateTensor(forget_layer_norm_weights_, f);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetCellLayerNormWeights(std::vector<float> f) {
|
||||||
|
PopulateTensor(cell_layer_norm_weights_, f);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetOutputLayerNormWeights(std::vector<float> f) {
|
||||||
|
PopulateTensor(output_layer_norm_weights_, f);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetInputGateBias(std::vector<float> f) {
|
||||||
|
PopulateTensor(input_gate_bias_, f);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetForgetGateBias(std::vector<float> f) {
|
||||||
|
PopulateTensor(forget_gate_bias_, f);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetCellBias(std::vector<float> f) { PopulateTensor(cell_bias_, f); }
|
||||||
|
|
||||||
|
void SetOutputGateBias(std::vector<float> f) {
|
||||||
|
PopulateTensor(output_gate_bias_, f);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetProjectionWeights(std::vector<float> f) {
|
||||||
|
SignedSymmetricQuantizeAndPopulate(projection_weights_, f);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetProjectionBias(std::vector<float> f) {
|
||||||
|
PopulateTensor(projection_bias_, f);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetInput(int offset, const float* begin, const float* end) {
|
||||||
|
PopulateTensor(input_, offset, const_cast<float*>(begin),
|
||||||
|
const_cast<float*>(end));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
|
||||||
|
|
||||||
|
int num_inputs() { return n_input_; }
|
||||||
|
int num_outputs() { return n_output_; }
|
||||||
|
int num_cells() { return n_cell_; }
|
||||||
|
int num_batches() { return n_batch_; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
int input_;
|
||||||
|
int input_to_input_weights_;
|
||||||
|
int input_to_forget_weights_;
|
||||||
|
int input_to_cell_weights_;
|
||||||
|
int input_to_output_weights_;
|
||||||
|
|
||||||
|
int recurrent_to_input_weights_;
|
||||||
|
int recurrent_to_forget_weights_;
|
||||||
|
int recurrent_to_cell_weights_;
|
||||||
|
int recurrent_to_output_weights_;
|
||||||
|
|
||||||
|
int cell_to_input_weights_;
|
||||||
|
int cell_to_forget_weights_;
|
||||||
|
int cell_to_output_weights_;
|
||||||
|
|
||||||
|
int input_layer_norm_weights_;
|
||||||
|
int forget_layer_norm_weights_;
|
||||||
|
int cell_layer_norm_weights_;
|
||||||
|
int output_layer_norm_weights_;
|
||||||
|
|
||||||
|
int input_gate_bias_;
|
||||||
|
int forget_gate_bias_;
|
||||||
|
int cell_bias_;
|
||||||
|
int output_gate_bias_;
|
||||||
|
|
||||||
|
int projection_weights_;
|
||||||
|
int projection_bias_;
|
||||||
|
|
||||||
|
int output_state_;
|
||||||
|
int cell_state_;
|
||||||
|
|
||||||
|
int output_;
|
||||||
|
|
||||||
|
int n_batch_;
|
||||||
|
int n_input_;
|
||||||
|
int n_cell_;
|
||||||
|
int n_output_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class BaseSparseLstmTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
// Weights of the Sparse Layer Norm LSTM model. Some are optional.
|
||||||
|
std::vector<float> input_to_input_weights_;
|
||||||
|
std::vector<float> input_to_cell_weights_;
|
||||||
|
std::vector<float> input_to_forget_weights_;
|
||||||
|
std::vector<float> input_to_output_weights_;
|
||||||
|
std::vector<float> input_gate_bias_;
|
||||||
|
std::vector<float> cell_gate_bias_;
|
||||||
|
std::vector<float> forget_gate_bias_;
|
||||||
|
std::vector<float> output_gate_bias_;
|
||||||
|
std::vector<float> recurrent_to_input_weights_;
|
||||||
|
std::vector<float> recurrent_to_cell_weights_;
|
||||||
|
std::vector<float> recurrent_to_forget_weights_;
|
||||||
|
std::vector<float> recurrent_to_output_weights_;
|
||||||
|
std::vector<float> cell_to_input_weights_;
|
||||||
|
std::vector<float> cell_to_forget_weights_;
|
||||||
|
std::vector<float> cell_to_output_weights_;
|
||||||
|
std::vector<float> input_layer_norm_weights_;
|
||||||
|
std::vector<float> forget_layer_norm_weights_;
|
||||||
|
std::vector<float> cell_layer_norm_weights_;
|
||||||
|
std::vector<float> output_layer_norm_weights_;
|
||||||
|
std::vector<float> projection_weights_;
|
||||||
|
|
||||||
|
std::vector<int> input_to_input_weights_size_;
|
||||||
|
std::vector<int> input_to_cell_weights_size_;
|
||||||
|
std::vector<int> input_to_forget_weights_size_;
|
||||||
|
std::vector<int> input_to_output_weights_size_;
|
||||||
|
std::vector<int> recurrent_to_input_weights_size_;
|
||||||
|
std::vector<int> recurrent_to_cell_weights_size_;
|
||||||
|
std::vector<int> recurrent_to_forget_weights_size_;
|
||||||
|
std::vector<int> recurrent_to_output_weights_size_;
|
||||||
|
|
||||||
|
int n_batch_;
|
||||||
|
int n_input_;
|
||||||
|
int n_cell_;
|
||||||
|
int n_output_;
|
||||||
|
float cell_clip_;
|
||||||
|
float proj_clip_;
|
||||||
|
|
||||||
|
// Layer Norm LSTM input is stored as num_batch x num_inputs vector.
|
||||||
|
std::vector<std::vector<float>> sparse_layer_norm_lstm_input_;
|
||||||
|
|
||||||
|
// Compares output up to tolerance to the result of the layer_norm_lstm given
|
||||||
|
// the input.
|
||||||
|
void VerifyGoldens(const std::vector<std::vector<float>>& input,
|
||||||
|
const std::vector<std::vector<float>>& output,
|
||||||
|
HybridSparseLSTMOpModel* sparse_layer_norm_lstm,
|
||||||
|
float tolerance = 1e-5) {
|
||||||
|
const int num_batches = input.size();
|
||||||
|
EXPECT_GT(num_batches, 0);
|
||||||
|
const int num_inputs = sparse_layer_norm_lstm->num_inputs();
|
||||||
|
EXPECT_GT(num_inputs, 0);
|
||||||
|
const int input_sequence_size = input[0].size() / num_inputs;
|
||||||
|
EXPECT_GT(input_sequence_size, 0);
|
||||||
|
for (int i = 0; i < input_sequence_size; ++i) {
|
||||||
|
for (int b = 0; b < num_batches; ++b) {
|
||||||
|
const float* batch_start = input[b].data() + i * num_inputs;
|
||||||
|
const float* batch_end = batch_start + num_inputs;
|
||||||
|
|
||||||
|
sparse_layer_norm_lstm->SetInput(
|
||||||
|
b * sparse_layer_norm_lstm->num_inputs(), batch_start, batch_end);
|
||||||
|
}
|
||||||
|
|
||||||
|
sparse_layer_norm_lstm->Invoke();
|
||||||
|
|
||||||
|
const int num_outputs = sparse_layer_norm_lstm->num_outputs();
|
||||||
|
std::vector<float> expected;
|
||||||
|
for (int b = 0; b < num_batches; ++b) {
|
||||||
|
const float* golden_start_batch = output[b].data() + i * num_outputs;
|
||||||
|
const float* golden_end_batch = golden_start_batch + num_outputs;
|
||||||
|
expected.insert(expected.end(), golden_start_batch, golden_end_batch);
|
||||||
|
}
|
||||||
|
EXPECT_THAT(
|
||||||
|
sparse_layer_norm_lstm->GetOutput(),
|
||||||
|
ElementsAreArray(::tflite::ArrayFloatNear(expected, tolerance)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class NoCifgPeepholeProjectionNoClippingSparseLstmTest
|
||||||
|
: public BaseSparseLstmTest {
|
||||||
|
void SetUp() override {
|
||||||
|
n_batch_ = 2;
|
||||||
|
n_input_ = 48;
|
||||||
|
n_cell_ = 4;
|
||||||
|
n_output_ = 16;
|
||||||
|
cell_clip_ = 0.0;
|
||||||
|
proj_clip_ = 0.0;
|
||||||
|
|
||||||
|
/* clang-format off */
|
||||||
|
input_to_input_weights_ = {
|
||||||
|
/* 1st row */
|
||||||
|
1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
|
||||||
|
14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
|
||||||
|
39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
/* 2nd row */
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
|
||||||
|
-25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
/* 3rd row */
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
|
||||||
|
-26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
/* 4th row */
|
||||||
|
-1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
|
||||||
|
-13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
|
||||||
|
38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0};
|
||||||
|
input_to_input_weights_size_ = {4, 48};
|
||||||
|
|
||||||
|
input_to_forget_weights_ = {
|
||||||
|
/* 1st row */
|
||||||
|
1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
|
||||||
|
14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
|
||||||
|
39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
/* 2nd row */
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
|
||||||
|
-25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
/* 3rd row */
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
|
||||||
|
-26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
/* 4th row */
|
||||||
|
-1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
|
||||||
|
-13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
|
||||||
|
38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0};
|
||||||
|
input_to_forget_weights_size_ = {4, 48};
|
||||||
|
|
||||||
|
input_to_cell_weights_ = {
|
||||||
|
/* 1st row */
|
||||||
|
1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
|
||||||
|
14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
|
||||||
|
39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
/* 2nd row */
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
|
||||||
|
-25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
/* 3rd row */
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
|
||||||
|
-26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
/* 4th row */
|
||||||
|
-1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
|
||||||
|
-13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
|
||||||
|
38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0};
|
||||||
|
input_to_cell_weights_size_ = {4, 48};
|
||||||
|
|
||||||
|
input_to_output_weights_ = {
|
||||||
|
/* 1st row */
|
||||||
|
1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
|
||||||
|
14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
|
||||||
|
39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
/* 2nd row */
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
|
||||||
|
-25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
/* 3rd row */
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
|
||||||
|
-26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
/* 4th row */
|
||||||
|
-1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
|
||||||
|
-13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
|
||||||
|
38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0};
|
||||||
|
input_to_output_weights_size_ = {4, 48};
|
||||||
|
|
||||||
|
input_gate_bias_ = {0.03, 0.15, 0.22, 0.38};
|
||||||
|
|
||||||
|
forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
|
||||||
|
|
||||||
|
cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
|
||||||
|
|
||||||
|
output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};
|
||||||
|
|
||||||
|
recurrent_to_input_weights_ = {
|
||||||
|
-0.2, -0.3, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, // 1st row
|
||||||
|
0.1, -0.5, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, // 2nd row
|
||||||
|
-0.2, -0.3, -0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, // 3rd row
|
||||||
|
0.05, -0.2, -0.6, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, // 4th row
|
||||||
|
};
|
||||||
|
recurrent_to_input_weights_size_ = {4, 16};
|
||||||
|
|
||||||
|
recurrent_to_cell_weights_ = {
|
||||||
|
-0.3, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, // 1st row
|
||||||
|
-0.3, 0.8, -0.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, // 2nd row
|
||||||
|
-0.2, 0.3, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, // 3rd row
|
||||||
|
-0.6, -0.1, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, // 4th row
|
||||||
|
};
|
||||||
|
recurrent_to_cell_weights_size_ = {4, 16};
|
||||||
|
|
||||||
|
recurrent_to_forget_weights_ = {
|
||||||
|
-0.5, -0.3, -0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, // 1st row
|
||||||
|
-0.2, 0.6, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, // 2nd row
|
||||||
|
0.9, 0.3, -0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, // 3rd row
|
||||||
|
0.2, 0.5, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, // 4th row
|
||||||
|
};
|
||||||
|
recurrent_to_forget_weights_size_ = {4, 16};
|
||||||
|
|
||||||
|
recurrent_to_output_weights_ = {
|
||||||
|
0.3, -0.1, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, // 1st row
|
||||||
|
-0.2, -0.5, -0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, // 2nd row
|
||||||
|
-0.2, -0.6, -0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, // 3rd row
|
||||||
|
-0.4, -0.7, -0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, // 4th row
|
||||||
|
};
|
||||||
|
recurrent_to_output_weights_size_ = {4, 16};
|
||||||
|
|
||||||
|
cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15};
|
||||||
|
|
||||||
|
cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};
|
||||||
|
|
||||||
|
cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};
|
||||||
|
|
||||||
|
input_layer_norm_weights_ = {0.1, 0.2, 0.3, 0.5};
|
||||||
|
forget_layer_norm_weights_ = {0.2, 0.2, 0.4, 0.3};
|
||||||
|
cell_layer_norm_weights_ = {0.7, 0.2, 0.3, 0.8};
|
||||||
|
output_layer_norm_weights_ = {0.6, 0.2, 0.2, 0.5};
|
||||||
|
|
||||||
|
projection_weights_ = {
|
||||||
|
-0.1, 0.2, 0.01, -0.2, // 1st row
|
||||||
|
0.1, 0.5, 0.3, 0.08, // 2nd row
|
||||||
|
0.07, 0.2, -0.4, 0.2, // 3rd row
|
||||||
|
0.0, 0.0, 0.0, 0.0, // 4th row
|
||||||
|
0.0, 0.0, 0.0, 0.0, // 5th row
|
||||||
|
0.0, 0.0, 0.0, 0.0, // 6th row
|
||||||
|
0.0, 0.0, 0.0, 0.0, // 7th row
|
||||||
|
0.0, 0.0, 0.0, 0.0, // 8th row
|
||||||
|
0.0, 0.0, 0.0, 0.0, // 9th row
|
||||||
|
0.0, 0.0, 0.0, 0.0, // 10th row
|
||||||
|
0.0, 0.0, 0.0, 0.0, // 11th row
|
||||||
|
0.0, 0.0, 0.0, 0.0, // 12th row
|
||||||
|
0.0, 0.0, 0.0, 0.0, // 13th row
|
||||||
|
0.0, 0.0, 0.0, 0.0, // 14th row
|
||||||
|
0.0, 0.0, 0.0, 0.0, // 15th row
|
||||||
|
0.0, 0.0, 0.0, 0.0, // 16th row
|
||||||
|
0.0, 0.0, 0.0, 0.0, // 17th row
|
||||||
|
0.0, 0.0, 0.0, 0.0, // 18th row
|
||||||
|
};
|
||||||
|
|
||||||
|
sparse_layer_norm_lstm_input_ = {
|
||||||
|
// Batch0: 2 (input_sequence_size) * 45 (n_input_)
|
||||||
|
{
|
||||||
|
1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0,
|
||||||
|
-1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
|
||||||
|
1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0,
|
||||||
|
-1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, // seq 0
|
||||||
|
2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0, -1.1, 0.0, 2.0, 0.0,
|
||||||
|
-1.7, 0.0, 1.9, 0.0, -1.5, 0.0, 0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3,
|
||||||
|
0.0, 2.8, 0.0, -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1,
|
||||||
|
1.0, -2.5, 0.7, -1.9, 0.2, 0.1, 0.2, 0.3, // seq 1
|
||||||
|
},
|
||||||
|
// Batch1: 2 (input_sequence_size) * 45 (n_input_)
|
||||||
|
{
|
||||||
|
1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0,
|
||||||
|
-1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
|
||||||
|
1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0,
|
||||||
|
-1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, // seq 0
|
||||||
|
2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0, -1.1, 0.0, 2.0, 0.0,
|
||||||
|
-1.7, 0.0, 1.9, 0.0, -1.5, 0.0, 0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3,
|
||||||
|
0.0, 2.8, 0.0, -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1,
|
||||||
|
1.0, -2.5, 0.7, -1.9, 0.2, -1.0, 1.0, -1.0, // seq 1
|
||||||
|
},
|
||||||
|
};
|
||||||
|
/* clang-format on */
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(NoCifgPeepholeProjectionNoClippingSparseLstmTest,
|
||||||
|
HybridSparseLstmBlackBoxTest) {
|
||||||
|
TensorData input_weight = {};
|
||||||
|
input_weight.type = TensorType_FLOAT32;
|
||||||
|
input_weight.shape = {4, 48};
|
||||||
|
input_weight.traversal_order = {0, 1, 2};
|
||||||
|
input_weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
|
||||||
|
input_weight.block_map = {1};
|
||||||
|
input_weight.block_size = {16};
|
||||||
|
TensorData recurrent_weight = {};
|
||||||
|
recurrent_weight.type = TensorType_FLOAT32;
|
||||||
|
recurrent_weight.shape = {4, 16};
|
||||||
|
recurrent_weight.traversal_order = {0, 1, 2};
|
||||||
|
recurrent_weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
|
||||||
|
recurrent_weight.block_map = {1};
|
||||||
|
recurrent_weight.block_size = {16};
|
||||||
|
HybridSparseLSTMOpModel sparse_layer_norm_lstm(
|
||||||
|
n_batch_, n_input_, n_cell_, n_output_,
|
||||||
|
/*use_cifg=*/false, /*use_peephole=*/true,
|
||||||
|
/*use_projection_weights=*/true,
|
||||||
|
/*use_projection_bias=*/false, cell_clip_, proj_clip_,
|
||||||
|
{
|
||||||
|
{n_batch_, n_input_}, // input tensor
|
||||||
|
|
||||||
|
{input_to_input_weights_size_},
|
||||||
|
{input_to_forget_weights_size_},
|
||||||
|
{input_to_cell_weights_size_},
|
||||||
|
{input_to_output_weights_size_},
|
||||||
|
|
||||||
|
{recurrent_to_input_weights_size_},
|
||||||
|
{recurrent_to_forget_weights_size_},
|
||||||
|
{recurrent_to_cell_weights_size_},
|
||||||
|
{recurrent_to_output_weights_size_},
|
||||||
|
|
||||||
|
{n_cell_}, // cell_to_input_weight tensor
|
||||||
|
{n_cell_}, // cell_to_forget_weight tensor
|
||||||
|
{n_cell_}, // cell_to_output_weight tensor
|
||||||
|
|
||||||
|
{n_cell_}, // input_gate_bias tensor
|
||||||
|
{n_cell_}, // forget_gate_bias tensor
|
||||||
|
{n_cell_}, // cell_bias tensor
|
||||||
|
{n_cell_}, // output_gate_bias tensor
|
||||||
|
|
||||||
|
{n_output_, n_cell_}, // projection_weight tensor
|
||||||
|
{0}, // projection_bias tensor
|
||||||
|
|
||||||
|
{n_output_ * n_batch_}, // output_state tensor
|
||||||
|
{n_cell_ * n_batch_}, // cell_state tensor
|
||||||
|
|
||||||
|
{n_cell_}, // input_layer_norm_weight tensor
|
||||||
|
{n_cell_}, // forget_layer_norm_weight tensor
|
||||||
|
{n_cell_}, // cell_layer_norm_weight tensor
|
||||||
|
{n_cell_}, // output_layer_norm_weight tensor
|
||||||
|
},
|
||||||
|
input_weight, input_to_input_weights_, input_to_forget_weights_,
|
||||||
|
input_to_cell_weights_, input_to_output_weights_, recurrent_weight,
|
||||||
|
recurrent_to_input_weights_, recurrent_to_forget_weights_,
|
||||||
|
recurrent_to_cell_weights_, recurrent_to_output_weights_);
|
||||||
|
|
||||||
|
sparse_layer_norm_lstm.SetInputGateBias(input_gate_bias_);
|
||||||
|
sparse_layer_norm_lstm.SetCellBias(cell_gate_bias_);
|
||||||
|
sparse_layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
|
||||||
|
sparse_layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
|
||||||
|
|
||||||
|
sparse_layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
|
||||||
|
sparse_layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
||||||
|
sparse_layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
||||||
|
|
||||||
|
sparse_layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_);
|
||||||
|
sparse_layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_);
|
||||||
|
sparse_layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_);
|
||||||
|
sparse_layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_);
|
||||||
|
|
||||||
|
sparse_layer_norm_lstm.SetProjectionWeights(projection_weights_);
|
||||||
|
|
||||||
|
/* clang-format off */
|
||||||
|
const std::vector<std::vector<float>> sparse_layer_norm_lstm_golden_output = {
|
||||||
|
{
|
||||||
|
// Batch0: 2 (input_sequence_size) * 3 (n_output_)
|
||||||
|
0.0550758, 0.138464, -0.0628034, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.069672, 0.195428, -0.0605584, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Batch1: 3 (input_sequence_size) * 3 (n_output_)
|
||||||
|
0.0550758, 0.138464, -0.0628034, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.069672, 0.195428, -0.0605584, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
}};
|
||||||
|
/* clang-format on */
|
||||||
|
|
||||||
|
VerifyGoldens(sparse_layer_norm_lstm_input_,
|
||||||
|
sparse_layer_norm_lstm_golden_output, &sparse_layer_norm_lstm);
|
||||||
|
}
|
||||||
|
|
||||||
// Test parameter controls asymmetric_quantize_inputs in LSTMOpModel.
|
// Test parameter controls asymmetric_quantize_inputs in LSTMOpModel.
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
Parameterized, LstmOpTest,
|
Parameterized, LstmOpTest,
|
||||||
|
@ -214,9 +214,82 @@ class SingleOpModel {
|
|||||||
return AddConstInput(TensorData{type, shape}, data);
|
return AddConstInput(TensorData{type, shape}, data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(b/166202747): Use a better way to do type specialization. Reduce
|
||||||
|
// duplicate code in the two functions below.
|
||||||
|
int AddConstSparseInput(const TensorData& t,
|
||||||
|
const std::vector<int8_t>& data) {
|
||||||
|
int id = tensors_.size();
|
||||||
|
const int dims_count = t.traversal_order.size();
|
||||||
|
std::vector<int8_t> dense_data(data);
|
||||||
|
|
||||||
|
tflite::optimize::sparsity::FormatConverter<int8_t> converter(
|
||||||
|
t.shape, t.traversal_order, t.format, t.block_size, t.block_map);
|
||||||
|
converter.DenseToSparse(dense_data.data());
|
||||||
|
|
||||||
|
const auto dim_metadata = converter.GetDimMetadata();
|
||||||
|
const auto sparse_data = converter.GetData();
|
||||||
|
|
||||||
|
// Build sparsity parameter.
|
||||||
|
std::vector<flatbuffers::Offset<DimensionMetadata>> fb_dim_metadata(
|
||||||
|
dims_count);
|
||||||
|
for (int i = 0; i < dims_count; i++) {
|
||||||
|
const int metadata_idx = 2 * i;
|
||||||
|
if (i < t.shape.size() &&
|
||||||
|
t.format[t.traversal_order[i]] == kTfLiteDimSparseCSR) {
|
||||||
|
auto array_segments =
|
||||||
|
CreateInt32Vector(builder_,
|
||||||
|
builder_.CreateVector(dim_metadata[metadata_idx]))
|
||||||
|
.Union();
|
||||||
|
auto array_indices =
|
||||||
|
CreateInt32Vector(
|
||||||
|
builder_, builder_.CreateVector(dim_metadata[metadata_idx + 1]))
|
||||||
|
.Union();
|
||||||
|
fb_dim_metadata[i] = CreateDimensionMetadata(
|
||||||
|
builder_, DimensionType_SPARSE_CSR, 0,
|
||||||
|
SparseIndexVector_Int32Vector, array_segments,
|
||||||
|
SparseIndexVector_Int32Vector, array_indices);
|
||||||
|
} else {
|
||||||
|
fb_dim_metadata[i] = CreateDimensionMetadata(
|
||||||
|
builder_, DimensionType_DENSE, dim_metadata[metadata_idx][0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
flatbuffers::Offset<SparsityParameters> s_param = CreateSparsityParameters(
|
||||||
|
builder_, builder_.CreateVector(t.traversal_order),
|
||||||
|
builder_.CreateVector(t.block_map),
|
||||||
|
builder_.CreateVector(fb_dim_metadata));
|
||||||
|
|
||||||
|
int buffer_id = 0;
|
||||||
|
if (!data.empty()) {
|
||||||
|
// Initialize buffers list with empty buffer to allow for non-const
|
||||||
|
// tensors.
|
||||||
|
if (buffers_.empty()) {
|
||||||
|
buffers_.push_back(CreateBuffer(builder_, builder_.CreateVector({})));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add compressed data as a Buffer to buffers list.
|
||||||
|
buffer_id = buffers_.size();
|
||||||
|
auto data_buffer = builder_.CreateVector(
|
||||||
|
reinterpret_cast<const uint8_t*>(sparse_data.data()),
|
||||||
|
sparse_data.size());
|
||||||
|
buffers_.push_back(CreateBuffer(builder_, data_buffer));
|
||||||
|
}
|
||||||
|
|
||||||
|
tensors_.push_back(CreateTensor(
|
||||||
|
builder_, builder_.CreateVector<int>(t.shape), t.type,
|
||||||
|
/*buffer=*/buffer_id,
|
||||||
|
/*name=*/0, /*quantization=*/0, /*is_variable=*/false, s_param));
|
||||||
|
|
||||||
|
inputs_.push_back(id);
|
||||||
|
tensor_data_[id] = t;
|
||||||
|
|
||||||
|
return id;
|
||||||
|
}
|
||||||
|
|
||||||
// Add a constant sparse tensor as input.
|
// Add a constant sparse tensor as input.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
int AddConstSparseInput(const TensorData& t, std::initializer_list<T> data) {
|
int AddConstSparseInput(const TensorData& t, const std::vector<T>& data,
|
||||||
|
bool symmetric_quantize = false) {
|
||||||
int id = tensors_.size();
|
int id = tensors_.size();
|
||||||
const int dims_count = t.traversal_order.size();
|
const int dims_count = t.traversal_order.size();
|
||||||
std::vector<T> dense_data(data);
|
std::vector<T> dense_data(data);
|
||||||
@ -258,8 +331,9 @@ class SingleOpModel {
|
|||||||
builder_.CreateVector(t.block_map),
|
builder_.CreateVector(t.block_map),
|
||||||
builder_.CreateVector(fb_dim_metadata));
|
builder_.CreateVector(fb_dim_metadata));
|
||||||
|
|
||||||
|
flatbuffers::Offset<QuantizationParameters> q_params = 0;
|
||||||
int buffer_id = 0;
|
int buffer_id = 0;
|
||||||
if (data.size()) {
|
if (!data.empty()) {
|
||||||
// Initialize buffers list with empty buffer to allow for non-const
|
// Initialize buffers list with empty buffer to allow for non-const
|
||||||
// tensors.
|
// tensors.
|
||||||
if (buffers_.empty()) {
|
if (buffers_.empty()) {
|
||||||
@ -268,16 +342,31 @@ class SingleOpModel {
|
|||||||
|
|
||||||
// Add compressed data as a Buffer to buffers list.
|
// Add compressed data as a Buffer to buffers list.
|
||||||
buffer_id = buffers_.size();
|
buffer_id = buffers_.size();
|
||||||
auto data_buffer = builder_.CreateVector(
|
if (symmetric_quantize) {
|
||||||
reinterpret_cast<const uint8_t*>(sparse_data.data()),
|
const int length = sparse_data.size();
|
||||||
sizeof(T) * sparse_data.size());
|
std::vector<int8_t> q(length);
|
||||||
buffers_.push_back(CreateBuffer(builder_, data_buffer));
|
float min, max, scaling_factor;
|
||||||
|
tensor_utils::SymmetricQuantizeFloats(
|
||||||
|
sparse_data.data(), length, q.data(), &min, &max, &scaling_factor);
|
||||||
|
q_params = CreateQuantizationParameters(
|
||||||
|
builder_, 0, 0, builder_.CreateVector<float>({scaling_factor}),
|
||||||
|
builder_.CreateVector<int64_t>({0}));
|
||||||
|
auto data_buffer = builder_.CreateVector(
|
||||||
|
reinterpret_cast<const uint8_t*>(q.data()), q.size());
|
||||||
|
buffers_.push_back(CreateBuffer(builder_, data_buffer));
|
||||||
|
} else {
|
||||||
|
auto data_buffer = builder_.CreateVector(
|
||||||
|
reinterpret_cast<const uint8_t*>(sparse_data.data()),
|
||||||
|
sizeof(T) * sparse_data.size());
|
||||||
|
buffers_.push_back(CreateBuffer(builder_, data_buffer));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tensors_.push_back(CreateTensor(
|
tensors_.push_back(
|
||||||
builder_, builder_.CreateVector<int>(t.shape), t.type,
|
CreateTensor(builder_, builder_.CreateVector<int>(t.shape),
|
||||||
/*buffer=*/buffer_id,
|
symmetric_quantize ? TensorType_INT8 : t.type,
|
||||||
/*name=*/0, /*quantization=*/0, /*is_variable=*/false, s_param));
|
/*buffer=*/buffer_id,
|
||||||
|
/*name=*/0, q_params, /*is_variable=*/false, s_param));
|
||||||
|
|
||||||
inputs_.push_back(id);
|
inputs_.push_back(id);
|
||||||
tensor_data_[id] = t;
|
tensor_data_[id] = t;
|
||||||
|
@ -650,11 +650,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
|
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
|
||||||
const int row_sums_size = row_sums->dims->data[0];
|
const int row_sums_size = row_sums->dims->data[0];
|
||||||
return lstm_eval::EvalHybrid(
|
return lstm_eval::EvalHybrid(
|
||||||
input, input_to_input_weights, input_to_forget_weights,
|
input, input_to_input_weights,
|
||||||
input_to_cell_weights, input_to_output_weights,
|
/*input_to_input_weights_ledger*/ nullptr, input_to_forget_weights,
|
||||||
recurrent_to_input_weights, recurrent_to_forget_weights,
|
/*input_to_forget_weights_ledger*/ nullptr, input_to_cell_weights,
|
||||||
recurrent_to_cell_weights, recurrent_to_output_weights,
|
/*input_to_cell_weights_ledger*/ nullptr, input_to_output_weights,
|
||||||
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
|
/*input_to_output_weights_ledger*/ nullptr,
|
||||||
|
recurrent_to_input_weights,
|
||||||
|
/*recurrent_to_input_weights_ledger*/ nullptr,
|
||||||
|
recurrent_to_forget_weights,
|
||||||
|
/*recurrent_to_forget_weights_ledger*/ nullptr,
|
||||||
|
recurrent_to_cell_weights,
|
||||||
|
/*recurrent_to_cell_weights_ledger*/ nullptr,
|
||||||
|
recurrent_to_output_weights,
|
||||||
|
/*recurrent_to_output_weights_ledger*/ nullptr, cell_to_input_weights,
|
||||||
|
cell_to_forget_weights, cell_to_output_weights,
|
||||||
input_layer_norm_coefficients, forget_layer_norm_coefficients,
|
input_layer_norm_coefficients, forget_layer_norm_coefficients,
|
||||||
cell_layer_norm_coefficients, output_layer_norm_coefficients,
|
cell_layer_norm_coefficients, output_layer_norm_coefficients,
|
||||||
/*aux_input=*/nullptr,
|
/*aux_input=*/nullptr,
|
||||||
@ -663,7 +672,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
/*aux_input_to_cell_weights=*/nullptr,
|
/*aux_input_to_cell_weights=*/nullptr,
|
||||||
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
||||||
forget_gate_bias, cell_gate_bias, output_gate_bias,
|
forget_gate_bias, cell_gate_bias, output_gate_bias,
|
||||||
projection_weights, projection_bias, &lstm_params,
|
projection_weights, /*projection_weights_ledger*/ nullptr,
|
||||||
|
projection_bias, &lstm_params,
|
||||||
/*forward_sequence=*/true, time_major,
|
/*forward_sequence=*/true, time_major,
|
||||||
/*output_offset=*/0, scratch_buffer,
|
/*output_offset=*/0, scratch_buffer,
|
||||||
GetTemporary(context, node, kInputScalingFactors),
|
GetTemporary(context, node, kInputScalingFactors),
|
||||||
|
Loading…
Reference in New Issue
Block a user