Use eigen for lstm_eval activations.
PiperOrigin-RevId: 247544281
parent 127471f213 · commit fc61ca2d4f
BUILD
@@ -389,11 +389,13 @@ cc_library(
     srcs = ["lstm_eval.cc"],
     hdrs = ["lstm_eval.h"],
     deps = [
         ":kernel_util",
         ":op_macros",
         "//tensorflow/lite/c:c_api_internal",
         "//tensorflow/lite/kernels:kernel_util",
         "//tensorflow/lite/kernels/internal:kernel_utils",
         "//tensorflow/lite/kernels/internal:tensor_utils",
+        "//third_party/eigen3",
+        "@gemmlowp",
     ],
 )
lstm_eval.cc
@@ -16,6 +16,12 @@ limitations under the License.
 
 #include <cstdint>
+
+#ifdef GEMMLOWP_PROFILING
+#include "profiling/profiler.h"
+#endif
+
+#include "third_party/eigen3/Eigen/Core"
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
@@ -119,6 +125,9 @@ inline void LstmStepWithAuxInput(
     float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
     float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
     float* output_ptr_batch) {
+#ifdef GEMMLOWP_PROFILING
+  gemmlowp::ScopedProfilingLabel label("LstmStepWithAuxInputFloat");
+#endif
   // Since we have already checked that weights are all there or none, we can
   // check the existence of only one to the get the condition.
   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
@@ -362,6 +371,28 @@ inline void LstmStepWithAuxInput(
   }
 }
 
+void ApplyActivationsToVector(float* input, int input_size,
+                              TfLiteFusedActivation activation_type,
+                              float* output) {
+  using VectorMap = Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, 1>>;
+  VectorMap input_map(input, input_size, 1);
+  VectorMap output_map(output, input_size, 1);
+  switch (activation_type) {
+    case kTfLiteActSigmoid: {
+      output_map.array() = input_map.array().logistic();
+      break;
+    }
+    case kTfLiteActTanh: {
+      output_map.array() = input_map.array().tanh();
+      break;
+    }
+    default: {
+      tensor_utils::ApplyActivationToVector(input, input_size, activation_type,
+                                            output);
+    }
+  }
+}
+
 // Same as above but with quantized weight matrices. In detail:
 // Input of size 'n_batch * n_input':
 //   input_ptr_batch
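The helper above is the core of the change: for sigmoid and tanh it maps the raw float buffer into an Eigen vector and applies Eigen's vectorized coefficient-wise operations, falling back to the existing tensor_utils path for any other fused activation. The following standalone sketch (not part of the commit; the file and buffer names are illustrative) shows the same Map-plus-array idiom in isolation:

// eigen_activation_demo.cc -- illustrative only; assumes Eigen is on the
// include path and provides Array::logistic(), as the version vendored by
// TensorFlow does.
#include <cstdio>

#include "Eigen/Core"

int main() {
  float input[4] = {-2.0f, -0.5f, 0.5f, 2.0f};
  float output[4];
  // Map wraps existing memory as an Eigen column vector without copying.
  using VectorMap = Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, 1>>;
  VectorMap input_map(input, 4, 1);
  VectorMap output_map(output, 4, 1);
  // .array() switches to coefficient-wise semantics; logistic() is the
  // element-wise sigmoid 1 / (1 + exp(-x)).
  output_map.array() = input_map.array().logistic();
  for (float v : output) std::printf("%f\n", v);
  return 0;
}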
@@ -473,6 +504,9 @@ inline void LstmStepWithAuxInput(
     int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
     int8_t* quantized_cell_state_ptr, float* output_state_ptr,
     float* cell_state_ptr, float* output_ptr_batch) {
+#ifdef GEMMLOWP_PROFILING
+  gemmlowp::ScopedProfilingLabel label("LstmStepWithAuxInputHybrid");
+#endif
   // Since we have already checked that weights are all there or none, we
   // can check the existence of only one to the get the condition.
   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
@@ -674,8 +708,8 @@ inline void LstmStepWithAuxInput(
       tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
                                          input_gate_scratch);
     }
-    tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
-                                       input_gate_scratch);
+    ApplyActivationsToVector(input_gate_scratch, n_cell * n_batch,
+                             kTfLiteActSigmoid, input_gate_scratch);
   }
 
   // For each batch and cell: update forget gate.
@@ -697,8 +731,8 @@ inline void LstmStepWithAuxInput(
     tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
                                        forget_gate_scratch);
   }
-  tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
-                                     forget_gate_scratch);
+  ApplyActivationsToVector(forget_gate_scratch, n_cell * n_batch,
+                           kTfLiteActSigmoid, forget_gate_scratch);
 
   // For each batch and cell: update the cell.
   tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
@@ -712,8 +746,8 @@ inline void LstmStepWithAuxInput(
     tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
                                        cell_scratch);
   }
-  tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
-                                        params->activation, cell_scratch);
+  ApplyActivationsToVector(cell_scratch, n_batch * n_cell, params->activation,
+                           cell_scratch);
   if (use_cifg) {
     tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                              forget_gate_scratch);
@@ -749,10 +783,10 @@ inline void LstmStepWithAuxInput(
     tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
                                        output_gate_scratch);
   }
-  tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
-                                     output_gate_scratch);
-  tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
-                                        params->activation, cell_scratch);
+  ApplyActivationsToVector(output_gate_scratch, n_batch * n_cell,
+                           kTfLiteActSigmoid, output_gate_scratch);
+  ApplyActivationsToVector(cell_state_ptr, n_batch * n_cell, params->activation,
+                           cell_scratch);
   tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                          n_batch * n_cell, output_gate_scratch);
 
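Note that every converted call site passes the same buffer as both input and output. That in-place usage is safe with the new helper because the coefficient-wise Eigen expressions involved read and write each element independently. A minimal sketch of the aliasing pattern (illustrative names, not from the commit):

#include <cstdio>

#include "Eigen/Core"

// In-place sigmoid over an existing scratch buffer, mirroring how the LSTM
// step code calls ApplyActivationsToVector(scratch, n, kTfLiteActSigmoid,
// scratch). Safe because logistic() is purely element-wise.
void SigmoidInPlace(float* buffer, int size) {
  using VectorMap = Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, 1>>;
  VectorMap map(buffer, size, 1);
  map.array() = map.array().logistic();
}

int main() {
  float scratch[3] = {-1.0f, 0.0f, 1.0f};
  SigmoidInPlace(scratch, 3);
  for (float v : scratch) std::printf("%f\n", v);  // ~0.269, 0.500, 0.731
  return 0;
}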