parent
0c3bae3842
commit
0ee4bc62dd
@ -14,9 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
#include "tensorflow/lite/c/c_api_internal.h"
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
|
||||||
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
||||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
@ -27,11 +25,6 @@ namespace builtin {
|
|||||||
namespace elementwise {
|
namespace elementwise {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
enum KernelType {
|
|
||||||
kReference,
|
|
||||||
kGenericOptimized,
|
|
||||||
};
|
|
||||||
|
|
||||||
bool IsNumericSupportedType(const TfLiteType type) {
|
bool IsNumericSupportedType(const TfLiteType type) {
|
||||||
return type == kTfLiteFloat32;
|
return type == kTfLiteFloat32;
|
||||||
}
|
}
|
||||||
@ -86,40 +79,15 @@ TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
return EvalNumeric(context, node, std::abs);
|
return EvalNumeric(context, node, std::abs);
|
||||||
}
|
}
|
||||||
|
|
||||||
typedef void (*UnaryOp)(const float* in, int n, float* out);
|
|
||||||
|
|
||||||
inline TfLiteStatus EvalNumericOptimized(TfLiteContext* context,
|
|
||||||
TfLiteNode* node, UnaryOp op) {
|
|
||||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
|
||||||
TfLiteTensor* output = GetOutput(context, node, 0);
|
|
||||||
const int64_t num_elements = NumElements(input);
|
|
||||||
const float* in_data = GetTensorData<float>(input);
|
|
||||||
float* out_data = GetTensorData<float>(output);
|
|
||||||
op(in_data, num_elements, out_data);
|
|
||||||
return kTfLiteOk;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <KernelType kernel_type>
|
|
||||||
TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
if (kernel_type == kGenericOptimized) {
|
|
||||||
return EvalNumericOptimized(context, node, optimized_ops::Sin<float>);
|
|
||||||
}
|
|
||||||
return EvalNumeric(context, node, std::sin);
|
return EvalNumeric(context, node, std::sin);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <KernelType kernel_type>
|
|
||||||
TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
if (kernel_type == kGenericOptimized) {
|
|
||||||
return EvalNumericOptimized(context, node, optimized_ops::Cos<float>);
|
|
||||||
}
|
|
||||||
return EvalNumeric(context, node, std::cos);
|
return EvalNumeric(context, node, std::cos);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <KernelType kernel_type>
|
|
||||||
TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
if (kernel_type == kGenericOptimized) {
|
|
||||||
return EvalNumericOptimized(context, node, optimized_ops::Log<float>);
|
|
||||||
}
|
|
||||||
return EvalNumeric(context, node, std::log);
|
return EvalNumeric(context, node, std::log);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -150,60 +118,30 @@ TfLiteRegistration* Register_ABS() {
|
|||||||
return &r;
|
return &r;
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteRegistration* Register_SIN_REF() {
|
TfLiteRegistration* Register_SIN() {
|
||||||
static TfLiteRegistration r = {
|
static TfLiteRegistration r = {
|
||||||
/*init=*/nullptr, /*free=*/nullptr,
|
/*init=*/nullptr, /*free=*/nullptr,
|
||||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||||
elementwise::SinEval<elementwise::kReference>};
|
elementwise::SinEval};
|
||||||
return &r;
|
return &r;
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteRegistration* Register_SIN_GENERIC_OPT() {
|
TfLiteRegistration* Register_COS() {
|
||||||
static TfLiteRegistration r = {
|
static TfLiteRegistration r = {
|
||||||
/*init=*/nullptr, /*free=*/nullptr,
|
/*init=*/nullptr, /*free=*/nullptr,
|
||||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||||
elementwise::SinEval<elementwise::kGenericOptimized>};
|
elementwise::CosEval};
|
||||||
return &r;
|
return &r;
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteRegistration* Register_SIN() { return Register_SIN_GENERIC_OPT(); }
|
TfLiteRegistration* Register_LOG() {
|
||||||
|
|
||||||
TfLiteRegistration* Register_COS_REF() {
|
|
||||||
static TfLiteRegistration r = {
|
static TfLiteRegistration r = {
|
||||||
/*init=*/nullptr, /*free=*/nullptr,
|
/*init=*/nullptr, /*free=*/nullptr,
|
||||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||||
elementwise::CosEval<elementwise::kReference>};
|
elementwise::LogEval};
|
||||||
return &r;
|
return &r;
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteRegistration* Register_COS_GENERIC_OPT() {
|
|
||||||
static TfLiteRegistration r = {
|
|
||||||
/*init=*/nullptr, /*free=*/nullptr,
|
|
||||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
|
||||||
elementwise::CosEval<elementwise::kGenericOptimized>};
|
|
||||||
return &r;
|
|
||||||
}
|
|
||||||
|
|
||||||
TfLiteRegistration* Register_COS() { return Register_COS_GENERIC_OPT(); }
|
|
||||||
|
|
||||||
TfLiteRegistration* Register_LOG_REF() {
|
|
||||||
static TfLiteRegistration r = {
|
|
||||||
/*init=*/nullptr, /*free=*/nullptr,
|
|
||||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
|
||||||
elementwise::LogEval<elementwise::kReference>};
|
|
||||||
return &r;
|
|
||||||
}
|
|
||||||
|
|
||||||
TfLiteRegistration* Register_LOG_GENERIC_OPT() {
|
|
||||||
static TfLiteRegistration r = {
|
|
||||||
/*init=*/nullptr, /*free=*/nullptr,
|
|
||||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
|
||||||
elementwise::LogEval<elementwise::kGenericOptimized>};
|
|
||||||
return &r;
|
|
||||||
}
|
|
||||||
|
|
||||||
TfLiteRegistration* Register_LOG() { return Register_LOG_GENERIC_OPT(); }
|
|
||||||
|
|
||||||
TfLiteRegistration* Register_SQRT() {
|
TfLiteRegistration* Register_SQRT() {
|
||||||
static TfLiteRegistration r = {
|
static TfLiteRegistration r = {
|
||||||
/*init=*/nullptr, /*free=*/nullptr,
|
/*init=*/nullptr, /*free=*/nullptr,
|
||||||
|
@ -13,12 +13,9 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||||
#include "tensorflow/lite/c/c_api_internal.h"
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
|
||||||
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
||||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
@ -32,7 +29,6 @@ namespace exp {
|
|||||||
// This file has reference implementation of Exp.
|
// This file has reference implementation of Exp.
|
||||||
enum KernelType {
|
enum KernelType {
|
||||||
kReference,
|
kReference,
|
||||||
kGenericOptimized,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ExpContext {
|
struct ExpContext {
|
||||||
@ -57,8 +53,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
template <KernelType kernel_type>
|
template <KernelType kernel_type>
|
||||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
ExpContext op_context(context, node);
|
ExpContext op_context(context, node);
|
||||||
switch (op_context.input->type) {
|
|
||||||
case kTfLiteFloat32:
|
|
||||||
#define TF_LITE_EXP(kernel_type, data_type) \
|
#define TF_LITE_EXP(kernel_type, data_type) \
|
||||||
kernel_type::Exp<data_type>(GetTensorData<data_type>(op_context.input), \
|
kernel_type::Exp<data_type>(GetTensorData<data_type>(op_context.input), \
|
||||||
NumElements(op_context.input), \
|
NumElements(op_context.input), \
|
||||||
@ -66,18 +61,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
|
|
||||||
// TODO(kanlig): supports half, bfloat16, float64, complex64, and complex128.
|
// TODO(kanlig): supports half, bfloat16, float64, complex64, and complex128.
|
||||||
if (kernel_type == kReference) {
|
if (kernel_type == kReference) {
|
||||||
TF_LITE_EXP(reference_ops, float);
|
switch (op_context.input->type) {
|
||||||
} else {
|
case kTfLiteFloat32:
|
||||||
TF_LITE_EXP(optimized_ops, float);
|
TF_LITE_EXP(reference_ops, float);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
context->ReportError(context,
|
||||||
|
"Type %d is currently not supported by Exp.",
|
||||||
|
op_context.input->type);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
#undef TF_LITE_EXP
|
#undef TF_LITE_EXP
|
||||||
break;
|
|
||||||
default:
|
|
||||||
context->ReportError(context,
|
|
||||||
"Type %d is currently not supported by Exp.",
|
|
||||||
op_context.input->type);
|
|
||||||
return kTfLiteError;
|
|
||||||
}
|
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,13 +84,8 @@ TfLiteRegistration* Register_EXP_REF() {
|
|||||||
return &r;
|
return &r;
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteRegistration* Register_EXP_GENERIC_OPT() {
|
// TODO(kanlig): add optimized implementation of Exp.
|
||||||
static TfLiteRegistration r = {nullptr, nullptr, exp::Prepare,
|
TfLiteRegistration* Register_EXP() { return Register_EXP_REF(); }
|
||||||
exp::Eval<exp::kGenericOptimized>};
|
|
||||||
return &r;
|
|
||||||
}
|
|
||||||
|
|
||||||
TfLiteRegistration* Register_EXP() { return Register_EXP_GENERIC_OPT(); }
|
|
||||||
|
|
||||||
} // namespace builtin
|
} // namespace builtin
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
@ -3917,42 +3917,6 @@ inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
|
|||||||
output_map.array() = input_map.array().template cast<DstT>();
|
output_map.array() = input_map.array().template cast<DstT>();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename ScalarOp>
|
|
||||||
inline void ElementWise(const T* input_data, const int buffer_size,
|
|
||||||
T* output_data) {
|
|
||||||
auto input_map = VectorMap<const T>(input_data, buffer_size, 1);
|
|
||||||
auto output_map = VectorMap<T>(output_data, buffer_size, 1);
|
|
||||||
output_map.array() = input_map.array().unaryExpr(ScalarOp());
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
inline void Cos(const T* input_data, const int buffer_size, T* output_data) {
|
|
||||||
gemmlowp::ScopedProfilingLabel label("Cos");
|
|
||||||
ElementWise<T, Eigen::internal::scalar_cos_op<T>>(input_data, buffer_size,
|
|
||||||
output_data);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
inline void Exp(const T* input_data, const int buffer_size, T* output_data) {
|
|
||||||
gemmlowp::ScopedProfilingLabel label("Exp");
|
|
||||||
return ElementWise<T, Eigen::internal::scalar_exp_op<T>>(
|
|
||||||
input_data, buffer_size, output_data);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
inline void Log(const T* input_data, const int buffer_size, T* output_data) {
|
|
||||||
gemmlowp::ScopedProfilingLabel label("Log");
|
|
||||||
ElementWise<T, Eigen::internal::scalar_log_op<T>>(input_data, buffer_size,
|
|
||||||
output_data);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
inline void Sin(const T* input_data, const int buffer_size, T* output_data) {
|
|
||||||
gemmlowp::ScopedProfilingLabel label("Sin");
|
|
||||||
ElementWise<T, Eigen::internal::scalar_sin_op<T>>(input_data, buffer_size,
|
|
||||||
output_data);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void Floor(const RuntimeShape& input_shape, const float* input_data,
|
inline void Floor(const RuntimeShape& input_shape, const float* input_data,
|
||||||
const RuntimeShape& output_shape, float* output_data) {
|
const RuntimeShape& output_shape, float* output_data) {
|
||||||
gemmlowp::ScopedProfilingLabel label("Floor");
|
gemmlowp::ScopedProfilingLabel label("Floor");
|
||||||
|
@ -78,9 +78,9 @@ TfLiteRegistration* Register_SPLIT();
|
|||||||
TfLiteRegistration* Register_SPLIT_V();
|
TfLiteRegistration* Register_SPLIT_V();
|
||||||
TfLiteRegistration* Register_SQUEEZE();
|
TfLiteRegistration* Register_SQUEEZE();
|
||||||
TfLiteRegistration* Register_STRIDED_SLICE_REF();
|
TfLiteRegistration* Register_STRIDED_SLICE_REF();
|
||||||
TfLiteRegistration* Register_EXP_REF();
|
TfLiteRegistration* Register_EXP();
|
||||||
TfLiteRegistration* Register_TOPK_V2();
|
TfLiteRegistration* Register_TOPK_V2();
|
||||||
TfLiteRegistration* Register_LOG_REF();
|
TfLiteRegistration* Register_LOG();
|
||||||
TfLiteRegistration* Register_LOG_SOFTMAX_REF();
|
TfLiteRegistration* Register_LOG_SOFTMAX_REF();
|
||||||
TfLiteRegistration* Register_CAST();
|
TfLiteRegistration* Register_CAST();
|
||||||
TfLiteRegistration* Register_DEQUANTIZE();
|
TfLiteRegistration* Register_DEQUANTIZE();
|
||||||
@ -103,7 +103,7 @@ TfLiteRegistration* Register_REDUCE_MIN();
|
|||||||
TfLiteRegistration* Register_REDUCE_ANY();
|
TfLiteRegistration* Register_REDUCE_ANY();
|
||||||
TfLiteRegistration* Register_SELECT();
|
TfLiteRegistration* Register_SELECT();
|
||||||
TfLiteRegistration* Register_SLICE_REF();
|
TfLiteRegistration* Register_SLICE_REF();
|
||||||
TfLiteRegistration* Register_SIN_REF();
|
TfLiteRegistration* Register_SIN();
|
||||||
TfLiteRegistration* Register_TRANSPOSECONV_REF();
|
TfLiteRegistration* Register_TRANSPOSECONV_REF();
|
||||||
TfLiteRegistration* Register_EXPAND_DIMS();
|
TfLiteRegistration* Register_EXPAND_DIMS();
|
||||||
TfLiteRegistration* Register_SPARSE_TO_DENSE();
|
TfLiteRegistration* Register_SPARSE_TO_DENSE();
|
||||||
@ -229,9 +229,9 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
|
|||||||
AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V());
|
AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V());
|
||||||
AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
|
AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
|
||||||
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE_REF());
|
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE_REF());
|
||||||
AddBuiltin(BuiltinOperator_EXP, Register_EXP_REF());
|
AddBuiltin(BuiltinOperator_EXP, Register_EXP());
|
||||||
AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2());
|
AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2());
|
||||||
AddBuiltin(BuiltinOperator_LOG, Register_LOG_REF());
|
AddBuiltin(BuiltinOperator_LOG, Register_LOG());
|
||||||
AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX_REF());
|
AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX_REF());
|
||||||
AddBuiltin(BuiltinOperator_CAST, Register_CAST());
|
AddBuiltin(BuiltinOperator_CAST, Register_CAST());
|
||||||
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
|
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
|
||||||
@ -250,7 +250,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
|
|||||||
AddBuiltin(BuiltinOperator_NEG, Register_NEG());
|
AddBuiltin(BuiltinOperator_NEG, Register_NEG());
|
||||||
AddBuiltin(BuiltinOperator_SELECT, Register_SELECT());
|
AddBuiltin(BuiltinOperator_SELECT, Register_SELECT());
|
||||||
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE_REF());
|
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE_REF());
|
||||||
AddBuiltin(BuiltinOperator_SIN, Register_SIN_REF());
|
AddBuiltin(BuiltinOperator_SIN, Register_SIN());
|
||||||
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSECONV_REF());
|
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSECONV_REF());
|
||||||
AddBuiltin(BuiltinOperator_TILE, Register_TILE());
|
AddBuiltin(BuiltinOperator_TILE, Register_TILE());
|
||||||
AddBuiltin(BuiltinOperator_SUM, Register_SUM());
|
AddBuiltin(BuiltinOperator_SUM, Register_SUM());
|
||||||
|
Loading…
Reference in New Issue
Block a user