Add optimized versions of numeric unary ops cos, exp, log, sin

PiperOrigin-RevId: 266885222
This commit is contained in:
A. Unique TensorFlower 2019-09-03 02:25:45 -07:00 committed by TensorFlower Gardener
parent 218836e14c
commit 185a465225
4 changed files with 133 additions and 25 deletions

View File

@ -14,7 +14,9 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <cmath> #include <cmath>
#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/kernel_util.h"
@ -25,6 +27,11 @@ namespace builtin {
namespace elementwise { namespace elementwise {
namespace { namespace {
// Selects which implementation the templated Eval functions (SinEval,
// CosEval, LogEval) dispatch to.
enum KernelType {
// The Eval functions fall through to EvalNumeric with the std:: math function.
kReference,
// The Eval functions call EvalNumericOptimized with an optimized_ops kernel.
kGenericOptimized,
};
bool IsNumericSupportedType(const TfLiteType type) { bool IsNumericSupportedType(const TfLiteType type) {
return type == kTfLiteFloat32; return type == kTfLiteFloat32;
} }
@ -79,15 +86,40 @@ TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
return EvalNumeric(context, node, std::abs); return EvalNumeric(context, node, std::abs);
} }
// Function-pointer type for the bulk float kernels handed to
// EvalNumericOptimized (for example optimized_ops::Sin<float>).
using UnaryOp = void (*)(const float* in, int n, float* out);

// Invokes `op` once on the float data of input tensor 0 of `node`, with the
// float data of output tensor 0 as destination. Always returns kTfLiteOk.
inline TfLiteStatus EvalNumericOptimized(TfLiteContext* context,
                                         TfLiteNode* node, UnaryOp op) {
  const TfLiteTensor* src_tensor = GetInput(context, node, 0);
  TfLiteTensor* dst_tensor = GetOutput(context, node, 0);
  // The element count comes from the input tensor only. It is held as an
  // int64_t and converted to the kernel's `int` parameter at the call below.
  const int64_t element_count = NumElements(src_tensor);
  op(GetTensorData<float>(src_tensor), element_count,
     GetTensorData<float>(dst_tensor));
  return kTfLiteOk;
}
template <KernelType kernel_type>
TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
if (kernel_type == kGenericOptimized) {
return EvalNumericOptimized(context, node, optimized_ops::Sin<float>);
}
return EvalNumeric(context, node, std::sin); return EvalNumeric(context, node, std::sin);
} }
template <KernelType kernel_type>
TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
if (kernel_type == kGenericOptimized) {
return EvalNumericOptimized(context, node, optimized_ops::Cos<float>);
}
return EvalNumeric(context, node, std::cos); return EvalNumeric(context, node, std::cos);
} }
template <KernelType kernel_type>
TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
if (kernel_type == kGenericOptimized) {
return EvalNumericOptimized(context, node, optimized_ops::Log<float>);
}
return EvalNumeric(context, node, std::log); return EvalNumeric(context, node, std::log);
} }
@ -118,30 +150,60 @@ TfLiteRegistration* Register_ABS() {
return &r; return &r;
} }
TfLiteRegistration* Register_SIN() { TfLiteRegistration* Register_SIN_REF() {
static TfLiteRegistration r = { static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr, /*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::SinEval}; elementwise::SinEval<elementwise::kReference>};
return &r; return &r;
} }
TfLiteRegistration* Register_COS() { TfLiteRegistration* Register_SIN_GENERIC_OPT() {
static TfLiteRegistration r = { static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr, /*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::CosEval}; elementwise::SinEval<elementwise::kGenericOptimized>};
return &r; return &r;
} }
TfLiteRegistration* Register_LOG() { TfLiteRegistration* Register_SIN() { return Register_SIN_GENERIC_OPT(); }
TfLiteRegistration* Register_COS_REF() {
static TfLiteRegistration r = { static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr, /*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::LogEval}; elementwise::CosEval<elementwise::kReference>};
return &r; return &r;
} }
TfLiteRegistration* Register_COS_GENERIC_OPT() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::CosEval<elementwise::kGenericOptimized>};
return &r;
}
TfLiteRegistration* Register_COS() { return Register_COS_GENERIC_OPT(); }
TfLiteRegistration* Register_LOG_REF() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::LogEval<elementwise::kReference>};
return &r;
}
TfLiteRegistration* Register_LOG_GENERIC_OPT() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::LogEval<elementwise::kGenericOptimized>};
return &r;
}
TfLiteRegistration* Register_LOG() { return Register_LOG_GENERIC_OPT(); }
TfLiteRegistration* Register_SQRT() { TfLiteRegistration* Register_SQRT() {
static TfLiteRegistration r = { static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr, /*init=*/nullptr, /*free=*/nullptr,

View File

@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <string.h> #include <string.h>
#include <vector> #include <vector>
#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/kernel_util.h"
@ -29,6 +32,7 @@ namespace exp {
// This file has reference implementation of Exp. // This file has reference implementation of Exp.
enum KernelType { enum KernelType {
kReference, kReference,
kGenericOptimized,
}; };
struct ExpContext { struct ExpContext {
@ -53,7 +57,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
template <KernelType kernel_type> template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
ExpContext op_context(context, node); ExpContext op_context(context, node);
switch (op_context.input->type) {
case kTfLiteFloat32:
#define TF_LITE_EXP(kernel_type, data_type) \ #define TF_LITE_EXP(kernel_type, data_type) \
kernel_type::Exp<data_type>(GetTensorData<data_type>(op_context.input), \ kernel_type::Exp<data_type>(GetTensorData<data_type>(op_context.input), \
NumElements(op_context.input), \ NumElements(op_context.input), \
@ -61,9 +66,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// TODO(kanlig): supports half, bfloat16, float64, complex64, and complex128. // TODO(kanlig): supports half, bfloat16, float64, complex64, and complex128.
if (kernel_type == kReference) { if (kernel_type == kReference) {
switch (op_context.input->type) {
case kTfLiteFloat32:
TF_LITE_EXP(reference_ops, float); TF_LITE_EXP(reference_ops, float);
} else {
TF_LITE_EXP(optimized_ops, float);
}
#undef TF_LITE_EXP
break; break;
default: default:
context->ReportError(context, context->ReportError(context,
@ -71,8 +78,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
op_context.input->type); op_context.input->type);
return kTfLiteError; return kTfLiteError;
} }
}
#undef TF_LITE_EXP
return kTfLiteOk; return kTfLiteOk;
} }
@ -84,8 +89,13 @@ TfLiteRegistration* Register_EXP_REF() {
return &r; return &r;
} }
// TODO(kanlig): add optimized implementation of Exp. TfLiteRegistration* Register_EXP_GENERIC_OPT() {
TfLiteRegistration* Register_EXP() { return Register_EXP_REF(); } static TfLiteRegistration r = {nullptr, nullptr, exp::Prepare,
exp::Eval<exp::kGenericOptimized>};
return &r;
}
TfLiteRegistration* Register_EXP() { return Register_EXP_GENERIC_OPT(); }
} // namespace builtin } // namespace builtin
} // namespace ops } // namespace ops

View File

@ -3917,6 +3917,42 @@ inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
output_map.array() = input_map.array().template cast<DstT>(); output_map.array() = input_map.array().template cast<DstT>();
} }
// Applies a default-constructed `ScalarOp` functor to every coefficient of
// the `buffer_size`-long input buffer and assigns the results to the output
// buffer. Both buffers are viewed through VectorMap (declared elsewhere).
template <typename T, typename ScalarOp>
inline void ElementWise(const T* input_data, const int buffer_size,
                        T* output_data) {
  const auto source = VectorMap<const T>(input_data, buffer_size, 1);
  auto destination = VectorMap<T>(output_data, buffer_size, 1);
  const ScalarOp scalar_op = ScalarOp();
  destination.array() = source.array().unaryExpr(scalar_op);
}
template <typename T>
inline void Cos(const T* input_data, const int buffer_size, T* output_data) {
gemmlowp::ScopedProfilingLabel label("Cos");
ElementWise<T, Eigen::internal::scalar_cos_op<T>>(input_data, buffer_size,
output_data);
}
// Optimized element-wise exponential: runs Eigen::internal::scalar_exp_op<T>
// over `buffer_size` values of `input_data` through ElementWise, writing to
// `output_data`. The call is wrapped in an "Exp" profiling label.
template <typename T>
inline void Exp(const T* input_data, const int buffer_size, T* output_data) {
  gemmlowp::ScopedProfilingLabel label("Exp");
  // ElementWise returns void, so it is called as a plain statement, matching
  // the sibling Cos/Log/Sin wrappers.
  ElementWise<T, Eigen::internal::scalar_exp_op<T>>(input_data, buffer_size,
                                                    output_data);
}
template <typename T>
inline void Log(const T* input_data, const int buffer_size, T* output_data) {
gemmlowp::ScopedProfilingLabel label("Log");
ElementWise<T, Eigen::internal::scalar_log_op<T>>(input_data, buffer_size,
output_data);
}
template <typename T>
inline void Sin(const T* input_data, const int buffer_size, T* output_data) {
gemmlowp::ScopedProfilingLabel label("Sin");
ElementWise<T, Eigen::internal::scalar_sin_op<T>>(input_data, buffer_size,
output_data);
}
inline void Floor(const RuntimeShape& input_shape, const float* input_data, inline void Floor(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) { const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Floor"); gemmlowp::ScopedProfilingLabel label("Floor");

View File

@ -78,9 +78,9 @@ TfLiteRegistration* Register_SPLIT();
TfLiteRegistration* Register_SPLIT_V(); TfLiteRegistration* Register_SPLIT_V();
TfLiteRegistration* Register_SQUEEZE(); TfLiteRegistration* Register_SQUEEZE();
TfLiteRegistration* Register_STRIDED_SLICE_REF(); TfLiteRegistration* Register_STRIDED_SLICE_REF();
TfLiteRegistration* Register_EXP(); TfLiteRegistration* Register_EXP_REF();
TfLiteRegistration* Register_TOPK_V2(); TfLiteRegistration* Register_TOPK_V2();
TfLiteRegistration* Register_LOG(); TfLiteRegistration* Register_LOG_REF();
TfLiteRegistration* Register_LOG_SOFTMAX_REF(); TfLiteRegistration* Register_LOG_SOFTMAX_REF();
TfLiteRegistration* Register_CAST(); TfLiteRegistration* Register_CAST();
TfLiteRegistration* Register_DEQUANTIZE(); TfLiteRegistration* Register_DEQUANTIZE();
@ -103,7 +103,7 @@ TfLiteRegistration* Register_REDUCE_MIN();
TfLiteRegistration* Register_REDUCE_ANY(); TfLiteRegistration* Register_REDUCE_ANY();
TfLiteRegistration* Register_SELECT(); TfLiteRegistration* Register_SELECT();
TfLiteRegistration* Register_SLICE_REF(); TfLiteRegistration* Register_SLICE_REF();
TfLiteRegistration* Register_SIN(); TfLiteRegistration* Register_SIN_REF();
TfLiteRegistration* Register_TRANSPOSECONV_REF(); TfLiteRegistration* Register_TRANSPOSECONV_REF();
TfLiteRegistration* Register_EXPAND_DIMS(); TfLiteRegistration* Register_EXPAND_DIMS();
TfLiteRegistration* Register_SPARSE_TO_DENSE(); TfLiteRegistration* Register_SPARSE_TO_DENSE();
@ -229,9 +229,9 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V()); AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V());
AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE()); AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE_REF()); AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE_REF());
AddBuiltin(BuiltinOperator_EXP, Register_EXP()); AddBuiltin(BuiltinOperator_EXP, Register_EXP_REF());
AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2()); AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2());
AddBuiltin(BuiltinOperator_LOG, Register_LOG()); AddBuiltin(BuiltinOperator_LOG, Register_LOG_REF());
AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX_REF()); AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX_REF());
AddBuiltin(BuiltinOperator_CAST, Register_CAST()); AddBuiltin(BuiltinOperator_CAST, Register_CAST());
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(), AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
@ -250,7 +250,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
AddBuiltin(BuiltinOperator_NEG, Register_NEG()); AddBuiltin(BuiltinOperator_NEG, Register_NEG());
AddBuiltin(BuiltinOperator_SELECT, Register_SELECT()); AddBuiltin(BuiltinOperator_SELECT, Register_SELECT());
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE_REF()); AddBuiltin(BuiltinOperator_SLICE, Register_SLICE_REF());
AddBuiltin(BuiltinOperator_SIN, Register_SIN()); AddBuiltin(BuiltinOperator_SIN, Register_SIN_REF());
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSECONV_REF()); AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSECONV_REF());
AddBuiltin(BuiltinOperator_TILE, Register_TILE()); AddBuiltin(BuiltinOperator_TILE, Register_TILE());
AddBuiltin(BuiltinOperator_SUM, Register_SUM()); AddBuiltin(BuiltinOperator_SUM, Register_SUM());