Add builtin sparse LSTM kernel.

PiperOrigin-RevId: 329562447
Change-Id: I5c407b513fbc86d21f6ea2d626da7b69dcd38bc7
This commit is contained in:
Yunlu Li 2020-09-01 12:48:41 -07:00 committed by TensorFlower Gardener
parent b4ee2c4294
commit 8744e4b2b9
10 changed files with 1230 additions and 119 deletions

View File

@ -1136,10 +1136,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const int fw_row_sums_size = fw_row_sums->dims->data[0];
const int bw_row_sums_size = bw_row_sums->dims->data[0];
TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
input, fw_input_to_input_weights, fw_input_to_forget_weights,
fw_input_to_cell_weights, fw_input_to_output_weights,
fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
input, fw_input_to_input_weights,
/*input_to_input_weights_ledger*/ nullptr, fw_input_to_forget_weights,
/*input_to_forget_weights_ledger*/ nullptr, fw_input_to_cell_weights,
/*input_to_cell_weights_ledger*/ nullptr, fw_input_to_output_weights,
/*input_to_output_weights_ledger*/ nullptr,
fw_recurrent_to_input_weights,
/*recurrent_to_input_weights_ledger*/ nullptr,
fw_recurrent_to_forget_weights,
/*recurrent_to_forget_weights_ledger*/ nullptr,
fw_recurrent_to_cell_weights,
/*recurrent_to_cell_weights_ledger*/ nullptr,
fw_recurrent_to_output_weights,
/*recurrent_to_output_weights_ledger*/ nullptr,
fw_cell_to_input_weights, fw_cell_to_forget_weights,
fw_cell_to_output_weights,
/*input_layer_norm_coefficients=*/nullptr,
@ -1149,7 +1158,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
fw_input_gate_bias, fw_forget_gate_bias, fw_cell_gate_bias,
fw_output_gate_bias, fw_projection_weights, fw_projection_bias,
fw_output_gate_bias, fw_projection_weights,
/*projection_weights_ledger*/ nullptr, fw_projection_bias,
&lstm_params,
/*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
fw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors),
@ -1167,10 +1177,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, fw_pass_status);
TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
bw_input, bw_input_to_input_weights, bw_input_to_forget_weights,
bw_input_to_cell_weights, bw_input_to_output_weights,
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
bw_input, bw_input_to_input_weights,
/*input_to_input_weights_ledger*/ nullptr, bw_input_to_forget_weights,
/*input_to_forget_weights_ledger*/ nullptr, bw_input_to_cell_weights,
/*input_to_cell_weights_ledger*/ nullptr, bw_input_to_output_weights,
/*input_to_output_weights_ledger*/ nullptr,
bw_recurrent_to_input_weights,
/*recurrent_to_input_weights_ledger*/ nullptr,
bw_recurrent_to_forget_weights,
/*recurrent_to_forget_weights_ledger*/ nullptr,
bw_recurrent_to_cell_weights,
/*recurrent_to_cell_weights_ledger*/ nullptr,
bw_recurrent_to_output_weights,
/*recurrent_to_output_weights_ledger*/ nullptr,
bw_cell_to_input_weights, bw_cell_to_forget_weights,
bw_cell_to_output_weights,
/*input_layer_norm_coefficients=*/nullptr,
@ -1180,7 +1199,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
bw_input_gate_bias, bw_forget_gate_bias, bw_cell_gate_bias,
bw_output_gate_bias, bw_projection_weights, bw_projection_bias,
bw_output_gate_bias, bw_projection_weights,
/*projection_weights_ledger*/ nullptr, bw_projection_bias,
&lstm_params,
/*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
bw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors),

View File

@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <initializer_list>
#include <memory>
#include <vector>
@ -42,7 +41,7 @@ using ::testing::ElementsAreArray;
template <typename T>
class DensifyOpModel : public SingleOpModel {
public:
DensifyOpModel(const TensorData& input, std::initializer_list<T> input_data,
DensifyOpModel(const TensorData& input, const std::vector<T>& input_data,
int version = 1) {
input_ = AddConstSparseInput(input, input_data);
output_ = AddOutput({input.type, input.shape});
@ -65,9 +64,8 @@ class DensifyOpModel : public SingleOpModel {
};
TEST(DensifyOpTest, Float) {
std::initializer_list<float> dense_values = {6, 0, 9, 8, 0, 0,
0, 0, 5, 0, 0, 7};
std::initializer_list<float> sparse_values = {6, 9, 8, 5, 7};
std::vector<float> dense_values = {6, 0, 9, 8, 0, 0, 0, 0, 5, 0, 0, 7};
std::vector<float> sparse_values = {6, 9, 8, 5, 7};
TensorData input = {};
input.type = TensorType_FLOAT32;
input.shape = {3, 4};
@ -80,9 +78,8 @@ TEST(DensifyOpTest, Float) {
}
TEST(DensifyOpTest, Float3D) {
std::initializer_list<float> dense_values = {6, 0, 9, 8, 0, 0,
0, 0, 5, 0, 0, 7};
std::initializer_list<float> sparse_values = {6, 9, 8, 5, 7};
std::vector<float> dense_values = {6, 0, 9, 8, 0, 0, 0, 0, 5, 0, 0, 7};
std::vector<float> sparse_values = {6, 9, 8, 5, 7};
TensorData input = {};
input.type = TensorType_FLOAT32;
input.shape = {3, 2, 2};
@ -95,9 +92,8 @@ TEST(DensifyOpTest, Float3D) {
}
TEST(DensifyOpTest, Int8) {
std::initializer_list<int8_t> dense_values = {6, 0, 9, 8, 0, 0,
0, 0, 5, 0, 0, 7};
std::initializer_list<int8_t> sparse_values = {6, 9, 8, 5, 7};
std::vector<int8_t> dense_values = {6, 0, 9, 8, 0, 0, 0, 0, 5, 0, 0, 7};
std::vector<int8_t> sparse_values = {6, 9, 8, 5, 7};
TensorData input = {};
input.type = TensorType_INT8;
input.shape = {3, 4};

View File

@ -1144,7 +1144,7 @@ class SparseFullyConnectedOpModel : public SingleOpModel {
SparseFullyConnectedOpModel(TfLiteRegistration* registration, int units,
int batches, const TensorData& input,
const TensorData& weights,
std::initializer_list<T> weights_data,
const std::vector<T>& weights_data,
int num_threads = 1)
: batches_(batches), units_(units) {
int total_input_size = 1;

View File

@ -55,6 +55,10 @@ struct OpData {
int scratch_tensor_index;
lstm_eval::IntegerLstmParameter integer_lstm_param;
bool compute_row_sums;
// Only used for sparse hybrid lstm kernels.
int ledger_index;
bool ledger_initialized;
};
namespace full {
@ -77,6 +81,63 @@ enum HybridTemporaryTensor {
kNumHybridTemporaryTensors = 12,
};
constexpr int kLedgersToAdd = 9;
constexpr int kInputToInputWeightsLedgerOffset = 0;
constexpr int kInputToForgetWeightsLedgerOffset = 1;
constexpr int kInputToCellWeightsLedgerOffset = 2;
constexpr int kInputToOutputWeightsLedgerOffset = 3;
constexpr int kRecurrentToInputWeightsLedgerOffset = 4;
constexpr int kRecurrentToForgetWeightsLedgerOffset = 5;
constexpr int kRecurrentToCellWeightsLedgerOffset = 6;
constexpr int kRecurrentToOutputWeightsLedgerOffset = 7;
constexpr int kProjectionWeightsLedgerOffset = 8;
TfLiteStatus make_ledger(const TfLiteSparsity* sparsity, TfLiteContext* context,
TfLiteTensor* ledger) {
ledger->type = kTfLiteUInt8;
ledger->allocation_type = kTfLiteArenaRwPersistent;
if (sparsity == nullptr) {
return kTfLiteOk;
}
TfLiteIntArray* ledger_size = TfLiteIntArrayCreate(1);
ledger_size->data[0] = sparsity->dim_metadata[1].array_indices->size +
sparsity->dim_metadata[1].array_segments->size - 1;
return context->ResizeTensor(context, ledger, ledger_size);
}
TfLiteStatus copy_ledger(const TfLiteSparsity* sparsity, TfLiteTensor* ledger) {
if (sparsity == nullptr) {
return kTfLiteOk;
}
const auto* array_segments = sparsity->dim_metadata[1].array_segments;
const auto* array_indices = sparsity->dim_metadata[1].array_indices;
uint8_t* output_data = GetTensorData<uint8_t>(ledger);
int output_data_ptr = 0;
for (int i = 0; i < array_segments->size - 1; i++) {
int row_start = array_segments->data[i];
int row_end = array_segments->data[i + 1];
if (row_end - row_start > UINT8_MAX) {
return kTfLiteError;
}
// Copy num of non-zero blocks in row i.
output_data[output_data_ptr] = static_cast<uint8_t>(row_end - row_start);
output_data_ptr++;
for (int j = row_start; j < row_end; j++) {
if (array_indices->data[j] > UINT8_MAX) {
return kTfLiteError;
}
// Copy indices of non-zero blocks in row i.
output_data[output_data_ptr] =
static_cast<uint8_t>(array_indices->data[j]);
output_data_ptr++;
}
}
return kTfLiteOk;
}
TfLiteStatus PopulateQuantizedLstmParams8x8_16(
TfLiteContext* context, TfLiteNode* node,
lstm_eval::IntegerLstmParameter* integer_lstm_param) {
@ -744,6 +805,9 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// TODO(b/159066113): maybe just add the minimum required temp tensors?
context->AddTensors(context, kNumHybridTemporaryTensors,
&op_data->scratch_tensor_index);
// Tensors used for the sparse hybrid kernel.
context->AddTensors(context, /*tensors_to_add=*/kLedgersToAdd,
&op_data->ledger_index);
return op_data;
}
@ -1239,6 +1303,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// The weights are of consistent type, so it suffices to check one.
const bool is_hybrid_op = IsHybridOp(input, input_to_output_weights);
const bool is_sparse_op = (input_to_output_weights->sparsity != nullptr);
// The type of Integer LSTM.
const int num_intermediate_tensors = node->intermediates->size;
if (is_integer) {
@ -1251,7 +1317,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArrayFree(node->temporaries);
if (is_hybrid_op) {
node->temporaries = TfLiteIntArrayCreate(kNumHybridTemporaryTensors);
if (is_sparse_op) {
node->temporaries =
TfLiteIntArrayCreate(kNumHybridTemporaryTensors + kLedgersToAdd);
} else {
node->temporaries = TfLiteIntArrayCreate(kNumHybridTemporaryTensors);
}
} else if (is_integer) {
if (is_8x8_16) {
node->temporaries = TfLiteIntArrayCreate(6);
@ -1289,7 +1360,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
if (is_hybrid_op) {
op_data->compute_row_sums = true;
if (!is_sparse_op) {
op_data->compute_row_sums = true;
}
// Allocate temporary tensors to store quantized values of input,
// output_state and cell_state tensors.
node->temporaries->data[kInputQuantized] =
@ -1454,6 +1527,125 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, row_sums, row_sums_size));
}
if (is_sparse_op) {
op_data->ledger_initialized = false;
int offset = kNumHybridTemporaryTensors;
{
node->temporaries->data[offset + kInputToInputWeightsLedgerOffset] =
op_data->ledger_index + kInputToInputWeightsLedgerOffset;
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
TfLiteTensor* input_to_input_weights_ledger =
&context->tensors[op_data->ledger_index +
kInputToInputWeightsLedgerOffset];
auto status = make_ledger(input_to_input_weights == nullptr
? nullptr
: input_to_input_weights->sparsity,
context, input_to_input_weights_ledger);
if (status != kTfLiteOk) return status;
}
{
node->temporaries->data[offset + kInputToForgetWeightsLedgerOffset] =
op_data->ledger_index + kInputToForgetWeightsLedgerOffset;
const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, kInputToForgetWeightsTensor);
TfLiteTensor* input_to_forget_weights_ledger =
&context->tensors[op_data->ledger_index +
kInputToForgetWeightsLedgerOffset];
auto status = make_ledger(input_to_forget_weights->sparsity, context,
input_to_forget_weights_ledger);
if (status != kTfLiteOk) return status;
}
{
node->temporaries->data[offset + kInputToCellWeightsLedgerOffset] =
op_data->ledger_index + kInputToCellWeightsLedgerOffset;
const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, kInputToCellWeightsTensor);
TfLiteTensor* input_to_cell_weights_ledger =
&context->tensors[op_data->ledger_index +
kInputToCellWeightsLedgerOffset];
auto status = make_ledger(input_to_cell_weights->sparsity, context,
input_to_cell_weights_ledger);
if (status != kTfLiteOk) return status;
}
{
node->temporaries->data[offset + kInputToOutputWeightsLedgerOffset] =
op_data->ledger_index + kInputToOutputWeightsLedgerOffset;
const TfLiteTensor* input_to_output_weights =
GetInput(context, node, kInputToOutputWeightsTensor);
TfLiteTensor* input_to_output_weights_ledger =
&context->tensors[op_data->ledger_index +
kInputToOutputWeightsLedgerOffset];
auto status = make_ledger(input_to_output_weights->sparsity, context,
input_to_output_weights_ledger);
if (status != kTfLiteOk) return status;
}
{
node->temporaries->data[offset + kRecurrentToInputWeightsLedgerOffset] =
op_data->ledger_index + kRecurrentToInputWeightsLedgerOffset;
const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
context, node, kRecurrentToInputWeightsTensor);
TfLiteTensor* recurrent_to_input_weights_ledger =
&context->tensors[op_data->ledger_index +
kRecurrentToInputWeightsLedgerOffset];
auto status = make_ledger(recurrent_to_input_weights == nullptr
? nullptr
: recurrent_to_input_weights->sparsity,
context, recurrent_to_input_weights_ledger);
if (status != kTfLiteOk) return status;
}
{
node->temporaries
->data[offset + kRecurrentToForgetWeightsLedgerOffset] =
op_data->ledger_index + kRecurrentToForgetWeightsLedgerOffset;
const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, kRecurrentToForgetWeightsTensor);
TfLiteTensor* recurrent_to_forget_weights_ledger =
&context->tensors[op_data->ledger_index +
kRecurrentToForgetWeightsLedgerOffset];
auto status = make_ledger(recurrent_to_forget_weights->sparsity,
context, recurrent_to_forget_weights_ledger);
if (status != kTfLiteOk) return status;
}
{
node->temporaries->data[offset + kRecurrentToCellWeightsLedgerOffset] =
op_data->ledger_index + kRecurrentToCellWeightsLedgerOffset;
const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, kRecurrentToCellWeightsTensor);
TfLiteTensor* recurrent_to_cell_weights_ledger =
&context->tensors[op_data->ledger_index +
kRecurrentToCellWeightsLedgerOffset];
auto status = make_ledger(recurrent_to_cell_weights->sparsity, context,
recurrent_to_cell_weights_ledger);
if (status != kTfLiteOk) return status;
}
{
node->temporaries
->data[offset + kRecurrentToOutputWeightsLedgerOffset] =
op_data->ledger_index + kRecurrentToOutputWeightsLedgerOffset;
const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, kRecurrentToOutputWeightsTensor);
TfLiteTensor* recurrent_to_output_weights_ledger =
&context->tensors[op_data->ledger_index +
kRecurrentToOutputWeightsLedgerOffset];
auto status = make_ledger(recurrent_to_output_weights->sparsity,
context, recurrent_to_output_weights_ledger);
if (status != kTfLiteOk) return status;
}
{
node->temporaries->data[offset + kProjectionWeightsLedgerOffset] =
op_data->ledger_index + kProjectionWeightsLedgerOffset;
const TfLiteTensor* projection_weights =
GetInput(context, node, kProjectionWeightsTensor);
TfLiteTensor* projection_weights_ledger =
&context->tensors[op_data->ledger_index +
kProjectionWeightsLedgerOffset];
auto status = make_ledger(projection_weights->sparsity, context,
projection_weights_ledger);
if (status != kTfLiteOk) return status;
}
}
}
if (is_integer) {
@ -1624,14 +1816,116 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteUInt8:
case kTfLiteInt8: {
const bool is_hybrid = (input->type == kTfLiteFloat32);
const bool is_sparse = input_to_output_weights->sparsity != nullptr;
if (is_hybrid) {
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
const int row_sums_size = row_sums->dims->data[0];
if (is_sparse) {
TfLiteTensor* input_to_input_weights_ledger =
&context->tensors[op_data->ledger_index +
kInputToInputWeightsLedgerOffset];
TfLiteTensor* input_to_forget_weights_ledger =
&context->tensors[op_data->ledger_index +
kInputToForgetWeightsLedgerOffset];
TfLiteTensor* input_to_cell_weights_ledger =
&context->tensors[op_data->ledger_index +
kInputToCellWeightsLedgerOffset];
TfLiteTensor* input_to_output_weights_ledger =
&context->tensors[op_data->ledger_index +
kInputToOutputWeightsLedgerOffset];
TfLiteTensor* recurrent_to_input_weights_ledger =
&context->tensors[op_data->ledger_index +
kRecurrentToInputWeightsLedgerOffset];
TfLiteTensor* recurrent_to_forget_weights_ledger =
&context->tensors[op_data->ledger_index +
kRecurrentToForgetWeightsLedgerOffset];
TfLiteTensor* recurrent_to_cell_weights_ledger =
&context->tensors[op_data->ledger_index +
kRecurrentToCellWeightsLedgerOffset];
TfLiteTensor* recurrent_to_output_weights_ledger =
&context->tensors[op_data->ledger_index +
kRecurrentToOutputWeightsLedgerOffset];
TfLiteTensor* projection_weights_ledger =
&context->tensors[op_data->ledger_index +
kProjectionWeightsLedgerOffset];
if (!op_data->ledger_initialized) {
copy_ledger(input_to_input_weights == nullptr
? nullptr
: input_to_input_weights->sparsity,
input_to_input_weights_ledger);
copy_ledger(input_to_forget_weights->sparsity,
input_to_forget_weights_ledger);
copy_ledger(input_to_cell_weights->sparsity,
input_to_cell_weights_ledger);
copy_ledger(input_to_output_weights->sparsity,
input_to_output_weights_ledger);
copy_ledger(recurrent_to_input_weights == nullptr
? nullptr
: recurrent_to_input_weights->sparsity,
recurrent_to_input_weights_ledger);
copy_ledger(recurrent_to_forget_weights->sparsity,
recurrent_to_forget_weights_ledger);
copy_ledger(recurrent_to_cell_weights->sparsity,
recurrent_to_cell_weights_ledger);
copy_ledger(recurrent_to_output_weights->sparsity,
recurrent_to_output_weights_ledger);
copy_ledger(projection_weights->sparsity,
projection_weights_ledger);
op_data->ledger_initialized = true;
}
return lstm_eval::EvalHybrid(
input, input_to_input_weights, input_to_input_weights_ledger,
input_to_forget_weights, input_to_forget_weights_ledger,
input_to_cell_weights, input_to_cell_weights_ledger,
input_to_output_weights, input_to_output_weights_ledger,
recurrent_to_input_weights, recurrent_to_input_weights_ledger,
recurrent_to_forget_weights, recurrent_to_forget_weights_ledger,
recurrent_to_cell_weights, recurrent_to_cell_weights_ledger,
recurrent_to_output_weights, recurrent_to_output_weights_ledger,
cell_to_input_weights, cell_to_forget_weights,
cell_to_output_weights, input_layer_norm_coefficients,
forget_layer_norm_coefficients, cell_layer_norm_coefficients,
output_layer_norm_coefficients,
/*aux_input=*/nullptr,
/*aux_input_to_input_weights=*/nullptr,
/*aux_input_to_forget_weights=*/nullptr,
/*aux_input_to_cell_weights=*/nullptr,
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
forget_gate_bias, cell_gate_bias, output_gate_bias,
projection_weights, projection_weights_ledger, projection_bias,
params,
/*forward_sequence=*/true, /*time_major=*/true,
/*output_offset=*/0, GetTemporary(context, node, kScratchBuffer),
GetTemporary(context, node, kInputScalingFactors),
/*aux_input_sf=*/nullptr,
GetTemporary(context, node, kOutputStateScalingFactors),
GetTemporary(context, node, kProductScalingFactors),
GetTemporary(context, node, kRecoveredCellWeights),
GetTemporary(context, node, kInputQuantized),
/*aux_input_quantized=*/nullptr,
GetTemporary(context, node, kOutputStateQuantized),
GetTemporary(context, node, kCellStateQuantized), output_state,
cell_state, GetTemporary(context, node, kAccumScratch), output,
GetTemporary(context, node, kInputZeroPoints),
/*aux_input_zp=*/nullptr,
GetTemporary(context, node, kOutputStateZeroPoints), row_sums,
row_sums_size, &op_data->compute_row_sums,
CpuBackendContext::GetFromContext(context));
}
return lstm_eval::EvalHybrid(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
recurrent_to_input_weights, recurrent_to_forget_weights,
recurrent_to_cell_weights, recurrent_to_output_weights,
input, input_to_input_weights,
/*input_to_input_weights_ledger*/ nullptr, input_to_forget_weights,
/*input_to_forget_weights_ledger*/ nullptr, input_to_cell_weights,
/*input_to_cell_weights_ledger*/ nullptr, input_to_output_weights,
/*input_to_output_weights_ledger*/ nullptr,
recurrent_to_input_weights,
/*recurrent_to_input_weights_ledger*/ nullptr,
recurrent_to_forget_weights,
/*recurrent_to_forget_weights_ledger*/ nullptr,
recurrent_to_cell_weights,
/*recurrent_to_cell_weights_ledger*/ nullptr,
recurrent_to_output_weights,
/*recurrent_to_output_weights_ledger*/ nullptr,
cell_to_input_weights, cell_to_forget_weights,
cell_to_output_weights, input_layer_norm_coefficients,
forget_layer_norm_coefficients, cell_layer_norm_coefficients,
@ -1641,7 +1935,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*aux_input_to_cell_weights=*/nullptr,
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
forget_gate_bias, cell_gate_bias, output_gate_bias,
projection_weights, projection_bias, params,
projection_weights, /*projection_weights_ledger*/ nullptr,
projection_bias, params,
/*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
GetTemporary(context, node, kScratchBuffer),
GetTemporary(context, node, kInputScalingFactors),

View File

@ -312,6 +312,7 @@ void CalculateLstmGateHybrid(
// Input and weights
const int8_t* input, const float* input_sf, const int32_t* input_zp,
const int8_t* input_to_gate_weights,
const uint8_t* input_to_gate_weights_ledger,
const float input_to_gate_weights_scale, int32_t* input_to_gate_row_sums,
// Aux input and weights
const int8_t* aux_input, const float* aux_input_sf,
@ -321,6 +322,7 @@ void CalculateLstmGateHybrid(
// Output state and weights
const int8_t* output_state, const float* output_state_sf,
const int32_t* output_state_zp, const int8_t* recurrent_to_gate_weights,
const uint8_t* recurrent_to_gate_weights_ledger,
const float recurrent_to_gate_weights_scale,
int32_t* recurrent_to_gate_row_sums,
// Cell state and weights (peephole LSTM)
@ -356,11 +358,22 @@ void CalculateLstmGateHybrid(
// For each batch and cell: compute input_weight * input.
// Skip if input is all zeros.
if (!is_input_all_zeros) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_gate_weights, n_cell, n_input, input,
input_to_gate_weights_scale, input_sf, n_batch, gate,
/*per_channel_scale=*/nullptr, input_zp, accum_scratch,
input_to_gate_row_sums, compute_row_sums, scratch0, context);
if (input_to_gate_weights_ledger != nullptr) {
std::vector<float> scales(n_batch);
for (int i = 0; i < n_batch; i++) {
scales[i] = input_to_gate_weights_scale * input_sf[i];
}
tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
input_to_gate_weights, input_to_gate_weights_ledger, n_cell, n_input,
input, scales.data(), n_batch, gate);
} else {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_gate_weights, n_cell, n_input, input,
input_to_gate_weights_scale, input_sf, n_batch, gate,
/*per_channel_scale=*/nullptr, input_zp, accum_scratch,
input_to_gate_row_sums, compute_row_sums, scratch0, context);
}
}
// For each batch and cell: compute aux_input_weight * aux_input.
// Skip if auxiliary input is not available or all zeros.
@ -374,11 +387,21 @@ void CalculateLstmGateHybrid(
// For each batch and cell: compute recurrent_weight * output_state.
// Skip if output state is all zeros.
if (!is_output_state_all_zeros) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_gate_weights, n_cell, n_output, output_state,
recurrent_to_gate_weights_scale, output_state_sf, n_batch, gate,
/*per_channel_scale=*/nullptr, output_state_zp, accum_scratch,
recurrent_to_gate_row_sums, compute_row_sums, scratch0, context);
if (recurrent_to_gate_weights_ledger != nullptr) {
std::vector<float> scales(n_batch);
for (int i = 0; i < n_batch; i++) {
scales[i] = recurrent_to_gate_weights_scale * input_sf[i];
}
tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
recurrent_to_gate_weights, recurrent_to_gate_weights_ledger, n_cell,
n_output, output_state, scales.data(), n_batch, gate);
} else {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_gate_weights, n_cell, n_output, output_state,
recurrent_to_gate_weights_scale, output_state_sf, n_batch, gate,
/*per_channel_scale=*/nullptr, output_state_zp, accum_scratch,
recurrent_to_gate_row_sums, compute_row_sums, scratch0, context);
}
}
// For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
if (use_peephole) {
@ -422,11 +445,12 @@ void CalculateLstmGateHybrid(
void CalculateLstmOutputHybrid(
int n_batch, int n_cell, int n_output, const float* cell_state,
const float* output_gate, TfLiteFusedActivation activation,
const int8_t* projection_weights, float projection_weights_scale,
const float* projection_bias, const float proj_clip, float* output_state,
bool asymmetric_quantize_inputs, int32_t* projection_weights_row_sums,
bool* compute_row_sums, CpuBackendContext* context, float* scratch0,
int8_t* scratch1, float* scratch2, int32_t* scratch3, int32_t* scratch4) {
const int8_t* projection_weights, const uint8_t* projection_weights_ledger,
float projection_weights_scale, const float* projection_bias,
const float proj_clip, float* output_state, bool asymmetric_quantize_inputs,
int32_t* projection_weights_row_sums, bool* compute_row_sums,
CpuBackendContext* context, float* scratch0, int8_t* scratch1,
float* scratch2, int32_t* scratch3, int32_t* scratch4) {
tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
activation, scratch0);
tensor_utils::VectorVectorCwiseProduct(output_gate, scratch0,
@ -447,11 +471,21 @@ void CalculateLstmOutputHybrid(
tensor_utils::BatchQuantizeFloats(scratch0, n_batch, n_cell, scratch1,
scratch2, scratch3,
asymmetric_quantize_inputs);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
projection_weights, n_output, n_cell, scratch1,
projection_weights_scale, scratch2, n_batch, output_state,
/*per_channel_scale=*/nullptr, scratch3, scratch4,
projection_weights_row_sums, compute_row_sums, scratch2, context);
if (projection_weights_ledger != nullptr) {
std::vector<float> scales(n_batch);
for (int i = 0; i < n_batch; i++) {
scales[i] = projection_weights_scale * scratch2[i];
}
tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
projection_weights, projection_weights_ledger, n_output, n_cell,
scratch1, scales.data(), n_batch, output_state);
} else {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
projection_weights, n_output, n_cell, scratch1,
projection_weights_scale, scratch2, n_batch, output_state,
/*per_channel_scale=*/nullptr, scratch3, scratch4,
projection_weights_row_sums, compute_row_sums, scratch2, context);
}
}
if (proj_clip > 0.0f) {
tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
@ -955,11 +989,16 @@ inline void LstmStepFloat(
// output_ptr - size 'n_batch * output_batch_leading_dim'
inline void LstmStepHybrid(
const float* input_ptr, const int8_t* input_to_input_weights_ptr,
const uint8_t* input_to_input_weights_ledger_ptr,
float input_to_input_weights_scale,
const int8_t* input_to_forget_weights_ptr,
const uint8_t* input_to_forget_weights_ledger_ptr,
float input_to_forget_weights_scale,
const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
const int8_t* input_to_cell_weights_ptr,
const uint8_t* input_to_cell_weights_ledger_ptr,
float input_to_cell_weights_scale,
const int8_t* input_to_output_weights_ptr,
const uint8_t* input_to_output_weights_ledger_ptr,
float input_to_output_weights_scale, const float* aux_input_ptr,
const int8_t* aux_input_to_input_weights_ptr,
float aux_input_to_input_weights_scale,
@ -970,12 +1009,16 @@ inline void LstmStepHybrid(
const int8_t* aux_input_to_output_weights_ptr,
float aux_input_to_output_weights_scale,
const int8_t* recurrent_to_input_weights_ptr,
const uint8_t* recurrent_to_input_weights_ledger_ptr,
float recurrent_to_input_weights_scale,
const int8_t* recurrent_to_forget_weights_ptr,
const uint8_t* recurrent_to_forget_weights_ledger_ptr,
float recurrent_to_forget_weights_scale,
const int8_t* recurrent_to_cell_weights_ptr,
const uint8_t* recurrent_to_cell_weights_ledger_ptr,
float recurrent_to_cell_weights_scale,
const int8_t* recurrent_to_output_weights_ptr,
const uint8_t* recurrent_to_output_weights_ledger_ptr,
float recurrent_to_output_weights_scale,
const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
const int8_t* cell_to_forget_weights_ptr,
@ -988,19 +1031,21 @@ inline void LstmStepHybrid(
const float* output_layer_norm_coefficients_ptr,
const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
const int8_t* projection_weights_ptr, float projection_weights_scale,
const float* projection_bias_ptr, const TfLiteLSTMParams* params,
int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
int output_batch_leading_dim, float* scratch0, float* scratch1,
float* scratch2, float* scratch3, float* input_sf, float* aux_input_sf,
float* output_state_sf, float* scaling_factors_scratch,
float* recovered_cell_weights, int8_t* quantized_input_ptr,
int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr,
int8_t* quantized_output_scratch, float* output_state_ptr,
float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr,
int32_t* input_zp, int32_t* aux_input_zp, int32_t* output_state_zp,
int32_t* row_sums, int row_sums_size, bool* compute_row_sums,
bool asymmetric_quantize_inputs, CpuBackendContext* context) {
const int8_t* projection_weights_ptr,
const uint8_t* projection_weights_ledger_ptr,
float projection_weights_scale, const float* projection_bias_ptr,
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
int n_aux_input, int n_output, int output_batch_leading_dim,
float* scratch0, float* scratch1, float* scratch2, float* scratch3,
float* input_sf, float* aux_input_sf, float* output_state_sf,
float* scaling_factors_scratch, float* recovered_cell_weights,
int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
int8_t* quantized_output_state_ptr, int8_t* quantized_output_scratch,
float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr,
float* output_ptr, int32_t* input_zp, int32_t* aux_input_zp,
int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
bool* compute_row_sums, bool asymmetric_quantize_inputs,
CpuBackendContext* context) {
ruy::profiler::ScopeLabel label("LstmStepHybrid");
// Since we have already checked that weights are all there or none, we
// can check the existence of only one to the get the condition.
@ -1106,11 +1151,12 @@ inline void LstmStepHybrid(
// Calculate the input gate. (If not CIFG.)
CalculateLstmGateHybrid(
quantized_input_ptr, input_sf, input_zp, input_to_input_weights_ptr,
input_to_input_weights_scale, input_to_input_row_sums,
quantized_aux_input_ptr, aux_input_sf, aux_input_zp,
aux_input_to_input_weights_ptr, aux_input_to_input_weights_scale,
aux_input_to_input_row_sums, quantized_output_state_ptr,
output_state_sf, output_state_zp, recurrent_to_input_weights_ptr,
input_to_input_weights_ledger_ptr, input_to_input_weights_scale,
input_to_input_row_sums, quantized_aux_input_ptr, aux_input_sf,
aux_input_zp, aux_input_to_input_weights_ptr,
aux_input_to_input_weights_scale, aux_input_to_input_row_sums,
quantized_output_state_ptr, output_state_sf, output_state_zp,
recurrent_to_input_weights_ptr, recurrent_to_input_weights_ledger_ptr,
recurrent_to_input_weights_scale, recurrent_to_input_row_sums,
cell_state_ptr, cell_to_input_weights_ptr, cell_to_input_weights_scale,
input_layer_norm_coefficients_ptr, input_gate_bias_ptr, n_batch,
@ -1122,11 +1168,12 @@ inline void LstmStepHybrid(
// Calculate the forget gate.
CalculateLstmGateHybrid(
quantized_input_ptr, input_sf, input_zp, input_to_forget_weights_ptr,
input_to_forget_weights_scale, input_to_forget_row_sums,
quantized_aux_input_ptr, aux_input_sf, aux_input_zp,
aux_input_to_forget_weights_ptr, aux_input_to_forget_weights_scale,
aux_input_to_forget_row_sums, quantized_output_state_ptr, output_state_sf,
output_state_zp, recurrent_to_forget_weights_ptr,
input_to_forget_weights_ledger_ptr, input_to_forget_weights_scale,
input_to_forget_row_sums, quantized_aux_input_ptr, aux_input_sf,
aux_input_zp, aux_input_to_forget_weights_ptr,
aux_input_to_forget_weights_scale, aux_input_to_forget_row_sums,
quantized_output_state_ptr, output_state_sf, output_state_zp,
recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_ledger_ptr,
recurrent_to_forget_weights_scale, recurrent_to_forget_row_sums,
cell_state_ptr, cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
forget_layer_norm_coefficients_ptr, forget_gate_bias_ptr, n_batch,
@ -1137,11 +1184,12 @@ inline void LstmStepHybrid(
// Calculate the cell update gate.
CalculateLstmGateHybrid(
quantized_input_ptr, input_sf, input_zp, input_to_cell_weights_ptr,
input_to_cell_weights_scale, input_to_cell_row_sums,
quantized_aux_input_ptr, aux_input_sf, aux_input_zp,
aux_input_to_cell_weights_ptr, aux_input_to_cell_weights_scale,
aux_input_to_cell_row_sums, quantized_output_state_ptr, output_state_sf,
output_state_zp, recurrent_to_cell_weights_ptr,
input_to_cell_weights_ledger_ptr, input_to_cell_weights_scale,
input_to_cell_row_sums, quantized_aux_input_ptr, aux_input_sf,
aux_input_zp, aux_input_to_cell_weights_ptr,
aux_input_to_cell_weights_scale, aux_input_to_cell_row_sums,
quantized_output_state_ptr, output_state_sf, output_state_zp,
recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_ledger_ptr,
recurrent_to_cell_weights_scale, recurrent_to_cell_row_sums,
/*cell_state=*/nullptr, /*cell_to_gate_weights=*/nullptr,
/*cell_to_gate_weights_scale=*/0.0f, cell_layer_norm_coefficients_ptr,
@ -1157,11 +1205,12 @@ inline void LstmStepHybrid(
// Calculate the output gate.
CalculateLstmGateHybrid(
quantized_input_ptr, input_sf, input_zp, input_to_output_weights_ptr,
input_to_output_weights_scale, input_to_output_row_sums,
quantized_aux_input_ptr, aux_input_sf, aux_input_zp,
aux_input_to_output_weights_ptr, aux_input_to_output_weights_scale,
aux_input_to_output_row_sums, quantized_output_state_ptr, output_state_sf,
output_state_zp, recurrent_to_output_weights_ptr,
input_to_output_weights_ledger_ptr, input_to_output_weights_scale,
input_to_output_row_sums, quantized_aux_input_ptr, aux_input_sf,
aux_input_zp, aux_input_to_output_weights_ptr,
aux_input_to_output_weights_scale, aux_input_to_output_row_sums,
quantized_output_state_ptr, output_state_sf, output_state_zp,
recurrent_to_output_weights_ptr, recurrent_to_output_weights_ledger_ptr,
recurrent_to_output_weights_scale, recurrent_to_output_row_sums,
cell_state_ptr, cell_to_output_weights_ptr, cell_to_output_weights_scale,
output_layer_norm_coefficients_ptr, output_gate_bias_ptr, n_batch,
@ -1172,11 +1221,11 @@ inline void LstmStepHybrid(
// Update the output state.
CalculateLstmOutputHybrid(
n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
params->activation, projection_weights_ptr, projection_weights_scale,
projection_bias_ptr, params->proj_clip, output_state_ptr,
asymmetric_quantize_inputs, projection_weights_row_sums, compute_row_sums,
context, scratch2, quantized_output_scratch, input_sf, input_zp,
accum_scratch_ptr);
params->activation, projection_weights_ptr, projection_weights_ledger_ptr,
projection_weights_scale, projection_bias_ptr, params->proj_clip,
output_state_ptr, asymmetric_quantize_inputs, projection_weights_row_sums,
compute_row_sums, context, scratch2, quantized_output_scratch, input_sf,
input_zp, accum_scratch_ptr);
// Copy output state to the output. Note that the output's rows may not be
// contiguous (output_batch_leading_dim != n_output).
for (int b = 0; b < n_batch; b++) {
@ -1829,13 +1878,21 @@ TfLiteStatus EvalFloat(
TfLiteStatus EvalHybrid(
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
const TfLiteTensor* input_to_input_weights_ledger,
const TfLiteTensor* input_to_forget_weights,
const TfLiteTensor* input_to_forget_weights_ledger,
const TfLiteTensor* input_to_cell_weights,
const TfLiteTensor* input_to_cell_weights_ledger,
const TfLiteTensor* input_to_output_weights,
const TfLiteTensor* input_to_output_weights_ledger,
const TfLiteTensor* recurrent_to_input_weights,
const TfLiteTensor* recurrent_to_input_weights_ledger,
const TfLiteTensor* recurrent_to_forget_weights,
const TfLiteTensor* recurrent_to_forget_weights_ledger,
const TfLiteTensor* recurrent_to_cell_weights,
const TfLiteTensor* recurrent_to_cell_weights_ledger,
const TfLiteTensor* recurrent_to_output_weights,
const TfLiteTensor* recurrent_to_output_weights_ledger,
const TfLiteTensor* cell_to_input_weights,
const TfLiteTensor* cell_to_forget_weights,
const TfLiteTensor* cell_to_output_weights,
@ -1850,9 +1907,11 @@ TfLiteStatus EvalHybrid(
const TfLiteTensor* aux_input_to_output_weights,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
const TfLiteTensor* projection_weights,
const TfLiteTensor* projection_weights_ledger,
const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params,
bool forward_sequence, bool time_major, int output_offset,
TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
@ -1929,12 +1988,16 @@ TfLiteStatus EvalHybrid(
GetTensorData<float>(output) + t_rel * output_step + output_offset;
LstmStepHybrid(
input_ptr, GetTensorData<int8_t>(input_to_input_weights),
GetTensorData<uint8_t>(input_to_input_weights_ledger),
GetTensorScale(input_to_input_weights),
GetTensorData<int8_t>(input_to_forget_weights),
GetTensorData<uint8_t>(input_to_forget_weights_ledger),
GetTensorScale(input_to_forget_weights),
GetTensorData<int8_t>(input_to_cell_weights),
GetTensorData<uint8_t>(input_to_cell_weights_ledger),
GetTensorScale(input_to_cell_weights),
GetTensorData<int8_t>(input_to_output_weights),
GetTensorData<uint8_t>(input_to_output_weights_ledger),
GetTensorScale(input_to_output_weights), aux_input_ptr,
GetTensorData<int8_t>(aux_input_to_input_weights),
GetTensorScale(aux_input_to_input_weights),
@ -1945,12 +2008,16 @@ TfLiteStatus EvalHybrid(
GetTensorData<int8_t>(aux_input_to_output_weights),
GetTensorScale(aux_input_to_output_weights),
GetTensorData<int8_t>(recurrent_to_input_weights),
GetTensorData<uint8_t>(recurrent_to_input_weights_ledger),
GetTensorScale(recurrent_to_input_weights),
GetTensorData<int8_t>(recurrent_to_forget_weights),
GetTensorData<uint8_t>(recurrent_to_forget_weights_ledger),
GetTensorScale(recurrent_to_forget_weights),
GetTensorData<int8_t>(recurrent_to_cell_weights),
GetTensorData<uint8_t>(recurrent_to_cell_weights_ledger),
GetTensorScale(recurrent_to_cell_weights),
GetTensorData<int8_t>(recurrent_to_output_weights),
GetTensorData<uint8_t>(recurrent_to_output_weights_ledger),
GetTensorScale(recurrent_to_output_weights),
GetTensorData<int8_t>(cell_to_input_weights),
GetTensorScale(cell_to_input_weights),
@ -1967,6 +2034,7 @@ TfLiteStatus EvalHybrid(
GetTensorData<float>(cell_gate_bias),
GetTensorData<float>(output_gate_bias),
GetTensorData<int8_t>(projection_weights),
GetTensorData<uint8_t>(projection_weights_ledger),
GetTensorScale(projection_weights),
GetTensorData<float>(projection_bias), params, n_batch, n_cell,
n_input, aux_input_size, n_output, output_batch_leading_dim,
@ -2018,12 +2086,16 @@ TfLiteStatus EvalHybrid(
LstmStepHybrid(
input_ptr, GetTensorData<int8_t>(input_to_input_weights),
GetTensorData<uint8_t>(input_to_input_weights_ledger),
GetTensorScale(input_to_input_weights),
GetTensorData<int8_t>(input_to_forget_weights),
GetTensorData<uint8_t>(input_to_forget_weights_ledger),
GetTensorScale(input_to_forget_weights),
GetTensorData<int8_t>(input_to_cell_weights),
GetTensorData<uint8_t>(input_to_cell_weights_ledger),
GetTensorScale(input_to_cell_weights),
GetTensorData<int8_t>(input_to_output_weights),
GetTensorData<uint8_t>(input_to_output_weights_ledger),
GetTensorScale(input_to_output_weights), aux_input_ptr,
GetTensorData<int8_t>(aux_input_to_input_weights),
GetTensorScale(aux_input_to_input_weights),
@ -2034,12 +2106,16 @@ TfLiteStatus EvalHybrid(
GetTensorData<int8_t>(aux_input_to_output_weights),
GetTensorScale(aux_input_to_output_weights),
GetTensorData<int8_t>(recurrent_to_input_weights),
GetTensorData<uint8_t>(recurrent_to_input_weights_ledger),
GetTensorScale(recurrent_to_input_weights),
GetTensorData<int8_t>(recurrent_to_forget_weights),
GetTensorData<uint8_t>(recurrent_to_forget_weights_ledger),
GetTensorScale(recurrent_to_forget_weights),
GetTensorData<int8_t>(recurrent_to_cell_weights),
GetTensorData<uint8_t>(recurrent_to_cell_weights_ledger),
GetTensorScale(recurrent_to_cell_weights),
GetTensorData<int8_t>(recurrent_to_output_weights),
GetTensorData<uint8_t>(recurrent_to_output_weights_ledger),
GetTensorScale(recurrent_to_output_weights),
GetTensorData<int8_t>(cell_to_input_weights),
GetTensorScale(cell_to_input_weights),
@ -2056,6 +2132,7 @@ TfLiteStatus EvalHybrid(
GetTensorData<float>(cell_gate_bias),
GetTensorData<float>(output_gate_bias),
GetTensorData<int8_t>(projection_weights),
GetTensorData<uint8_t>(projection_weights_ledger),
GetTensorScale(projection_weights),
GetTensorData<float>(projection_bias), params,
/*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,

View File

@ -125,13 +125,21 @@ TfLiteStatus EvalFloat(
TfLiteStatus EvalHybrid(
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
const TfLiteTensor* input_to_input_weights_ledger,
const TfLiteTensor* input_to_forget_weights,
const TfLiteTensor* input_to_forget_weights_ledger,
const TfLiteTensor* input_to_cell_weights,
const TfLiteTensor* input_to_cell_weights_ledger,
const TfLiteTensor* input_to_output_weights,
const TfLiteTensor* input_to_output_weights_ledger,
const TfLiteTensor* recurrent_to_input_weights,
const TfLiteTensor* recurrent_to_input_weights_ledger,
const TfLiteTensor* recurrent_to_forget_weights,
const TfLiteTensor* recurrent_to_forget_weights_ledger,
const TfLiteTensor* recurrent_to_cell_weights,
const TfLiteTensor* recurrent_to_cell_weights_ledger,
const TfLiteTensor* recurrent_to_output_weights,
const TfLiteTensor* recurrent_to_output_weights_ledger,
const TfLiteTensor* cell_to_input_weights,
const TfLiteTensor* cell_to_forget_weights,
const TfLiteTensor* cell_to_output_weights,
@ -146,9 +154,11 @@ TfLiteStatus EvalHybrid(
const TfLiteTensor* aux_input_to_output_weights,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
const TfLiteTensor* projection_weights,
const TfLiteTensor* projection_weights_ledger,
const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params,
bool forward_sequence, bool time_major, int output_offset,
TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,

View File

@ -906,14 +906,14 @@ void TestOneHybridAsymmLSTM() {
constexpr float kDefaultScale = 18.0;
ops::builtin::lstm_eval::EvalHybrid(
one_parameter.GetFloatInput(),
HybridLstmParam::addScale(one_parameter.Geti2i(), kDefaultScale),
HybridLstmParam::addScale(one_parameter.Geti2f(), kDefaultScale),
HybridLstmParam::addScale(one_parameter.Geti2c(), kDefaultScale),
HybridLstmParam::addScale(one_parameter.Geti2o(), kDefaultScale),
HybridLstmParam::addScale(one_parameter.Getr2i(), kDefaultScale),
HybridLstmParam::addScale(one_parameter.Getr2f(), kDefaultScale),
HybridLstmParam::addScale(one_parameter.Getr2c(), kDefaultScale),
HybridLstmParam::addScale(one_parameter.Getr2o(), kDefaultScale),
HybridLstmParam::addScale(one_parameter.Geti2i(), kDefaultScale), nullptr,
HybridLstmParam::addScale(one_parameter.Geti2f(), kDefaultScale), nullptr,
HybridLstmParam::addScale(one_parameter.Geti2c(), kDefaultScale), nullptr,
HybridLstmParam::addScale(one_parameter.Geti2o(), kDefaultScale), nullptr,
HybridLstmParam::addScale(one_parameter.Getr2i(), kDefaultScale), nullptr,
HybridLstmParam::addScale(one_parameter.Getr2f(), kDefaultScale), nullptr,
HybridLstmParam::addScale(one_parameter.Getr2c(), kDefaultScale), nullptr,
HybridLstmParam::addScale(one_parameter.Getr2o(), kDefaultScale), nullptr,
/*cell_to_input_weights=*/nullptr,
/*cell_to_forget_weights=*/nullptr,
/*cell_to_output_weights=*/nullptr, one_parameter.GetInputLayerNorm(),
@ -926,7 +926,7 @@ void TestOneHybridAsymmLSTM() {
/*aux_input_to_output_weights=*/nullptr, one_parameter.GetInputBias(),
one_parameter.GetForgetBias(), one_parameter.GetCellBias(),
one_parameter.GetOutputBias(),
HybridLstmParam::addScale(one_parameter.GetProjection(), 1.0),
HybridLstmParam::addScale(one_parameter.GetProjection(), 1.0), nullptr,
one_parameter.GetProjectionBias(), &param,
/*forward_sequence=*/true,
/*time_major=*/true,

View File

@ -2114,6 +2114,620 @@ TEST(LstmOpTest, InvalidTypes) {
}
#endif
class HybridSparseLSTMOpModel : public ::tflite::SingleOpModel {
public:
HybridSparseLSTMOpModel(
int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
bool use_peephole, bool use_projection_weights, bool use_projection_bias,
float cell_clip, float proj_clip,
const std::vector<std::vector<int>>& input_shapes,
const TensorData& input_weights_td,
const std::vector<float>& input_to_input_weights,
const std::vector<float>& input_to_forget_weights,
const std::vector<float>& input_to_cell_weights,
const std::vector<float>& input_to_output_weights,
const TensorData& recurrent_weights_td,
const std::vector<float>& recurrent_to_input_weights,
const std::vector<float>& recurrent_to_forget_weights,
const std::vector<float>& recurrent_to_cell_weights,
const std::vector<float>& recurrent_to_output_weights,
const ::tflite::TensorType& weight_type = ::tflite::TensorType_INT8)
: n_batch_(n_batch),
n_input_(n_input),
n_cell_(n_cell),
n_output_(n_output) {
input_ = AddInput(::tflite::TensorType_FLOAT32);
if (use_cifg) {
input_to_input_weights_ = AddNullInput();
} else {
input_to_input_weights_ =
AddConstSparseInput(input_weights_td, input_to_input_weights, true);
}
input_to_forget_weights_ =
AddConstSparseInput(input_weights_td, input_to_forget_weights, true);
input_to_cell_weights_ =
AddConstSparseInput(input_weights_td, input_to_cell_weights, true);
input_to_output_weights_ =
AddConstSparseInput(input_weights_td, input_to_output_weights, true);
if (use_cifg) {
recurrent_to_input_weights_ = AddNullInput();
} else {
recurrent_to_input_weights_ = AddConstSparseInput(
recurrent_weights_td, recurrent_to_input_weights, true);
}
recurrent_to_forget_weights_ = AddConstSparseInput(
recurrent_weights_td, recurrent_to_forget_weights, true);
recurrent_to_cell_weights_ = AddConstSparseInput(
recurrent_weights_td, recurrent_to_cell_weights, true);
recurrent_to_output_weights_ = AddConstSparseInput(
recurrent_weights_td, recurrent_to_output_weights, true);
if (use_peephole) {
if (use_cifg) {
cell_to_input_weights_ = AddNullInput();
} else {
cell_to_input_weights_ = AddInput(weight_type);
}
cell_to_forget_weights_ = AddInput(weight_type);
cell_to_output_weights_ = AddInput(weight_type);
} else {
cell_to_input_weights_ = AddNullInput();
cell_to_forget_weights_ = AddNullInput();
cell_to_output_weights_ = AddNullInput();
}
if (use_cifg) {
input_gate_bias_ = AddNullInput();
} else {
input_gate_bias_ = AddInput(::tflite::TensorType_FLOAT32);
}
forget_gate_bias_ = AddInput(::tflite::TensorType_FLOAT32);
cell_bias_ = AddInput(::tflite::TensorType_FLOAT32);
output_gate_bias_ = AddInput(::tflite::TensorType_FLOAT32);
if (use_projection_weights) {
projection_weights_ = AddInput(weight_type);
if (use_projection_bias) {
projection_bias_ = AddInput(::tflite::TensorType_FLOAT32);
} else {
projection_bias_ = AddNullInput();
}
} else {
projection_weights_ = AddNullInput();
projection_bias_ = AddNullInput();
}
// Adding the 2 state tensors.
output_state_ = AddInput(::tflite::TensorData{::tflite::TensorType_FLOAT32,
{n_output_ * n_batch_}},
true);
cell_state_ = AddInput(::tflite::TensorData{::tflite::TensorType_FLOAT32,
{n_cell_ * n_batch_}},
true);
if (use_cifg) {
input_layer_norm_weights_ = AddNullInput();
} else {
input_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32);
}
forget_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32);
cell_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32);
output_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32);
output_ = AddOutput(::tflite::TensorType_FLOAT32);
SetBuiltinOp(
BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
CreateLSTMOptions(builder_, ActivationFunctionType_TANH, cell_clip,
proj_clip, LSTMKernelType_FULL, false)
.Union());
BuildInterpreter(input_shapes);
}
void SetCellToInputWeights(std::vector<float> f) {
SignedSymmetricQuantizeAndPopulate(cell_to_input_weights_, f);
}
void SetCellToForgetWeights(std::vector<float> f) {
SignedSymmetricQuantizeAndPopulate(cell_to_forget_weights_, f);
}
void SetCellToOutputWeights(std::vector<float> f) {
SignedSymmetricQuantizeAndPopulate(cell_to_output_weights_, f);
}
void SetInputLayerNormWeights(std::vector<float> f) {
PopulateTensor(input_layer_norm_weights_, f);
}
void SetForgetLayerNormWeights(std::vector<float> f) {
PopulateTensor(forget_layer_norm_weights_, f);
}
void SetCellLayerNormWeights(std::vector<float> f) {
PopulateTensor(cell_layer_norm_weights_, f);
}
void SetOutputLayerNormWeights(std::vector<float> f) {
PopulateTensor(output_layer_norm_weights_, f);
}
void SetInputGateBias(std::vector<float> f) {
PopulateTensor(input_gate_bias_, f);
}
void SetForgetGateBias(std::vector<float> f) {
PopulateTensor(forget_gate_bias_, f);
}
void SetCellBias(std::vector<float> f) { PopulateTensor(cell_bias_, f); }
void SetOutputGateBias(std::vector<float> f) {
PopulateTensor(output_gate_bias_, f);
}
void SetProjectionWeights(std::vector<float> f) {
SignedSymmetricQuantizeAndPopulate(projection_weights_, f);
}
void SetProjectionBias(std::vector<float> f) {
PopulateTensor(projection_bias_, f);
}
void SetInput(int offset, const float* begin, const float* end) {
PopulateTensor(input_, offset, const_cast<float*>(begin),
const_cast<float*>(end));
}
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
int num_inputs() { return n_input_; }
int num_outputs() { return n_output_; }
int num_cells() { return n_cell_; }
int num_batches() { return n_batch_; }
protected:
int input_;
int input_to_input_weights_;
int input_to_forget_weights_;
int input_to_cell_weights_;
int input_to_output_weights_;
int recurrent_to_input_weights_;
int recurrent_to_forget_weights_;
int recurrent_to_cell_weights_;
int recurrent_to_output_weights_;
int cell_to_input_weights_;
int cell_to_forget_weights_;
int cell_to_output_weights_;
int input_layer_norm_weights_;
int forget_layer_norm_weights_;
int cell_layer_norm_weights_;
int output_layer_norm_weights_;
int input_gate_bias_;
int forget_gate_bias_;
int cell_bias_;
int output_gate_bias_;
int projection_weights_;
int projection_bias_;
int output_state_;
int cell_state_;
int output_;
int n_batch_;
int n_input_;
int n_cell_;
int n_output_;
};
class BaseSparseLstmTest : public ::testing::Test {
protected:
// Weights of the Sparse Layer Norm LSTM model. Some are optional.
std::vector<float> input_to_input_weights_;
std::vector<float> input_to_cell_weights_;
std::vector<float> input_to_forget_weights_;
std::vector<float> input_to_output_weights_;
std::vector<float> input_gate_bias_;
std::vector<float> cell_gate_bias_;
std::vector<float> forget_gate_bias_;
std::vector<float> output_gate_bias_;
std::vector<float> recurrent_to_input_weights_;
std::vector<float> recurrent_to_cell_weights_;
std::vector<float> recurrent_to_forget_weights_;
std::vector<float> recurrent_to_output_weights_;
std::vector<float> cell_to_input_weights_;
std::vector<float> cell_to_forget_weights_;
std::vector<float> cell_to_output_weights_;
std::vector<float> input_layer_norm_weights_;
std::vector<float> forget_layer_norm_weights_;
std::vector<float> cell_layer_norm_weights_;
std::vector<float> output_layer_norm_weights_;
std::vector<float> projection_weights_;
std::vector<int> input_to_input_weights_size_;
std::vector<int> input_to_cell_weights_size_;
std::vector<int> input_to_forget_weights_size_;
std::vector<int> input_to_output_weights_size_;
std::vector<int> recurrent_to_input_weights_size_;
std::vector<int> recurrent_to_cell_weights_size_;
std::vector<int> recurrent_to_forget_weights_size_;
std::vector<int> recurrent_to_output_weights_size_;
int n_batch_;
int n_input_;
int n_cell_;
int n_output_;
float cell_clip_;
float proj_clip_;
// Layer Norm LSTM input is stored as num_batch x num_inputs vector.
std::vector<std::vector<float>> sparse_layer_norm_lstm_input_;
// Compares output up to tolerance to the result of the layer_norm_lstm given
// the input.
void VerifyGoldens(const std::vector<std::vector<float>>& input,
const std::vector<std::vector<float>>& output,
HybridSparseLSTMOpModel* sparse_layer_norm_lstm,
float tolerance = 1e-5) {
const int num_batches = input.size();
EXPECT_GT(num_batches, 0);
const int num_inputs = sparse_layer_norm_lstm->num_inputs();
EXPECT_GT(num_inputs, 0);
const int input_sequence_size = input[0].size() / num_inputs;
EXPECT_GT(input_sequence_size, 0);
for (int i = 0; i < input_sequence_size; ++i) {
for (int b = 0; b < num_batches; ++b) {
const float* batch_start = input[b].data() + i * num_inputs;
const float* batch_end = batch_start + num_inputs;
sparse_layer_norm_lstm->SetInput(
b * sparse_layer_norm_lstm->num_inputs(), batch_start, batch_end);
}
sparse_layer_norm_lstm->Invoke();
const int num_outputs = sparse_layer_norm_lstm->num_outputs();
std::vector<float> expected;
for (int b = 0; b < num_batches; ++b) {
const float* golden_start_batch = output[b].data() + i * num_outputs;
const float* golden_end_batch = golden_start_batch + num_outputs;
expected.insert(expected.end(), golden_start_batch, golden_end_batch);
}
EXPECT_THAT(
sparse_layer_norm_lstm->GetOutput(),
ElementsAreArray(::tflite::ArrayFloatNear(expected, tolerance)));
}
}
};
class NoCifgPeepholeProjectionNoClippingSparseLstmTest
: public BaseSparseLstmTest {
void SetUp() override {
n_batch_ = 2;
n_input_ = 48;
n_cell_ = 4;
n_output_ = 16;
cell_clip_ = 0.0;
proj_clip_ = 0.0;
/* clang-format off */
input_to_input_weights_ = {
/* 1st row */
1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0,
/* 2nd row */
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
-25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
/* 3rd row */
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
-26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
/* 4th row */
-1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
-13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0};
input_to_input_weights_size_ = {4, 48};
input_to_forget_weights_ = {
/* 1st row */
1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0,
/* 2nd row */
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
-25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
/* 3rd row */
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
-26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
/* 4th row */
-1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
-13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0};
input_to_forget_weights_size_ = {4, 48};
input_to_cell_weights_ = {
/* 1st row */
1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0,
/* 2nd row */
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
-25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
/* 3rd row */
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
-26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
/* 4th row */
-1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
-13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0};
input_to_cell_weights_size_ = {4, 48};
input_to_output_weights_ = {
/* 1st row */
1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0,
/* 2nd row */
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
-25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
/* 3rd row */
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
-26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
/* 4th row */
-1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
-13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0};
input_to_output_weights_size_ = {4, 48};
input_gate_bias_ = {0.03, 0.15, 0.22, 0.38};
forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};
recurrent_to_input_weights_ = {
-0.2, -0.3, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, // 1st row
0.1, -0.5, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, // 2nd row
-0.2, -0.3, -0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, // 3rd row
0.05, -0.2, -0.6, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, // 4th row
};
recurrent_to_input_weights_size_ = {4, 16};
recurrent_to_cell_weights_ = {
-0.3, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, // 1st row
-0.3, 0.8, -0.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, // 2nd row
-0.2, 0.3, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, // 3rd row
-0.6, -0.1, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, // 4th row
};
recurrent_to_cell_weights_size_ = {4, 16};
recurrent_to_forget_weights_ = {
-0.5, -0.3, -0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, // 1st row
-0.2, 0.6, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, // 2nd row
0.9, 0.3, -0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, // 3rd row
0.2, 0.5, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, // 4th row
};
recurrent_to_forget_weights_size_ = {4, 16};
recurrent_to_output_weights_ = {
0.3, -0.1, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, // 1st row
-0.2, -0.5, -0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, // 2nd row
-0.2, -0.6, -0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, // 3rd row
-0.4, -0.7, -0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, // 4th row
};
recurrent_to_output_weights_size_ = {4, 16};
cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15};
cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};
cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};
input_layer_norm_weights_ = {0.1, 0.2, 0.3, 0.5};
forget_layer_norm_weights_ = {0.2, 0.2, 0.4, 0.3};
cell_layer_norm_weights_ = {0.7, 0.2, 0.3, 0.8};
output_layer_norm_weights_ = {0.6, 0.2, 0.2, 0.5};
projection_weights_ = {
-0.1, 0.2, 0.01, -0.2, // 1st row
0.1, 0.5, 0.3, 0.08, // 2nd row
0.07, 0.2, -0.4, 0.2, // 3rd row
0.0, 0.0, 0.0, 0.0, // 4th row
0.0, 0.0, 0.0, 0.0, // 5th row
0.0, 0.0, 0.0, 0.0, // 6th row
0.0, 0.0, 0.0, 0.0, // 7th row
0.0, 0.0, 0.0, 0.0, // 8th row
0.0, 0.0, 0.0, 0.0, // 9th row
0.0, 0.0, 0.0, 0.0, // 10th row
0.0, 0.0, 0.0, 0.0, // 11th row
0.0, 0.0, 0.0, 0.0, // 12th row
0.0, 0.0, 0.0, 0.0, // 13th row
0.0, 0.0, 0.0, 0.0, // 14th row
0.0, 0.0, 0.0, 0.0, // 15th row
0.0, 0.0, 0.0, 0.0, // 16th row
0.0, 0.0, 0.0, 0.0, // 17th row
0.0, 0.0, 0.0, 0.0, // 18th row
};
sparse_layer_norm_lstm_input_ = {
// Batch0: 2 (input_sequence_size) * 45 (n_input_)
{
1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0,
-1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0,
-1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, // seq 0
2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0, -1.1, 0.0, 2.0, 0.0,
-1.7, 0.0, 1.9, 0.0, -1.5, 0.0, 0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3,
0.0, 2.8, 0.0, -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1,
1.0, -2.5, 0.7, -1.9, 0.2, 0.1, 0.2, 0.3, // seq 1
},
// Batch1: 2 (input_sequence_size) * 45 (n_input_)
{
1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0,
-1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0,
-1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, // seq 0
2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0, -1.1, 0.0, 2.0, 0.0,
-1.7, 0.0, 1.9, 0.0, -1.5, 0.0, 0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3,
0.0, 2.8, 0.0, -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1,
1.0, -2.5, 0.7, -1.9, 0.2, -1.0, 1.0, -1.0, // seq 1
},
};
/* clang-format on */
}
};
TEST_F(NoCifgPeepholeProjectionNoClippingSparseLstmTest,
HybridSparseLstmBlackBoxTest) {
TensorData input_weight = {};
input_weight.type = TensorType_FLOAT32;
input_weight.shape = {4, 48};
input_weight.traversal_order = {0, 1, 2};
input_weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
input_weight.block_map = {1};
input_weight.block_size = {16};
TensorData recurrent_weight = {};
recurrent_weight.type = TensorType_FLOAT32;
recurrent_weight.shape = {4, 16};
recurrent_weight.traversal_order = {0, 1, 2};
recurrent_weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
recurrent_weight.block_map = {1};
recurrent_weight.block_size = {16};
HybridSparseLSTMOpModel sparse_layer_norm_lstm(
n_batch_, n_input_, n_cell_, n_output_,
/*use_cifg=*/false, /*use_peephole=*/true,
/*use_projection_weights=*/true,
/*use_projection_bias=*/false, cell_clip_, proj_clip_,
{
{n_batch_, n_input_}, // input tensor
{input_to_input_weights_size_},
{input_to_forget_weights_size_},
{input_to_cell_weights_size_},
{input_to_output_weights_size_},
{recurrent_to_input_weights_size_},
{recurrent_to_forget_weights_size_},
{recurrent_to_cell_weights_size_},
{recurrent_to_output_weights_size_},
{n_cell_}, // cell_to_input_weight tensor
{n_cell_}, // cell_to_forget_weight tensor
{n_cell_}, // cell_to_output_weight tensor
{n_cell_}, // input_gate_bias tensor
{n_cell_}, // forget_gate_bias tensor
{n_cell_}, // cell_bias tensor
{n_cell_}, // output_gate_bias tensor
{n_output_, n_cell_}, // projection_weight tensor
{0}, // projection_bias tensor
{n_output_ * n_batch_}, // output_state tensor
{n_cell_ * n_batch_}, // cell_state tensor
{n_cell_}, // input_layer_norm_weight tensor
{n_cell_}, // forget_layer_norm_weight tensor
{n_cell_}, // cell_layer_norm_weight tensor
{n_cell_}, // output_layer_norm_weight tensor
},
input_weight, input_to_input_weights_, input_to_forget_weights_,
input_to_cell_weights_, input_to_output_weights_, recurrent_weight,
recurrent_to_input_weights_, recurrent_to_forget_weights_,
recurrent_to_cell_weights_, recurrent_to_output_weights_);
sparse_layer_norm_lstm.SetInputGateBias(input_gate_bias_);
sparse_layer_norm_lstm.SetCellBias(cell_gate_bias_);
sparse_layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
sparse_layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
sparse_layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
sparse_layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
sparse_layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
sparse_layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_);
sparse_layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_);
sparse_layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_);
sparse_layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_);
sparse_layer_norm_lstm.SetProjectionWeights(projection_weights_);
/* clang-format off */
const std::vector<std::vector<float>> sparse_layer_norm_lstm_golden_output = {
{
// Batch0: 2 (input_sequence_size) * 3 (n_output_)
0.0550758, 0.138464, -0.0628034, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0,
0.069672, 0.195428, -0.0605584, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0,
},
{
// Batch1: 3 (input_sequence_size) * 3 (n_output_)
0.0550758, 0.138464, -0.0628034, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0,
0.069672, 0.195428, -0.0605584, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0,
}};
/* clang-format on */
VerifyGoldens(sparse_layer_norm_lstm_input_,
sparse_layer_norm_lstm_golden_output, &sparse_layer_norm_lstm);
}
// Test parameter controls asymmetric_quantize_inputs in LSTMOpModel.
INSTANTIATE_TEST_SUITE_P(
Parameterized, LstmOpTest,

View File

@ -214,9 +214,82 @@ class SingleOpModel {
return AddConstInput(TensorData{type, shape}, data);
}
// TODO(b/166202747): Use a better way to do type specialization. Reduce
// duplicate code in the two functions below.
int AddConstSparseInput(const TensorData& t,
const std::vector<int8_t>& data) {
int id = tensors_.size();
const int dims_count = t.traversal_order.size();
std::vector<int8_t> dense_data(data);
tflite::optimize::sparsity::FormatConverter<int8_t> converter(
t.shape, t.traversal_order, t.format, t.block_size, t.block_map);
converter.DenseToSparse(dense_data.data());
const auto dim_metadata = converter.GetDimMetadata();
const auto sparse_data = converter.GetData();
// Build sparsity parameter.
std::vector<flatbuffers::Offset<DimensionMetadata>> fb_dim_metadata(
dims_count);
for (int i = 0; i < dims_count; i++) {
const int metadata_idx = 2 * i;
if (i < t.shape.size() &&
t.format[t.traversal_order[i]] == kTfLiteDimSparseCSR) {
auto array_segments =
CreateInt32Vector(builder_,
builder_.CreateVector(dim_metadata[metadata_idx]))
.Union();
auto array_indices =
CreateInt32Vector(
builder_, builder_.CreateVector(dim_metadata[metadata_idx + 1]))
.Union();
fb_dim_metadata[i] = CreateDimensionMetadata(
builder_, DimensionType_SPARSE_CSR, 0,
SparseIndexVector_Int32Vector, array_segments,
SparseIndexVector_Int32Vector, array_indices);
} else {
fb_dim_metadata[i] = CreateDimensionMetadata(
builder_, DimensionType_DENSE, dim_metadata[metadata_idx][0]);
}
}
flatbuffers::Offset<SparsityParameters> s_param = CreateSparsityParameters(
builder_, builder_.CreateVector(t.traversal_order),
builder_.CreateVector(t.block_map),
builder_.CreateVector(fb_dim_metadata));
int buffer_id = 0;
if (!data.empty()) {
// Initialize buffers list with empty buffer to allow for non-const
// tensors.
if (buffers_.empty()) {
buffers_.push_back(CreateBuffer(builder_, builder_.CreateVector({})));
}
// Add compressed data as a Buffer to buffers list.
buffer_id = buffers_.size();
auto data_buffer = builder_.CreateVector(
reinterpret_cast<const uint8_t*>(sparse_data.data()),
sparse_data.size());
buffers_.push_back(CreateBuffer(builder_, data_buffer));
}
tensors_.push_back(CreateTensor(
builder_, builder_.CreateVector<int>(t.shape), t.type,
/*buffer=*/buffer_id,
/*name=*/0, /*quantization=*/0, /*is_variable=*/false, s_param));
inputs_.push_back(id);
tensor_data_[id] = t;
return id;
}
// Add a constant sparse tensor as input.
template <typename T>
int AddConstSparseInput(const TensorData& t, std::initializer_list<T> data) {
int AddConstSparseInput(const TensorData& t, const std::vector<T>& data,
bool symmetric_quantize = false) {
int id = tensors_.size();
const int dims_count = t.traversal_order.size();
std::vector<T> dense_data(data);
@ -258,8 +331,9 @@ class SingleOpModel {
builder_.CreateVector(t.block_map),
builder_.CreateVector(fb_dim_metadata));
flatbuffers::Offset<QuantizationParameters> q_params = 0;
int buffer_id = 0;
if (data.size()) {
if (!data.empty()) {
// Initialize buffers list with empty buffer to allow for non-const
// tensors.
if (buffers_.empty()) {
@ -268,16 +342,31 @@ class SingleOpModel {
// Add compressed data as a Buffer to buffers list.
buffer_id = buffers_.size();
auto data_buffer = builder_.CreateVector(
reinterpret_cast<const uint8_t*>(sparse_data.data()),
sizeof(T) * sparse_data.size());
buffers_.push_back(CreateBuffer(builder_, data_buffer));
if (symmetric_quantize) {
const int length = sparse_data.size();
std::vector<int8_t> q(length);
float min, max, scaling_factor;
tensor_utils::SymmetricQuantizeFloats(
sparse_data.data(), length, q.data(), &min, &max, &scaling_factor);
q_params = CreateQuantizationParameters(
builder_, 0, 0, builder_.CreateVector<float>({scaling_factor}),
builder_.CreateVector<int64_t>({0}));
auto data_buffer = builder_.CreateVector(
reinterpret_cast<const uint8_t*>(q.data()), q.size());
buffers_.push_back(CreateBuffer(builder_, data_buffer));
} else {
auto data_buffer = builder_.CreateVector(
reinterpret_cast<const uint8_t*>(sparse_data.data()),
sizeof(T) * sparse_data.size());
buffers_.push_back(CreateBuffer(builder_, data_buffer));
}
}
tensors_.push_back(CreateTensor(
builder_, builder_.CreateVector<int>(t.shape), t.type,
/*buffer=*/buffer_id,
/*name=*/0, /*quantization=*/0, /*is_variable=*/false, s_param));
tensors_.push_back(
CreateTensor(builder_, builder_.CreateVector<int>(t.shape),
symmetric_quantize ? TensorType_INT8 : t.type,
/*buffer=*/buffer_id,
/*name=*/0, q_params, /*is_variable=*/false, s_param));
inputs_.push_back(id);
tensor_data_[id] = t;

View File

@ -650,11 +650,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
const int row_sums_size = row_sums->dims->data[0];
return lstm_eval::EvalHybrid(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
recurrent_to_input_weights, recurrent_to_forget_weights,
recurrent_to_cell_weights, recurrent_to_output_weights,
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
input, input_to_input_weights,
/*input_to_input_weights_ledger*/ nullptr, input_to_forget_weights,
/*input_to_forget_weights_ledger*/ nullptr, input_to_cell_weights,
/*input_to_cell_weights_ledger*/ nullptr, input_to_output_weights,
/*input_to_output_weights_ledger*/ nullptr,
recurrent_to_input_weights,
/*recurrent_to_input_weights_ledger*/ nullptr,
recurrent_to_forget_weights,
/*recurrent_to_forget_weights_ledger*/ nullptr,
recurrent_to_cell_weights,
/*recurrent_to_cell_weights_ledger*/ nullptr,
recurrent_to_output_weights,
/*recurrent_to_output_weights_ledger*/ nullptr, cell_to_input_weights,
cell_to_forget_weights, cell_to_output_weights,
input_layer_norm_coefficients, forget_layer_norm_coefficients,
cell_layer_norm_coefficients, output_layer_norm_coefficients,
/*aux_input=*/nullptr,
@ -663,7 +672,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*aux_input_to_cell_weights=*/nullptr,
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
forget_gate_bias, cell_gate_bias, output_gate_bias,
projection_weights, projection_bias, &lstm_params,
projection_weights, /*projection_weights_ledger*/ nullptr,
projection_bias, &lstm_params,
/*forward_sequence=*/true, time_major,
/*output_offset=*/0, scratch_buffer,
GetTemporary(context, node, kInputScalingFactors),