diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index 3d0b8ba4220..d28786e7ad1 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -389,11 +389,13 @@ cc_library(
     srcs = ["lstm_eval.cc"],
     hdrs = ["lstm_eval.h"],
     deps = [
+        ":kernel_util",
         ":op_macros",
         "//tensorflow/lite/c:c_api_internal",
-        "//tensorflow/lite/kernels:kernel_util",
         "//tensorflow/lite/kernels/internal:kernel_utils",
         "//tensorflow/lite/kernels/internal:tensor_utils",
+        "//third_party/eigen3",
+        "@gemmlowp",
     ],
 )
 
diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc
index 2d0aee19ecb..a518daf2cfd 100644
--- a/tensorflow/lite/kernels/lstm_eval.cc
+++ b/tensorflow/lite/kernels/lstm_eval.cc
@@ -16,6 +16,12 @@ limitations under the License.
 
 #include
 
+#ifdef GEMMLOWP_PROFILING
+#include "profiling/profiler.h"
+#endif
+
+#include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
@@ -119,6 +125,9 @@ inline void LstmStepWithAuxInput(
     float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
     float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
     float* output_ptr_batch) {
+#ifdef GEMMLOWP_PROFILING
+  gemmlowp::ScopedProfilingLabel label("LstmStepWithAuxInputFloat");
+#endif
   // Since we have already checked that weights are all there or none, we can
   // check the existence of only one to the get the condition.
   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
@@ -362,6 +371,28 @@ inline void LstmStepWithAuxInput(
   }
 }
 
+void ApplyActivationsToVector(float* input, int input_size,
+                              TfLiteFusedActivation activation_type,
+                              float* output) {
+  using VectorMap = Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, 1>>;
+  VectorMap input_map(input, input_size, 1);
+  VectorMap output_map(output, input_size, 1);
+  switch (activation_type) {
+    case kTfLiteActSigmoid: {
+      output_map.array() = input_map.array().logistic();
+      break;
+    }
+    case kTfLiteActTanh: {
+      output_map.array() = input_map.array().tanh();
+      break;
+    }
+    default: {
+      tensor_utils::ApplyActivationToVector(input, input_size, activation_type,
+                                            output);
+    }
+  }
+}
+
 // Same as above but with quantized weight matrices. In detail:
 // Input of size 'n_batch * n_input':
 //   input_ptr_batch
@@ -473,6 +504,9 @@ inline void LstmStepWithAuxInput(
     int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
     int8_t* quantized_cell_state_ptr, float* output_state_ptr,
     float* cell_state_ptr, float* output_ptr_batch) {
+#ifdef GEMMLOWP_PROFILING
+  gemmlowp::ScopedProfilingLabel label("LstmStepWithAuxInputHybrid");
+#endif
   // Since we have already checked that weights are all there or none, we
   // can check the existence of only one to the get the condition.
   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
@@ -674,8 +708,8 @@ inline void LstmStepWithAuxInput(
       tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
                                          input_gate_scratch);
     }
-    tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
-                                       input_gate_scratch);
+    ApplyActivationsToVector(input_gate_scratch, n_cell * n_batch,
+                             kTfLiteActSigmoid, input_gate_scratch);
   }
 
   // For each batch and cell: update forget gate.
@@ -697,8 +731,8 @@ inline void LstmStepWithAuxInput(
     tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
                                        forget_gate_scratch);
   }
-  tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
-                                     forget_gate_scratch);
+  ApplyActivationsToVector(forget_gate_scratch, n_cell * n_batch,
+                           kTfLiteActSigmoid, forget_gate_scratch);
 
   // For each batch and cell: update the cell.
   tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
@@ -712,8 +746,8 @@ inline void LstmStepWithAuxInput(
     tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
                                        cell_scratch);
   }
-  tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
-                                        params->activation, cell_scratch);
+  ApplyActivationsToVector(cell_scratch, n_batch * n_cell, params->activation,
+                           cell_scratch);
   if (use_cifg) {
     tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                              forget_gate_scratch);
@@ -749,10 +783,10 @@ inline void LstmStepWithAuxInput(
     tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
                                        output_gate_scratch);
   }
-  tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
-                                     output_gate_scratch);
-  tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
-                                        params->activation, cell_scratch);
+  ApplyActivationsToVector(output_gate_scratch, n_batch * n_cell,
+                           kTfLiteActSigmoid, output_gate_scratch);
+  ApplyActivationsToVector(cell_state_ptr, n_batch * n_cell, params->activation,
+                           cell_scratch);
   tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                          n_batch * n_cell, output_gate_scratch);
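For context, the functional change in this patch is that the LSTM kernels now route sigmoid and tanh activations through Eigen's vectorized array operations (logistic() and tanh()) via the new ApplyActivationsToVector helper, and fall back to tensor_utils::ApplyActivationToVector for any other fused activation; the GEMMLOWP_PROFILING blocks only add scoped profiling labels. The following is a minimal, standalone sketch of that dispatch pattern, not the TF Lite code itself: ApplyActivationSketch and ActivationKind are hypothetical stand-ins for ApplyActivationsToVector and TfLiteFusedActivation, and it assumes an Eigen version that provides Array::logistic(), as the patch itself does.

// Standalone sketch (not part of the patch): mirrors the Eigen-based
// activation dispatch introduced above, with a hypothetical ActivationKind
// enum standing in for TfLiteFusedActivation.
#include <cstdio>

#include <Eigen/Core>

enum class ActivationKind { kSigmoid, kTanh, kOther };

void ApplyActivationSketch(const float* input, int size, ActivationKind kind,
                           float* output) {
  // Map the raw buffers as Eigen arrays so the activation is applied with
  // vectorized coefficient-wise operations instead of a scalar loop.
  Eigen::Map<const Eigen::ArrayXf> in(input, size);
  Eigen::Map<Eigen::ArrayXf> out(output, size);
  switch (kind) {
    case ActivationKind::kSigmoid:
      out = in.logistic();  // vectorized 1 / (1 + exp(-x))
      break;
    case ActivationKind::kTanh:
      out = in.tanh();  // vectorized tanh
      break;
    default:
      // The patch defers to tensor_utils::ApplyActivationToVector here;
      // this sketch simply copies the input through.
      out = in;
      break;
  }
}

int main() {
  const float x[4] = {-2.0f, -0.5f, 0.5f, 2.0f};
  float y[4];
  ApplyActivationSketch(x, 4, ActivationKind::kSigmoid, y);
  for (float v : y) std::printf("%f\n", v);
  return 0;
}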