diff --git a/tensorflow/lite/kernels/elementwise.cc b/tensorflow/lite/kernels/elementwise.cc
index 1cc188ae5f7..d3d8be666d3 100644
--- a/tensorflow/lite/kernels/elementwise.cc
+++ b/tensorflow/lite/kernels/elementwise.cc
@@ -14,7 +14,9 @@ limitations under the License.
 ==============================================================================*/
 #include <cmath>
+
 #include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
@@ -25,6 +27,11 @@ namespace builtin {
 namespace elementwise {
 namespace {
 
+enum KernelType {
+  kReference,
+  kGenericOptimized,
+};
+
 bool IsNumericSupportedType(const TfLiteType type) {
   return type == kTfLiteFloat32;
 }
@@ -79,15 +86,40 @@ TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
   return EvalNumeric(context, node, std::abs);
 }
 
+typedef void (*UnaryOp)(const float* in, int n, float* out);
+
+inline TfLiteStatus EvalNumericOptimized(TfLiteContext* context,
+                                         TfLiteNode* node, UnaryOp op) {
+  const TfLiteTensor* input = GetInput(context, node, 0);
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  const int64_t num_elements = NumElements(input);
+  const float* in_data = GetTensorData<float>(input);
+  float* out_data = GetTensorData<float>(output);
+  op(in_data, num_elements, out_data);
+  return kTfLiteOk;
+}
+
+template <KernelType kernel_type>
 TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
+  if (kernel_type == kGenericOptimized) {
+    return EvalNumericOptimized(context, node, optimized_ops::Sin);
+  }
   return EvalNumeric(context, node, std::sin);
 }
 
+template <KernelType kernel_type>
 TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
+  if (kernel_type == kGenericOptimized) {
+    return EvalNumericOptimized(context, node, optimized_ops::Cos);
+  }
   return EvalNumeric(context, node, std::cos);
 }
 
+template <KernelType kernel_type>
 TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
+  if (kernel_type == kGenericOptimized) {
+    return EvalNumericOptimized(context, node, optimized_ops::Log);
+  }
   return EvalNumeric(context, node, std::log);
 }
 
@@ -118,30 +150,60 @@ TfLiteRegistration* Register_ABS() {
   return &r;
 }
 
-TfLiteRegistration* Register_SIN() {
+TfLiteRegistration* Register_SIN_REF() {
   static TfLiteRegistration r = {
       /*init=*/nullptr, /*free=*/nullptr,
       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-      elementwise::SinEval};
+      elementwise::SinEval<elementwise::kReference>};
   return &r;
 }
 
-TfLiteRegistration* Register_COS() {
+TfLiteRegistration* Register_SIN_GENERIC_OPT() {
   static TfLiteRegistration r = {
       /*init=*/nullptr, /*free=*/nullptr,
       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-      elementwise::CosEval};
+      elementwise::SinEval<elementwise::kGenericOptimized>};
   return &r;
 }
 
-TfLiteRegistration* Register_LOG() {
+TfLiteRegistration* Register_SIN() { return Register_SIN_GENERIC_OPT(); }
+
+TfLiteRegistration* Register_COS_REF() {
   static TfLiteRegistration r = {
       /*init=*/nullptr, /*free=*/nullptr,
       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-      elementwise::LogEval};
+      elementwise::CosEval<elementwise::kReference>};
   return &r;
 }
 
