Use Eigen for lstm_eval activations.

PiperOrigin-RevId: 247544281
Renjie Liu 2019-05-09 20:21:06 -07:00 committed by TensorFlower Gardener
parent 127471f213
commit fc61ca2d4f
2 changed files with 47 additions and 11 deletions

tensorflow/lite/kernels/BUILD

@@ -389,11 +389,13 @@ cc_library(
    srcs = ["lstm_eval.cc"],
    hdrs = ["lstm_eval.h"],
    deps = [
        ":kernel_util",
        ":op_macros",
        "//tensorflow/lite/c:c_api_internal",
        "//tensorflow/lite/kernels:kernel_util",
        "//tensorflow/lite/kernels/internal:kernel_utils",
        "//tensorflow/lite/kernels/internal:tensor_utils",
        "//third_party/eigen3",
        "@gemmlowp",
    ],
)

tensorflow/lite/kernels/lstm_eval.cc

@@ -16,6 +16,12 @@ limitations under the License.
#include <cstdint>
#ifdef GEMMLOWP_PROFILING
#include "profiling/profiler.h"
#endif
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
@@ -119,6 +125,9 @@ inline void LstmStepWithAuxInput(
float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
float* output_ptr_batch) {
#ifdef GEMMLOWP_PROFILING
gemmlowp::ScopedProfilingLabel label("LstmStepWithAuxInputFloat");
#endif
// Since we have already checked that weights are all there or none, we can
// check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
@@ -362,6 +371,28 @@ inline void LstmStepWithAuxInput(
}
}
void ApplyActivationsToVector(float* input, int input_size,
                              TfLiteFusedActivation activation_type,
                              float* output) {
  using VectorMap = Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, 1>>;
  VectorMap input_map(input, input_size, 1);
  VectorMap output_map(output, input_size, 1);
  switch (activation_type) {
    case kTfLiteActSigmoid: {
      output_map.array() = input_map.array().logistic();
      break;
    }
    case kTfLiteActTanh: {
      output_map.array() = input_map.array().tanh();
      break;
    }
    default: {
      tensor_utils::ApplyActivationToVector(input, input_size,
                                            activation_type, output);
    }
  }
}
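
For context: the helper above maps the raw gate buffers into Eigen vectors and applies the activation coefficient-wise. Below is a minimal standalone sketch of that Eigen::Map pattern; the buffer, its values, and main() are illustrative assumptions, not part of this commit.

#include <cstdio>
#include "third_party/eigen3/Eigen/Core"

int main() {
  float input[4] = {-2.0f, -0.5f, 0.5f, 2.0f};  // illustrative values
  float output[4];
  using VectorMap = Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, 1>>;
  VectorMap input_map(input, 4, 1);    // a view over the raw buffer; no copy
  VectorMap output_map(output, 4, 1);
  // .array() switches to coefficient-wise semantics, so logistic() is the
  // elementwise sigmoid used for the kTfLiteActSigmoid case above.
  output_map.array() = input_map.array().logistic();
  for (float v : output) std::printf("%f\n", v);
  return 0;
}

Because the maps alias the caller's memory, an elementwise assignment like this also works in place when input == output, which is how the call sites below reuse the same gate scratch buffer for input and output.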
// Same as above but with quantized weight matrices. In detail:
// Input of size 'n_batch * n_input':
// input_ptr_batch
@@ -473,6 +504,9 @@ inline void LstmStepWithAuxInput(
int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
int8_t* quantized_cell_state_ptr, float* output_state_ptr,
float* cell_state_ptr, float* output_ptr_batch) {
#ifdef GEMMLOWP_PROFILING
gemmlowp::ScopedProfilingLabel label("LstmStepWithAuxInputHybrid");
#endif
// Since we have already checked that weights are all there or none, we
// can check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
@@ -674,8 +708,8 @@ inline void LstmStepWithAuxInput(
tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
input_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
input_gate_scratch);
ApplyActivationsToVector(input_gate_scratch, n_cell * n_batch,
kTfLiteActSigmoid, input_gate_scratch);
}
// For each batch and cell: update forget gate.
@@ -697,8 +731,8 @@ inline void LstmStepWithAuxInput(
tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
forget_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
forget_gate_scratch);
ApplyActivationsToVector(forget_gate_scratch, n_cell * n_batch,
kTfLiteActSigmoid, forget_gate_scratch);
// For each batch and cell: update the cell.
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
@@ -712,8 +746,8 @@ inline void LstmStepWithAuxInput(
tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
cell_scratch);
}
tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
params->activation, cell_scratch);
ApplyActivationsToVector(cell_scratch, n_batch * n_cell, params->activation,
cell_scratch);
if (use_cifg) {
tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
forget_gate_scratch);
@@ -749,10 +783,10 @@ inline void LstmStepWithAuxInput(
tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
output_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
output_gate_scratch);
tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
params->activation, cell_scratch);
ApplyActivationsToVector(output_gate_scratch, n_batch * n_cell,
kTfLiteActSigmoid, output_gate_scratch);
ApplyActivationsToVector(cell_state_ptr, n_batch * n_cell, params->activation,
cell_scratch);
tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
n_batch * n_cell, output_gate_scratch);
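
Read together, the call substitutions in these hunks preserve the standard LSTM recurrence; only the sigmoid/tanh implementations move to Eigen. As a reading aid (the W/R/b notation is assumed here, not taken from the commit):

// i_t = sigmoid(W_i x_t + R_i h_{t-1} + b_i)  -> kTfLiteActSigmoid (input gate)
// f_t = sigmoid(W_f x_t + R_f h_{t-1} + b_f)  -> kTfLiteActSigmoid (forget gate)
// g_t = act(W_c x_t + R_c h_{t-1} + b_c)      -> params->activation (cell input)
// c_t = f_t .* c_{t-1} + i_t .* g_t           -> VectorVectorCwiseProduct calls
// o_t = sigmoid(W_o x_t + R_o h_{t-1} + b_o)  -> kTfLiteActSigmoid (output gate)
// h_t = o_t .* act(c_t)                       -> params->activation, then cwise product

Under CIFG (use_cifg), the input gate is not computed directly; i_t = 1 - f_t via the Sub1Vector call on forget_gate_scratch.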