Use eigen for lstm_eval activations.
PiperOrigin-RevId: 247544281
commit fc61ca2d4f (parent 127471f213)
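A minimal standalone sketch (not part of the diff) of the Eigen pattern the new ApplyActivationsToVector helper below relies on: mapping a raw float buffer with Eigen::Map and applying the activation elementwise. It assumes an Eigen version that provides Array::logistic(), as the copy bundled with TensorFlow does; the include path and the TfLite activation enum are omitted here.

// Sketch only: apply a sigmoid in place over a raw float buffer via Eigen.
#include <cstdio>
#include "Eigen/Core"  // include path may differ per build setup

void SigmoidInPlace(float* data, int size) {
  using VectorMap = Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, 1>>;
  VectorMap map(data, size, 1);
  // logistic() computes 1 / (1 + exp(-x)) elementwise, vectorized by Eigen.
  map.array() = map.array().logistic();
}

int main() {
  float v[4] = {-2.0f, -0.5f, 0.5f, 2.0f};
  SigmoidInPlace(v, 4);
  for (float x : v) std::printf("%f\n", x);
  return 0;
}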
tensorflow/lite/kernels/BUILD
@@ -389,11 +389,13 @@ cc_library(
     srcs = ["lstm_eval.cc"],
     hdrs = ["lstm_eval.h"],
     deps = [
+        ":kernel_util",
         ":op_macros",
         "//tensorflow/lite/c:c_api_internal",
-        "//tensorflow/lite/kernels:kernel_util",
         "//tensorflow/lite/kernels/internal:kernel_utils",
         "//tensorflow/lite/kernels/internal:tensor_utils",
+        "//third_party/eigen3",
+        "@gemmlowp",
     ],
 )
 
tensorflow/lite/kernels/lstm_eval.cc
@@ -16,6 +16,12 @@ limitations under the License.
 
 #include <cstdint>
 
+#ifdef GEMMLOWP_PROFILING
+#include "profiling/profiler.h"
+#endif
+
+#include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
@@ -119,6 +125,9 @@ inline void LstmStepWithAuxInput(
     float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
     float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
     float* output_ptr_batch) {
+#ifdef GEMMLOWP_PROFILING
+  gemmlowp::ScopedProfilingLabel label("LstmStepWithAuxInputFloat");
+#endif
   // Since we have already checked that weights are all there or none, we can
   // check the existence of only one to the get the condition.
   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
@@ -362,6 +371,28 @@ inline void LstmStepWithAuxInput(
   }
 }
 
+void ApplyActivationsToVector(float* input, int input_size,
+                              TfLiteFusedActivation activation_type,
+                              float* output) {
+  using VectorMap = Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, 1>>;
+  VectorMap input_map(input, input_size, 1);
+  VectorMap output_map(output, input_size, 1);
+  switch (activation_type) {
+    case kTfLiteActSigmoid: {
+      output_map.array() = input_map.array().logistic();
+      break;
+    }
+    case kTfLiteActTanh: {
+      output_map.array() = input_map.array().tanh();
+      break;
+    }
+    default: {
+      tensor_utils::ApplyActivationToVector(input, input_size, activation_type,
+                                            output);
+    }
+  }
+}
+
 // Same as above but with quantized weight matrices. In detail:
 // Input of size 'n_batch * n_input':
 //   input_ptr_batch
@@ -473,6 +504,9 @@ inline void LstmStepWithAuxInput(
     int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
     int8_t* quantized_cell_state_ptr, float* output_state_ptr,
     float* cell_state_ptr, float* output_ptr_batch) {
+#ifdef GEMMLOWP_PROFILING
+  gemmlowp::ScopedProfilingLabel label("LstmStepWithAuxInputHybrid");
+#endif
   // Since we have already checked that weights are all there or none, we
   // can check the existence of only one to the get the condition.
   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
@@ -674,8 +708,8 @@ inline void LstmStepWithAuxInput(
       tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
                                          input_gate_scratch);
     }
-    tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
-                                       input_gate_scratch);
+    ApplyActivationsToVector(input_gate_scratch, n_cell * n_batch,
+                             kTfLiteActSigmoid, input_gate_scratch);
   }
 
   // For each batch and cell: update forget gate.
@@ -697,8 +731,8 @@ inline void LstmStepWithAuxInput(
     tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
                                        forget_gate_scratch);
   }
-  tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
-                                     forget_gate_scratch);
+  ApplyActivationsToVector(forget_gate_scratch, n_cell * n_batch,
+                           kTfLiteActSigmoid, forget_gate_scratch);
 
   // For each batch and cell: update the cell.
   tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
@@ -712,8 +746,8 @@ inline void LstmStepWithAuxInput(
     tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
                                        cell_scratch);
   }
-  tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
-                                        params->activation, cell_scratch);
+  ApplyActivationsToVector(cell_scratch, n_batch * n_cell, params->activation,
+                           cell_scratch);
   if (use_cifg) {
     tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                              forget_gate_scratch);
@@ -749,10 +783,10 @@ inline void LstmStepWithAuxInput(
     tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
                                        output_gate_scratch);
   }
-  tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
-                                     output_gate_scratch);
-  tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
-                                        params->activation, cell_scratch);
+  ApplyActivationsToVector(output_gate_scratch, n_batch * n_cell,
+                           kTfLiteActSigmoid, output_gate_scratch);
+  ApplyActivationsToVector(cell_state_ptr, n_batch * n_cell, params->activation,
+                           cell_scratch);
   tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                          n_batch * n_cell, output_gate_scratch);
 