diff --git a/tensorflow/lite/kernels/elementwise.cc b/tensorflow/lite/kernels/elementwise.cc
index d3d8be666d3..1cc188ae5f7 100644
--- a/tensorflow/lite/kernels/elementwise.cc
+++ b/tensorflow/lite/kernels/elementwise.cc
@@ -14,9 +14,7 @@ limitations under the License.
 ==============================================================================*/
 #include <cmath>
-
 #include "tensorflow/lite/c/c_api_internal.h"
-#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 
@@ -27,11 +25,6 @@ namespace builtin {
 namespace elementwise {
 namespace {
 
-enum KernelType {
-  kReference,
-  kGenericOptimized,
-};
-
 bool IsNumericSupportedType(const TfLiteType type) {
   return type == kTfLiteFloat32;
 }
@@ -86,40 +79,15 @@ TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
   return EvalNumeric(context, node, std::abs);
 }
 
-typedef void (*UnaryOp)(const float* in, int n, float* out);
-
-inline TfLiteStatus EvalNumericOptimized(TfLiteContext* context,
-                                         TfLiteNode* node, UnaryOp op) {
-  const TfLiteTensor* input = GetInput(context, node, 0);
-  TfLiteTensor* output = GetOutput(context, node, 0);
-  const int64_t num_elements = NumElements(input);
-  const float* in_data = GetTensorData<float>(input);
-  float* out_data = GetTensorData<float>(output);
-  op(in_data, num_elements, out_data);
-  return kTfLiteOk;
-}
-
-template <KernelType kernel_type>
 TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
-  if (kernel_type == kGenericOptimized) {
-    return EvalNumericOptimized(context, node, optimized_ops::Sin);
-  }
   return EvalNumeric(context, node, std::sin);
 }
 
-template <KernelType kernel_type>
 TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
-  if (kernel_type == kGenericOptimized) {
-    return EvalNumericOptimized(context, node, optimized_ops::Cos);
-  }
   return EvalNumeric(context, node, std::cos);
 }
 
-template <KernelType kernel_type>
 TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
-  if (kernel_type == kGenericOptimized) {
-    return EvalNumericOptimized(context, node, optimized_ops::Log);
-  }
   return EvalNumeric(context, node, std::log);
 }
 
@@ -150,60 +118,30 @@ TfLiteRegistration* Register_ABS() {
   return &r;
 }
 
-TfLiteRegistration* Register_SIN_REF() {
+TfLiteRegistration* Register_SIN() {
   static TfLiteRegistration r = {
       /*init=*/nullptr, /*free=*/nullptr,
       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-      elementwise::SinEval<elementwise::kReference>};
+      elementwise::SinEval};
   return &r;
 }
 
-TfLiteRegistration* Register_SIN_GENERIC_OPT() {
+TfLiteRegistration* Register_COS() {
   static TfLiteRegistration r = {
       /*init=*/nullptr, /*free=*/nullptr,
       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-      elementwise::SinEval<elementwise::kGenericOptimized>};
+      elementwise::CosEval};
   return &r;
 }
 
-TfLiteRegistration* Register_SIN() { return Register_SIN_GENERIC_OPT(); }
-
-TfLiteRegistration* Register_COS_REF() {
+TfLiteRegistration* Register_LOG() {
   static TfLiteRegistration r = {
       /*init=*/nullptr, /*free=*/nullptr,
       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-      elementwise::CosEval<elementwise::kReference>};
+      elementwise::LogEval};
   return &r;
 }
 
