Add optimized versions of numeric unary ops cos, exp, log, sin
PiperOrigin-RevId: 266885222
commit 185a465225 (parent 218836e14c)
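The change follows TFLite's usual split between a reference kernel and a generic optimized kernel: each op's Eval is templated on a KernelType enum, and the Register_* functions pick the instantiation. A minimal self-contained sketch of that dispatch pattern (simplified names, not the actual TFLite API):

#include <cmath>
#include <cstdio>

enum KernelType { kReference, kGenericOptimized };

// Stand-in for optimized_ops::Sin: handles the whole buffer in one call so
// the implementation is free to vectorize.
void BulkSin(const float* in, int n, float* out) {
  for (int i = 0; i < n; ++i) out[i] = std::sin(in[i]);
}

template <KernelType kernel_type>
void SinEval(const float* in, int n, float* out) {
  // kernel_type is a compile-time constant, so the untaken branch is dead
  // code and the dispatch costs nothing at runtime.
  if (kernel_type == kGenericOptimized) {
    BulkSin(in, n, out);
    return;
  }
  for (int i = 0; i < n; ++i) out[i] = std::sin(in[i]);  // reference path
}

int main() {
  const float in[3] = {0.0f, 1.0f, 2.0f};
  float out[3];
  SinEval<kGenericOptimized>(in, 3, out);
  std::printf("%f %f %f\n", out[0], out[1], out[2]);
  return 0;
}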
tensorflow/lite/kernels/elementwise.cc
@@ -14,7 +14,9 @@ limitations under the License.
 ==============================================================================*/
 
 #include <cmath>
 
 #include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
@@ -25,6 +27,11 @@ namespace builtin {
 namespace elementwise {
 namespace {
 
+enum KernelType {
+  kReference,
+  kGenericOptimized,
+};
+
 bool IsNumericSupportedType(const TfLiteType type) {
   return type == kTfLiteFloat32;
 }
@@ -79,15 +86,40 @@ TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
   return EvalNumeric(context, node, std::abs);
 }
 
+typedef void (*UnaryOp)(const float* in, int n, float* out);
+
+inline TfLiteStatus EvalNumericOptimized(TfLiteContext* context,
+                                         TfLiteNode* node, UnaryOp op) {
+  const TfLiteTensor* input = GetInput(context, node, 0);
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  const int64_t num_elements = NumElements(input);
+  const float* in_data = GetTensorData<float>(input);
+  float* out_data = GetTensorData<float>(output);
+  op(in_data, num_elements, out_data);
+  return kTfLiteOk;
+}
+
+template <KernelType kernel_type>
 TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
+  if (kernel_type == kGenericOptimized) {
+    return EvalNumericOptimized(context, node, optimized_ops::Sin<float>);
+  }
   return EvalNumeric(context, node, std::sin);
 }
 
+template <KernelType kernel_type>
 TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
+  if (kernel_type == kGenericOptimized) {
+    return EvalNumericOptimized(context, node, optimized_ops::Cos<float>);
+  }
   return EvalNumeric(context, node, std::cos);
 }
 
+template <KernelType kernel_type>
 TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
+  if (kernel_type == kGenericOptimized) {
+    return EvalNumericOptimized(context, node, optimized_ops::Log<float>);
+  }
   return EvalNumeric(context, node, std::log);
 }
 
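EvalNumericOptimized above takes a plain function pointer, and an instantiated function template such as optimized_ops::Sin<float> converts to that pointer type implicitly. A small standalone illustration of the mechanism (Negate is a hypothetical op, used only for the example):

#include <cstdio>

typedef void (*UnaryOp)(const float* in, int n, float* out);

// Hypothetical elementwise op; any instantiation of a matching function
// template decays to the UnaryOp pointer type.
template <typename T>
void Negate(const T* in, int n, T* out) {
  for (int i = 0; i < n; ++i) out[i] = -in[i];
}

void Apply(UnaryOp op, const float* in, int n, float* out) {
  op(in, n, out);  // indirect call through the function pointer
}

int main() {
  const float in[2] = {1.0f, -2.0f};
  float out[2];
  Apply(&Negate<float>, in, 2, out);
  std::printf("%f %f\n", out[0], out[1]);  // -1.000000 2.000000
  return 0;
}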
@@ -118,30 +150,60 @@ TfLiteRegistration* Register_ABS() {
   return &r;
 }
 
-TfLiteRegistration* Register_SIN() {
+TfLiteRegistration* Register_SIN_REF() {
   static TfLiteRegistration r = {
       /*init=*/nullptr, /*free=*/nullptr,
       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-      elementwise::SinEval};
+      elementwise::SinEval<elementwise::kReference>};
   return &r;
 }
 
-TfLiteRegistration* Register_COS() {
+TfLiteRegistration* Register_SIN_GENERIC_OPT() {
   static TfLiteRegistration r = {
       /*init=*/nullptr, /*free=*/nullptr,
       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-      elementwise::CosEval};
+      elementwise::SinEval<elementwise::kGenericOptimized>};
   return &r;
 }
 
-TfLiteRegistration* Register_LOG() {
+TfLiteRegistration* Register_SIN() { return Register_SIN_GENERIC_OPT(); }
+
+TfLiteRegistration* Register_COS_REF() {
   static TfLiteRegistration r = {
       /*init=*/nullptr, /*free=*/nullptr,
       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
-      elementwise::LogEval};
+      elementwise::CosEval<elementwise::kReference>};
   return &r;
 }
 
+TfLiteRegistration* Register_COS_GENERIC_OPT() {
+  static TfLiteRegistration r = {
+      /*init=*/nullptr, /*free=*/nullptr,
+      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::CosEval<elementwise::kGenericOptimized>};
+  return &r;
+}
+
+TfLiteRegistration* Register_COS() { return Register_COS_GENERIC_OPT(); }
+
+TfLiteRegistration* Register_LOG_REF() {
+  static TfLiteRegistration r = {
+      /*init=*/nullptr, /*free=*/nullptr,
+      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::LogEval<elementwise::kReference>};
+  return &r;
+}
+
+TfLiteRegistration* Register_LOG_GENERIC_OPT() {
+  static TfLiteRegistration r = {
+      /*init=*/nullptr, /*free=*/nullptr,
+      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::LogEval<elementwise::kGenericOptimized>};
+  return &r;
+}
+
+TfLiteRegistration* Register_LOG() { return Register_LOG_GENERIC_OPT(); }
+
 TfLiteRegistration* Register_SQRT() {
   static TfLiteRegistration r = {
       /*init=*/nullptr, /*free=*/nullptr,
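With this hunk, Register_SIN, Register_COS, and Register_LOG default to the generic optimized kernels, while the _REF variants keep the previous behavior reachable. A hedged sketch of overriding the default from application code (the forward declaration is an assumption for a standalone file; the symbol is defined in elementwise.cc above):

#include "tensorflow/lite/mutable_op_resolver.h"

// Forward declaration; defined in elementwise.cc (see above).
namespace tflite {
namespace ops {
namespace builtin {
TfLiteRegistration* Register_SIN_REF();
}  // namespace builtin
}  // namespace ops
}  // namespace tflite

// Swap the reference Sin kernel into a resolver, e.g. to compare the
// numerics of the new optimized path against the old behavior.
void UseReferenceSin(tflite::MutableOpResolver* resolver) {
  resolver->AddBuiltin(tflite::BuiltinOperator_SIN,
                       tflite::ops::builtin::Register_SIN_REF());
}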
tensorflow/lite/kernels/exp.cc
@@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 #include <string.h>
 
+#include <vector>
+
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
@@ -29,6 +32,7 @@ namespace exp {
 // This file has reference implementation of Exp.
 enum KernelType {
   kReference,
+  kGenericOptimized,
 };
 
 struct ExpContext {
@@ -53,7 +57,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   ExpContext op_context(context, node);
-
+  switch (op_context.input->type) {
+    case kTfLiteFloat32:
 #define TF_LITE_EXP(kernel_type, data_type)                               \
   kernel_type::Exp<data_type>(GetTensorData<data_type>(op_context.input), \
                               NumElements(op_context.input),              \
@@ -61,18 +66,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 
   // TODO(kanlig): supports half, bfloat16, float64, complex64, and complex128.
-  switch (op_context.input->type) {
-    case kTfLiteFloat32:
-      TF_LITE_EXP(reference_ops, float);
-      break;
-    default:
-      context->ReportError(context,
-                           "Type %d is currently not supported by Exp.",
-                           op_context.input->type);
-      return kTfLiteError;
-  }
+      if (kernel_type == kReference) {
+        TF_LITE_EXP(reference_ops, float);
+      } else {
+        TF_LITE_EXP(optimized_ops, float);
+      }
 #undef TF_LITE_EXP
+      break;
+    default:
+      context->ReportError(context,
+                           "Type %d is currently not supported by Exp.",
+                           op_context.input->type);
+      return kTfLiteError;
+  }
   return kTfLiteOk;
 }
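The TF_LITE_EXP macro takes the kernel namespace as its first argument, so the same expansion can call either reference_ops::Exp or optimized_ops::Exp, and the branch on the compile-time kernel_type selects between them. A toy standalone version of that namespace-selecting trick:

#include <cstdio>

namespace reference_ops {
inline int Twice(int x) { return x + x; }
}  // namespace reference_ops

namespace optimized_ops {
inline int Twice(int x) { return x << 1; }
}  // namespace optimized_ops

// Like TF_LITE_EXP, the macro's first argument names the namespace.
#define TWICE(kernel_ns, x) kernel_ns::Twice(x)

int main() {
  std::printf("%d %d\n", TWICE(reference_ops, 3), TWICE(optimized_ops, 3));
  return 0;
}
#undef TWICE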
@@ -84,8 +89,13 @@ TfLiteRegistration* Register_EXP_REF() {
   return &r;
 }
 
-// TODO(kanlig): add optimized implementation of Exp.
-TfLiteRegistration* Register_EXP() { return Register_EXP_REF(); }
+TfLiteRegistration* Register_EXP_GENERIC_OPT() {
+  static TfLiteRegistration r = {nullptr, nullptr, exp::Prepare,
+                                 exp::Eval<exp::kGenericOptimized>};
+  return &r;
+}
+
+TfLiteRegistration* Register_EXP() { return Register_EXP_GENERIC_OPT(); }
 
 }  // namespace builtin
 }  // namespace ops
tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -3917,6 +3917,42 @@ inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
   output_map.array() = input_map.array().template cast<DstT>();
 }
 
+template <typename T, typename ScalarOp>
+inline void ElementWise(const T* input_data, const int buffer_size,
+                        T* output_data) {
+  auto input_map = VectorMap<const T>(input_data, buffer_size, 1);
+  auto output_map = VectorMap<T>(output_data, buffer_size, 1);
+  output_map.array() = input_map.array().unaryExpr(ScalarOp());
+}
+
+template <typename T>
+inline void Cos(const T* input_data, const int buffer_size, T* output_data) {
+  gemmlowp::ScopedProfilingLabel label("Cos");
+  ElementWise<T, Eigen::internal::scalar_cos_op<T>>(input_data, buffer_size,
+                                                    output_data);
+}
+
+template <typename T>
+inline void Exp(const T* input_data, const int buffer_size, T* output_data) {
+  gemmlowp::ScopedProfilingLabel label("Exp");
+  return ElementWise<T, Eigen::internal::scalar_exp_op<T>>(
+      input_data, buffer_size, output_data);
+}
+
+template <typename T>
+inline void Log(const T* input_data, const int buffer_size, T* output_data) {
+  gemmlowp::ScopedProfilingLabel label("Log");
+  ElementWise<T, Eigen::internal::scalar_log_op<T>>(input_data, buffer_size,
+                                                    output_data);
+}
+
+template <typename T>
+inline void Sin(const T* input_data, const int buffer_size, T* output_data) {
+  gemmlowp::ScopedProfilingLabel label("Sin");
+  ElementWise<T, Eigen::internal::scalar_sin_op<T>>(input_data, buffer_size,
+                                                    output_data);
+}
+
 inline void Floor(const RuntimeShape& input_shape, const float* input_data,
                   const RuntimeShape& output_shape, float* output_data) {
   gemmlowp::ScopedProfilingLabel label("Floor");
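ElementWise above maps both raw buffers as Eigen vectors and applies the scalar functor with unaryExpr, which lets Eigen vectorize the loop. A standalone sketch of the same pattern using Eigen::Map directly (VectorMap is TFLite's own alias, so plain Map stands in for it here):

#include <iostream>

#include "Eigen/Core"

template <typename T, typename ScalarOp>
void ElementWiseSketch(const T* in, int n, T* out) {
  // Map the raw buffers as n x 1 Eigen arrays (no copies).
  Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>> in_map(in, n);
  Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> out_map(out, n);
  out_map = in_map.unaryExpr(ScalarOp());  // vectorized elementwise apply
}

int main() {
  const float in[4] = {0.0f, 1.0f, 2.0f, 3.0f};
  float out[4];
  ElementWiseSketch<float, Eigen::internal::scalar_cos_op<float>>(in, 4, out);
  for (float v : out) std::cout << v << " ";  // cos(0) cos(1) cos(2) cos(3)
  std::cout << "\n";
  return 0;
}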
tensorflow/lite/kernels/register_ref.cc
@@ -78,9 +78,9 @@ TfLiteRegistration* Register_SPLIT();
 TfLiteRegistration* Register_SPLIT_V();
 TfLiteRegistration* Register_SQUEEZE();
 TfLiteRegistration* Register_STRIDED_SLICE_REF();
-TfLiteRegistration* Register_EXP();
+TfLiteRegistration* Register_EXP_REF();
 TfLiteRegistration* Register_TOPK_V2();
-TfLiteRegistration* Register_LOG();
+TfLiteRegistration* Register_LOG_REF();
 TfLiteRegistration* Register_LOG_SOFTMAX_REF();
 TfLiteRegistration* Register_CAST();
 TfLiteRegistration* Register_DEQUANTIZE();
@@ -103,7 +103,7 @@ TfLiteRegistration* Register_REDUCE_MIN();
 TfLiteRegistration* Register_REDUCE_ANY();
 TfLiteRegistration* Register_SELECT();
 TfLiteRegistration* Register_SLICE_REF();
-TfLiteRegistration* Register_SIN();
+TfLiteRegistration* Register_SIN_REF();
 TfLiteRegistration* Register_TRANSPOSECONV_REF();
 TfLiteRegistration* Register_EXPAND_DIMS();
 TfLiteRegistration* Register_SPARSE_TO_DENSE();
@@ -229,9 +229,9 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
   AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V());
   AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
   AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE_REF());
-  AddBuiltin(BuiltinOperator_EXP, Register_EXP());
+  AddBuiltin(BuiltinOperator_EXP, Register_EXP_REF());
   AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2());
-  AddBuiltin(BuiltinOperator_LOG, Register_LOG());
+  AddBuiltin(BuiltinOperator_LOG, Register_LOG_REF());
   AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX_REF());
   AddBuiltin(BuiltinOperator_CAST, Register_CAST());
   AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
@@ -250,7 +250,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
   AddBuiltin(BuiltinOperator_NEG, Register_NEG());
   AddBuiltin(BuiltinOperator_SELECT, Register_SELECT());
   AddBuiltin(BuiltinOperator_SLICE, Register_SLICE_REF());
-  AddBuiltin(BuiltinOperator_SIN, Register_SIN());
+  AddBuiltin(BuiltinOperator_SIN, Register_SIN_REF());
   AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSECONV_REF());
   AddBuiltin(BuiltinOperator_TILE, Register_TILE());
   AddBuiltin(BuiltinOperator_SUM, Register_SUM());
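These register_ref.cc changes pin EXP, LOG, and SIN in BuiltinRefOpResolver to the _REF kernels, so the reference resolver stays a pure-reference baseline now that the default resolver returns the optimized registrations. A hedged usage sketch with the standard TFLite API of that era:

#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register_ref.h"
#include "tensorflow/lite/model.h"

// Build an interpreter that runs only reference kernels; useful as a
// ground truth when validating the new optimized Cos/Exp/Log/Sin paths.
std::unique_ptr<tflite::Interpreter> BuildReferenceInterpreter(
    const tflite::FlatBufferModel& model) {
  tflite::ops::builtin::BuiltinRefOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(model, resolver)(&interpreter);
  return interpreter;
}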