Add builtin sparse LSTM kernel.
PiperOrigin-RevId: 329562447 Change-Id: I5c407b513fbc86d21f6ea2d626da7b69dcd38bc7
This commit is contained in:
parent
b4ee2c4294
commit
8744e4b2b9
@ -1136,10 +1136,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const int fw_row_sums_size = fw_row_sums->dims->data[0];
|
||||
const int bw_row_sums_size = bw_row_sums->dims->data[0];
|
||||
TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
|
||||
input, fw_input_to_input_weights, fw_input_to_forget_weights,
|
||||
fw_input_to_cell_weights, fw_input_to_output_weights,
|
||||
fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
|
||||
fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
|
||||
input, fw_input_to_input_weights,
|
||||
/*input_to_input_weights_ledger*/ nullptr, fw_input_to_forget_weights,
|
||||
/*input_to_forget_weights_ledger*/ nullptr, fw_input_to_cell_weights,
|
||||
/*input_to_cell_weights_ledger*/ nullptr, fw_input_to_output_weights,
|
||||
/*input_to_output_weights_ledger*/ nullptr,
|
||||
fw_recurrent_to_input_weights,
|
||||
/*recurrent_to_input_weights_ledger*/ nullptr,
|
||||
fw_recurrent_to_forget_weights,
|
||||
/*recurrent_to_forget_weights_ledger*/ nullptr,
|
||||
fw_recurrent_to_cell_weights,
|
||||
/*recurrent_to_cell_weights_ledger*/ nullptr,
|
||||
fw_recurrent_to_output_weights,
|
||||
/*recurrent_to_output_weights_ledger*/ nullptr,
|
||||
fw_cell_to_input_weights, fw_cell_to_forget_weights,
|
||||
fw_cell_to_output_weights,
|
||||
/*input_layer_norm_coefficients=*/nullptr,
|
||||
@ -1149,7 +1158,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
|
||||
fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
|
||||
fw_input_gate_bias, fw_forget_gate_bias, fw_cell_gate_bias,
|
||||
fw_output_gate_bias, fw_projection_weights, fw_projection_bias,
|
||||
fw_output_gate_bias, fw_projection_weights,
|
||||
/*projection_weights_ledger*/ nullptr, fw_projection_bias,
|
||||
&lstm_params,
|
||||
/*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
|
||||
fw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors),
|
||||
@ -1167,10 +1177,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_OK(context, fw_pass_status);
|
||||
|
||||
TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
|
||||
bw_input, bw_input_to_input_weights, bw_input_to_forget_weights,
|
||||
bw_input_to_cell_weights, bw_input_to_output_weights,
|
||||
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
|
||||
bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
|
||||
bw_input, bw_input_to_input_weights,
|
||||
/*input_to_input_weights_ledger*/ nullptr, bw_input_to_forget_weights,
|
||||
/*input_to_forget_weights_ledger*/ nullptr, bw_input_to_cell_weights,
|
||||
/*input_to_cell_weights_ledger*/ nullptr, bw_input_to_output_weights,
|
||||
/*input_to_output_weights_ledger*/ nullptr,
|
||||
bw_recurrent_to_input_weights,
|
||||
/*recurrent_to_input_weights_ledger*/ nullptr,
|
||||
bw_recurrent_to_forget_weights,
|
||||
/*recurrent_to_forget_weights_ledger*/ nullptr,
|
||||
bw_recurrent_to_cell_weights,
|
||||
/*recurrent_to_cell_weights_ledger*/ nullptr,
|
||||
bw_recurrent_to_output_weights,
|
||||
/*recurrent_to_output_weights_ledger*/ nullptr,
|
||||
bw_cell_to_input_weights, bw_cell_to_forget_weights,
|
||||
bw_cell_to_output_weights,
|
||||
/*input_layer_norm_coefficients=*/nullptr,
|
||||
@ -1180,7 +1199,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
|
||||
bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
|
||||
bw_input_gate_bias, bw_forget_gate_bias, bw_cell_gate_bias,
|
||||
bw_output_gate_bias, bw_projection_weights, bw_projection_bias,
|
||||
bw_output_gate_bias, bw_projection_weights,
|
||||
/*projection_weights_ledger*/ nullptr, bw_projection_bias,
|
||||
&lstm_params,
|
||||
/*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
|
||||
bw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors),
|
||||
|
@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <cstdint>
|
||||
#include <initializer_list>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
@ -42,7 +41,7 @@ using ::testing::ElementsAreArray;
|
||||
template <typename T>
|
||||
class DensifyOpModel : public SingleOpModel {
|
||||
public:
|
||||
DensifyOpModel(const TensorData& input, std::initializer_list<T> input_data,
|
||||
DensifyOpModel(const TensorData& input, const std::vector<T>& input_data,
|
||||
int version = 1) {
|
||||
input_ = AddConstSparseInput(input, input_data);
|
||||
output_ = AddOutput({input.type, input.shape});
|
||||
@ -65,9 +64,8 @@ class DensifyOpModel : public SingleOpModel {
|
||||
};
|
||||
|
||||
TEST(DensifyOpTest, Float) {
|
||||
std::initializer_list<float> dense_values = {6, 0, 9, 8, 0, 0,
|
||||
0, 0, 5, 0, 0, 7};
|
||||
std::initializer_list<float> sparse_values = {6, 9, 8, 5, 7};
|
||||
std::vector<float> dense_values = {6, 0, 9, 8, 0, 0, 0, 0, 5, 0, 0, 7};
|
||||
std::vector<float> sparse_values = {6, 9, 8, 5, 7};
|
||||
TensorData input = {};
|
||||
input.type = TensorType_FLOAT32;
|
||||
input.shape = {3, 4};
|
||||
@ -80,9 +78,8 @@ TEST(DensifyOpTest, Float) {
|
||||
}
|
||||
|
||||
TEST(DensifyOpTest, Float3D) {
|
||||
std::initializer_list<float> dense_values = {6, 0, 9, 8, 0, 0,
|
||||
0, 0, 5, 0, 0, 7};
|
||||
std::initializer_list<float> sparse_values = {6, 9, 8, 5, 7};
|
||||
std::vector<float> dense_values = {6, 0, 9, 8, 0, 0, 0, 0, 5, 0, 0, 7};
|
||||
std::vector<float> sparse_values = {6, 9, 8, 5, 7};
|
||||
TensorData input = {};
|
||||
input.type = TensorType_FLOAT32;
|
||||
input.shape = {3, 2, 2};
|
||||
@ -95,9 +92,8 @@ TEST(DensifyOpTest, Float3D) {
|
||||
}
|
||||
|
||||
TEST(DensifyOpTest, Int8) {
|
||||
std::initializer_list<int8_t> dense_values = {6, 0, 9, 8, 0, 0,
|
||||
0, 0, 5, 0, 0, 7};
|
||||
std::initializer_list<int8_t> sparse_values = {6, 9, 8, 5, 7};
|
||||
std::vector<int8_t> dense_values = {6, 0, 9, 8, 0, 0, 0, 0, 5, 0, 0, 7};
|
||||
std::vector<int8_t> sparse_values = {6, 9, 8, 5, 7};
|
||||
TensorData input = {};
|
||||
input.type = TensorType_INT8;
|
||||
input.shape = {3, 4};
|
||||
|
@ -1144,7 +1144,7 @@ class SparseFullyConnectedOpModel : public SingleOpModel {
|
||||
SparseFullyConnectedOpModel(TfLiteRegistration* registration, int units,
|
||||
int batches, const TensorData& input,
|
||||
const TensorData& weights,
|
||||
std::initializer_list<T> weights_data,
|
||||
const std::vector<T>& weights_data,
|
||||
int num_threads = 1)
|
||||
: batches_(batches), units_(units) {
|
||||
int total_input_size = 1;
|
||||
|
@ -55,6 +55,10 @@ struct OpData {
|
||||
int scratch_tensor_index;
|
||||
lstm_eval::IntegerLstmParameter integer_lstm_param;
|
||||
bool compute_row_sums;
|
||||
|
||||
// Only used for sparse hybrid lstm kernels.
|
||||
int ledger_index;
|
||||
bool ledger_initialized;
|
||||
};
|
||||
|
||||
namespace full {
|
||||
@ -77,6 +81,63 @@ enum HybridTemporaryTensor {
|
||||
kNumHybridTemporaryTensors = 12,
|
||||
};
|
||||
|
||||
constexpr int kLedgersToAdd = 9;
|
||||
constexpr int kInputToInputWeightsLedgerOffset = 0;
|
||||
constexpr int kInputToForgetWeightsLedgerOffset = 1;
|
||||
constexpr int kInputToCellWeightsLedgerOffset = 2;
|
||||
constexpr int kInputToOutputWeightsLedgerOffset = 3;
|
||||
constexpr int kRecurrentToInputWeightsLedgerOffset = 4;
|
||||
constexpr int kRecurrentToForgetWeightsLedgerOffset = 5;
|
||||
constexpr int kRecurrentToCellWeightsLedgerOffset = 6;
|
||||
constexpr int kRecurrentToOutputWeightsLedgerOffset = 7;
|
||||
constexpr int kProjectionWeightsLedgerOffset = 8;
|
||||
|
||||
TfLiteStatus make_ledger(const TfLiteSparsity* sparsity, TfLiteContext* context,
|
||||
TfLiteTensor* ledger) {
|
||||
ledger->type = kTfLiteUInt8;
|
||||
ledger->allocation_type = kTfLiteArenaRwPersistent;
|
||||
if (sparsity == nullptr) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
TfLiteIntArray* ledger_size = TfLiteIntArrayCreate(1);
|
||||
ledger_size->data[0] = sparsity->dim_metadata[1].array_indices->size +
|
||||
sparsity->dim_metadata[1].array_segments->size - 1;
|
||||
return context->ResizeTensor(context, ledger, ledger_size);
|
||||
}
|
||||
|
||||
TfLiteStatus copy_ledger(const TfLiteSparsity* sparsity, TfLiteTensor* ledger) {
|
||||
if (sparsity == nullptr) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
const auto* array_segments = sparsity->dim_metadata[1].array_segments;
|
||||
const auto* array_indices = sparsity->dim_metadata[1].array_indices;
|
||||
uint8_t* output_data = GetTensorData<uint8_t>(ledger);
|
||||
int output_data_ptr = 0;
|
||||
|
||||
for (int i = 0; i < array_segments->size - 1; i++) {
|
||||
int row_start = array_segments->data[i];
|
||||
int row_end = array_segments->data[i + 1];
|
||||
if (row_end - row_start > UINT8_MAX) {
|
||||
return kTfLiteError;
|
||||
}
|
||||
// Copy num of non-zero blocks in row i.
|
||||
output_data[output_data_ptr] = static_cast<uint8_t>(row_end - row_start);
|
||||
output_data_ptr++;
|
||||
|
||||
for (int j = row_start; j < row_end; j++) {
|
||||
if (array_indices->data[j] > UINT8_MAX) {
|
||||
return kTfLiteError;
|
||||
}
|
||||
// Copy indices of non-zero blocks in row i.
|
||||
output_data[output_data_ptr] =
|
||||
static_cast<uint8_t>(array_indices->data[j]);
|
||||
output_data_ptr++;
|
||||
}
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus PopulateQuantizedLstmParams8x8_16(
|
||||
TfLiteContext* context, TfLiteNode* node,
|
||||
lstm_eval::IntegerLstmParameter* integer_lstm_param) {
|
||||
@ -744,6 +805,9 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
// TODO(b/159066113): maybe just add the minimum required temp tensors?
|
||||
context->AddTensors(context, kNumHybridTemporaryTensors,
|
||||
&op_data->scratch_tensor_index);
|
||||
// Tensors used for the sparse hybrid kernel.
|
||||
context->AddTensors(context, /*tensors_to_add=*/kLedgersToAdd,
|
||||
&op_data->ledger_index);
|
||||
return op_data;
|
||||
}
|
||||
|
||||
@ -1239,6 +1303,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
// The weights are of consistent type, so it suffices to check one.
|
||||
const bool is_hybrid_op = IsHybridOp(input, input_to_output_weights);
|
||||
|
||||
const bool is_sparse_op = (input_to_output_weights->sparsity != nullptr);
|
||||
|
||||
// The type of Integer LSTM.
|
||||
const int num_intermediate_tensors = node->intermediates->size;
|
||||
if (is_integer) {
|
||||
@ -1251,7 +1317,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
TfLiteIntArrayFree(node->temporaries);
|
||||
if (is_hybrid_op) {
|
||||
node->temporaries = TfLiteIntArrayCreate(kNumHybridTemporaryTensors);
|
||||
if (is_sparse_op) {
|
||||
node->temporaries =
|
||||
TfLiteIntArrayCreate(kNumHybridTemporaryTensors + kLedgersToAdd);
|
||||
} else {
|
||||
node->temporaries = TfLiteIntArrayCreate(kNumHybridTemporaryTensors);
|
||||
}
|
||||
} else if (is_integer) {
|
||||
if (is_8x8_16) {
|
||||
node->temporaries = TfLiteIntArrayCreate(6);
|
||||
@ -1289,7 +1360,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
if (is_hybrid_op) {
|
||||
op_data->compute_row_sums = true;
|
||||
if (!is_sparse_op) {
|
||||
op_data->compute_row_sums = true;
|
||||
}
|
||||
// Allocate temporary tensors to store quantized values of input,
|
||||
// output_state and cell_state tensors.
|
||||
node->temporaries->data[kInputQuantized] =
|
||||
@ -1454,6 +1527,125 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, context->ResizeTensor(context, row_sums, row_sums_size));
|
||||
}
|
||||
|
||||
if (is_sparse_op) {
|
||||
op_data->ledger_initialized = false;
|
||||
int offset = kNumHybridTemporaryTensors;
|
||||
{
|
||||
node->temporaries->data[offset + kInputToInputWeightsLedgerOffset] =
|
||||
op_data->ledger_index + kInputToInputWeightsLedgerOffset;
|
||||
const TfLiteTensor* input_to_input_weights =
|
||||
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
|
||||
TfLiteTensor* input_to_input_weights_ledger =
|
||||
&context->tensors[op_data->ledger_index +
|
||||
kInputToInputWeightsLedgerOffset];
|
||||
auto status = make_ledger(input_to_input_weights == nullptr
|
||||
? nullptr
|
||||
: input_to_input_weights->sparsity,
|
||||
context, input_to_input_weights_ledger);
|
||||
if (status != kTfLiteOk) return status;
|
||||
}
|
||||
{
|
||||
node->temporaries->data[offset + kInputToForgetWeightsLedgerOffset] =
|
||||
op_data->ledger_index + kInputToForgetWeightsLedgerOffset;
|
||||
const TfLiteTensor* input_to_forget_weights =
|
||||
GetInput(context, node, kInputToForgetWeightsTensor);
|
||||
TfLiteTensor* input_to_forget_weights_ledger =
|
||||
&context->tensors[op_data->ledger_index +
|
||||
kInputToForgetWeightsLedgerOffset];
|
||||
auto status = make_ledger(input_to_forget_weights->sparsity, context,
|
||||
input_to_forget_weights_ledger);
|
||||
if (status != kTfLiteOk) return status;
|
||||
}
|
||||
{
|
||||
node->temporaries->data[offset + kInputToCellWeightsLedgerOffset] =
|
||||
op_data->ledger_index + kInputToCellWeightsLedgerOffset;
|
||||
const TfLiteTensor* input_to_cell_weights =
|
||||
GetInput(context, node, kInputToCellWeightsTensor);
|
||||
TfLiteTensor* input_to_cell_weights_ledger =
|
||||
&context->tensors[op_data->ledger_index +
|
||||
kInputToCellWeightsLedgerOffset];
|
||||
auto status = make_ledger(input_to_cell_weights->sparsity, context,
|
||||
input_to_cell_weights_ledger);
|
||||
if (status != kTfLiteOk) return status;
|
||||
}
|
||||
{
|
||||
node->temporaries->data[offset + kInputToOutputWeightsLedgerOffset] =
|
||||
op_data->ledger_index + kInputToOutputWeightsLedgerOffset;
|
||||
const TfLiteTensor* input_to_output_weights =
|
||||
GetInput(context, node, kInputToOutputWeightsTensor);
|
||||
TfLiteTensor* input_to_output_weights_ledger =
|
||||
&context->tensors[op_data->ledger_index +
|
||||
kInputToOutputWeightsLedgerOffset];
|
||||
auto status = make_ledger(input_to_output_weights->sparsity, context,
|
||||
input_to_output_weights_ledger);
|
||||
if (status != kTfLiteOk) return status;
|
||||
}
|
||||
{
|
||||
node->temporaries->data[offset + kRecurrentToInputWeightsLedgerOffset] =
|
||||
op_data->ledger_index + kRecurrentToInputWeightsLedgerOffset;
|
||||
const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
|
||||
context, node, kRecurrentToInputWeightsTensor);
|
||||
TfLiteTensor* recurrent_to_input_weights_ledger =
|
||||
&context->tensors[op_data->ledger_index +
|
||||
kRecurrentToInputWeightsLedgerOffset];
|
||||
auto status = make_ledger(recurrent_to_input_weights == nullptr
|
||||
? nullptr
|
||||
: recurrent_to_input_weights->sparsity,
|
||||
context, recurrent_to_input_weights_ledger);
|
||||
if (status != kTfLiteOk) return status;
|
||||
}
|
||||
{
|
||||
node->temporaries
|
||||
->data[offset + kRecurrentToForgetWeightsLedgerOffset] =
|
||||
op_data->ledger_index + kRecurrentToForgetWeightsLedgerOffset;
|
||||
const TfLiteTensor* recurrent_to_forget_weights =
|
||||
GetInput(context, node, kRecurrentToForgetWeightsTensor);
|
||||
TfLiteTensor* recurrent_to_forget_weights_ledger =
|
||||
&context->tensors[op_data->ledger_index +
|
||||
kRecurrentToForgetWeightsLedgerOffset];
|
||||
auto status = make_ledger(recurrent_to_forget_weights->sparsity,
|
||||
context, recurrent_to_forget_weights_ledger);
|
||||
if (status != kTfLiteOk) return status;
|
||||
}
|
||||
{
|
||||
node->temporaries->data[offset + kRecurrentToCellWeightsLedgerOffset] =
|
||||
op_data->ledger_index + kRecurrentToCellWeightsLedgerOffset;
|
||||
const TfLiteTensor* recurrent_to_cell_weights =
|
||||
GetInput(context, node, kRecurrentToCellWeightsTensor);
|
||||
TfLiteTensor* recurrent_to_cell_weights_ledger =
|
||||
&context->tensors[op_data->ledger_index +
|
||||
kRecurrentToCellWeightsLedgerOffset];
|
||||
auto status = make_ledger(recurrent_to_cell_weights->sparsity, context,
|
||||
recurrent_to_cell_weights_ledger);
|
||||
if (status != kTfLiteOk) return status;
|
||||
}
|
||||
{
|
||||
node->temporaries
|
||||
->data[offset + kRecurrentToOutputWeightsLedgerOffset] =
|
||||
op_data->ledger_index + kRecurrentToOutputWeightsLedgerOffset;
|
||||
const TfLiteTensor* recurrent_to_output_weights =
|
||||
GetInput(context, node, kRecurrentToOutputWeightsTensor);
|
||||
TfLiteTensor* recurrent_to_output_weights_ledger =
|
||||
&context->tensors[op_data->ledger_index +
|
||||
kRecurrentToOutputWeightsLedgerOffset];
|
||||
auto status = make_ledger(recurrent_to_output_weights->sparsity,
|
||||
context, recurrent_to_output_weights_ledger);
|
||||
if (status != kTfLiteOk) return status;
|
||||
}
|
||||
{
|
||||
node->temporaries->data[offset + kProjectionWeightsLedgerOffset] =
|
||||
op_data->ledger_index + kProjectionWeightsLedgerOffset;
|
||||
const TfLiteTensor* projection_weights =
|
||||
GetInput(context, node, kProjectionWeightsTensor);
|
||||
TfLiteTensor* projection_weights_ledger =
|
||||
&context->tensors[op_data->ledger_index +
|
||||
kProjectionWeightsLedgerOffset];
|
||||
auto status = make_ledger(projection_weights->sparsity, context,
|
||||
projection_weights_ledger);
|
||||
if (status != kTfLiteOk) return status;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (is_integer) {
|
||||
@ -1624,14 +1816,116 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
case kTfLiteUInt8:
|
||||
case kTfLiteInt8: {
|
||||
const bool is_hybrid = (input->type == kTfLiteFloat32);
|
||||
const bool is_sparse = input_to_output_weights->sparsity != nullptr;
|
||||
if (is_hybrid) {
|
||||
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
|
||||
const int row_sums_size = row_sums->dims->data[0];
|
||||
if (is_sparse) {
|
||||
TfLiteTensor* input_to_input_weights_ledger =
|
||||
&context->tensors[op_data->ledger_index +
|
||||
kInputToInputWeightsLedgerOffset];
|
||||
TfLiteTensor* input_to_forget_weights_ledger =
|
||||
&context->tensors[op_data->ledger_index +
|
||||
kInputToForgetWeightsLedgerOffset];
|
||||
TfLiteTensor* input_to_cell_weights_ledger =
|
||||
&context->tensors[op_data->ledger_index +
|
||||
kInputToCellWeightsLedgerOffset];
|
||||
TfLiteTensor* input_to_output_weights_ledger =
|
||||
&context->tensors[op_data->ledger_index +
|
||||
kInputToOutputWeightsLedgerOffset];
|
||||
TfLiteTensor* recurrent_to_input_weights_ledger =
|
||||
&context->tensors[op_data->ledger_index +
|
||||
kRecurrentToInputWeightsLedgerOffset];
|
||||
TfLiteTensor* recurrent_to_forget_weights_ledger =
|
||||
&context->tensors[op_data->ledger_index +
|
||||
kRecurrentToForgetWeightsLedgerOffset];
|
||||
TfLiteTensor* recurrent_to_cell_weights_ledger =
|
||||
&context->tensors[op_data->ledger_index +
|
||||
kRecurrentToCellWeightsLedgerOffset];
|
||||
TfLiteTensor* recurrent_to_output_weights_ledger =
|
||||
&context->tensors[op_data->ledger_index +
|
||||
kRecurrentToOutputWeightsLedgerOffset];
|
||||
TfLiteTensor* projection_weights_ledger =
|
||||
&context->tensors[op_data->ledger_index +
|
||||
kProjectionWeightsLedgerOffset];
|
||||
if (!op_data->ledger_initialized) {
|
||||
copy_ledger(input_to_input_weights == nullptr
|
||||
? nullptr
|
||||
: input_to_input_weights->sparsity,
|
||||
input_to_input_weights_ledger);
|
||||
copy_ledger(input_to_forget_weights->sparsity,
|
||||
input_to_forget_weights_ledger);
|
||||
copy_ledger(input_to_cell_weights->sparsity,
|
||||
input_to_cell_weights_ledger);
|
||||
copy_ledger(input_to_output_weights->sparsity,
|
||||
input_to_output_weights_ledger);
|
||||
copy_ledger(recurrent_to_input_weights == nullptr
|
||||
? nullptr
|
||||
: recurrent_to_input_weights->sparsity,
|
||||
recurrent_to_input_weights_ledger);
|
||||
copy_ledger(recurrent_to_forget_weights->sparsity,
|
||||
recurrent_to_forget_weights_ledger);
|
||||
copy_ledger(recurrent_to_cell_weights->sparsity,
|
||||
recurrent_to_cell_weights_ledger);
|
||||
copy_ledger(recurrent_to_output_weights->sparsity,
|
||||
recurrent_to_output_weights_ledger);
|
||||
copy_ledger(projection_weights->sparsity,
|
||||
projection_weights_ledger);
|
||||
op_data->ledger_initialized = true;
|
||||
}
|
||||
return lstm_eval::EvalHybrid(
|
||||
input, input_to_input_weights, input_to_input_weights_ledger,
|
||||
input_to_forget_weights, input_to_forget_weights_ledger,
|
||||
input_to_cell_weights, input_to_cell_weights_ledger,
|
||||
input_to_output_weights, input_to_output_weights_ledger,
|
||||
recurrent_to_input_weights, recurrent_to_input_weights_ledger,
|
||||
recurrent_to_forget_weights, recurrent_to_forget_weights_ledger,
|
||||
recurrent_to_cell_weights, recurrent_to_cell_weights_ledger,
|
||||
recurrent_to_output_weights, recurrent_to_output_weights_ledger,
|
||||
cell_to_input_weights, cell_to_forget_weights,
|
||||
cell_to_output_weights, input_layer_norm_coefficients,
|
||||
forget_layer_norm_coefficients, cell_layer_norm_coefficients,
|
||||
output_layer_norm_coefficients,
|
||||
/*aux_input=*/nullptr,
|
||||
/*aux_input_to_input_weights=*/nullptr,
|
||||
/*aux_input_to_forget_weights=*/nullptr,
|
||||
/*aux_input_to_cell_weights=*/nullptr,
|
||||
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
||||
forget_gate_bias, cell_gate_bias, output_gate_bias,
|
||||
projection_weights, projection_weights_ledger, projection_bias,
|
||||
params,
|
||||
/*forward_sequence=*/true, /*time_major=*/true,
|
||||
/*output_offset=*/0, GetTemporary(context, node, kScratchBuffer),
|
||||
GetTemporary(context, node, kInputScalingFactors),
|
||||
/*aux_input_sf=*/nullptr,
|
||||
GetTemporary(context, node, kOutputStateScalingFactors),
|
||||
GetTemporary(context, node, kProductScalingFactors),
|
||||
GetTemporary(context, node, kRecoveredCellWeights),
|
||||
GetTemporary(context, node, kInputQuantized),
|
||||
/*aux_input_quantized=*/nullptr,
|
||||
GetTemporary(context, node, kOutputStateQuantized),
|
||||
GetTemporary(context, node, kCellStateQuantized), output_state,
|
||||
cell_state, GetTemporary(context, node, kAccumScratch), output,
|
||||
GetTemporary(context, node, kInputZeroPoints),
|
||||
/*aux_input_zp=*/nullptr,
|
||||
GetTemporary(context, node, kOutputStateZeroPoints), row_sums,
|
||||
row_sums_size, &op_data->compute_row_sums,
|
||||
CpuBackendContext::GetFromContext(context));
|
||||
}
|
||||
return lstm_eval::EvalHybrid(
|
||||
input, input_to_input_weights, input_to_forget_weights,
|
||||
input_to_cell_weights, input_to_output_weights,
|
||||
recurrent_to_input_weights, recurrent_to_forget_weights,
|
||||
recurrent_to_cell_weights, recurrent_to_output_weights,
|
||||
input, input_to_input_weights,
|
||||
/*input_to_input_weights_ledger*/ nullptr, input_to_forget_weights,
|
||||
/*input_to_forget_weights_ledger*/ nullptr, input_to_cell_weights,
|
||||
/*input_to_cell_weights_ledger*/ nullptr, input_to_output_weights,
|
||||
/*input_to_output_weights_ledger*/ nullptr,
|
||||
recurrent_to_input_weights,
|
||||
/*recurrent_to_input_weights_ledger*/ nullptr,
|
||||
recurrent_to_forget_weights,
|
||||
/*recurrent_to_forget_weights_ledger*/ nullptr,
|
||||
recurrent_to_cell_weights,
|
||||
/*recurrent_to_cell_weights_ledger*/ nullptr,
|
||||
recurrent_to_output_weights,
|
||||
/*recurrent_to_output_weights_ledger*/ nullptr,
|
||||
cell_to_input_weights, cell_to_forget_weights,
|
||||
cell_to_output_weights, input_layer_norm_coefficients,
|
||||
forget_layer_norm_coefficients, cell_layer_norm_coefficients,
|
||||
@ -1641,7 +1935,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
/*aux_input_to_cell_weights=*/nullptr,
|
||||
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
||||
forget_gate_bias, cell_gate_bias, output_gate_bias,
|
||||
projection_weights, projection_bias, params,
|
||||
projection_weights, /*projection_weights_ledger*/ nullptr,
|
||||
projection_bias, params,
|
||||
/*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
|
||||
GetTemporary(context, node, kScratchBuffer),
|
||||
GetTemporary(context, node, kInputScalingFactors),
|
||||
|
@ -312,6 +312,7 @@ void CalculateLstmGateHybrid(
|
||||
// Input and weights
|
||||
const int8_t* input, const float* input_sf, const int32_t* input_zp,
|
||||
const int8_t* input_to_gate_weights,
|
||||
const uint8_t* input_to_gate_weights_ledger,
|
||||
const float input_to_gate_weights_scale, int32_t* input_to_gate_row_sums,
|
||||
// Aux input and weights
|
||||
const int8_t* aux_input, const float* aux_input_sf,
|
||||
@ -321,6 +322,7 @@ void CalculateLstmGateHybrid(
|
||||
// Output state and weights
|
||||
const int8_t* output_state, const float* output_state_sf,
|
||||
const int32_t* output_state_zp, const int8_t* recurrent_to_gate_weights,
|
||||
const uint8_t* recurrent_to_gate_weights_ledger,
|
||||
const float recurrent_to_gate_weights_scale,
|
||||
int32_t* recurrent_to_gate_row_sums,
|
||||
// Cell state and weights (peephole LSTM)
|
||||
@ -356,11 +358,22 @@ void CalculateLstmGateHybrid(
|
||||
// For each batch and cell: compute input_weight * input.
|
||||
// Skip if input is all zeros.
|
||||
if (!is_input_all_zeros) {
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
input_to_gate_weights, n_cell, n_input, input,
|
||||
input_to_gate_weights_scale, input_sf, n_batch, gate,
|
||||
/*per_channel_scale=*/nullptr, input_zp, accum_scratch,
|
||||
input_to_gate_row_sums, compute_row_sums, scratch0, context);
|
||||
if (input_to_gate_weights_ledger != nullptr) {
|
||||
std::vector<float> scales(n_batch);
|
||||
for (int i = 0; i < n_batch; i++) {
|
||||
scales[i] = input_to_gate_weights_scale * input_sf[i];
|
||||
}
|
||||
tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
|
||||
input_to_gate_weights, input_to_gate_weights_ledger, n_cell, n_input,
|
||||
input, scales.data(), n_batch, gate);
|
||||
|
||||
} else {
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
input_to_gate_weights, n_cell, n_input, input,
|
||||
input_to_gate_weights_scale, input_sf, n_batch, gate,
|
||||
/*per_channel_scale=*/nullptr, input_zp, accum_scratch,
|
||||
input_to_gate_row_sums, compute_row_sums, scratch0, context);
|
||||
}
|
||||
}
|
||||
// For each batch and cell: compute aux_input_weight * aux_input.
|
||||
// Skip if auxiliary input is not available or all zeros.
|
||||
@ -374,11 +387,21 @@ void CalculateLstmGateHybrid(
|
||||
// For each batch and cell: compute recurrent_weight * output_state.
|
||||
// Skip if output state is all zeros.
|
||||
if (!is_output_state_all_zeros) {
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
recurrent_to_gate_weights, n_cell, n_output, output_state,
|
||||
recurrent_to_gate_weights_scale, output_state_sf, n_batch, gate,
|
||||
/*per_channel_scale=*/nullptr, output_state_zp, accum_scratch,
|
||||
recurrent_to_gate_row_sums, compute_row_sums, scratch0, context);
|
||||
if (recurrent_to_gate_weights_ledger != nullptr) {
|
||||
std::vector<float> scales(n_batch);
|
||||
for (int i = 0; i < n_batch; i++) {
|
||||
scales[i] = recurrent_to_gate_weights_scale * input_sf[i];
|
||||
}
|
||||
tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
|
||||
recurrent_to_gate_weights, recurrent_to_gate_weights_ledger, n_cell,
|
||||
n_output, output_state, scales.data(), n_batch, gate);
|
||||
} else {
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
recurrent_to_gate_weights, n_cell, n_output, output_state,
|
||||
recurrent_to_gate_weights_scale, output_state_sf, n_batch, gate,
|
||||
/*per_channel_scale=*/nullptr, output_state_zp, accum_scratch,
|
||||
recurrent_to_gate_row_sums, compute_row_sums, scratch0, context);
|
||||
}
|
||||
}
|
||||
// For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
|
||||
if (use_peephole) {
|
||||
@ -422,11 +445,12 @@ void CalculateLstmGateHybrid(
|
||||
void CalculateLstmOutputHybrid(
|
||||
int n_batch, int n_cell, int n_output, const float* cell_state,
|
||||
const float* output_gate, TfLiteFusedActivation activation,
|
||||
const int8_t* projection_weights, float projection_weights_scale,
|
||||
const float* projection_bias, const float proj_clip, float* output_state,
|
||||
bool asymmetric_quantize_inputs, int32_t* projection_weights_row_sums,
|
||||
bool* compute_row_sums, CpuBackendContext* context, float* scratch0,
|
||||
int8_t* scratch1, float* scratch2, int32_t* scratch3, int32_t* scratch4) {
|
||||
const int8_t* projection_weights, const uint8_t* projection_weights_ledger,
|
||||
float projection_weights_scale, const float* projection_bias,
|
||||
const float proj_clip, float* output_state, bool asymmetric_quantize_inputs,
|
||||
int32_t* projection_weights_row_sums, bool* compute_row_sums,
|
||||
CpuBackendContext* context, float* scratch0, int8_t* scratch1,
|
||||
float* scratch2, int32_t* scratch3, int32_t* scratch4) {
|
||||
tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
|
||||
activation, scratch0);
|
||||
tensor_utils::VectorVectorCwiseProduct(output_gate, scratch0,
|
||||
@ -447,11 +471,21 @@ void CalculateLstmOutputHybrid(
|
||||
tensor_utils::BatchQuantizeFloats(scratch0, n_batch, n_cell, scratch1,
|
||||
scratch2, scratch3,
|
||||
asymmetric_quantize_inputs);
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
projection_weights, n_output, n_cell, scratch1,
|
||||
projection_weights_scale, scratch2, n_batch, output_state,
|
||||
/*per_channel_scale=*/nullptr, scratch3, scratch4,
|
||||
projection_weights_row_sums, compute_row_sums, scratch2, context);
|
||||
if (projection_weights_ledger != nullptr) {
|
||||
std::vector<float> scales(n_batch);
|
||||
for (int i = 0; i < n_batch; i++) {
|
||||
scales[i] = projection_weights_scale * scratch2[i];
|
||||
}
|
||||
tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
|
||||
projection_weights, projection_weights_ledger, n_output, n_cell,
|
||||
scratch1, scales.data(), n_batch, output_state);
|
||||
} else {
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
projection_weights, n_output, n_cell, scratch1,
|
||||
projection_weights_scale, scratch2, n_batch, output_state,
|
||||
/*per_channel_scale=*/nullptr, scratch3, scratch4,
|
||||
projection_weights_row_sums, compute_row_sums, scratch2, context);
|
||||
}
|
||||
}
|
||||
if (proj_clip > 0.0f) {
|
||||
tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
|
||||
@ -955,11 +989,16 @@ inline void LstmStepFloat(
|
||||
// output_ptr - size 'n_batch * output_batch_leading_dim'
|
||||
inline void LstmStepHybrid(
|
||||
const float* input_ptr, const int8_t* input_to_input_weights_ptr,
|
||||
const uint8_t* input_to_input_weights_ledger_ptr,
|
||||
float input_to_input_weights_scale,
|
||||
const int8_t* input_to_forget_weights_ptr,
|
||||
const uint8_t* input_to_forget_weights_ledger_ptr,
|
||||
float input_to_forget_weights_scale,
|
||||
const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
|
||||
const int8_t* input_to_cell_weights_ptr,
|
||||
const uint8_t* input_to_cell_weights_ledger_ptr,
|
||||
float input_to_cell_weights_scale,
|
||||
const int8_t* input_to_output_weights_ptr,
|
||||
const uint8_t* input_to_output_weights_ledger_ptr,
|
||||
float input_to_output_weights_scale, const float* aux_input_ptr,
|
||||
const int8_t* aux_input_to_input_weights_ptr,
|
||||
float aux_input_to_input_weights_scale,
|
||||
@ -970,12 +1009,16 @@ inline void LstmStepHybrid(
|
||||
const int8_t* aux_input_to_output_weights_ptr,
|
||||
float aux_input_to_output_weights_scale,
|
||||
const int8_t* recurrent_to_input_weights_ptr,
|
||||
const uint8_t* recurrent_to_input_weights_ledger_ptr,
|
||||
float recurrent_to_input_weights_scale,
|
||||
const int8_t* recurrent_to_forget_weights_ptr,
|
||||
const uint8_t* recurrent_to_forget_weights_ledger_ptr,
|
||||
float recurrent_to_forget_weights_scale,
|
||||
const int8_t* recurrent_to_cell_weights_ptr,
|
||||
const uint8_t* recurrent_to_cell_weights_ledger_ptr,
|
||||
float recurrent_to_cell_weights_scale,
|
||||
const int8_t* recurrent_to_output_weights_ptr,
|
||||
const uint8_t* recurrent_to_output_weights_ledger_ptr,
|
||||
float recurrent_to_output_weights_scale,
|
||||
const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
|
||||
const int8_t* cell_to_forget_weights_ptr,
|
||||
@ -988,19 +1031,21 @@ inline void LstmStepHybrid(
|
||||
const float* output_layer_norm_coefficients_ptr,
|
||||
const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
|
||||
const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
|
||||
const int8_t* projection_weights_ptr, float projection_weights_scale,
|
||||
const float* projection_bias_ptr, const TfLiteLSTMParams* params,
|
||||
int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
|
||||
int output_batch_leading_dim, float* scratch0, float* scratch1,
|
||||
float* scratch2, float* scratch3, float* input_sf, float* aux_input_sf,
|
||||
float* output_state_sf, float* scaling_factors_scratch,
|
||||
float* recovered_cell_weights, int8_t* quantized_input_ptr,
|
||||
int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr,
|
||||
int8_t* quantized_output_scratch, float* output_state_ptr,
|
||||
float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr,
|
||||
int32_t* input_zp, int32_t* aux_input_zp, int32_t* output_state_zp,
|
||||
int32_t* row_sums, int row_sums_size, bool* compute_row_sums,
|
||||
bool asymmetric_quantize_inputs, CpuBackendContext* context) {
|
||||
const int8_t* projection_weights_ptr,
|
||||
const uint8_t* projection_weights_ledger_ptr,
|
||||
float projection_weights_scale, const float* projection_bias_ptr,
|
||||
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
|
||||
int n_aux_input, int n_output, int output_batch_leading_dim,
|
||||
float* scratch0, float* scratch1, float* scratch2, float* scratch3,
|
||||
float* input_sf, float* aux_input_sf, float* output_state_sf,
|
||||
float* scaling_factors_scratch, float* recovered_cell_weights,
|
||||
int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
|
||||
int8_t* quantized_output_state_ptr, int8_t* quantized_output_scratch,
|
||||
float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr,
|
||||
float* output_ptr, int32_t* input_zp, int32_t* aux_input_zp,
|
||||
int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
|
||||
bool* compute_row_sums, bool asymmetric_quantize_inputs,
|
||||
CpuBackendContext* context) {
|
||||
ruy::profiler::ScopeLabel label("LstmStepHybrid");
|
||||
// Since we have already checked that weights are all there or none, we
|
||||
// can check the existence of only one to the get the condition.
|
||||
@ -1106,11 +1151,12 @@ inline void LstmStepHybrid(
|
||||
// Calculate the input gate. (If not CIFG.)
|
||||
CalculateLstmGateHybrid(
|
||||
quantized_input_ptr, input_sf, input_zp, input_to_input_weights_ptr,
|
||||
input_to_input_weights_scale, input_to_input_row_sums,
|
||||
quantized_aux_input_ptr, aux_input_sf, aux_input_zp,
|
||||
aux_input_to_input_weights_ptr, aux_input_to_input_weights_scale,
|
||||
aux_input_to_input_row_sums, quantized_output_state_ptr,
|
||||
output_state_sf, output_state_zp, recurrent_to_input_weights_ptr,
|
||||
input_to_input_weights_ledger_ptr, input_to_input_weights_scale,
|
||||
input_to_input_row_sums, quantized_aux_input_ptr, aux_input_sf,
|
||||
aux_input_zp, aux_input_to_input_weights_ptr,
|
||||
aux_input_to_input_weights_scale, aux_input_to_input_row_sums,
|
||||
quantized_output_state_ptr, output_state_sf, output_state_zp,
|
||||
recurrent_to_input_weights_ptr, recurrent_to_input_weights_ledger_ptr,
|
||||
recurrent_to_input_weights_scale, recurrent_to_input_row_sums,
|
||||
cell_state_ptr, cell_to_input_weights_ptr, cell_to_input_weights_scale,
|
||||
input_layer_norm_coefficients_ptr, input_gate_bias_ptr, n_batch,
|
||||
@ -1122,11 +1168,12 @@ inline void LstmStepHybrid(
|
||||
// Calculate the forget gate.
|
||||
CalculateLstmGateHybrid(
|
||||
quantized_input_ptr, input_sf, input_zp, input_to_forget_weights_ptr,
|
||||
input_to_forget_weights_scale, input_to_forget_row_sums,
|
||||
quantized_aux_input_ptr, aux_input_sf, aux_input_zp,
|
||||
aux_input_to_forget_weights_ptr, aux_input_to_forget_weights_scale,
|
||||
aux_input_to_forget_row_sums, quantized_output_state_ptr, output_state_sf,
|
||||
output_state_zp, recurrent_to_forget_weights_ptr,
|
||||
input_to_forget_weights_ledger_ptr, input_to_forget_weights_scale,
|
||||
input_to_forget_row_sums, quantized_aux_input_ptr, aux_input_sf,
|
||||
aux_input_zp, aux_input_to_forget_weights_ptr,
|
||||
aux_input_to_forget_weights_scale, aux_input_to_forget_row_sums,
|
||||
quantized_output_state_ptr, output_state_sf, output_state_zp,
|
||||
recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_ledger_ptr,
|
||||
recurrent_to_forget_weights_scale, recurrent_to_forget_row_sums,
|
||||
cell_state_ptr, cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
|
||||
forget_layer_norm_coefficients_ptr, forget_gate_bias_ptr, n_batch,
|
||||
@ -1137,11 +1184,12 @@ inline void LstmStepHybrid(
|
||||
// Calculate the cell update gate.
|
||||
CalculateLstmGateHybrid(
|
||||
quantized_input_ptr, input_sf, input_zp, input_to_cell_weights_ptr,
|
||||
input_to_cell_weights_scale, input_to_cell_row_sums,
|
||||
quantized_aux_input_ptr, aux_input_sf, aux_input_zp,
|
||||
aux_input_to_cell_weights_ptr, aux_input_to_cell_weights_scale,
|
||||
aux_input_to_cell_row_sums, quantized_output_state_ptr, output_state_sf,
|
||||
output_state_zp, recurrent_to_cell_weights_ptr,
|
||||
input_to_cell_weights_ledger_ptr, input_to_cell_weights_scale,
|
||||
input_to_cell_row_sums, quantized_aux_input_ptr, aux_input_sf,
|
||||
aux_input_zp, aux_input_to_cell_weights_ptr,
|
||||
aux_input_to_cell_weights_scale, aux_input_to_cell_row_sums,
|
||||
quantized_output_state_ptr, output_state_sf, output_state_zp,
|
||||
recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_ledger_ptr,
|
||||
recurrent_to_cell_weights_scale, recurrent_to_cell_row_sums,
|
||||
/*cell_state=*/nullptr, /*cell_to_gate_weights=*/nullptr,
|
||||
/*cell_to_gate_weights_scale=*/0.0f, cell_layer_norm_coefficients_ptr,
|
||||
@ -1157,11 +1205,12 @@ inline void LstmStepHybrid(
|
||||
// Calculate the output gate.
|
||||
CalculateLstmGateHybrid(
|
||||
quantized_input_ptr, input_sf, input_zp, input_to_output_weights_ptr,
|
||||
input_to_output_weights_scale, input_to_output_row_sums,
|
||||
quantized_aux_input_ptr, aux_input_sf, aux_input_zp,
|
||||
aux_input_to_output_weights_ptr, aux_input_to_output_weights_scale,
|
||||
aux_input_to_output_row_sums, quantized_output_state_ptr, output_state_sf,
|
||||
output_state_zp, recurrent_to_output_weights_ptr,
|
||||
input_to_output_weights_ledger_ptr, input_to_output_weights_scale,
|
||||
input_to_output_row_sums, quantized_aux_input_ptr, aux_input_sf,
|
||||
aux_input_zp, aux_input_to_output_weights_ptr,
|
||||
aux_input_to_output_weights_scale, aux_input_to_output_row_sums,
|
||||
quantized_output_state_ptr, output_state_sf, output_state_zp,
|
||||
recurrent_to_output_weights_ptr, recurrent_to_output_weights_ledger_ptr,
|
||||
recurrent_to_output_weights_scale, recurrent_to_output_row_sums,
|
||||
cell_state_ptr, cell_to_output_weights_ptr, cell_to_output_weights_scale,
|
||||
output_layer_norm_coefficients_ptr, output_gate_bias_ptr, n_batch,
|
||||
@ -1172,11 +1221,11 @@ inline void LstmStepHybrid(
|
||||
// Update the output state.
|
||||
CalculateLstmOutputHybrid(
|
||||
n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
|
||||
params->activation, projection_weights_ptr, projection_weights_scale,
|
||||
projection_bias_ptr, params->proj_clip, output_state_ptr,
|
||||
asymmetric_quantize_inputs, projection_weights_row_sums, compute_row_sums,
|
||||
context, scratch2, quantized_output_scratch, input_sf, input_zp,
|
||||
accum_scratch_ptr);
|
||||
params->activation, projection_weights_ptr, projection_weights_ledger_ptr,
|
||||
projection_weights_scale, projection_bias_ptr, params->proj_clip,
|
||||
output_state_ptr, asymmetric_quantize_inputs, projection_weights_row_sums,
|
||||
compute_row_sums, context, scratch2, quantized_output_scratch, input_sf,
|
||||
input_zp, accum_scratch_ptr);
|
||||
// Copy output state to the output. Note that the output's rows may not be
|
||||
// contiguous (output_batch_leading_dim != n_output).
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
@ -1829,13 +1878,21 @@ TfLiteStatus EvalFloat(
|
||||
|
||||
TfLiteStatus EvalHybrid(
|
||||
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
|
||||
const TfLiteTensor* input_to_input_weights_ledger,
|
||||
const TfLiteTensor* input_to_forget_weights,
|
||||
const TfLiteTensor* input_to_forget_weights_ledger,
|
||||
const TfLiteTensor* input_to_cell_weights,
|
||||
const TfLiteTensor* input_to_cell_weights_ledger,
|
||||
const TfLiteTensor* input_to_output_weights,
|
||||
const TfLiteTensor* input_to_output_weights_ledger,
|
||||
const TfLiteTensor* recurrent_to_input_weights,
|
||||
const TfLiteTensor* recurrent_to_input_weights_ledger,
|
||||
const TfLiteTensor* recurrent_to_forget_weights,
|
||||
const TfLiteTensor* recurrent_to_forget_weights_ledger,
|
||||
const TfLiteTensor* recurrent_to_cell_weights,
|
||||
const TfLiteTensor* recurrent_to_cell_weights_ledger,
|
||||
const TfLiteTensor* recurrent_to_output_weights,
|
||||
const TfLiteTensor* recurrent_to_output_weights_ledger,
|
||||
const TfLiteTensor* cell_to_input_weights,
|
||||
const TfLiteTensor* cell_to_forget_weights,
|
||||
const TfLiteTensor* cell_to_output_weights,
|
||||
@ -1850,9 +1907,11 @@ TfLiteStatus EvalHybrid(
|
||||
const TfLiteTensor* aux_input_to_output_weights,
|
||||
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
|
||||
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
|
||||
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
||||
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
|
||||
int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
|
||||
const TfLiteTensor* projection_weights,
|
||||
const TfLiteTensor* projection_weights_ledger,
|
||||
const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params,
|
||||
bool forward_sequence, bool time_major, int output_offset,
|
||||
TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
|
||||
TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
|
||||
TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
|
||||
TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
|
||||
@ -1929,12 +1988,16 @@ TfLiteStatus EvalHybrid(
|
||||
GetTensorData<float>(output) + t_rel * output_step + output_offset;
|
||||
LstmStepHybrid(
|
||||
input_ptr, GetTensorData<int8_t>(input_to_input_weights),
|
||||
GetTensorData<uint8_t>(input_to_input_weights_ledger),
|
||||
GetTensorScale(input_to_input_weights),
|
||||
GetTensorData<int8_t>(input_to_forget_weights),
|
||||
GetTensorData<uint8_t>(input_to_forget_weights_ledger),
|
||||
GetTensorScale(input_to_forget_weights),
|
||||
GetTensorData<int8_t>(input_to_cell_weights),
|
||||
GetTensorData<uint8_t>(input_to_cell_weights_ledger),
|
||||
GetTensorScale(input_to_cell_weights),
|
||||
GetTensorData<int8_t>(input_to_output_weights),
|
||||
GetTensorData<uint8_t>(input_to_output_weights_ledger),
|
||||
GetTensorScale(input_to_output_weights), aux_input_ptr,
|
||||
GetTensorData<int8_t>(aux_input_to_input_weights),
|
||||
GetTensorScale(aux_input_to_input_weights),
|
||||
@ -1945,12 +2008,16 @@ TfLiteStatus EvalHybrid(
|
||||
GetTensorData<int8_t>(aux_input_to_output_weights),
|
||||
GetTensorScale(aux_input_to_output_weights),
|
||||
GetTensorData<int8_t>(recurrent_to_input_weights),
|
||||
GetTensorData<uint8_t>(recurrent_to_input_weights_ledger),
|
||||
GetTensorScale(recurrent_to_input_weights),
|
||||
GetTensorData<int8_t>(recurrent_to_forget_weights),
|
||||
GetTensorData<uint8_t>(recurrent_to_forget_weights_ledger),
|
||||
GetTensorScale(recurrent_to_forget_weights),
|
||||
GetTensorData<int8_t>(recurrent_to_cell_weights),
|
||||
GetTensorData<uint8_t>(recurrent_to_cell_weights_ledger),
|
||||
GetTensorScale(recurrent_to_cell_weights),
|
||||
GetTensorData<int8_t>(recurrent_to_output_weights),
|
||||
GetTensorData<uint8_t>(recurrent_to_output_weights_ledger),
|
||||
GetTensorScale(recurrent_to_output_weights),
|
||||
GetTensorData<int8_t>(cell_to_input_weights),
|
||||
GetTensorScale(cell_to_input_weights),
|
||||
@ -1967,6 +2034,7 @@ TfLiteStatus EvalHybrid(
|
||||
GetTensorData<float>(cell_gate_bias),
|
||||
GetTensorData<float>(output_gate_bias),
|
||||
GetTensorData<int8_t>(projection_weights),
|
||||
GetTensorData<uint8_t>(projection_weights_ledger),
|
||||
GetTensorScale(projection_weights),
|
||||
GetTensorData<float>(projection_bias), params, n_batch, n_cell,
|
||||
n_input, aux_input_size, n_output, output_batch_leading_dim,
|
||||
@ -2018,12 +2086,16 @@ TfLiteStatus EvalHybrid(
|
||||
|
||||
LstmStepHybrid(
|
||||
input_ptr, GetTensorData<int8_t>(input_to_input_weights),
|
||||
GetTensorData<uint8_t>(input_to_input_weights_ledger),
|
||||
GetTensorScale(input_to_input_weights),
|
||||
GetTensorData<int8_t>(input_to_forget_weights),
|
||||
GetTensorData<uint8_t>(input_to_forget_weights_ledger),
|
||||
GetTensorScale(input_to_forget_weights),
|
||||
GetTensorData<int8_t>(input_to_cell_weights),
|
||||
GetTensorData<uint8_t>(input_to_cell_weights_ledger),
|
||||
GetTensorScale(input_to_cell_weights),
|
||||
GetTensorData<int8_t>(input_to_output_weights),
|
||||
GetTensorData<uint8_t>(input_to_output_weights_ledger),
|
||||
GetTensorScale(input_to_output_weights), aux_input_ptr,
|
||||
GetTensorData<int8_t>(aux_input_to_input_weights),
|
||||
GetTensorScale(aux_input_to_input_weights),
|
||||
@ -2034,12 +2106,16 @@ TfLiteStatus EvalHybrid(
|
||||
GetTensorData<int8_t>(aux_input_to_output_weights),
|
||||
GetTensorScale(aux_input_to_output_weights),
|
||||
GetTensorData<int8_t>(recurrent_to_input_weights),
|
||||
GetTensorData<uint8_t>(recurrent_to_input_weights_ledger),
|
||||
GetTensorScale(recurrent_to_input_weights),
|
||||
GetTensorData<int8_t>(recurrent_to_forget_weights),
|
||||
GetTensorData<uint8_t>(recurrent_to_forget_weights_ledger),
|
||||
GetTensorScale(recurrent_to_forget_weights),
|
||||
GetTensorData<int8_t>(recurrent_to_cell_weights),
|
||||
GetTensorData<uint8_t>(recurrent_to_cell_weights_ledger),
|
||||
GetTensorScale(recurrent_to_cell_weights),
|
||||
GetTensorData<int8_t>(recurrent_to_output_weights),
|
||||
GetTensorData<uint8_t>(recurrent_to_output_weights_ledger),
|
||||
GetTensorScale(recurrent_to_output_weights),
|
||||
GetTensorData<int8_t>(cell_to_input_weights),
|
||||
GetTensorScale(cell_to_input_weights),
|
||||
@ -2056,6 +2132,7 @@ TfLiteStatus EvalHybrid(
|
||||
GetTensorData<float>(cell_gate_bias),
|
||||
GetTensorData<float>(output_gate_bias),
|
||||
GetTensorData<int8_t>(projection_weights),
|
||||
GetTensorData<uint8_t>(projection_weights_ledger),
|
||||
GetTensorScale(projection_weights),
|
||||
GetTensorData<float>(projection_bias), params,
|
||||
/*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
|
||||
|
@ -125,13 +125,21 @@ TfLiteStatus EvalFloat(
|
||||
|
||||
TfLiteStatus EvalHybrid(
|
||||
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
|
||||
const TfLiteTensor* input_to_input_weights_ledger,
|
||||
const TfLiteTensor* input_to_forget_weights,
|
||||
const TfLiteTensor* input_to_forget_weights_ledger,
|
||||
const TfLiteTensor* input_to_cell_weights,
|
||||
const TfLiteTensor* input_to_cell_weights_ledger,
|
||||
const TfLiteTensor* input_to_output_weights,
|
||||
const TfLiteTensor* input_to_output_weights_ledger,
|
||||
const TfLiteTensor* recurrent_to_input_weights,
|
||||
const TfLiteTensor* recurrent_to_input_weights_ledger,
|
||||
const TfLiteTensor* recurrent_to_forget_weights,
|
||||
const TfLiteTensor* recurrent_to_forget_weights_ledger,
|
||||
const TfLiteTensor* recurrent_to_cell_weights,
|
||||
const TfLiteTensor* recurrent_to_cell_weights_ledger,
|
||||
const TfLiteTensor* recurrent_to_output_weights,
|
||||
const TfLiteTensor* recurrent_to_output_weights_ledger,
|
||||
const TfLiteTensor* cell_to_input_weights,
|
||||
const TfLiteTensor* cell_to_forget_weights,
|
||||
const TfLiteTensor* cell_to_output_weights,
|
||||
@ -146,9 +154,11 @@ TfLiteStatus EvalHybrid(
|
||||
const TfLiteTensor* aux_input_to_output_weights,
|
||||
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
|
||||
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
|
||||
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
||||
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
|
||||
int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
|
||||
const TfLiteTensor* projection_weights,
|
||||
const TfLiteTensor* projection_weights_ledger,
|
||||
const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params,
|
||||
bool forward_sequence, bool time_major, int output_offset,
|
||||
TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
|
||||
TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
|
||||
TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
|
||||
TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
|
||||
|
@ -906,14 +906,14 @@ void TestOneHybridAsymmLSTM() {
|
||||
constexpr float kDefaultScale = 18.0;
|
||||
ops::builtin::lstm_eval::EvalHybrid(
|
||||
one_parameter.GetFloatInput(),
|
||||
HybridLstmParam::addScale(one_parameter.Geti2i(), kDefaultScale),
|
||||
HybridLstmParam::addScale(one_parameter.Geti2f(), kDefaultScale),
|
||||
HybridLstmParam::addScale(one_parameter.Geti2c(), kDefaultScale),
|
||||
HybridLstmParam::addScale(one_parameter.Geti2o(), kDefaultScale),
|
||||
HybridLstmParam::addScale(one_parameter.Getr2i(), kDefaultScale),
|
||||
HybridLstmParam::addScale(one_parameter.Getr2f(), kDefaultScale),
|
||||
HybridLstmParam::addScale(one_parameter.Getr2c(), kDefaultScale),
|
||||
HybridLstmParam::addScale(one_parameter.Getr2o(), kDefaultScale),
|
||||
HybridLstmParam::addScale(one_parameter.Geti2i(), kDefaultScale), nullptr,
|
||||
HybridLstmParam::addScale(one_parameter.Geti2f(), kDefaultScale), nullptr,
|
||||
HybridLstmParam::addScale(one_parameter.Geti2c(), kDefaultScale), nullptr,
|
||||
HybridLstmParam::addScale(one_parameter.Geti2o(), kDefaultScale), nullptr,
|
||||
HybridLstmParam::addScale(one_parameter.Getr2i(), kDefaultScale), nullptr,
|
||||
HybridLstmParam::addScale(one_parameter.Getr2f(), kDefaultScale), nullptr,
|
||||
HybridLstmParam::addScale(one_parameter.Getr2c(), kDefaultScale), nullptr,
|
||||
HybridLstmParam::addScale(one_parameter.Getr2o(), kDefaultScale), nullptr,
|
||||
/*cell_to_input_weights=*/nullptr,
|
||||
/*cell_to_forget_weights=*/nullptr,
|
||||
/*cell_to_output_weights=*/nullptr, one_parameter.GetInputLayerNorm(),
|
||||
@ -926,7 +926,7 @@ void TestOneHybridAsymmLSTM() {
|
||||
/*aux_input_to_output_weights=*/nullptr, one_parameter.GetInputBias(),
|
||||
one_parameter.GetForgetBias(), one_parameter.GetCellBias(),
|
||||
one_parameter.GetOutputBias(),
|
||||
HybridLstmParam::addScale(one_parameter.GetProjection(), 1.0),
|
||||
HybridLstmParam::addScale(one_parameter.GetProjection(), 1.0), nullptr,
|
||||
one_parameter.GetProjectionBias(), ¶m,
|
||||
/*forward_sequence=*/true,
|
||||
/*time_major=*/true,
|
||||
|
@ -2114,6 +2114,620 @@ TEST(LstmOpTest, InvalidTypes) {
|
||||
}
|
||||
#endif
|
||||
|
||||
class HybridSparseLSTMOpModel : public ::tflite::SingleOpModel {
|
||||
public:
|
||||
HybridSparseLSTMOpModel(
|
||||
int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
|
||||
bool use_peephole, bool use_projection_weights, bool use_projection_bias,
|
||||
float cell_clip, float proj_clip,
|
||||
const std::vector<std::vector<int>>& input_shapes,
|
||||
const TensorData& input_weights_td,
|
||||
const std::vector<float>& input_to_input_weights,
|
||||
const std::vector<float>& input_to_forget_weights,
|
||||
const std::vector<float>& input_to_cell_weights,
|
||||
const std::vector<float>& input_to_output_weights,
|
||||
const TensorData& recurrent_weights_td,
|
||||
const std::vector<float>& recurrent_to_input_weights,
|
||||
const std::vector<float>& recurrent_to_forget_weights,
|
||||
const std::vector<float>& recurrent_to_cell_weights,
|
||||
const std::vector<float>& recurrent_to_output_weights,
|
||||
const ::tflite::TensorType& weight_type = ::tflite::TensorType_INT8)
|
||||
: n_batch_(n_batch),
|
||||
n_input_(n_input),
|
||||
n_cell_(n_cell),
|
||||
n_output_(n_output) {
|
||||
input_ = AddInput(::tflite::TensorType_FLOAT32);
|
||||
|
||||
if (use_cifg) {
|
||||
input_to_input_weights_ = AddNullInput();
|
||||
} else {
|
||||
input_to_input_weights_ =
|
||||
AddConstSparseInput(input_weights_td, input_to_input_weights, true);
|
||||
}
|
||||
|
||||
input_to_forget_weights_ =
|
||||
AddConstSparseInput(input_weights_td, input_to_forget_weights, true);
|
||||
|
||||
input_to_cell_weights_ =
|
||||
AddConstSparseInput(input_weights_td, input_to_cell_weights, true);
|
||||
|
||||
input_to_output_weights_ =
|
||||
AddConstSparseInput(input_weights_td, input_to_output_weights, true);
|
||||
|
||||
if (use_cifg) {
|
||||
recurrent_to_input_weights_ = AddNullInput();
|
||||
} else {
|
||||
recurrent_to_input_weights_ = AddConstSparseInput(
|
||||
recurrent_weights_td, recurrent_to_input_weights, true);
|
||||
}
|
||||
|
||||
recurrent_to_forget_weights_ = AddConstSparseInput(
|
||||
recurrent_weights_td, recurrent_to_forget_weights, true);
|
||||
recurrent_to_cell_weights_ = AddConstSparseInput(
|
||||
recurrent_weights_td, recurrent_to_cell_weights, true);
|
||||
recurrent_to_output_weights_ = AddConstSparseInput(
|
||||
recurrent_weights_td, recurrent_to_output_weights, true);
|
||||
|
||||
if (use_peephole) {
|
||||
if (use_cifg) {
|
||||
cell_to_input_weights_ = AddNullInput();
|
||||
} else {
|
||||
cell_to_input_weights_ = AddInput(weight_type);
|
||||
}
|
||||
cell_to_forget_weights_ = AddInput(weight_type);
|
||||
cell_to_output_weights_ = AddInput(weight_type);
|
||||
} else {
|
||||
cell_to_input_weights_ = AddNullInput();
|
||||
cell_to_forget_weights_ = AddNullInput();
|
||||
cell_to_output_weights_ = AddNullInput();
|
||||
}
|
||||
|
||||
if (use_cifg) {
|
||||
input_gate_bias_ = AddNullInput();
|
||||
} else {
|
||||
input_gate_bias_ = AddInput(::tflite::TensorType_FLOAT32);
|
||||
}
|
||||
forget_gate_bias_ = AddInput(::tflite::TensorType_FLOAT32);
|
||||
cell_bias_ = AddInput(::tflite::TensorType_FLOAT32);
|
||||
output_gate_bias_ = AddInput(::tflite::TensorType_FLOAT32);
|
||||
|
||||
if (use_projection_weights) {
|
||||
projection_weights_ = AddInput(weight_type);
|
||||
if (use_projection_bias) {
|
||||
projection_bias_ = AddInput(::tflite::TensorType_FLOAT32);
|
||||
} else {
|
||||
projection_bias_ = AddNullInput();
|
||||
}
|
||||
} else {
|
||||
projection_weights_ = AddNullInput();
|
||||
projection_bias_ = AddNullInput();
|
||||
}
|
||||
|
||||
// Adding the 2 state tensors.
|
||||
output_state_ = AddInput(::tflite::TensorData{::tflite::TensorType_FLOAT32,
|
||||
{n_output_ * n_batch_}},
|
||||
true);
|
||||
cell_state_ = AddInput(::tflite::TensorData{::tflite::TensorType_FLOAT32,
|
||||
{n_cell_ * n_batch_}},
|
||||
true);
|
||||
|
||||
if (use_cifg) {
|
||||
input_layer_norm_weights_ = AddNullInput();
|
||||
} else {
|
||||
input_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32);
|
||||
}
|
||||
forget_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32);
|
||||
cell_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32);
|
||||
output_layer_norm_weights_ = AddInput(::tflite::TensorType_FLOAT32);
|
||||
|
||||
output_ = AddOutput(::tflite::TensorType_FLOAT32);
|
||||
|
||||
SetBuiltinOp(
|
||||
BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
|
||||
CreateLSTMOptions(builder_, ActivationFunctionType_TANH, cell_clip,
|
||||
proj_clip, LSTMKernelType_FULL, false)
|
||||
.Union());
|
||||
BuildInterpreter(input_shapes);
|
||||
}

  void SetCellToInputWeights(std::vector<float> f) {
    SignedSymmetricQuantizeAndPopulate(cell_to_input_weights_, f);
  }

  void SetCellToForgetWeights(std::vector<float> f) {
    SignedSymmetricQuantizeAndPopulate(cell_to_forget_weights_, f);
  }

  void SetCellToOutputWeights(std::vector<float> f) {
    SignedSymmetricQuantizeAndPopulate(cell_to_output_weights_, f);
  }

  void SetInputLayerNormWeights(std::vector<float> f) {
    PopulateTensor(input_layer_norm_weights_, f);
  }

  void SetForgetLayerNormWeights(std::vector<float> f) {
    PopulateTensor(forget_layer_norm_weights_, f);
  }

  void SetCellLayerNormWeights(std::vector<float> f) {
    PopulateTensor(cell_layer_norm_weights_, f);
  }

  void SetOutputLayerNormWeights(std::vector<float> f) {
    PopulateTensor(output_layer_norm_weights_, f);
  }

  void SetInputGateBias(std::vector<float> f) {
    PopulateTensor(input_gate_bias_, f);
  }

  void SetForgetGateBias(std::vector<float> f) {
    PopulateTensor(forget_gate_bias_, f);
  }

  void SetCellBias(std::vector<float> f) { PopulateTensor(cell_bias_, f); }

  void SetOutputGateBias(std::vector<float> f) {
    PopulateTensor(output_gate_bias_, f);
  }

  void SetProjectionWeights(std::vector<float> f) {
    SignedSymmetricQuantizeAndPopulate(projection_weights_, f);
  }

  void SetProjectionBias(std::vector<float> f) {
    PopulateTensor(projection_bias_, f);
  }

  void SetInput(int offset, const float* begin, const float* end) {
    PopulateTensor(input_, offset, const_cast<float*>(begin),
                   const_cast<float*>(end));
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  int num_inputs() { return n_input_; }
  int num_outputs() { return n_output_; }
  int num_cells() { return n_cell_; }
  int num_batches() { return n_batch_; }

 protected:
  int input_;
  int input_to_input_weights_;
  int input_to_forget_weights_;
  int input_to_cell_weights_;
  int input_to_output_weights_;

  int recurrent_to_input_weights_;
  int recurrent_to_forget_weights_;
  int recurrent_to_cell_weights_;
  int recurrent_to_output_weights_;

  int cell_to_input_weights_;
  int cell_to_forget_weights_;
  int cell_to_output_weights_;

  int input_layer_norm_weights_;
  int forget_layer_norm_weights_;
  int cell_layer_norm_weights_;
  int output_layer_norm_weights_;

  int input_gate_bias_;
  int forget_gate_bias_;
  int cell_bias_;
  int output_gate_bias_;

  int projection_weights_;
  int projection_bias_;

  int output_state_;
  int cell_state_;

  int output_;

  int n_batch_;
  int n_input_;
  int n_cell_;
  int n_output_;
};

class BaseSparseLstmTest : public ::testing::Test {
 protected:
  // Weights of the Sparse Layer Norm LSTM model. Some are optional.
  std::vector<float> input_to_input_weights_;
  std::vector<float> input_to_cell_weights_;
  std::vector<float> input_to_forget_weights_;
  std::vector<float> input_to_output_weights_;
  std::vector<float> input_gate_bias_;
  std::vector<float> cell_gate_bias_;
  std::vector<float> forget_gate_bias_;
  std::vector<float> output_gate_bias_;
  std::vector<float> recurrent_to_input_weights_;
  std::vector<float> recurrent_to_cell_weights_;
  std::vector<float> recurrent_to_forget_weights_;
  std::vector<float> recurrent_to_output_weights_;
  std::vector<float> cell_to_input_weights_;
  std::vector<float> cell_to_forget_weights_;
  std::vector<float> cell_to_output_weights_;
  std::vector<float> input_layer_norm_weights_;
  std::vector<float> forget_layer_norm_weights_;
  std::vector<float> cell_layer_norm_weights_;
  std::vector<float> output_layer_norm_weights_;
  std::vector<float> projection_weights_;

  std::vector<int> input_to_input_weights_size_;
  std::vector<int> input_to_cell_weights_size_;
  std::vector<int> input_to_forget_weights_size_;
  std::vector<int> input_to_output_weights_size_;
  std::vector<int> recurrent_to_input_weights_size_;
  std::vector<int> recurrent_to_cell_weights_size_;
  std::vector<int> recurrent_to_forget_weights_size_;
  std::vector<int> recurrent_to_output_weights_size_;

  int n_batch_;
  int n_input_;
  int n_cell_;
  int n_output_;
  float cell_clip_;
  float proj_clip_;

  // Layer Norm LSTM input is stored as num_batch x num_inputs vector.
  std::vector<std::vector<float>> sparse_layer_norm_lstm_input_;

  // Compares output up to tolerance to the result of the layer_norm_lstm given
  // the input.
  void VerifyGoldens(const std::vector<std::vector<float>>& input,
                     const std::vector<std::vector<float>>& output,
                     HybridSparseLSTMOpModel* sparse_layer_norm_lstm,
                     float tolerance = 1e-5) {
    const int num_batches = input.size();
    EXPECT_GT(num_batches, 0);
    const int num_inputs = sparse_layer_norm_lstm->num_inputs();
    EXPECT_GT(num_inputs, 0);
    const int input_sequence_size = input[0].size() / num_inputs;
    EXPECT_GT(input_sequence_size, 0);
    for (int i = 0; i < input_sequence_size; ++i) {
      for (int b = 0; b < num_batches; ++b) {
        const float* batch_start = input[b].data() + i * num_inputs;
        const float* batch_end = batch_start + num_inputs;

        sparse_layer_norm_lstm->SetInput(
            b * sparse_layer_norm_lstm->num_inputs(), batch_start, batch_end);
      }

      sparse_layer_norm_lstm->Invoke();

      const int num_outputs = sparse_layer_norm_lstm->num_outputs();
      std::vector<float> expected;
      for (int b = 0; b < num_batches; ++b) {
        const float* golden_start_batch = output[b].data() + i * num_outputs;
        const float* golden_end_batch = golden_start_batch + num_outputs;
        expected.insert(expected.end(), golden_start_batch, golden_end_batch);
      }
      EXPECT_THAT(
          sparse_layer_norm_lstm->GetOutput(),
          ElementsAreArray(::tflite::ArrayFloatNear(expected, tolerance)));
    }
  }
};

class NoCifgPeepholeProjectionNoClippingSparseLstmTest
    : public BaseSparseLstmTest {
  void SetUp() override {
    n_batch_ = 2;
    n_input_ = 48;
    n_cell_ = 4;
    n_output_ = 16;
    cell_clip_ = 0.0;
    proj_clip_ = 0.0;

    /* clang-format off */
    input_to_input_weights_ = {
        /* 1st row */
        1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
        14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
        39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0,
        /* 2nd row */
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
        -25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        /* 3rd row */
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
        -26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        /* 4th row */
        -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
        -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
        38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0};
    input_to_input_weights_size_ = {4, 48};

    input_to_forget_weights_ = {
        /* 1st row */
        1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
        14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
        39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0,
        /* 2nd row */
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
        -25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        /* 3rd row */
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
        -26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        /* 4th row */
        -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
        -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
        38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0};
    input_to_forget_weights_size_ = {4, 48};

    input_to_cell_weights_ = {
        /* 1st row */
        1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
        14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
        39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0,
        /* 2nd row */
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
        -25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        /* 3rd row */
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
        -26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        /* 4th row */
        -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
        -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
        38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0};
    input_to_cell_weights_size_ = {4, 48};

    input_to_output_weights_ = {
        /* 1st row */
        1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
        14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 33.33, 34.34, 35.35, 36.36, 37.37, 38.38,
        39.39, 40.40, 41.41, 42.42, 43.43, 44.44, 0.0, 0.0, 0.0, 0.0,
        /* 2nd row */
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, -17.17, -18.18, -19.19, -20.2, -21.21, -22.22, -23.23, -24.24,
        -25.25, -26.26, -27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        /* 3rd row */
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22, 23.23, -24.24, 25.25,
        -26.26, 27.27, -28.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        /* 4th row */
        -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
        -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -33.33, 34.34, -35.35, 36.36, -37.37,
        38.38, -39.39, 40.40, -41.41, 42.42, -43.43, 44.44, 0.0, 0.0, 0.0, 0};
    input_to_output_weights_size_ = {4, 48};

    input_gate_bias_ = {0.03, 0.15, 0.22, 0.38};

    forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};

    cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};

    output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};

    recurrent_to_input_weights_ = {
        -0.2, -0.3, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 1st row
        0.1, -0.5, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 2nd row
        -0.2, -0.3, -0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 3rd row
        0.05, -0.2, -0.6, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 4th row
    };
    recurrent_to_input_weights_size_ = {4, 16};

    recurrent_to_cell_weights_ = {
        -0.3, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 1st row
        -0.3, 0.8, -0.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 2nd row
        -0.2, 0.3, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 3rd row
        -0.6, -0.1, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 4th row
    };
    recurrent_to_cell_weights_size_ = {4, 16};

    recurrent_to_forget_weights_ = {
        -0.5, -0.3, -0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 1st row
        -0.2, 0.6, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 2nd row
        0.9, 0.3, -0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 3rd row
        0.2, 0.5, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 4th row
    };
    recurrent_to_forget_weights_size_ = {4, 16};

    recurrent_to_output_weights_ = {
        0.3, -0.1, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 1st row
        -0.2, -0.5, -0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 2nd row
        -0.2, -0.6, -0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 3rd row
        -0.4, -0.7, -0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0,  // 4th row
    };
    recurrent_to_output_weights_size_ = {4, 16};

    cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15};

    cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};

    cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};

    input_layer_norm_weights_ = {0.1, 0.2, 0.3, 0.5};
    forget_layer_norm_weights_ = {0.2, 0.2, 0.4, 0.3};
    cell_layer_norm_weights_ = {0.7, 0.2, 0.3, 0.8};
    output_layer_norm_weights_ = {0.6, 0.2, 0.2, 0.5};

    projection_weights_ = {
        -0.1, 0.2, 0.01, -0.2,  // 1st row
        0.1, 0.5, 0.3, 0.08,  // 2nd row
        0.07, 0.2, -0.4, 0.2,  // 3rd row
        0.0, 0.0, 0.0, 0.0,  // 4th row
        0.0, 0.0, 0.0, 0.0,  // 5th row
        0.0, 0.0, 0.0, 0.0,  // 6th row
        0.0, 0.0, 0.0, 0.0,  // 7th row
        0.0, 0.0, 0.0, 0.0,  // 8th row
        0.0, 0.0, 0.0, 0.0,  // 9th row
        0.0, 0.0, 0.0, 0.0,  // 10th row
        0.0, 0.0, 0.0, 0.0,  // 11th row
        0.0, 0.0, 0.0, 0.0,  // 12th row
        0.0, 0.0, 0.0, 0.0,  // 13th row
        0.0, 0.0, 0.0, 0.0,  // 14th row
        0.0, 0.0, 0.0, 0.0,  // 15th row
        0.0, 0.0, 0.0, 0.0,  // 16th row
        0.0, 0.0, 0.0, 0.0,  // 17th row
        0.0, 0.0, 0.0, 0.0,  // 18th row
    };

    sparse_layer_norm_lstm_input_ = {
        // Batch0: 2 (input_sequence_size) * 48 (n_input_)
        {
            1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0,
            -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
            1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0,
            -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,  // seq 0
            2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0, -1.1, 0.0, 2.0, 0.0,
            -1.7, 0.0, 1.9, 0.0, -1.5, 0.0, 0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3,
            0.0, 2.8, 0.0, -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1,
            1.0, -2.5, 0.7, -1.9, 0.2, 0.1, 0.2, 0.3,  // seq 1
        },
        // Batch1: 2 (input_sequence_size) * 48 (n_input_)
        {
            1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0,
            -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
            1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0,
            -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,  // seq 0
            2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0, -1.1, 0.0, 2.0, 0.0,
            -1.7, 0.0, 1.9, 0.0, -1.5, 0.0, 0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3,
            0.0, 2.8, 0.0, -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1,
            1.0, -2.5, 0.7, -1.9, 0.2, -1.0, 1.0, -1.0,  // seq 1
        },
    };
    /* clang-format on */
  }
};

TEST_F(NoCifgPeepholeProjectionNoClippingSparseLstmTest,
       HybridSparseLstmBlackBoxTest) {
  TensorData input_weight = {};
  input_weight.type = TensorType_FLOAT32;
  input_weight.shape = {4, 48};
  input_weight.traversal_order = {0, 1, 2};
  input_weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  input_weight.block_map = {1};
  input_weight.block_size = {16};
  TensorData recurrent_weight = {};
  recurrent_weight.type = TensorType_FLOAT32;
  recurrent_weight.shape = {4, 16};
  recurrent_weight.traversal_order = {0, 1, 2};
  recurrent_weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  recurrent_weight.block_map = {1};
  recurrent_weight.block_size = {16};
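  // Both weight descriptions above use a block-sparse layout: dimension 0
  // (rows) stays dense, dimension 1 (columns) is stored in CSR form, and
  // block_map/block_size group the columns into 1x16 blocks that are kept or
  // dropped as a unit, which is why traversal_order lists a third, blocked
  // dimension.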
  HybridSparseLSTMOpModel sparse_layer_norm_lstm(
      n_batch_, n_input_, n_cell_, n_output_,
      /*use_cifg=*/false, /*use_peephole=*/true,
      /*use_projection_weights=*/true,
      /*use_projection_bias=*/false, cell_clip_, proj_clip_,
      {
          {n_batch_, n_input_},  // input tensor

          {input_to_input_weights_size_},
          {input_to_forget_weights_size_},
          {input_to_cell_weights_size_},
          {input_to_output_weights_size_},

          {recurrent_to_input_weights_size_},
          {recurrent_to_forget_weights_size_},
          {recurrent_to_cell_weights_size_},
          {recurrent_to_output_weights_size_},

          {n_cell_},  // cell_to_input_weight tensor
          {n_cell_},  // cell_to_forget_weight tensor
          {n_cell_},  // cell_to_output_weight tensor

          {n_cell_},  // input_gate_bias tensor
          {n_cell_},  // forget_gate_bias tensor
          {n_cell_},  // cell_bias tensor
          {n_cell_},  // output_gate_bias tensor

          {n_output_, n_cell_},  // projection_weight tensor
          {0},                   // projection_bias tensor

          {n_output_ * n_batch_},  // output_state tensor
          {n_cell_ * n_batch_},    // cell_state tensor

          {n_cell_},  // input_layer_norm_weight tensor
          {n_cell_},  // forget_layer_norm_weight tensor
          {n_cell_},  // cell_layer_norm_weight tensor
          {n_cell_},  // output_layer_norm_weight tensor
      },
      input_weight, input_to_input_weights_, input_to_forget_weights_,
      input_to_cell_weights_, input_to_output_weights_, recurrent_weight,
      recurrent_to_input_weights_, recurrent_to_forget_weights_,
      recurrent_to_cell_weights_, recurrent_to_output_weights_);

  sparse_layer_norm_lstm.SetInputGateBias(input_gate_bias_);
  sparse_layer_norm_lstm.SetCellBias(cell_gate_bias_);
  sparse_layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
  sparse_layer_norm_lstm.SetOutputGateBias(output_gate_bias_);

  sparse_layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
  sparse_layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
  sparse_layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);

  sparse_layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_);
  sparse_layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_);
  sparse_layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_);
  sparse_layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_);

  sparse_layer_norm_lstm.SetProjectionWeights(projection_weights_);

  /* clang-format off */
  const std::vector<std::vector<float>> sparse_layer_norm_lstm_golden_output = {
      {
          // Batch0: 2 (input_sequence_size) * 16 (n_output_)
          0.0550758, 0.138464, -0.0628034, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
          0.0, 0.0, 0.0, 0.0, 0.0,
          0.069672, 0.195428, -0.0605584, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
          0.0, 0.0, 0.0, 0.0, 0.0,
      },
      {
          // Batch1: 2 (input_sequence_size) * 16 (n_output_)
          0.0550758, 0.138464, -0.0628034, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
          0.0, 0.0, 0.0, 0.0, 0.0,
          0.069672, 0.195428, -0.0605584, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
          0.0, 0.0, 0.0, 0.0, 0.0,
      }};
  /* clang-format on */

  VerifyGoldens(sparse_layer_norm_lstm_input_,
                sparse_layer_norm_lstm_golden_output, &sparse_layer_norm_lstm);
}

// Test parameter controls asymmetric_quantize_inputs in LSTMOpModel.
INSTANTIATE_TEST_SUITE_P(
    Parameterized, LstmOpTest,

@ -214,9 +214,82 @@ class SingleOpModel {
    return AddConstInput(TensorData{type, shape}, data);
  }

  // TODO(b/166202747): Use a better way to do type specialization. Reduce
  // duplicate code in the two functions below.
  int AddConstSparseInput(const TensorData& t,
                          const std::vector<int8_t>& data) {
    int id = tensors_.size();
    const int dims_count = t.traversal_order.size();
    std::vector<int8_t> dense_data(data);

    tflite::optimize::sparsity::FormatConverter<int8_t> converter(
        t.shape, t.traversal_order, t.format, t.block_size, t.block_map);
    converter.DenseToSparse(dense_data.data());

    const auto dim_metadata = converter.GetDimMetadata();
    const auto sparse_data = converter.GetData();
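
    // The converter reports two metadata arrays per traversal dimension:
    // entry 2*i holds the dense size (or CSR segments) for dimension i and
    // entry 2*i+1 holds the CSR indices; the loop below packs these into the
    // flatbuffer DimensionMetadata.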
    // Build sparsity parameter.
    std::vector<flatbuffers::Offset<DimensionMetadata>> fb_dim_metadata(
        dims_count);
    for (int i = 0; i < dims_count; i++) {
      const int metadata_idx = 2 * i;
      if (i < t.shape.size() &&
          t.format[t.traversal_order[i]] == kTfLiteDimSparseCSR) {
        auto array_segments =
            CreateInt32Vector(builder_,
                              builder_.CreateVector(dim_metadata[metadata_idx]))
                .Union();
        auto array_indices =
            CreateInt32Vector(
                builder_, builder_.CreateVector(dim_metadata[metadata_idx + 1]))
                .Union();
        fb_dim_metadata[i] = CreateDimensionMetadata(
            builder_, DimensionType_SPARSE_CSR, 0,
            SparseIndexVector_Int32Vector, array_segments,
            SparseIndexVector_Int32Vector, array_indices);
      } else {
        fb_dim_metadata[i] = CreateDimensionMetadata(
            builder_, DimensionType_DENSE, dim_metadata[metadata_idx][0]);
      }
    }

    flatbuffers::Offset<SparsityParameters> s_param = CreateSparsityParameters(
        builder_, builder_.CreateVector(t.traversal_order),
        builder_.CreateVector(t.block_map),
        builder_.CreateVector(fb_dim_metadata));

    int buffer_id = 0;
    if (!data.empty()) {
      // Initialize buffers list with empty buffer to allow for non-const
      // tensors.
      if (buffers_.empty()) {
        buffers_.push_back(CreateBuffer(builder_, builder_.CreateVector({})));
      }

      // Add compressed data as a Buffer to buffers list.
      buffer_id = buffers_.size();
      auto data_buffer = builder_.CreateVector(
          reinterpret_cast<const uint8_t*>(sparse_data.data()),
          sparse_data.size());
      buffers_.push_back(CreateBuffer(builder_, data_buffer));
    }

    tensors_.push_back(CreateTensor(
        builder_, builder_.CreateVector<int>(t.shape), t.type,
        /*buffer=*/buffer_id,
        /*name=*/0, /*quantization=*/0, /*is_variable=*/false, s_param));

    inputs_.push_back(id);
    tensor_data_[id] = t;

    return id;
  }

  // Add a constant sparse tensor as input.
  template <typename T>
  int AddConstSparseInput(const TensorData& t, std::initializer_list<T> data) {
  int AddConstSparseInput(const TensorData& t, const std::vector<T>& data,
                          bool symmetric_quantize = false) {
    int id = tensors_.size();
    const int dims_count = t.traversal_order.size();
    std::vector<T> dense_data(data);
@ -258,8 +331,9 @@ class SingleOpModel {
        builder_.CreateVector(t.block_map),
        builder_.CreateVector(fb_dim_metadata));

    flatbuffers::Offset<QuantizationParameters> q_params = 0;
    int buffer_id = 0;
    if (data.size()) {
    if (!data.empty()) {
      // Initialize buffers list with empty buffer to allow for non-const
      // tensors.
      if (buffers_.empty()) {
@ -268,16 +342,31 @@ class SingleOpModel {

      // Add compressed data as a Buffer to buffers list.
      buffer_id = buffers_.size();
      auto data_buffer = builder_.CreateVector(
          reinterpret_cast<const uint8_t*>(sparse_data.data()),
          sizeof(T) * sparse_data.size());
      buffers_.push_back(CreateBuffer(builder_, data_buffer));
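      // When symmetric_quantize is set, the float sparse data is quantized to
      // int8 with a single per-tensor scale (zero point 0) and that scale is
      // recorded in the tensor's QuantizationParameters; otherwise the raw
      // T-typed data is stored as the buffer.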
      if (symmetric_quantize) {
        const int length = sparse_data.size();
        std::vector<int8_t> q(length);
        float min, max, scaling_factor;
        tensor_utils::SymmetricQuantizeFloats(
            sparse_data.data(), length, q.data(), &min, &max, &scaling_factor);
        q_params = CreateQuantizationParameters(
            builder_, 0, 0, builder_.CreateVector<float>({scaling_factor}),
            builder_.CreateVector<int64_t>({0}));
        auto data_buffer = builder_.CreateVector(
            reinterpret_cast<const uint8_t*>(q.data()), q.size());
        buffers_.push_back(CreateBuffer(builder_, data_buffer));
      } else {
        auto data_buffer = builder_.CreateVector(
            reinterpret_cast<const uint8_t*>(sparse_data.data()),
            sizeof(T) * sparse_data.size());
        buffers_.push_back(CreateBuffer(builder_, data_buffer));
      }
    }

    tensors_.push_back(CreateTensor(
        builder_, builder_.CreateVector<int>(t.shape), t.type,
        /*buffer=*/buffer_id,
        /*name=*/0, /*quantization=*/0, /*is_variable=*/false, s_param));
    tensors_.push_back(
        CreateTensor(builder_, builder_.CreateVector<int>(t.shape),
                     symmetric_quantize ? TensorType_INT8 : t.type,
                     /*buffer=*/buffer_id,
                     /*name=*/0, q_params, /*is_variable=*/false, s_param));

    inputs_.push_back(id);
    tensor_data_[id] = t;

@ -650,11 +650,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
  const int row_sums_size = row_sums->dims->data[0];
  return lstm_eval::EvalHybrid(
      input, input_to_input_weights, input_to_forget_weights,
      input_to_cell_weights, input_to_output_weights,
      recurrent_to_input_weights, recurrent_to_forget_weights,
      recurrent_to_cell_weights, recurrent_to_output_weights,
      cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
      input, input_to_input_weights,
      /*input_to_input_weights_ledger*/ nullptr, input_to_forget_weights,
      /*input_to_forget_weights_ledger*/ nullptr, input_to_cell_weights,
      /*input_to_cell_weights_ledger*/ nullptr, input_to_output_weights,
      /*input_to_output_weights_ledger*/ nullptr,
      recurrent_to_input_weights,
      /*recurrent_to_input_weights_ledger*/ nullptr,
      recurrent_to_forget_weights,
      /*recurrent_to_forget_weights_ledger*/ nullptr,
      recurrent_to_cell_weights,
      /*recurrent_to_cell_weights_ledger*/ nullptr,
      recurrent_to_output_weights,
      /*recurrent_to_output_weights_ledger*/ nullptr, cell_to_input_weights,
      cell_to_forget_weights, cell_to_output_weights,
      input_layer_norm_coefficients, forget_layer_norm_coefficients,
      cell_layer_norm_coefficients, output_layer_norm_coefficients,
      /*aux_input=*/nullptr,
@ -663,7 +672,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
      /*aux_input_to_cell_weights=*/nullptr,
      /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
      forget_gate_bias, cell_gate_bias, output_gate_bias,
      projection_weights, projection_bias, &lstm_params,
      projection_weights, /*projection_weights_ledger*/ nullptr,
      projection_bias, &lstm_params,
      /*forward_sequence=*/true, time_major,
      /*output_offset=*/0, scratch_buffer,
      GetTemporary(context, node, kInputScalingFactors),