+TfLiteRegistration* Register_COS_GENERIC_OPT() {
+  static TfLiteRegistration r = {
+      /*init=*/nullptr, /*free=*/nullptr,
+      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::CosEval<elementwise::kGenericOptimized>};
+  return &r;
+}
+
+TfLiteRegistration* Register_COS() { return Register_COS_GENERIC_OPT(); }
+
+TfLiteRegistration* Register_LOG_REF() {
+  static TfLiteRegistration r = {
+      /*init=*/nullptr, /*free=*/nullptr,
+      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::LogEval<elementwise::kReference>};
+  return &r;
+}
+
+TfLiteRegistration* Register_LOG_GENERIC_OPT() {
+  static TfLiteRegistration r = {
+      /*init=*/nullptr, /*free=*/nullptr,
+      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::LogEval<elementwise::kGenericOptimized>};
+  return &r;
+}
+
+TfLiteRegistration* Register_LOG() { return Register_LOG_GENERIC_OPT(); }
+
 TfLiteRegistration* Register_SQRT() {
   static TfLiteRegistration r = {
       /*init=*/nullptr, /*free=*/nullptr,
diff --git a/tensorflow/lite/kernels/exp.cc b/tensorflow/lite/kernels/exp.cc
index 607b398ebd7..05624edf5d9 100644
--- a/tensorflow/lite/kernels/exp.cc
+++ b/tensorflow/lite/kernels/exp.cc
@@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 #include <string.h>
+
 #include <vector>
+
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
@@ -29,6 +32,7 @@ namespace exp {
 // This file has reference implementation of Exp.
 enum KernelType {
   kReference,
+  kGenericOptimized,
 };
 
 struct ExpContext {
@@ -53,7 +57,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   ExpContext op_context(context, node);
-
+  switch (op_context.input->type) {
+    case kTfLiteFloat32:
 #define TF_LITE_EXP(kernel_type, data_type)                       \
   kernel_type::Exp(GetTensorData<data_type>(op_context.input),    \
                    NumElements(op_context.input),                 \
@@ -61,18 +66,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 
 // TODO(kanlig): supports half, bfloat16, float64, complex64, and complex128.
   if (kernel_type == kReference) {
-    switch (op_context.input->type) {
-      case kTfLiteFloat32:
-        TF_LITE_EXP(reference_ops, float);
-        break;
-      default:
-        context->ReportError(context,
-                             "Type %d is currently not supported by Exp.",
-                             op_context.input->type);
-        return kTfLiteError;
-    }
+    TF_LITE_EXP(reference_ops, float);
+  } else {
+    TF_LITE_EXP(optimized_ops, float);
   }
 #undef TF_LITE_EXP
+      break;
+    default:
+      context->ReportError(context,
+                           "Type %d is currently not supported by Exp.",
+                           op_context.input->type);
+      return kTfLiteError;
+  }
 
   return kTfLiteOk;
 }
@@ -84,8 +89,13 @@ TfLiteRegistration* Register_EXP_REF() {
   return &r;
 }
 
-// TODO(kanlig): add optimized implementation of Exp.
-TfLiteRegistration* Register_EXP() { return Register_EXP_REF(); }
+TfLiteRegistration* Register_EXP_GENERIC_OPT() {
+  static TfLiteRegistration r = {nullptr, nullptr, exp::Prepare,
+                                 exp::Eval<exp::kGenericOptimized>};
+  return &r;
+}
+
+TfLiteRegistration* Register_EXP() { return Register_EXP_GENERIC_OPT(); }
 
 }  // namespace builtin
 }  // namespace ops
diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index ba7b0fd2f32..048ea55889c 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -3917,6 +3917,42 @@ inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
   output_map.array() = input_map.array().template cast<DstT>();
 }
 
+template <typename T, typename ScalarOp>
+inline void ElementWise(const T* input_data, const int buffer_size,
+                        T* output_data) {
+  auto input_map = VectorMap<const T>(input_data, buffer_size, 1);
+  auto output_map = VectorMap<T>(output_data, buffer_size, 1);
+  output_map.array() = input_map.array().unaryExpr(ScalarOp());
+}
+
+template <typename T>
+inline void Cos(const T* input_data, const int buffer_size, T* output_data) {
+  gemmlowp::ScopedProfilingLabel label("Cos");
+  ElementWise<T, Eigen::internal::scalar_cos_op<T>>(input_data, buffer_size,
+                                                    output_data);
+}
+
+template <typename T>
+inline void Exp(const T* input_data, const int buffer_size, T* output_data) {
+  gemmlowp::ScopedProfilingLabel label("Exp");
+  return ElementWise<T, Eigen::internal::scalar_exp_op<T>>(
+      input_data, buffer_size, output_data);
+}
+
+template <typename T>
+inline void Log(const T* input_data, const int buffer_size, T* output_data) {
+  gemmlowp::ScopedProfilingLabel label("Log");
+  ElementWise<T, Eigen::internal::scalar_log_op<T>>(input_data, buffer_size,
+                                                    output_data);
+}
+
+template <typename T>
+inline void Sin(const T* input_data, const int buffer_size, T* output_data) {
+  gemmlowp::ScopedProfilingLabel label("Sin");
+  ElementWise<T, Eigen::internal::scalar_sin_op<T>>(input_data, buffer_size,
+                                                    output_data);
+}
+
 inline void Floor(const RuntimeShape& input_shape, const float* input_data,
                   const RuntimeShape& output_shape, float* output_data) {
   gemmlowp::ScopedProfilingLabel label("Floor");
diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc
index 6d10828869f..57c10510954 100644
--- a/tensorflow/lite/kernels/register_ref.cc
+++ b/tensorflow/lite/kernels/register_ref.cc
@@ -78,9 +78,9 @@ TfLiteRegistration* Register_SPLIT();
 TfLiteRegistration* Register_SPLIT_V();
 TfLiteRegistration* Register_SQUEEZE();
 TfLiteRegistration* Register_STRIDED_SLICE_REF();
-TfLiteRegistration* Register_EXP();
+TfLiteRegistration* Register_EXP_REF();
 TfLiteRegistration* Register_TOPK_V2();
-TfLiteRegistration* Register_LOG();
+TfLiteRegistration* Register_LOG_REF();
 TfLiteRegistration* Register_LOG_SOFTMAX_REF();
 TfLiteRegistration* Register_CAST();
 TfLiteRegistration* Register_DEQUANTIZE();
@@ -103,7 +103,7 @@ TfLiteRegistration* Register_REDUCE_MIN();
 TfLiteRegistration* Register_REDUCE_ANY();
 TfLiteRegistration* Register_SELECT();
 TfLiteRegistration* Register_SLICE_REF();
-TfLiteRegistration* Register_SIN();
+TfLiteRegistration* Register_SIN_REF();
 TfLiteRegistration* Register_TRANSPOSECONV_REF();
 TfLiteRegistration* Register_EXPAND_DIMS();
 TfLiteRegistration* Register_SPARSE_TO_DENSE();
@@ -229,9 +229,9 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
   AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V());
   AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
   AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE_REF());
-  AddBuiltin(BuiltinOperator_EXP, Register_EXP());
+  AddBuiltin(BuiltinOperator_EXP, Register_EXP_REF());
   AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2());
-  AddBuiltin(BuiltinOperator_LOG, Register_LOG());
+  AddBuiltin(BuiltinOperator_LOG, Register_LOG_REF());
   AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX_REF());
   AddBuiltin(BuiltinOperator_CAST, Register_CAST());
   AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
@@ -250,7 +250,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
   AddBuiltin(BuiltinOperator_NEG, Register_NEG());
   AddBuiltin(BuiltinOperator_SELECT, Register_SELECT());
   AddBuiltin(BuiltinOperator_SLICE, Register_SLICE_REF());
-  AddBuiltin(BuiltinOperator_SIN, Register_SIN());
+  AddBuiltin(BuiltinOperator_SIN, Register_SIN_REF());
   AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSECONV_REF());
   AddBuiltin(BuiltinOperator_TILE, Register_TILE());
   AddBuiltin(BuiltinOperator_SUM, Register_SUM());