-TfLiteRegistration* Register_COS_GENERIC_OPT() {
-  static TfLiteRegistration r = {
-      /*init=*/nullptr, /*free=*/nullptr,
-      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-      elementwise::CosEval<elementwise::kGenericOptimized>};
-  return &r;
-}
-
-TfLiteRegistration* Register_COS() { return Register_COS_GENERIC_OPT(); }
-
-TfLiteRegistration* Register_LOG_REF() {
-  static TfLiteRegistration r = {
-      /*init=*/nullptr, /*free=*/nullptr,
-      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-      elementwise::LogEval<elementwise::kReference>};
-  return &r;
-}
-
-TfLiteRegistration* Register_LOG_GENERIC_OPT() {
-  static TfLiteRegistration r = {
-      /*init=*/nullptr, /*free=*/nullptr,
-      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-      elementwise::LogEval<elementwise::kGenericOptimized>};
-  return &r;
-}
-
-TfLiteRegistration* Register_LOG() { return Register_LOG_GENERIC_OPT(); }
-
 TfLiteRegistration* Register_SQRT() {
   static TfLiteRegistration r = {
       /*init=*/nullptr, /*free=*/nullptr,
diff --git a/tensorflow/lite/kernels/exp.cc b/tensorflow/lite/kernels/exp.cc
index 05624edf5d9..607b398ebd7 100644
--- a/tensorflow/lite/kernels/exp.cc
+++ b/tensorflow/lite/kernels/exp.cc
@@ -13,12 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 #include <string.h>
-
 #include <vector>
-
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_internal.h"
-#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
@@ -32,7 +29,6 @@ namespace exp {
 // This file has reference implementation of Exp.
 enum KernelType {
   kReference,
-  kGenericOptimized,
 };
 
 struct ExpContext {
@@ -57,8 +53,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   ExpContext op_context(context, node);
-  switch (op_context.input->type) {
-    case kTfLiteFloat32:
+
 #define TF_LITE_EXP(kernel_type, data_type)                               \
   kernel_type::Exp<data_type>(GetTensorData<data_type>(op_context.input), \
                               NumElements(op_context.input),              \
@@ -66,18 +61,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 
   // TODO(kanlig): supports half, bfloat16, float64, complex64, and complex128.
   if (kernel_type == kReference) {
-    TF_LITE_EXP(reference_ops, float);
-  } else {
-    TF_LITE_EXP(optimized_ops, float);
+    switch (op_context.input->type) {
+      case kTfLiteFloat32:
+        TF_LITE_EXP(reference_ops, float);
+        break;
+      default:
+        context->ReportError(context,
+                             "Type %d is currently not supported by Exp.",
+                             op_context.input->type);
+        return kTfLiteError;
+    }
   }
 #undef TF_LITE_EXP
-  break;
-    default:
-      context->ReportError(context,
-                           "Type %d is currently not supported by Exp.",
-                           op_context.input->type);
-      return kTfLiteError;
-  }
 
   return kTfLiteOk;
 }
@@ -89,13 +84,8 @@ TfLiteRegistration* Register_EXP_REF() {
   return &r;
 }
 
-TfLiteRegistration* Register_EXP_GENERIC_OPT() {
-  static TfLiteRegistration r = {nullptr, nullptr, exp::Prepare,
-                                 exp::Eval<exp::kGenericOptimized>};
-  return &r;
-}
-
-TfLiteRegistration* Register_EXP() { return Register_EXP_GENERIC_OPT(); }
+// TODO(kanlig): add optimized implementation of Exp.
+TfLiteRegistration* Register_EXP() { return Register_EXP_REF(); }
 
 }  // namespace builtin
 }  // namespace ops
diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index 048ea55889c..ba7b0fd2f32 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -3917,42 +3917,6 @@ inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
   output_map.array() = input_map.array().template cast<DstT>();
 }
 
-template <typename T, typename ScalarOp>
-inline void ElementWise(const T* input_data, const int buffer_size,
-                        T* output_data) {
-  auto input_map = VectorMap<const T>(input_data, buffer_size, 1);
-  auto output_map = VectorMap<T>(output_data, buffer_size, 1);
-  output_map.array() = input_map.array().unaryExpr(ScalarOp());
-}
-
-template <typename T>
-inline void Cos(const T* input_data, const int buffer_size, T* output_data) {
-  gemmlowp::ScopedProfilingLabel label("Cos");
-  ElementWise<T, Eigen::internal::scalar_cos_op<T>>(input_data, buffer_size,
-                                                    output_data);
-}
-
-template <typename T>
-inline void Exp(const T* input_data, const int buffer_size, T* output_data) {
-  gemmlowp::ScopedProfilingLabel label("Exp");
-  return ElementWise<T, Eigen::internal::scalar_exp_op<T>>(
-      input_data, buffer_size, output_data);
-}
-
-template <typename T>
-inline void Log(const T* input_data, const int buffer_size, T* output_data) {
-  gemmlowp::ScopedProfilingLabel label("Log");
-  ElementWise<T, Eigen::internal::scalar_log_op<T>>(input_data, buffer_size,
-                                                    output_data);
-}
-
-template <typename T>
-inline void Sin(const T* input_data, const int buffer_size, T* output_data) {
-  gemmlowp::ScopedProfilingLabel label("Sin");
-  ElementWise<T, Eigen::internal::scalar_sin_op<T>>(input_data, buffer_size,
-                                                    output_data);
-}
-
 inline void Floor(const RuntimeShape& input_shape, const float* input_data,
                   const RuntimeShape& output_shape, float* output_data) {
   gemmlowp::ScopedProfilingLabel label("Floor");
diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc
index 57c10510954..6d10828869f 100644
--- a/tensorflow/lite/kernels/register_ref.cc
+++ b/tensorflow/lite/kernels/register_ref.cc
@@ -78,9 +78,9 @@ TfLiteRegistration* Register_SPLIT();
 TfLiteRegistration* Register_SPLIT_V();
 TfLiteRegistration* Register_SQUEEZE();
 TfLiteRegistration* Register_STRIDED_SLICE_REF();
-TfLiteRegistration* Register_EXP_REF();
+TfLiteRegistration* Register_EXP();
 TfLiteRegistration* Register_TOPK_V2();
-TfLiteRegistration* Register_LOG_REF();
+TfLiteRegistration* Register_LOG();
 TfLiteRegistration* Register_LOG_SOFTMAX_REF();
 TfLiteRegistration* Register_CAST();
 TfLiteRegistration* Register_DEQUANTIZE();
@@ -103,7 +103,7 @@ TfLiteRegistration* Register_REDUCE_MIN();
 TfLiteRegistration* Register_REDUCE_ANY();
 TfLiteRegistration* Register_SELECT();
 TfLiteRegistration* Register_SLICE_REF();
-TfLiteRegistration* Register_SIN_REF();
+TfLiteRegistration* Register_SIN();
 TfLiteRegistration* Register_TRANSPOSECONV_REF();
 TfLiteRegistration* Register_EXPAND_DIMS();
 TfLiteRegistration* Register_SPARSE_TO_DENSE();
@@ -229,9 +229,9 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
   AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V());
   AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
   AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE_REF());
-  AddBuiltin(BuiltinOperator_EXP, Register_EXP_REF());
+  AddBuiltin(BuiltinOperator_EXP, Register_EXP());
   AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2());
-  AddBuiltin(BuiltinOperator_LOG, Register_LOG_REF());
+  AddBuiltin(BuiltinOperator_LOG, Register_LOG());
   AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX_REF());
   AddBuiltin(BuiltinOperator_CAST, Register_CAST());
   AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
@@ -250,7 +250,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
   AddBuiltin(BuiltinOperator_NEG, Register_NEG());
   AddBuiltin(BuiltinOperator_SELECT, Register_SELECT());
   AddBuiltin(BuiltinOperator_SLICE, Register_SLICE_REF());
-  AddBuiltin(BuiltinOperator_SIN, Register_SIN_REF());
+  AddBuiltin(BuiltinOperator_SIN, Register_SIN());
   AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSECONV_REF());
   AddBuiltin(BuiltinOperator_TILE, Register_TILE());
   AddBuiltin(BuiltinOperator_SUM, Register_SUM());