From 8744e4b2b9e766debdb0e7bca2861e7134ce8339 Mon Sep 17 00:00:00 2001 From: Yunlu Li Date: Tue, 1 Sep 2020 12:48:41 -0700 Subject: [PATCH] Add builtin sparse LSTM kernel. PiperOrigin-RevId: 329562447 Change-Id: I5c407b513fbc86d21f6ea2d626da7b69dcd38bc7 --- .../kernels/bidirectional_sequence_lstm.cc | 40 +- tensorflow/lite/kernels/densify_test.cc | 18 +- .../lite/kernels/fully_connected_test.cc | 2 +- tensorflow/lite/kernels/lstm.cc | 309 ++++++++- tensorflow/lite/kernels/lstm_eval.cc | 201 ++++-- tensorflow/lite/kernels/lstm_eval.h | 16 +- tensorflow/lite/kernels/lstm_eval_test.cc | 18 +- tensorflow/lite/kernels/lstm_test.cc | 614 ++++++++++++++++++ tensorflow/lite/kernels/test_util.h | 109 +++- .../kernels/unidirectional_sequence_lstm.cc | 22 +- 10 files changed, 1230 insertions(+), 119 deletions(-) diff --git a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc index 1ce131a96ac..45d973d1d98 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc @@ -1136,10 +1136,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const int fw_row_sums_size = fw_row_sums->dims->data[0]; const int bw_row_sums_size = bw_row_sums->dims->data[0]; TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid( - input, fw_input_to_input_weights, fw_input_to_forget_weights, - fw_input_to_cell_weights, fw_input_to_output_weights, - fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights, - fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights, + input, fw_input_to_input_weights, + /*input_to_input_weights_ledger*/ nullptr, fw_input_to_forget_weights, + /*input_to_forget_weights_ledger*/ nullptr, fw_input_to_cell_weights, + /*input_to_cell_weights_ledger*/ nullptr, fw_input_to_output_weights, + /*input_to_output_weights_ledger*/ nullptr, + fw_recurrent_to_input_weights, + /*recurrent_to_input_weights_ledger*/ nullptr, + 
fw_recurrent_to_forget_weights, + /*recurrent_to_forget_weights_ledger*/ nullptr, + fw_recurrent_to_cell_weights, + /*recurrent_to_cell_weights_ledger*/ nullptr, + fw_recurrent_to_output_weights, + /*recurrent_to_output_weights_ledger*/ nullptr, fw_cell_to_input_weights, fw_cell_to_forget_weights, fw_cell_to_output_weights, /*input_layer_norm_coefficients=*/nullptr, @@ -1149,7 +1158,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights, fw_input_gate_bias, fw_forget_gate_bias, fw_cell_gate_bias, - fw_output_gate_bias, fw_projection_weights, fw_projection_bias, + fw_output_gate_bias, fw_projection_weights, + /*projection_weights_ledger*/ nullptr, fw_projection_bias, &lstm_params, /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0, fw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors), @@ -1167,10 +1177,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, fw_pass_status); TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid( - bw_input, bw_input_to_input_weights, bw_input_to_forget_weights, - bw_input_to_cell_weights, bw_input_to_output_weights, - bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights, - bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights, + bw_input, bw_input_to_input_weights, + /*input_to_input_weights_ledger*/ nullptr, bw_input_to_forget_weights, + /*input_to_forget_weights_ledger*/ nullptr, bw_input_to_cell_weights, + /*input_to_cell_weights_ledger*/ nullptr, bw_input_to_output_weights, + /*input_to_output_weights_ledger*/ nullptr, + bw_recurrent_to_input_weights, + /*recurrent_to_input_weights_ledger*/ nullptr, + bw_recurrent_to_forget_weights, + /*recurrent_to_forget_weights_ledger*/ nullptr, + bw_recurrent_to_cell_weights, + /*recurrent_to_cell_weights_ledger*/ nullptr, + bw_recurrent_to_output_weights, + 
/*recurrent_to_output_weights_ledger*/ nullptr, bw_cell_to_input_weights, bw_cell_to_forget_weights, bw_cell_to_output_weights, /*input_layer_norm_coefficients=*/nullptr, @@ -1180,7 +1199,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights, bw_input_gate_bias, bw_forget_gate_bias, bw_cell_gate_bias, - bw_output_gate_bias, bw_projection_weights, bw_projection_bias, + bw_output_gate_bias, bw_projection_weights, + /*projection_weights_ledger*/ nullptr, bw_projection_bias, &lstm_params, /*forward_sequence=*/false, /*time_major=*/true, bw_output_offset, bw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors), diff --git a/tensorflow/lite/kernels/densify_test.cc b/tensorflow/lite/kernels/densify_test.cc index d453606cf2e..384a2bed3c7 100644 --- a/tensorflow/lite/kernels/densify_test.cc +++ b/tensorflow/lite/kernels/densify_test.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include -#include #include #include @@ -42,7 +41,7 @@ using ::testing::ElementsAreArray; template class DensifyOpModel : public SingleOpModel { public: - DensifyOpModel(const TensorData& input, std::initializer_list input_data, + DensifyOpModel(const TensorData& input, const std::vector& input_data, int version = 1) { input_ = AddConstSparseInput(input, input_data); output_ = AddOutput({input.type, input.shape}); @@ -65,9 +64,8 @@ class DensifyOpModel : public SingleOpModel { }; TEST(DensifyOpTest, Float) { - std::initializer_list dense_values = {6, 0, 9, 8, 0, 0, - 0, 0, 5, 0, 0, 7}; - std::initializer_list sparse_values = {6, 9, 8, 5, 7}; + std::vector dense_values = {6, 0, 9, 8, 0, 0, 0, 0, 5, 0, 0, 7}; + std::vector sparse_values = {6, 9, 8, 5, 7}; TensorData input = {}; input.type = TensorType_FLOAT32; input.shape = {3, 4}; @@ -80,9 +78,8 @@ TEST(DensifyOpTest, Float) { } TEST(DensifyOpTest, Float3D) { - std::initializer_list dense_values = {6, 0, 9, 8, 0, 0, - 0, 0, 5, 0, 0, 7}; - std::initializer_list sparse_values = {6, 9, 8, 5, 7}; + std::vector dense_values = {6, 0, 9, 8, 0, 0, 0, 0, 5, 0, 0, 7}; + std::vector sparse_values = {6, 9, 8, 5, 7}; TensorData input = {}; input.type = TensorType_FLOAT32; input.shape = {3, 2, 2}; @@ -95,9 +92,8 @@ TEST(DensifyOpTest, Float3D) { } TEST(DensifyOpTest, Int8) { - std::initializer_list dense_values = {6, 0, 9, 8, 0, 0, - 0, 0, 5, 0, 0, 7}; - std::initializer_list sparse_values = {6, 9, 8, 5, 7}; + std::vector dense_values = {6, 0, 9, 8, 0, 0, 0, 0, 5, 0, 0, 7}; + std::vector sparse_values = {6, 9, 8, 5, 7}; TensorData input = {}; input.type = TensorType_INT8; input.shape = {3, 4}; diff --git a/tensorflow/lite/kernels/fully_connected_test.cc b/tensorflow/lite/kernels/fully_connected_test.cc index 7f02ed079bd..9a80c4eebfa 100644 --- a/tensorflow/lite/kernels/fully_connected_test.cc +++ b/tensorflow/lite/kernels/fully_connected_test.cc @@ 
-1144,7 +1144,7 @@ class SparseFullyConnectedOpModel : public SingleOpModel { SparseFullyConnectedOpModel(TfLiteRegistration* registration, int units, int batches, const TensorData& input, const TensorData& weights, - std::initializer_list weights_data, + const std::vector& weights_data, int num_threads = 1) : batches_(batches), units_(units) { int total_input_size = 1; diff --git a/tensorflow/lite/kernels/lstm.cc b/tensorflow/lite/kernels/lstm.cc index c39f715446b..6d67f759ce8 100644 --- a/tensorflow/lite/kernels/lstm.cc +++ b/tensorflow/lite/kernels/lstm.cc @@ -55,6 +55,10 @@ struct OpData { int scratch_tensor_index; lstm_eval::IntegerLstmParameter integer_lstm_param; bool compute_row_sums; + + // Only used for sparse hybrid lstm kernels. + int ledger_index; + bool ledger_initialized; }; namespace full { @@ -77,6 +81,63 @@ enum HybridTemporaryTensor { kNumHybridTemporaryTensors = 12, }; +constexpr int kLedgersToAdd = 9; +constexpr int kInputToInputWeightsLedgerOffset = 0; +constexpr int kInputToForgetWeightsLedgerOffset = 1; +constexpr int kInputToCellWeightsLedgerOffset = 2; +constexpr int kInputToOutputWeightsLedgerOffset = 3; +constexpr int kRecurrentToInputWeightsLedgerOffset = 4; +constexpr int kRecurrentToForgetWeightsLedgerOffset = 5; +constexpr int kRecurrentToCellWeightsLedgerOffset = 6; +constexpr int kRecurrentToOutputWeightsLedgerOffset = 7; +constexpr int kProjectionWeightsLedgerOffset = 8; + +TfLiteStatus make_ledger(const TfLiteSparsity* sparsity, TfLiteContext* context, + TfLiteTensor* ledger) { + ledger->type = kTfLiteUInt8; + ledger->allocation_type = kTfLiteArenaRwPersistent; + if (sparsity == nullptr) { + return kTfLiteOk; + } + TfLiteIntArray* ledger_size = TfLiteIntArrayCreate(1); + ledger_size->data[0] = sparsity->dim_metadata[1].array_indices->size + + sparsity->dim_metadata[1].array_segments->size - 1; + return context->ResizeTensor(context, ledger, ledger_size); +} + +TfLiteStatus copy_ledger(const TfLiteSparsity* sparsity, TfLiteTensor* 
ledger) { + if (sparsity == nullptr) { + return kTfLiteOk; + } + + const auto* array_segments = sparsity->dim_metadata[1].array_segments; + const auto* array_indices = sparsity->dim_metadata[1].array_indices; + uint8_t* output_data = GetTensorData(ledger); + int output_data_ptr = 0; + + for (int i = 0; i < array_segments->size - 1; i++) { + int row_start = array_segments->data[i]; + int row_end = array_segments->data[i + 1]; + if (row_end - row_start > UINT8_MAX) { + return kTfLiteError; + } + // Copy num of non-zero blocks in row i. + output_data[output_data_ptr] = static_cast(row_end - row_start); + output_data_ptr++; + + for (int j = row_start; j < row_end; j++) { + if (array_indices->data[j] > UINT8_MAX) { + return kTfLiteError; + } + // Copy indices of non-zero blocks in row i. + output_data[output_data_ptr] = + static_cast(array_indices->data[j]); + output_data_ptr++; + } + } + return kTfLiteOk; +} + TfLiteStatus PopulateQuantizedLstmParams8x8_16( TfLiteContext* context, TfLiteNode* node, lstm_eval::IntegerLstmParameter* integer_lstm_param) { @@ -744,6 +805,9 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { // TODO(b/159066113): maybe just add the minimum required temp tensors? context->AddTensors(context, kNumHybridTemporaryTensors, &op_data->scratch_tensor_index); + // Tensors used for the sparse hybrid kernel. + context->AddTensors(context, /*tensors_to_add=*/kLedgersToAdd, + &op_data->ledger_index); return op_data; } @@ -1239,6 +1303,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // The weights are of consistent type, so it suffices to check one. const bool is_hybrid_op = IsHybridOp(input, input_to_output_weights); + const bool is_sparse_op = (input_to_output_weights->sparsity != nullptr); + // The type of Integer LSTM. 
const int num_intermediate_tensors = node->intermediates->size; if (is_integer) { @@ -1251,7 +1317,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArrayFree(node->temporaries); if (is_hybrid_op) { - node->temporaries = TfLiteIntArrayCreate(kNumHybridTemporaryTensors); + if (is_sparse_op) { + node->temporaries = + TfLiteIntArrayCreate(kNumHybridTemporaryTensors + kLedgersToAdd); + } else { + node->temporaries = TfLiteIntArrayCreate(kNumHybridTemporaryTensors); + } } else if (is_integer) { if (is_8x8_16) { node->temporaries = TfLiteIntArrayCreate(6); @@ -1289,7 +1360,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } if (is_hybrid_op) { - op_data->compute_row_sums = true; + if (!is_sparse_op) { + op_data->compute_row_sums = true; + } // Allocate temporary tensors to store quantized values of input, // output_state and cell_state tensors. node->temporaries->data[kInputQuantized] = @@ -1454,6 +1527,125 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK( context, context->ResizeTensor(context, row_sums, row_sums_size)); } + + if (is_sparse_op) { + op_data->ledger_initialized = false; + int offset = kNumHybridTemporaryTensors; + { + node->temporaries->data[offset + kInputToInputWeightsLedgerOffset] = + op_data->ledger_index + kInputToInputWeightsLedgerOffset; + const TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + TfLiteTensor* input_to_input_weights_ledger = + &context->tensors[op_data->ledger_index + + kInputToInputWeightsLedgerOffset]; + auto status = make_ledger(input_to_input_weights == nullptr + ? 
nullptr + : input_to_input_weights->sparsity, + context, input_to_input_weights_ledger); + if (status != kTfLiteOk) return status; + } + { + node->temporaries->data[offset + kInputToForgetWeightsLedgerOffset] = + op_data->ledger_index + kInputToForgetWeightsLedgerOffset; + const TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + TfLiteTensor* input_to_forget_weights_ledger = + &context->tensors[op_data->ledger_index + + kInputToForgetWeightsLedgerOffset]; + auto status = make_ledger(input_to_forget_weights->sparsity, context, + input_to_forget_weights_ledger); + if (status != kTfLiteOk) return status; + } + { + node->temporaries->data[offset + kInputToCellWeightsLedgerOffset] = + op_data->ledger_index + kInputToCellWeightsLedgerOffset; + const TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + TfLiteTensor* input_to_cell_weights_ledger = + &context->tensors[op_data->ledger_index + + kInputToCellWeightsLedgerOffset]; + auto status = make_ledger(input_to_cell_weights->sparsity, context, + input_to_cell_weights_ledger); + if (status != kTfLiteOk) return status; + } + { + node->temporaries->data[offset + kInputToOutputWeightsLedgerOffset] = + op_data->ledger_index + kInputToOutputWeightsLedgerOffset; + const TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + TfLiteTensor* input_to_output_weights_ledger = + &context->tensors[op_data->ledger_index + + kInputToOutputWeightsLedgerOffset]; + auto status = make_ledger(input_to_output_weights->sparsity, context, + input_to_output_weights_ledger); + if (status != kTfLiteOk) return status; + } + { + node->temporaries->data[offset + kRecurrentToInputWeightsLedgerOffset] = + op_data->ledger_index + kRecurrentToInputWeightsLedgerOffset; + const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor( + context, node, kRecurrentToInputWeightsTensor); + TfLiteTensor* 
recurrent_to_input_weights_ledger = + &context->tensors[op_data->ledger_index + + kRecurrentToInputWeightsLedgerOffset]; + auto status = make_ledger(recurrent_to_input_weights == nullptr + ? nullptr + : recurrent_to_input_weights->sparsity, + context, recurrent_to_input_weights_ledger); + if (status != kTfLiteOk) return status; + } + { + node->temporaries + ->data[offset + kRecurrentToForgetWeightsLedgerOffset] = + op_data->ledger_index + kRecurrentToForgetWeightsLedgerOffset; + const TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + TfLiteTensor* recurrent_to_forget_weights_ledger = + &context->tensors[op_data->ledger_index + + kRecurrentToForgetWeightsLedgerOffset]; + auto status = make_ledger(recurrent_to_forget_weights->sparsity, + context, recurrent_to_forget_weights_ledger); + if (status != kTfLiteOk) return status; + } + { + node->temporaries->data[offset + kRecurrentToCellWeightsLedgerOffset] = + op_data->ledger_index + kRecurrentToCellWeightsLedgerOffset; + const TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + TfLiteTensor* recurrent_to_cell_weights_ledger = + &context->tensors[op_data->ledger_index + + kRecurrentToCellWeightsLedgerOffset]; + auto status = make_ledger(recurrent_to_cell_weights->sparsity, context, + recurrent_to_cell_weights_ledger); + if (status != kTfLiteOk) return status; + } + { + node->temporaries + ->data[offset + kRecurrentToOutputWeightsLedgerOffset] = + op_data->ledger_index + kRecurrentToOutputWeightsLedgerOffset; + const TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + TfLiteTensor* recurrent_to_output_weights_ledger = + &context->tensors[op_data->ledger_index + + kRecurrentToOutputWeightsLedgerOffset]; + auto status = make_ledger(recurrent_to_output_weights->sparsity, + context, recurrent_to_output_weights_ledger); + if (status != kTfLiteOk) return status; + } + { 
+ node->temporaries->data[offset + kProjectionWeightsLedgerOffset] = + op_data->ledger_index + kProjectionWeightsLedgerOffset; + const TfLiteTensor* projection_weights = + GetInput(context, node, kProjectionWeightsTensor); + TfLiteTensor* projection_weights_ledger = + &context->tensors[op_data->ledger_index + + kProjectionWeightsLedgerOffset]; + auto status = make_ledger(projection_weights->sparsity, context, + projection_weights_ledger); + if (status != kTfLiteOk) return status; + } + } } if (is_integer) { @@ -1624,14 +1816,116 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteUInt8: case kTfLiteInt8: { const bool is_hybrid = (input->type == kTfLiteFloat32); + const bool is_sparse = input_to_output_weights->sparsity != nullptr; if (is_hybrid) { TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums); const int row_sums_size = row_sums->dims->data[0]; + if (is_sparse) { + TfLiteTensor* input_to_input_weights_ledger = + &context->tensors[op_data->ledger_index + + kInputToInputWeightsLedgerOffset]; + TfLiteTensor* input_to_forget_weights_ledger = + &context->tensors[op_data->ledger_index + + kInputToForgetWeightsLedgerOffset]; + TfLiteTensor* input_to_cell_weights_ledger = + &context->tensors[op_data->ledger_index + + kInputToCellWeightsLedgerOffset]; + TfLiteTensor* input_to_output_weights_ledger = + &context->tensors[op_data->ledger_index + + kInputToOutputWeightsLedgerOffset]; + TfLiteTensor* recurrent_to_input_weights_ledger = + &context->tensors[op_data->ledger_index + + kRecurrentToInputWeightsLedgerOffset]; + TfLiteTensor* recurrent_to_forget_weights_ledger = + &context->tensors[op_data->ledger_index + + kRecurrentToForgetWeightsLedgerOffset]; + TfLiteTensor* recurrent_to_cell_weights_ledger = + &context->tensors[op_data->ledger_index + + kRecurrentToCellWeightsLedgerOffset]; + TfLiteTensor* recurrent_to_output_weights_ledger = + &context->tensors[op_data->ledger_index + + kRecurrentToOutputWeightsLedgerOffset]; + 
TfLiteTensor* projection_weights_ledger = + &context->tensors[op_data->ledger_index + + kProjectionWeightsLedgerOffset]; + if (!op_data->ledger_initialized) { + copy_ledger(input_to_input_weights == nullptr + ? nullptr + : input_to_input_weights->sparsity, + input_to_input_weights_ledger); + copy_ledger(input_to_forget_weights->sparsity, + input_to_forget_weights_ledger); + copy_ledger(input_to_cell_weights->sparsity, + input_to_cell_weights_ledger); + copy_ledger(input_to_output_weights->sparsity, + input_to_output_weights_ledger); + copy_ledger(recurrent_to_input_weights == nullptr + ? nullptr + : recurrent_to_input_weights->sparsity, + recurrent_to_input_weights_ledger); + copy_ledger(recurrent_to_forget_weights->sparsity, + recurrent_to_forget_weights_ledger); + copy_ledger(recurrent_to_cell_weights->sparsity, + recurrent_to_cell_weights_ledger); + copy_ledger(recurrent_to_output_weights->sparsity, + recurrent_to_output_weights_ledger); + copy_ledger(projection_weights->sparsity, + projection_weights_ledger); + op_data->ledger_initialized = true; + } + return lstm_eval::EvalHybrid( + input, input_to_input_weights, input_to_input_weights_ledger, + input_to_forget_weights, input_to_forget_weights_ledger, + input_to_cell_weights, input_to_cell_weights_ledger, + input_to_output_weights, input_to_output_weights_ledger, + recurrent_to_input_weights, recurrent_to_input_weights_ledger, + recurrent_to_forget_weights, recurrent_to_forget_weights_ledger, + recurrent_to_cell_weights, recurrent_to_cell_weights_ledger, + recurrent_to_output_weights, recurrent_to_output_weights_ledger, + cell_to_input_weights, cell_to_forget_weights, + cell_to_output_weights, input_layer_norm_coefficients, + forget_layer_norm_coefficients, cell_layer_norm_coefficients, + output_layer_norm_coefficients, + /*aux_input=*/nullptr, + /*aux_input_to_input_weights=*/nullptr, + /*aux_input_to_forget_weights=*/nullptr, + /*aux_input_to_cell_weights=*/nullptr, + /*aux_input_to_output_weights=*/nullptr, 
input_gate_bias, + forget_gate_bias, cell_gate_bias, output_gate_bias, + projection_weights, projection_weights_ledger, projection_bias, + params, + /*forward_sequence=*/true, /*time_major=*/true, + /*output_offset=*/0, GetTemporary(context, node, kScratchBuffer), + GetTemporary(context, node, kInputScalingFactors), + /*aux_input_sf=*/nullptr, + GetTemporary(context, node, kOutputStateScalingFactors), + GetTemporary(context, node, kProductScalingFactors), + GetTemporary(context, node, kRecoveredCellWeights), + GetTemporary(context, node, kInputQuantized), + /*aux_input_quantized=*/nullptr, + GetTemporary(context, node, kOutputStateQuantized), + GetTemporary(context, node, kCellStateQuantized), output_state, + cell_state, GetTemporary(context, node, kAccumScratch), output, + GetTemporary(context, node, kInputZeroPoints), + /*aux_input_zp=*/nullptr, + GetTemporary(context, node, kOutputStateZeroPoints), row_sums, + row_sums_size, &op_data->compute_row_sums, + CpuBackendContext::GetFromContext(context)); + } return lstm_eval::EvalHybrid( - input, input_to_input_weights, input_to_forget_weights, - input_to_cell_weights, input_to_output_weights, - recurrent_to_input_weights, recurrent_to_forget_weights, - recurrent_to_cell_weights, recurrent_to_output_weights, + input, input_to_input_weights, + /*input_to_input_weights_ledger*/ nullptr, input_to_forget_weights, + /*input_to_forget_weights_ledger*/ nullptr, input_to_cell_weights, + /*input_to_cell_weights_ledger*/ nullptr, input_to_output_weights, + /*input_to_output_weights_ledger*/ nullptr, + recurrent_to_input_weights, + /*recurrent_to_input_weights_ledger*/ nullptr, + recurrent_to_forget_weights, + /*recurrent_to_forget_weights_ledger*/ nullptr, + recurrent_to_cell_weights, + /*recurrent_to_cell_weights_ledger*/ nullptr, + recurrent_to_output_weights, + /*recurrent_to_output_weights_ledger*/ nullptr, cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, input_layer_norm_coefficients, 
forget_layer_norm_coefficients, cell_layer_norm_coefficients, @@ -1641,7 +1935,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { /*aux_input_to_cell_weights=*/nullptr, /*aux_input_to_output_weights=*/nullptr, input_gate_bias, forget_gate_bias, cell_gate_bias, output_gate_bias, - projection_weights, projection_bias, params, + projection_weights, /*projection_weights_ledger*/ nullptr, + projection_bias, params, /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0, GetTemporary(context, node, kScratchBuffer), GetTemporary(context, node, kInputScalingFactors), diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc index e11a7c5a026..695100fa92f 100644 --- a/tensorflow/lite/kernels/lstm_eval.cc +++ b/tensorflow/lite/kernels/lstm_eval.cc @@ -312,6 +312,7 @@ void CalculateLstmGateHybrid( // Input and weights const int8_t* input, const float* input_sf, const int32_t* input_zp, const int8_t* input_to_gate_weights, + const uint8_t* input_to_gate_weights_ledger, const float input_to_gate_weights_scale, int32_t* input_to_gate_row_sums, // Aux input and weights const int8_t* aux_input, const float* aux_input_sf, @@ -321,6 +322,7 @@ void CalculateLstmGateHybrid( // Output state and weights const int8_t* output_state, const float* output_state_sf, const int32_t* output_state_zp, const int8_t* recurrent_to_gate_weights, + const uint8_t* recurrent_to_gate_weights_ledger, const float recurrent_to_gate_weights_scale, int32_t* recurrent_to_gate_row_sums, // Cell state and weights (peephole LSTM) @@ -356,11 +358,22 @@ void CalculateLstmGateHybrid( // For each batch and cell: compute input_weight * input. // Skip if input is all zeros. 
if (!is_input_all_zeros) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_gate_weights, n_cell, n_input, input, - input_to_gate_weights_scale, input_sf, n_batch, gate, - /*per_channel_scale=*/nullptr, input_zp, accum_scratch, - input_to_gate_row_sums, compute_row_sums, scratch0, context); + if (input_to_gate_weights_ledger != nullptr) { + std::vector scales(n_batch); + for (int i = 0; i < n_batch; i++) { + scales[i] = input_to_gate_weights_scale * input_sf[i]; + } + tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate( + input_to_gate_weights, input_to_gate_weights_ledger, n_cell, n_input, + input, scales.data(), n_batch, gate); + + } else { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_gate_weights, n_cell, n_input, input, + input_to_gate_weights_scale, input_sf, n_batch, gate, + /*per_channel_scale=*/nullptr, input_zp, accum_scratch, + input_to_gate_row_sums, compute_row_sums, scratch0, context); + } } // For each batch and cell: compute aux_input_weight * aux_input. // Skip if auxiliary input is not available or all zeros. @@ -374,11 +387,21 @@ void CalculateLstmGateHybrid( // For each batch and cell: compute recurrent_weight * output_state. // Skip if output state is all zeros. 
if (!is_output_state_all_zeros) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_gate_weights, n_cell, n_output, output_state, - recurrent_to_gate_weights_scale, output_state_sf, n_batch, gate, - /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch, - recurrent_to_gate_row_sums, compute_row_sums, scratch0, context); + if (recurrent_to_gate_weights_ledger != nullptr) { + std::vector scales(n_batch); + for (int i = 0; i < n_batch; i++) { + scales[i] = recurrent_to_gate_weights_scale * input_sf[i]; + } + tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate( + recurrent_to_gate_weights, recurrent_to_gate_weights_ledger, n_cell, + n_output, output_state, scales.data(), n_batch, gate); + } else { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_gate_weights, n_cell, n_output, output_state, + recurrent_to_gate_weights_scale, output_state_sf, n_batch, gate, + /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch, + recurrent_to_gate_row_sums, compute_row_sums, scratch0, context); + } } // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM) if (use_peephole) { @@ -422,11 +445,12 @@ void CalculateLstmGateHybrid( void CalculateLstmOutputHybrid( int n_batch, int n_cell, int n_output, const float* cell_state, const float* output_gate, TfLiteFusedActivation activation, - const int8_t* projection_weights, float projection_weights_scale, - const float* projection_bias, const float proj_clip, float* output_state, - bool asymmetric_quantize_inputs, int32_t* projection_weights_row_sums, - bool* compute_row_sums, CpuBackendContext* context, float* scratch0, - int8_t* scratch1, float* scratch2, int32_t* scratch3, int32_t* scratch4) { + const int8_t* projection_weights, const uint8_t* projection_weights_ledger, + float projection_weights_scale, const float* projection_bias, + const float proj_clip, float* output_state, bool asymmetric_quantize_inputs, + int32_t* projection_weights_row_sums, bool* 
compute_row_sums, + CpuBackendContext* context, float* scratch0, int8_t* scratch1, + float* scratch2, int32_t* scratch3, int32_t* scratch4) { tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell, activation, scratch0); tensor_utils::VectorVectorCwiseProduct(output_gate, scratch0, @@ -447,11 +471,21 @@ void CalculateLstmOutputHybrid( tensor_utils::BatchQuantizeFloats(scratch0, n_batch, n_cell, scratch1, scratch2, scratch3, asymmetric_quantize_inputs); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - projection_weights, n_output, n_cell, scratch1, - projection_weights_scale, scratch2, n_batch, output_state, - /*per_channel_scale=*/nullptr, scratch3, scratch4, - projection_weights_row_sums, compute_row_sums, scratch2, context); + if (projection_weights_ledger != nullptr) { + std::vector scales(n_batch); + for (int i = 0; i < n_batch; i++) { + scales[i] = projection_weights_scale * scratch2[i]; + } + tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate( + projection_weights, projection_weights_ledger, n_output, n_cell, + scratch1, scales.data(), n_batch, output_state); + } else { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + projection_weights, n_output, n_cell, scratch1, + projection_weights_scale, scratch2, n_batch, output_state, + /*per_channel_scale=*/nullptr, scratch3, scratch4, + projection_weights_row_sums, compute_row_sums, scratch2, context); + } } if (proj_clip > 0.0f) { tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip); @@ -955,11 +989,16 @@ inline void LstmStepFloat( // output_ptr - size 'n_batch * output_batch_leading_dim' inline void LstmStepHybrid( const float* input_ptr, const int8_t* input_to_input_weights_ptr, + const uint8_t* input_to_input_weights_ledger_ptr, float input_to_input_weights_scale, const int8_t* input_to_forget_weights_ptr, + const uint8_t* input_to_forget_weights_ledger_ptr, float input_to_forget_weights_scale, - const int8_t* input_to_cell_weights_ptr, float 
input_to_cell_weights_scale, + const int8_t* input_to_cell_weights_ptr, + const uint8_t* input_to_cell_weights_ledger_ptr, + float input_to_cell_weights_scale, const int8_t* input_to_output_weights_ptr, + const uint8_t* input_to_output_weights_ledger_ptr, float input_to_output_weights_scale, const float* aux_input_ptr, const int8_t* aux_input_to_input_weights_ptr, float aux_input_to_input_weights_scale, @@ -970,12 +1009,16 @@ inline void LstmStepHybrid( const int8_t* aux_input_to_output_weights_ptr, float aux_input_to_output_weights_scale, const int8_t* recurrent_to_input_weights_ptr, + const uint8_t* recurrent_to_input_weights_ledger_ptr, float recurrent_to_input_weights_scale, const int8_t* recurrent_to_forget_weights_ptr, + const uint8_t* recurrent_to_forget_weights_ledger_ptr, float recurrent_to_forget_weights_scale, const int8_t* recurrent_to_cell_weights_ptr, + const uint8_t* recurrent_to_cell_weights_ledger_ptr, float recurrent_to_cell_weights_scale, const int8_t* recurrent_to_output_weights_ptr, + const uint8_t* recurrent_to_output_weights_ledger_ptr, float recurrent_to_output_weights_scale, const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, const int8_t* cell_to_forget_weights_ptr, @@ -988,19 +1031,21 @@ inline void LstmStepHybrid( const float* output_layer_norm_coefficients_ptr, const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr, const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr, - const int8_t* projection_weights_ptr, float projection_weights_scale, - const float* projection_bias_ptr, const TfLiteLSTMParams* params, - int n_batch, int n_cell, int n_input, int n_aux_input, int n_output, - int output_batch_leading_dim, float* scratch0, float* scratch1, - float* scratch2, float* scratch3, float* input_sf, float* aux_input_sf, - float* output_state_sf, float* scaling_factors_scratch, - float* recovered_cell_weights, int8_t* quantized_input_ptr, - int8_t* quantized_aux_input_ptr, int8_t* 
quantized_output_state_ptr, - int8_t* quantized_output_scratch, float* output_state_ptr, - float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr, - int32_t* input_zp, int32_t* aux_input_zp, int32_t* output_state_zp, - int32_t* row_sums, int row_sums_size, bool* compute_row_sums, - bool asymmetric_quantize_inputs, CpuBackendContext* context) { + const int8_t* projection_weights_ptr, + const uint8_t* projection_weights_ledger_ptr, + float projection_weights_scale, const float* projection_bias_ptr, + const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, + int n_aux_input, int n_output, int output_batch_leading_dim, + float* scratch0, float* scratch1, float* scratch2, float* scratch3, + float* input_sf, float* aux_input_sf, float* output_state_sf, + float* scaling_factors_scratch, float* recovered_cell_weights, + int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr, + int8_t* quantized_output_state_ptr, int8_t* quantized_output_scratch, + float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr, + float* output_ptr, int32_t* input_zp, int32_t* aux_input_zp, + int32_t* output_state_zp, int32_t* row_sums, int row_sums_size, + bool* compute_row_sums, bool asymmetric_quantize_inputs, + CpuBackendContext* context) { ruy::profiler::ScopeLabel label("LstmStepHybrid"); // Since we have already checked that weights are all there or none, we // can check the existence of only one to the get the condition. @@ -1106,11 +1151,12 @@ inline void LstmStepHybrid( // Calculate the input gate. (If not CIFG.) 
CalculateLstmGateHybrid( quantized_input_ptr, input_sf, input_zp, input_to_input_weights_ptr, - input_to_input_weights_scale, input_to_input_row_sums, - quantized_aux_input_ptr, aux_input_sf, aux_input_zp, - aux_input_to_input_weights_ptr, aux_input_to_input_weights_scale, - aux_input_to_input_row_sums, quantized_output_state_ptr, - output_state_sf, output_state_zp, recurrent_to_input_weights_ptr, + input_to_input_weights_ledger_ptr, input_to_input_weights_scale, + input_to_input_row_sums, quantized_aux_input_ptr, aux_input_sf, + aux_input_zp, aux_input_to_input_weights_ptr, + aux_input_to_input_weights_scale, aux_input_to_input_row_sums, + quantized_output_state_ptr, output_state_sf, output_state_zp, + recurrent_to_input_weights_ptr, recurrent_to_input_weights_ledger_ptr, recurrent_to_input_weights_scale, recurrent_to_input_row_sums, cell_state_ptr, cell_to_input_weights_ptr, cell_to_input_weights_scale, input_layer_norm_coefficients_ptr, input_gate_bias_ptr, n_batch, @@ -1122,11 +1168,12 @@ inline void LstmStepHybrid( // Calculate the forget gate. 
CalculateLstmGateHybrid( quantized_input_ptr, input_sf, input_zp, input_to_forget_weights_ptr, - input_to_forget_weights_scale, input_to_forget_row_sums, - quantized_aux_input_ptr, aux_input_sf, aux_input_zp, - aux_input_to_forget_weights_ptr, aux_input_to_forget_weights_scale, - aux_input_to_forget_row_sums, quantized_output_state_ptr, output_state_sf, - output_state_zp, recurrent_to_forget_weights_ptr, + input_to_forget_weights_ledger_ptr, input_to_forget_weights_scale, + input_to_forget_row_sums, quantized_aux_input_ptr, aux_input_sf, + aux_input_zp, aux_input_to_forget_weights_ptr, + aux_input_to_forget_weights_scale, aux_input_to_forget_row_sums, + quantized_output_state_ptr, output_state_sf, output_state_zp, + recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_ledger_ptr, recurrent_to_forget_weights_scale, recurrent_to_forget_row_sums, cell_state_ptr, cell_to_forget_weights_ptr, cell_to_forget_weights_scale, forget_layer_norm_coefficients_ptr, forget_gate_bias_ptr, n_batch, @@ -1137,11 +1184,12 @@ inline void LstmStepHybrid( // Calculate the cell update gate. 
CalculateLstmGateHybrid( quantized_input_ptr, input_sf, input_zp, input_to_cell_weights_ptr, - input_to_cell_weights_scale, input_to_cell_row_sums, - quantized_aux_input_ptr, aux_input_sf, aux_input_zp, - aux_input_to_cell_weights_ptr, aux_input_to_cell_weights_scale, - aux_input_to_cell_row_sums, quantized_output_state_ptr, output_state_sf, - output_state_zp, recurrent_to_cell_weights_ptr, + input_to_cell_weights_ledger_ptr, input_to_cell_weights_scale, + input_to_cell_row_sums, quantized_aux_input_ptr, aux_input_sf, + aux_input_zp, aux_input_to_cell_weights_ptr, + aux_input_to_cell_weights_scale, aux_input_to_cell_row_sums, + quantized_output_state_ptr, output_state_sf, output_state_zp, + recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_ledger_ptr, recurrent_to_cell_weights_scale, recurrent_to_cell_row_sums, /*cell_state=*/nullptr, /*cell_to_gate_weights=*/nullptr, /*cell_to_gate_weights_scale=*/0.0f, cell_layer_norm_coefficients_ptr, @@ -1157,11 +1205,12 @@ inline void LstmStepHybrid( // Calculate the output gate. 
CalculateLstmGateHybrid( quantized_input_ptr, input_sf, input_zp, input_to_output_weights_ptr, - input_to_output_weights_scale, input_to_output_row_sums, - quantized_aux_input_ptr, aux_input_sf, aux_input_zp, - aux_input_to_output_weights_ptr, aux_input_to_output_weights_scale, - aux_input_to_output_row_sums, quantized_output_state_ptr, output_state_sf, - output_state_zp, recurrent_to_output_weights_ptr, + input_to_output_weights_ledger_ptr, input_to_output_weights_scale, + input_to_output_row_sums, quantized_aux_input_ptr, aux_input_sf, + aux_input_zp, aux_input_to_output_weights_ptr, + aux_input_to_output_weights_scale, aux_input_to_output_row_sums, + quantized_output_state_ptr, output_state_sf, output_state_zp, + recurrent_to_output_weights_ptr, recurrent_to_output_weights_ledger_ptr, recurrent_to_output_weights_scale, recurrent_to_output_row_sums, cell_state_ptr, cell_to_output_weights_ptr, cell_to_output_weights_scale, output_layer_norm_coefficients_ptr, output_gate_bias_ptr, n_batch, @@ -1172,11 +1221,11 @@ inline void LstmStepHybrid( // Update the output state. CalculateLstmOutputHybrid( n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch, - params->activation, projection_weights_ptr, projection_weights_scale, - projection_bias_ptr, params->proj_clip, output_state_ptr, - asymmetric_quantize_inputs, projection_weights_row_sums, compute_row_sums, - context, scratch2, quantized_output_scratch, input_sf, input_zp, - accum_scratch_ptr); + params->activation, projection_weights_ptr, projection_weights_ledger_ptr, + projection_weights_scale, projection_bias_ptr, params->proj_clip, + output_state_ptr, asymmetric_quantize_inputs, projection_weights_row_sums, + compute_row_sums, context, scratch2, quantized_output_scratch, input_sf, + input_zp, accum_scratch_ptr); // Copy output state to the output. Note that the output's rows may not be // contiguous (output_batch_leading_dim != n_output). 
for (int b = 0; b < n_batch; b++) { @@ -1829,13 +1878,21 @@ TfLiteStatus EvalFloat( TfLiteStatus EvalHybrid( const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_input_weights_ledger, const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_forget_weights_ledger, const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_cell_weights_ledger, const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* input_to_output_weights_ledger, const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_input_weights_ledger, const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_forget_weights_ledger, const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_cell_weights_ledger, const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* recurrent_to_output_weights_ledger, const TfLiteTensor* cell_to_input_weights, const TfLiteTensor* cell_to_forget_weights, const TfLiteTensor* cell_to_output_weights, @@ -1850,9 +1907,11 @@ TfLiteStatus EvalHybrid( const TfLiteTensor* aux_input_to_output_weights, const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias, - const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, - const TfLiteLSTMParams* params, bool forward_sequence, bool time_major, - int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf, + const TfLiteTensor* projection_weights, + const TfLiteTensor* projection_weights_ledger, + const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params, + bool forward_sequence, bool time_major, int output_offset, + TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf, TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf, TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized, 
TfLiteTensor* aux_input_quantized, @@ -1929,12 +1988,16 @@ TfLiteStatus EvalHybrid( GetTensorData(output) + t_rel * output_step + output_offset; LstmStepHybrid( input_ptr, GetTensorData(input_to_input_weights), + GetTensorData(input_to_input_weights_ledger), GetTensorScale(input_to_input_weights), GetTensorData(input_to_forget_weights), + GetTensorData(input_to_forget_weights_ledger), GetTensorScale(input_to_forget_weights), GetTensorData(input_to_cell_weights), + GetTensorData(input_to_cell_weights_ledger), GetTensorScale(input_to_cell_weights), GetTensorData(input_to_output_weights), + GetTensorData(input_to_output_weights_ledger), GetTensorScale(input_to_output_weights), aux_input_ptr, GetTensorData(aux_input_to_input_weights), GetTensorScale(aux_input_to_input_weights), @@ -1945,12 +2008,16 @@ TfLiteStatus EvalHybrid( GetTensorData(aux_input_to_output_weights), GetTensorScale(aux_input_to_output_weights), GetTensorData(recurrent_to_input_weights), + GetTensorData(recurrent_to_input_weights_ledger), GetTensorScale(recurrent_to_input_weights), GetTensorData(recurrent_to_forget_weights), + GetTensorData(recurrent_to_forget_weights_ledger), GetTensorScale(recurrent_to_forget_weights), GetTensorData(recurrent_to_cell_weights), + GetTensorData(recurrent_to_cell_weights_ledger), GetTensorScale(recurrent_to_cell_weights), GetTensorData(recurrent_to_output_weights), + GetTensorData(recurrent_to_output_weights_ledger), GetTensorScale(recurrent_to_output_weights), GetTensorData(cell_to_input_weights), GetTensorScale(cell_to_input_weights), @@ -1967,6 +2034,7 @@ TfLiteStatus EvalHybrid( GetTensorData(cell_gate_bias), GetTensorData(output_gate_bias), GetTensorData(projection_weights), + GetTensorData(projection_weights_ledger), GetTensorScale(projection_weights), GetTensorData(projection_bias), params, n_batch, n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim, @@ -2018,12 +2086,16 @@ TfLiteStatus EvalHybrid( LstmStepHybrid( input_ptr, 
GetTensorData(input_to_input_weights), + GetTensorData(input_to_input_weights_ledger), GetTensorScale(input_to_input_weights), GetTensorData(input_to_forget_weights), + GetTensorData(input_to_forget_weights_ledger), GetTensorScale(input_to_forget_weights), GetTensorData(input_to_cell_weights), + GetTensorData(input_to_cell_weights_ledger), GetTensorScale(input_to_cell_weights), GetTensorData(input_to_output_weights), + GetTensorData(input_to_output_weights_ledger), GetTensorScale(input_to_output_weights), aux_input_ptr, GetTensorData(aux_input_to_input_weights), GetTensorScale(aux_input_to_input_weights), @@ -2034,12 +2106,16 @@ TfLiteStatus EvalHybrid( GetTensorData(aux_input_to_output_weights), GetTensorScale(aux_input_to_output_weights), GetTensorData(recurrent_to_input_weights), + GetTensorData(recurrent_to_input_weights_ledger), GetTensorScale(recurrent_to_input_weights), GetTensorData(recurrent_to_forget_weights), + GetTensorData(recurrent_to_forget_weights_ledger), GetTensorScale(recurrent_to_forget_weights), GetTensorData(recurrent_to_cell_weights), + GetTensorData(recurrent_to_cell_weights_ledger), GetTensorScale(recurrent_to_cell_weights), GetTensorData(recurrent_to_output_weights), + GetTensorData(recurrent_to_output_weights_ledger), GetTensorScale(recurrent_to_output_weights), GetTensorData(cell_to_input_weights), GetTensorScale(cell_to_input_weights), @@ -2056,6 +2132,7 @@ TfLiteStatus EvalHybrid( GetTensorData(cell_gate_bias), GetTensorData(output_gate_bias), GetTensorData(projection_weights), + GetTensorData(projection_weights_ledger), GetTensorScale(projection_weights), GetTensorData(projection_bias), params, /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output, diff --git a/tensorflow/lite/kernels/lstm_eval.h b/tensorflow/lite/kernels/lstm_eval.h index d3fdf037b5c..5807c9ee56d 100644 --- a/tensorflow/lite/kernels/lstm_eval.h +++ b/tensorflow/lite/kernels/lstm_eval.h @@ -125,13 +125,21 @@ TfLiteStatus EvalFloat( TfLiteStatus EvalHybrid( const 
TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_input_weights_ledger, const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_forget_weights_ledger, const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_cell_weights_ledger, const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* input_to_output_weights_ledger, const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_input_weights_ledger, const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_forget_weights_ledger, const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_cell_weights_ledger, const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* recurrent_to_output_weights_ledger, const TfLiteTensor* cell_to_input_weights, const TfLiteTensor* cell_to_forget_weights, const TfLiteTensor* cell_to_output_weights, @@ -146,9 +154,11 @@ TfLiteStatus EvalHybrid( const TfLiteTensor* aux_input_to_output_weights, const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias, - const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, - const TfLiteLSTMParams* params, bool forward_sequence, bool time_major, - int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf, + const TfLiteTensor* projection_weights, + const TfLiteTensor* projection_weights_ledger, + const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params, + bool forward_sequence, bool time_major, int output_offset, + TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf, TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf, TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized, diff --git a/tensorflow/lite/kernels/lstm_eval_test.cc 
b/tensorflow/lite/kernels/lstm_eval_test.cc index adaa5db1e20..c7d935a4b4f 100644 --- a/tensorflow/lite/kernels/lstm_eval_test.cc +++ b/tensorflow/lite/kernels/lstm_eval_test.cc @@ -906,14 +906,14 @@ void TestOneHybridAsymmLSTM() { constexpr float kDefaultScale = 18.0; ops::builtin::lstm_eval::EvalHybrid( one_parameter.GetFloatInput(), - HybridLstmParam::addScale(one_parameter.Geti2i(), kDefaultScale), - HybridLstmParam::addScale(one_parameter.Geti2f(), kDefaultScale), - HybridLstmParam::addScale(one_parameter.Geti2c(), kDefaultScale), - HybridLstmParam::addScale(one_parameter.Geti2o(), kDefaultScale), - HybridLstmParam::addScale(one_parameter.Getr2i(), kDefaultScale), - HybridLstmParam::addScale(one_parameter.Getr2f(), kDefaultScale), - HybridLstmParam::addScale(one_parameter.Getr2c(), kDefaultScale), - HybridLstmParam::addScale(one_parameter.Getr2o(), kDefaultScale), + HybridLstmParam::addScale(one_parameter.Geti2i(), kDefaultScale), nullptr, + HybridLstmParam::addScale(one_parameter.Geti2f(), kDefaultScale), nullptr, + HybridLstmParam::addScale(one_parameter.Geti2c(), kDefaultScale), nullptr, + HybridLstmParam::addScale(one_parameter.Geti2o(), kDefaultScale), nullptr, + HybridLstmParam::addScale(one_parameter.Getr2i(), kDefaultScale), nullptr, + HybridLstmParam::addScale(one_parameter.Getr2f(), kDefaultScale), nullptr, + HybridLstmParam::addScale(one_parameter.Getr2c(), kDefaultScale), nullptr, + HybridLstmParam::addScale(one_parameter.Getr2o(), kDefaultScale), nullptr, /*cell_to_input_weights=*/nullptr, /*cell_to_forget_weights=*/nullptr, /*cell_to_output_weights=*/nullptr, one_parameter.GetInputLayerNorm(), @@ -926,7 +926,7 @@ void TestOneHybridAsymmLSTM() { /*aux_input_to_output_weights=*/nullptr, one_parameter.GetInputBias(), one_parameter.GetForgetBias(), one_parameter.GetCellBias(), one_parameter.GetOutputBias(), - HybridLstmParam::addScale(one_parameter.GetProjection(), 1.0), + HybridLstmParam::addScale(one_parameter.GetProjection(), 1.0), nullptr, 
one_parameter.GetProjectionBias(), &param, /*forward_sequence=*/true, /*time_major=*/true, diff --git a/tensorflow/lite/kernels/lstm_test.cc b/tensorflow/lite/kernels/lstm_test.cc index 16e28619daf..39e03afdfd8 100644 --- a/tensorflow/lite/kernels/lstm_test.cc +++ b/tensorflow/lite/kernels/lstm_test.cc @@ -2114,6 +2114,620 @@ TEST(LstmOpTest, InvalidTypes) { } #endif +class HybridSparseLSTMOpModel : public ::tflite::SingleOpModel { + public: + HybridSparseLSTMOpModel( + int n_batch, int n_input, int n_cell, int n_output, bool use_cifg, + bool use_peephole, bool use_projection_weights, bool use_projection_bias, + float cell_clip, float proj_clip, + const std::vector>& input_shapes, + const TensorData& input_weights_td, + const std::vector& input_to_input_weights, + const std::vector& input_to_forget_weights, + const std::vector& input_to_cell_weights, + const std::vector& input_to_output_weights, + const TensorData& recurrent_weights_td, + const std::vector& recurrent_to_input_weights, + const std::vector& recurrent_to_forget_weights, + const std::vector& recurrent_to_cell_weights, + const std::vector& recurrent_to_output_weights, + const ::tflite::TensorType& weight_type = ::tflite::TensorType_INT8) + : n_batch_(n_batch), + n_input_(n_input), + n_cell_(n_cell), + n_output_(n_output) { + input_ = AddInput(::tflite::TensorType_FLOAT32); + + if (use_cifg) { + input_to_input_weights_ = AddNullInput(); + } else { + input_to_input_weights_ = + AddConstSparseInput(input_weights_td, input_to_input_weights, true); + } + + input_to_forget_weights_ = + AddConstSparseInput(input_weights_td, input_to_forget_weights, true); + + input_to_cell_weights_ = + AddConstSparseInput(input_weights_td, input_to_cell_weights, true); + + input_to_output_weights_ = + AddConstSparseInput(input_weights_td, input_to_output_weights, true); + + if (use_cifg) { + recurrent_to_input_weights_ = AddNullInput(); + } else { + recurrent_to_input_weights_ = AddConstSparseInput( + recurrent_weights_td, 
recurrent_to_input_weights, true); + } + + recurrent_to_forget_weights_ = AddConstSparseInput( + recurrent_weights_td, recurrent_to_forget_weights, true); + recurrent_to_cell_weights_ = AddConstSparseInput( + recurrent_weights_td, recurrent_to_cell_weights, true); + recurrent_to_output_weights_ = AddConstSparseInput( + recurrent_weights_td, recurrent_to_output_weights, true); + + if (use_peephole) { + if (use_cifg) { + cell_to_input_weights_ = AddNullInput(); + } else { + cell_to_input_weights_ = AddInput(weight_type); + } + cell_to_forget_weights_ = AddInput(weight_type); + cell_to_output_weights_ = AddInput(weight_type); + } else { + cell_to_input_weights_ = AddNullInput(); + cell_to_forget_weights_ = AddNullInput(); + cell_to_output_weights_ = AddNullInput(); + } + + if (use_cifg) { + input_gate_bias_ = AddNullInput(); + } else { + input_gate_bias_ = AddInput(::tflite::TensorType_FLOAT32); + } + forget_gate_bias_ = AddInput(::tflite::TensorType_FLOAT32); + cell_bias_ = AddInput(::tflite::TensorType_FLOAT32); + output_gate_bias_ = AddInput(::tflite::TensorType_FLOAT32); + + if (use_projection_weights) { + projection_weights_ = AddInput(weight_type); + if (use_projection_bias) { + projection_bias_ = AddInput(::tflite::TensorType_FLOAT32); + } else { + projection_bias_ = AddNullInput(); + } + } else { + projection_weights_ = AddNullInput(); + projection_bias_ = AddNullInput(); + } + + // Adding the 2 state tensors. 
+ output_state_ = AddInput(::tflite::TensorData{::tflite::TensorType_FLOAT32, + {n_output_ * n_batch_}}, + true); + cell_state_ = AddInput(::tflite::TensorData{::tflite::TensorType_FLOAT32, + {n_cell_ * n_batch_}}, + true); + + if (use_cifg) { + input_layer_norm_weights_ = AddNullInput(); + } else { + input_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32); + } + forget_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32); + cell_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32); + output_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32); + + output_ = AddOutput(::tflite::TensorType_FLOAT32); + + SetBuiltinOp( + BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, + CreateLSTMOptions(builder_, ActivationFunctionType_TANH, cell_clip, + proj_clip, LSTMKernelType_FULL, false) + .Union()); + BuildInterpreter(input_shapes); + } + + void SetCellToInputWeights(std::vector f) { + SignedSymmetricQuantizeAndPopulate(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::vector f) { + SignedSymmetricQuantizeAndPopulate(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::vector f) { + SignedSymmetricQuantizeAndPopulate(cell_to_output_weights_, f); + } + + void SetInputLayerNormWeights(std::vector f) { + PopulateTensor(input_layer_norm_weights_, f); + } + + void SetForgetLayerNormWeights(std::vector f) { + PopulateTensor(forget_layer_norm_weights_, f); + } + + void SetCellLayerNormWeights(std::vector f) { + PopulateTensor(cell_layer_norm_weights_, f); + } + + void SetOutputLayerNormWeights(std::vector f) { + PopulateTensor(output_layer_norm_weights_, f); + } + + void SetInputGateBias(std::vector f) { + PopulateTensor(input_gate_bias_, f); + } + + void SetForgetGateBias(std::vector f) { + PopulateTensor(forget_gate_bias_, f); + } + + void SetCellBias(std::vector f) { PopulateTensor(cell_bias_, f); } + + void SetOutputGateBias(std::vector f) { + PopulateTensor(output_gate_bias_, f); + } + + void 
SetProjectionWeights(std::vector f) { + SignedSymmetricQuantizeAndPopulate(projection_weights_, f); + } + + void SetProjectionBias(std::vector f) { + PopulateTensor(projection_bias_, f); + } + + void SetInput(int offset, const float* begin, const float* end) { + PopulateTensor(input_, offset, const_cast(begin), + const_cast(end)); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int num_inputs() { return n_input_; } + int num_outputs() { return n_output_; } + int num_cells() { return n_cell_; } + int num_batches() { return n_batch_; } + + protected: + int input_; + int input_to_input_weights_; + int input_to_forget_weights_; + int input_to_cell_weights_; + int input_to_output_weights_; + + int recurrent_to_input_weights_; + int recurrent_to_forget_weights_; + int recurrent_to_cell_weights_; + int recurrent_to_output_weights_; + + int cell_to_input_weights_; + int cell_to_forget_weights_; + int cell_to_output_weights_; + + int input_layer_norm_weights_; + int forget_layer_norm_weights_; + int cell_layer_norm_weights_; + int output_layer_norm_weights_; + + int input_gate_bias_; + int forget_gate_bias_; + int cell_bias_; + int output_gate_bias_; + + int projection_weights_; + int projection_bias_; + + int output_state_; + int cell_state_; + + int output_; + + int n_batch_; + int n_input_; + int n_cell_; + int n_output_; +}; + +class BaseSparseLstmTest : public ::testing::Test { + protected: + // Weights of the Sparse Layer Norm LSTM model. Some are optional. 
+ std::vector input_to_input_weights_; + std::vector input_to_cell_weights_; + std::vector input_to_forget_weights_; + std::vector input_to_output_weights_; + std::vector input_gate_bias_; + std::vector cell_gate_bias_; + std::vector forget_gate_bias_; + std::vector output_gate_bias_; + std::vector recurrent_to_input_weights_; + std::vector recurrent_to_cell_weights_; + std::vector recurrent_to_forget_weights_; + std::vector recurrent_to_output_weights_; + std::vector cell_to_input_weights_; + std::vector cell_to_forget_weights_; + std::vector cell_to_output_weights_; + std::vector input_layer_norm_weights_; + std::vector forget_layer_norm_weights_; + std::vector cell_layer_norm_weights_; + std::vector output_layer_norm_weights_; + std::vector projection_weights_; + + std::vector input_to_input_weights_size_; + std::vector input_to_cell_weights_size_; + std::vector input_to_forget_weights_size_; + std::vector input_to_output_weights_size_; + std::vector recurrent_to_input_weights_size_; + std::vector recurrent_to_cell_weights_size_; + std::vector recurrent_to_forget_weights_size_; + std::vector recurrent_to_output_weights_size_; + + int n_batch_; + int n_input_; + int n_cell_; + int n_output_; + float cell_clip_; + float proj_clip_; + + // Layer Norm LSTM input is stored as num_batch x num_inputs vector. + std::vector> sparse_layer_norm_lstm_input_; + + // Compares output up to tolerance to the result of the layer_norm_lstm given + // the input. 
+ void VerifyGoldens(const std::vector>& input, + const std::vector>& output, + HybridSparseLSTMOpModel* sparse_layer_norm_lstm, + float tolerance = 1e-5) { + const int num_batches = input.size(); + EXPECT_GT(num_batches, 0); + const int num_inputs = sparse_layer_norm_lstm->num_inputs(); + EXPECT_GT(num_inputs, 0); + const int input_sequence_size = input[0].size() / num_inputs; + EXPECT_GT(input_sequence_size, 0); + for (int i = 0; i < input_sequence_size; ++i) { + for (int b = 0; b < num_batches; ++b) { + const float* batch_start = input[b].data() + i * num_inputs; + const float* batch_end = batch_start + num_inputs; + + sparse_layer_norm_lstm->SetInput( + b * sparse_layer_norm_lstm->num_inputs(), batch_start, batch_end); + } + + sparse_layer_norm_lstm->Invoke(); + + const int num_outputs = sparse_layer_norm_lstm->num_outputs(); + std::vector expected; + for (int b = 0; b < num_batches; ++b) { + const float* golden_start_batch = output[b].data() + i * num_outputs; + const float* golden_end_batch = golden_start_batch + num_outputs; + expected.insert(expected.end(), golden_start_batch, golden_end_batch); + } + EXPECT_THAT( + sparse_layer_norm_lstm->GetOutput(), + ElementsAreArray(::tflite::ArrayFloatNear(expected, tolerance))); + } + } +}; + +class NoCifgPeepholeProjectionNoClippingSparseLstmTest + : public BaseSparseLstmTest { + void SetUp() override { + n_batch_ = 2; + n_input_ = 48; + n_cell_ = 4; + n_output_ = 16; + cell_clip_ = 0.0; + proj_clip_ = 0.0; + + /* clang-format off */ + input_to_input_weights_ = { + /* 1st row */ + 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13, + 14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38, + 39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0, + /* 2nd row */ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, 
-24.24, + -25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + /* 3rd row */ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25, + -26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + /* 4th row */ + -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12, + -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37, + 38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0}; + input_to_input_weights_size_ = {4, 48}; + + input_to_forget_weights_ = { + /* 1st row */ + 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13, + 14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38, + 39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0, + /* 2nd row */ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24, + -25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + /* 3rd row */ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25, + -26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + /* 4th row */ + -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12, + -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37, + 38.38, -39.39, 40.40, -41.41, 42.42, 
-43.43, 44.44, 0.0, 0.0, 0.0, 0}; + input_to_forget_weights_size_ = {4, 48}; + + input_to_cell_weights_ = { + /* 1st row */ + 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13, + 14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38, + 39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0, + /* 2nd row */ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24, + -25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + /* 3rd row */ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25, + -26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + /* 4th row */ + -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12, + -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37, + 38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0}; + input_to_cell_weights_size_ = {4, 48}; + + input_to_output_weights_ = { + /* 1st row */ + 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13, + 14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38, + 39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0, + /* 2nd row */ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24, + -25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + /* 3rd 
row */ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25, + -26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + /* 4th row */ + -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12, + -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37, + 38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0}; + input_to_output_weights_size_ = {4, 48}; + + input_gate_bias_ = {0.03, 0.15, 0.22, 0.38}; + + forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1}; + + cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08}; + + output_gate_bias_ = {0.05, -0.01, 0.2, 0.1}; + + recurrent_to_input_weights_ = { + -0.2, -0.3, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, // 1st row + 0.1, -0.5, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, // 2nd row + -0.2, -0.3, -0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, // 3rd row + 0.05, -0.2, -0.6, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, // 4th row + }; + recurrent_to_input_weights_size_ = {4, 16}; + + recurrent_to_cell_weights_ = { + -0.3, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, // 1st row + -0.3, 0.8, -0.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, // 2nd row + -0.2, 0.3, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, // 3rd row + -0.6, -0.1, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, // 4th row + }; + recurrent_to_cell_weights_size_ = {4, 16}; + + recurrent_to_forget_weights_ = { + -0.5, -0.3, -0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, // 1st row + -0.2, 0.6, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 
0.0, // 2nd row + 0.9, 0.3, -0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, // 3rd row + 0.2, 0.5, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, // 4th row + }; + recurrent_to_forget_weights_size_ = {4, 16}; + + recurrent_to_output_weights_ = { + 0.3, -0.1, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, // 1st row + -0.2, -0.5, -0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, // 2nd row + -0.2, -0.6, -0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, // 3rd row + -0.4, -0.7, -0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, // 4th row + }; + recurrent_to_output_weights_size_ = {4, 16}; + + cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15}; + + cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03}; + + cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05}; + + input_layer_norm_weights_ = {0.1, 0.2, 0.3, 0.5}; + forget_layer_norm_weights_ = {0.2, 0.2, 0.4, 0.3}; + cell_layer_norm_weights_ = {0.7, 0.2, 0.3, 0.8}; + output_layer_norm_weights_ = {0.6, 0.2, 0.2, 0.5}; + + projection_weights_ = { + -0.1, 0.2, 0.01, -0.2, // 1st row + 0.1, 0.5, 0.3, 0.08, // 2nd row + 0.07, 0.2, -0.4, 0.2, // 3rd row + 0.0, 0.0, 0.0, 0.0, // 4th row + 0.0, 0.0, 0.0, 0.0, // 5th row + 0.0, 0.0, 0.0, 0.0, // 6th row + 0.0, 0.0, 0.0, 0.0, // 7th row + 0.0, 0.0, 0.0, 0.0, // 8th row + 0.0, 0.0, 0.0, 0.0, // 9th row + 0.0, 0.0, 0.0, 0.0, // 10th row + 0.0, 0.0, 0.0, 0.0, // 11th row + 0.0, 0.0, 0.0, 0.0, // 12th row + 0.0, 0.0, 0.0, 0.0, // 13th row + 0.0, 0.0, 0.0, 0.0, // 14th row + 0.0, 0.0, 0.0, 0.0, // 15th row + 0.0, 0.0, 0.0, 0.0, // 16th row + 0.0, 0.0, 0.0, 0.0, // 17th row + 0.0, 0.0, 0.0, 0.0, // 18th row + }; + + sparse_layer_norm_lstm_input_ = { + // Batch0: 2 (input_sequence_size) * 48 (n_input_) + { + 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, + -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, 
-1.0, + 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, + -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, // seq 0 + 2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0, -1.1, 0.0, 2.0, 0.0, + -1.7, 0.0, 1.9, 0.0, -1.5, 0.0, 0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3, + 0.0, 2.8, 0.0, -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1, + 1.0, -2.5, 0.7, -1.9, 0.2, 0.1, 0.2, 0.3, // seq 1 + }, + // Batch1: 2 (input_sequence_size) * 45 (n_input_) + { + 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, + -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, + 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, + -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, // seq 0 + 2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0, -1.1, 0.0, 2.0, 0.0, + -1.7, 0.0, 1.9, 0.0, -1.5, 0.0, 0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3, + 0.0, 2.8, 0.0, -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1, + 1.0, -2.5, 0.7, -1.9, 0.2, -1.0, 1.0, -1.0, // seq 1 + }, + }; + /* clang-format on */ + } +}; + +TEST_F(NoCifgPeepholeProjectionNoClippingSparseLstmTest, + HybridSparseLstmBlackBoxTest) { + TensorData input_weight = {}; + input_weight.type = TensorType_FLOAT32; + input_weight.shape = {4, 48}; + input_weight.traversal_order = {0, 1, 2}; + input_weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR}; + input_weight.block_map = {1}; + input_weight.block_size = {16}; + TensorData recurrent_weight = {}; + recurrent_weight.type = TensorType_FLOAT32; + recurrent_weight.shape = {4, 16}; + recurrent_weight.traversal_order = {0, 1, 2}; + recurrent_weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR}; + recurrent_weight.block_map = {1}; + recurrent_weight.block_size = {16}; + HybridSparseLSTMOpModel sparse_layer_norm_lstm( + n_batch_, n_input_, n_cell_, n_output_, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, cell_clip_, proj_clip_, + { + {n_batch_, 
n_input_}, // input tensor + + {input_to_input_weights_size_}, + {input_to_forget_weights_size_}, + {input_to_cell_weights_size_}, + {input_to_output_weights_size_}, + + {recurrent_to_input_weights_size_}, + {recurrent_to_forget_weights_size_}, + {recurrent_to_cell_weights_size_}, + {recurrent_to_output_weights_size_}, + + {n_cell_}, // cell_to_input_weight tensor + {n_cell_}, // cell_to_forget_weight tensor + {n_cell_}, // cell_to_output_weight tensor + + {n_cell_}, // input_gate_bias tensor + {n_cell_}, // forget_gate_bias tensor + {n_cell_}, // cell_bias tensor + {n_cell_}, // output_gate_bias tensor + + {n_output_, n_cell_}, // projection_weight tensor + {0}, // projection_bias tensor + + {n_output_ * n_batch_}, // output_state tensor + {n_cell_ * n_batch_}, // cell_state tensor + + {n_cell_}, // input_layer_norm_weight tensor + {n_cell_}, // forget_layer_norm_weight tensor + {n_cell_}, // cell_layer_norm_weight tensor + {n_cell_}, // output_layer_norm_weight tensor + }, + input_weight, input_to_input_weights_, input_to_forget_weights_, + input_to_cell_weights_, input_to_output_weights_, recurrent_weight, + recurrent_to_input_weights_, recurrent_to_forget_weights_, + recurrent_to_cell_weights_, recurrent_to_output_weights_); + + sparse_layer_norm_lstm.SetInputGateBias(input_gate_bias_); + sparse_layer_norm_lstm.SetCellBias(cell_gate_bias_); + sparse_layer_norm_lstm.SetForgetGateBias(forget_gate_bias_); + sparse_layer_norm_lstm.SetOutputGateBias(output_gate_bias_); + + sparse_layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_); + sparse_layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_); + sparse_layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_); + + sparse_layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_); + sparse_layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_); + sparse_layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_); + 
sparse_layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_); + + sparse_layer_norm_lstm.SetProjectionWeights(projection_weights_); + + /* clang-format off */ + const std::vector> sparse_layer_norm_lstm_golden_output = { + { + // Batch0: 2 (input_sequence_size) * 3 (n_output_) + 0.0550758, 0.138464, -0.0628034, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, + 0.069672, 0.195428, -0.0605584, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, + }, + { + // Batch1: 3 (input_sequence_size) * 3 (n_output_) + 0.0550758, 0.138464, -0.0628034, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, + 0.069672, 0.195428, -0.0605584, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, + }}; + /* clang-format on */ + + VerifyGoldens(sparse_layer_norm_lstm_input_, + sparse_layer_norm_lstm_golden_output, &sparse_layer_norm_lstm); +} + // Test parameter controls asymmetric_quantize_inputs in LSTMOpModel. INSTANTIATE_TEST_SUITE_P( Parameterized, LstmOpTest, diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h index 7785b6cba27..63193b021aa 100644 --- a/tensorflow/lite/kernels/test_util.h +++ b/tensorflow/lite/kernels/test_util.h @@ -214,9 +214,82 @@ class SingleOpModel { return AddConstInput(TensorData{type, shape}, data); } + // TODO(b/166202747): Use a better way to do type specialization. Reduce + // duplicate code in the two functions below. + int AddConstSparseInput(const TensorData& t, + const std::vector& data) { + int id = tensors_.size(); + const int dims_count = t.traversal_order.size(); + std::vector dense_data(data); + + tflite::optimize::sparsity::FormatConverter converter( + t.shape, t.traversal_order, t.format, t.block_size, t.block_map); + converter.DenseToSparse(dense_data.data()); + + const auto dim_metadata = converter.GetDimMetadata(); + const auto sparse_data = converter.GetData(); + + // Build sparsity parameter. 
+ std::vector> fb_dim_metadata( + dims_count); + for (int i = 0; i < dims_count; i++) { + const int metadata_idx = 2 * i; + if (i < t.shape.size() && + t.format[t.traversal_order[i]] == kTfLiteDimSparseCSR) { + auto array_segments = + CreateInt32Vector(builder_, + builder_.CreateVector(dim_metadata[metadata_idx])) + .Union(); + auto array_indices = + CreateInt32Vector( + builder_, builder_.CreateVector(dim_metadata[metadata_idx + 1])) + .Union(); + fb_dim_metadata[i] = CreateDimensionMetadata( + builder_, DimensionType_SPARSE_CSR, 0, + SparseIndexVector_Int32Vector, array_segments, + SparseIndexVector_Int32Vector, array_indices); + } else { + fb_dim_metadata[i] = CreateDimensionMetadata( + builder_, DimensionType_DENSE, dim_metadata[metadata_idx][0]); + } + } + + flatbuffers::Offset s_param = CreateSparsityParameters( + builder_, builder_.CreateVector(t.traversal_order), + builder_.CreateVector(t.block_map), + builder_.CreateVector(fb_dim_metadata)); + + int buffer_id = 0; + if (!data.empty()) { + // Initialize buffers list with empty buffer to allow for non-const + // tensors. + if (buffers_.empty()) { + buffers_.push_back(CreateBuffer(builder_, builder_.CreateVector({}))); + } + + // Add compressed data as a Buffer to buffers list. + buffer_id = buffers_.size(); + auto data_buffer = builder_.CreateVector( + reinterpret_cast(sparse_data.data()), + sparse_data.size()); + buffers_.push_back(CreateBuffer(builder_, data_buffer)); + } + + tensors_.push_back(CreateTensor( + builder_, builder_.CreateVector(t.shape), t.type, + /*buffer=*/buffer_id, + /*name=*/0, /*quantization=*/0, /*is_variable=*/false, s_param)); + + inputs_.push_back(id); + tensor_data_[id] = t; + + return id; + } + // Add a constant sparse tensor as input. 
template - int AddConstSparseInput(const TensorData& t, std::initializer_list data) { + int AddConstSparseInput(const TensorData& t, const std::vector& data, + bool symmetric_quantize = false) { int id = tensors_.size(); const int dims_count = t.traversal_order.size(); std::vector dense_data(data); @@ -258,8 +331,9 @@ class SingleOpModel { builder_.CreateVector(t.block_map), builder_.CreateVector(fb_dim_metadata)); + flatbuffers::Offset q_params = 0; int buffer_id = 0; - if (data.size()) { + if (!data.empty()) { // Initialize buffers list with empty buffer to allow for non-const // tensors. if (buffers_.empty()) { @@ -268,16 +342,31 @@ class SingleOpModel { // Add compressed data as a Buffer to buffers list. buffer_id = buffers_.size(); - auto data_buffer = builder_.CreateVector( - reinterpret_cast(sparse_data.data()), - sizeof(T) * sparse_data.size()); - buffers_.push_back(CreateBuffer(builder_, data_buffer)); + if (symmetric_quantize) { + const int length = sparse_data.size(); + std::vector q(length); + float min, max, scaling_factor; + tensor_utils::SymmetricQuantizeFloats( + sparse_data.data(), length, q.data(), &min, &max, &scaling_factor); + q_params = CreateQuantizationParameters( + builder_, 0, 0, builder_.CreateVector({scaling_factor}), + builder_.CreateVector({0})); + auto data_buffer = builder_.CreateVector( + reinterpret_cast(q.data()), q.size()); + buffers_.push_back(CreateBuffer(builder_, data_buffer)); + } else { + auto data_buffer = builder_.CreateVector( + reinterpret_cast(sparse_data.data()), + sizeof(T) * sparse_data.size()); + buffers_.push_back(CreateBuffer(builder_, data_buffer)); + } } - tensors_.push_back(CreateTensor( - builder_, builder_.CreateVector(t.shape), t.type, - /*buffer=*/buffer_id, - /*name=*/0, /*quantization=*/0, /*is_variable=*/false, s_param)); + tensors_.push_back( + CreateTensor(builder_, builder_.CreateVector(t.shape), + symmetric_quantize ? 
TensorType_INT8 : t.type, + /*buffer=*/buffer_id, + /*name=*/0, q_params, /*is_variable=*/false, s_param)); inputs_.push_back(id); tensor_data_[id] = t; diff --git a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc index 0849c6dc0e4..d6c9fb93d0a 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc @@ -650,11 +650,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums); const int row_sums_size = row_sums->dims->data[0]; return lstm_eval::EvalHybrid( - input, input_to_input_weights, input_to_forget_weights, - input_to_cell_weights, input_to_output_weights, - recurrent_to_input_weights, recurrent_to_forget_weights, - recurrent_to_cell_weights, recurrent_to_output_weights, - cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, + input, input_to_input_weights, + /*input_to_input_weights_ledger*/ nullptr, input_to_forget_weights, + /*input_to_forget_weights_ledger*/ nullptr, input_to_cell_weights, + /*input_to_cell_weights_ledger*/ nullptr, input_to_output_weights, + /*input_to_output_weights_ledger*/ nullptr, + recurrent_to_input_weights, + /*recurrent_to_input_weights_ledger*/ nullptr, + recurrent_to_forget_weights, + /*recurrent_to_forget_weights_ledger*/ nullptr, + recurrent_to_cell_weights, + /*recurrent_to_cell_weights_ledger*/ nullptr, + recurrent_to_output_weights, + /*recurrent_to_output_weights_ledger*/ nullptr, cell_to_input_weights, + cell_to_forget_weights, cell_to_output_weights, input_layer_norm_coefficients, forget_layer_norm_coefficients, cell_layer_norm_coefficients, output_layer_norm_coefficients, /*aux_input=*/nullptr, @@ -663,7 +672,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { /*aux_input_to_cell_weights=*/nullptr, /*aux_input_to_output_weights=*/nullptr, input_gate_bias, forget_gate_bias, 
cell_gate_bias, output_gate_bias, - projection_weights, projection_bias, &lstm_params, + projection_weights, /*projection_weights_ledger*/ nullptr, + projection_bias, &lstm_params, /*forward_sequence=*/true, time_major, /*output_offset=*/0, scratch_buffer, GetTemporary(context, node, kInputScalingFactors),