Automated rollback of commit 185a465225

PiperOrigin-RevId: 266946649
This commit is contained in:
A. Unique TensorFlower 2019-09-03 09:28:39 -07:00 committed by TensorFlower Gardener
parent 0c3bae3842
commit 0ee4bc62dd
4 changed files with 25 additions and 133 deletions

View File

@ -14,9 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <cmath> #include <cmath>
#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/kernel_util.h"
@ -27,11 +25,6 @@ namespace builtin {
namespace elementwise { namespace elementwise {
namespace { namespace {
enum KernelType {
kReference,
kGenericOptimized,
};
bool IsNumericSupportedType(const TfLiteType type) { bool IsNumericSupportedType(const TfLiteType type) {
return type == kTfLiteFloat32; return type == kTfLiteFloat32;
} }
@ -86,40 +79,15 @@ TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
return EvalNumeric(context, node, std::abs); return EvalNumeric(context, node, std::abs);
} }
typedef void (*UnaryOp)(const float* in, int n, float* out);
inline TfLiteStatus EvalNumericOptimized(TfLiteContext* context,
TfLiteNode* node, UnaryOp op) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const int64_t num_elements = NumElements(input);
const float* in_data = GetTensorData<float>(input);
float* out_data = GetTensorData<float>(output);
op(in_data, num_elements, out_data);
return kTfLiteOk;
}
template <KernelType kernel_type>
TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
if (kernel_type == kGenericOptimized) {
return EvalNumericOptimized(context, node, optimized_ops::Sin<float>);
}
return EvalNumeric(context, node, std::sin); return EvalNumeric(context, node, std::sin);
} }
template <KernelType kernel_type>
TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
if (kernel_type == kGenericOptimized) {
return EvalNumericOptimized(context, node, optimized_ops::Cos<float>);
}
return EvalNumeric(context, node, std::cos); return EvalNumeric(context, node, std::cos);
} }
template <KernelType kernel_type>
TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
if (kernel_type == kGenericOptimized) {
return EvalNumericOptimized(context, node, optimized_ops::Log<float>);
}
return EvalNumeric(context, node, std::log); return EvalNumeric(context, node, std::log);
} }
@ -150,60 +118,30 @@ TfLiteRegistration* Register_ABS() {
return &r; return &r;
} }
TfLiteRegistration* Register_SIN_REF() { TfLiteRegistration* Register_SIN() {
static TfLiteRegistration r = { static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr, /*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::SinEval<elementwise::kReference>}; elementwise::SinEval};
return &r; return &r;
} }
TfLiteRegistration* Register_SIN_GENERIC_OPT() { TfLiteRegistration* Register_COS() {
static TfLiteRegistration r = { static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr, /*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::SinEval<elementwise::kGenericOptimized>}; elementwise::CosEval};
return &r; return &r;
} }
TfLiteRegistration* Register_SIN() { return Register_SIN_GENERIC_OPT(); } TfLiteRegistration* Register_LOG() {
TfLiteRegistration* Register_COS_REF() {
static TfLiteRegistration r = { static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr, /*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>, elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::CosEval<elementwise::kReference>}; elementwise::LogEval};
return &r; return &r;
} }
TfLiteRegistration* Register_COS_GENERIC_OPT() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::CosEval<elementwise::kGenericOptimized>};
return &r;
}
TfLiteRegistration* Register_COS() { return Register_COS_GENERIC_OPT(); }
TfLiteRegistration* Register_LOG_REF() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::LogEval<elementwise::kReference>};
return &r;
}
TfLiteRegistration* Register_LOG_GENERIC_OPT() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::LogEval<elementwise::kGenericOptimized>};
return &r;
}
TfLiteRegistration* Register_LOG() { return Register_LOG_GENERIC_OPT(); }
TfLiteRegistration* Register_SQRT() { TfLiteRegistration* Register_SQRT() {
static TfLiteRegistration r = { static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr, /*init=*/nullptr, /*free=*/nullptr,

View File

@ -13,12 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <string.h> #include <string.h>
#include <vector> #include <vector>
#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/kernel_util.h"
@ -32,7 +29,6 @@ namespace exp {
// This file has reference implementation of Exp. // This file has reference implementation of Exp.
enum KernelType { enum KernelType {
kReference, kReference,
kGenericOptimized,
}; };
struct ExpContext { struct ExpContext {
@ -57,8 +53,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
template <KernelType kernel_type> template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
ExpContext op_context(context, node); ExpContext op_context(context, node);
switch (op_context.input->type) {
case kTfLiteFloat32:
#define TF_LITE_EXP(kernel_type, data_type) \ #define TF_LITE_EXP(kernel_type, data_type) \
kernel_type::Exp<data_type>(GetTensorData<data_type>(op_context.input), \ kernel_type::Exp<data_type>(GetTensorData<data_type>(op_context.input), \
NumElements(op_context.input), \ NumElements(op_context.input), \
@ -66,18 +61,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// TODO(kanlig): supports half, bfloat16, float64, complex64, and complex128. // TODO(kanlig): supports half, bfloat16, float64, complex64, and complex128.
if (kernel_type == kReference) { if (kernel_type == kReference) {
TF_LITE_EXP(reference_ops, float); switch (op_context.input->type) {
} else { case kTfLiteFloat32:
TF_LITE_EXP(optimized_ops, float); TF_LITE_EXP(reference_ops, float);
break;
default:
context->ReportError(context,
"Type %d is currently not supported by Exp.",
op_context.input->type);
return kTfLiteError;
}
} }
#undef TF_LITE_EXP #undef TF_LITE_EXP
break;
default:
context->ReportError(context,
"Type %d is currently not supported by Exp.",
op_context.input->type);
return kTfLiteError;
}
return kTfLiteOk; return kTfLiteOk;
} }
@ -89,13 +84,8 @@ TfLiteRegistration* Register_EXP_REF() {
return &r; return &r;
} }
TfLiteRegistration* Register_EXP_GENERIC_OPT() { // TODO(kanlig): add optimized implementation of Exp.
static TfLiteRegistration r = {nullptr, nullptr, exp::Prepare, TfLiteRegistration* Register_EXP() { return Register_EXP_REF(); }
exp::Eval<exp::kGenericOptimized>};
return &r;
}
TfLiteRegistration* Register_EXP() { return Register_EXP_GENERIC_OPT(); }
} // namespace builtin } // namespace builtin
} // namespace ops } // namespace ops

View File

@ -3917,42 +3917,6 @@ inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
output_map.array() = input_map.array().template cast<DstT>(); output_map.array() = input_map.array().template cast<DstT>();
} }
template <typename T, typename ScalarOp>
inline void ElementWise(const T* input_data, const int buffer_size,
T* output_data) {
auto input_map = VectorMap<const T>(input_data, buffer_size, 1);
auto output_map = VectorMap<T>(output_data, buffer_size, 1);
output_map.array() = input_map.array().unaryExpr(ScalarOp());
}
template <typename T>
inline void Cos(const T* input_data, const int buffer_size, T* output_data) {
gemmlowp::ScopedProfilingLabel label("Cos");
ElementWise<T, Eigen::internal::scalar_cos_op<T>>(input_data, buffer_size,
output_data);
}
template <typename T>
inline void Exp(const T* input_data, const int buffer_size, T* output_data) {
gemmlowp::ScopedProfilingLabel label("Exp");
return ElementWise<T, Eigen::internal::scalar_exp_op<T>>(
input_data, buffer_size, output_data);
}
template <typename T>
inline void Log(const T* input_data, const int buffer_size, T* output_data) {
gemmlowp::ScopedProfilingLabel label("Log");
ElementWise<T, Eigen::internal::scalar_log_op<T>>(input_data, buffer_size,
output_data);
}
template <typename T>
inline void Sin(const T* input_data, const int buffer_size, T* output_data) {
gemmlowp::ScopedProfilingLabel label("Sin");
ElementWise<T, Eigen::internal::scalar_sin_op<T>>(input_data, buffer_size,
output_data);
}
inline void Floor(const RuntimeShape& input_shape, const float* input_data, inline void Floor(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) { const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Floor"); gemmlowp::ScopedProfilingLabel label("Floor");

View File

@ -78,9 +78,9 @@ TfLiteRegistration* Register_SPLIT();
TfLiteRegistration* Register_SPLIT_V(); TfLiteRegistration* Register_SPLIT_V();
TfLiteRegistration* Register_SQUEEZE(); TfLiteRegistration* Register_SQUEEZE();
TfLiteRegistration* Register_STRIDED_SLICE_REF(); TfLiteRegistration* Register_STRIDED_SLICE_REF();
TfLiteRegistration* Register_EXP_REF(); TfLiteRegistration* Register_EXP();
TfLiteRegistration* Register_TOPK_V2(); TfLiteRegistration* Register_TOPK_V2();
TfLiteRegistration* Register_LOG_REF(); TfLiteRegistration* Register_LOG();
TfLiteRegistration* Register_LOG_SOFTMAX_REF(); TfLiteRegistration* Register_LOG_SOFTMAX_REF();
TfLiteRegistration* Register_CAST(); TfLiteRegistration* Register_CAST();
TfLiteRegistration* Register_DEQUANTIZE(); TfLiteRegistration* Register_DEQUANTIZE();
@ -103,7 +103,7 @@ TfLiteRegistration* Register_REDUCE_MIN();
TfLiteRegistration* Register_REDUCE_ANY(); TfLiteRegistration* Register_REDUCE_ANY();
TfLiteRegistration* Register_SELECT(); TfLiteRegistration* Register_SELECT();
TfLiteRegistration* Register_SLICE_REF(); TfLiteRegistration* Register_SLICE_REF();
TfLiteRegistration* Register_SIN_REF(); TfLiteRegistration* Register_SIN();
TfLiteRegistration* Register_TRANSPOSECONV_REF(); TfLiteRegistration* Register_TRANSPOSECONV_REF();
TfLiteRegistration* Register_EXPAND_DIMS(); TfLiteRegistration* Register_EXPAND_DIMS();
TfLiteRegistration* Register_SPARSE_TO_DENSE(); TfLiteRegistration* Register_SPARSE_TO_DENSE();
@ -229,9 +229,9 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V()); AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V());
AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE()); AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE_REF()); AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE_REF());
AddBuiltin(BuiltinOperator_EXP, Register_EXP_REF()); AddBuiltin(BuiltinOperator_EXP, Register_EXP());
AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2()); AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2());
AddBuiltin(BuiltinOperator_LOG, Register_LOG_REF()); AddBuiltin(BuiltinOperator_LOG, Register_LOG());
AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX_REF()); AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX_REF());
AddBuiltin(BuiltinOperator_CAST, Register_CAST()); AddBuiltin(BuiltinOperator_CAST, Register_CAST());
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(), AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
@ -250,7 +250,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
AddBuiltin(BuiltinOperator_NEG, Register_NEG()); AddBuiltin(BuiltinOperator_NEG, Register_NEG());
AddBuiltin(BuiltinOperator_SELECT, Register_SELECT()); AddBuiltin(BuiltinOperator_SELECT, Register_SELECT());
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE_REF()); AddBuiltin(BuiltinOperator_SLICE, Register_SLICE_REF());
AddBuiltin(BuiltinOperator_SIN, Register_SIN_REF()); AddBuiltin(BuiltinOperator_SIN, Register_SIN());
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSECONV_REF()); AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSECONV_REF());
AddBuiltin(BuiltinOperator_TILE, Register_TILE()); AddBuiltin(BuiltinOperator_TILE, Register_TILE());
AddBuiltin(BuiltinOperator_SUM, Register_SUM()); AddBuiltin(BuiltinOperator_SUM, Register_SUM());