- Fix type checking in elementwise.cc
- Update error messages for some Abs Cast Ceil Cos Sin Not Square Sqrt RSqrt Log PiperOrigin-RevId: 317807251 Change-Id: I2a4f359f04346551eda5a382b25f34bab2c73dc7
This commit is contained in:
parent
51ee63247d
commit
e5023a1738
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
@ -67,8 +68,8 @@ void copyCast(const std::complex<float>* in, std::complex<float>* out,
|
||||
}
|
||||
|
||||
template <typename FromT>
|
||||
TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out,
|
||||
int num_elements) {
|
||||
TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
|
||||
TfLiteTensor* out, int num_elements) {
|
||||
switch (out->type) {
|
||||
case kTfLiteInt64:
|
||||
copyCast(in, out->data.i64, num_elements);
|
||||
@ -91,7 +92,7 @@ TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out,
|
||||
break;
|
||||
default:
|
||||
// Unsupported type.
|
||||
return kTfLiteError;
|
||||
TF_LITE_UNSUPPORTED_TYPE(context, out->type, "Cast");
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
@ -103,22 +104,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, num_elements, NumElements(output));
|
||||
switch (input->type) {
|
||||
case kTfLiteInt64:
|
||||
return copyToTensor(input->data.i64, output, num_elements);
|
||||
return copyToTensor(context, input->data.i64, output, num_elements);
|
||||
case kTfLiteInt32:
|
||||
return copyToTensor(input->data.i32, output, num_elements);
|
||||
return copyToTensor(context, input->data.i32, output, num_elements);
|
||||
case kTfLiteUInt8:
|
||||
return copyToTensor(input->data.uint8, output, num_elements);
|
||||
return copyToTensor(context, input->data.uint8, output, num_elements);
|
||||
case kTfLiteFloat32:
|
||||
return copyToTensor(GetTensorData<float>(input), output, num_elements);
|
||||
return copyToTensor(context, GetTensorData<float>(input), output,
|
||||
num_elements);
|
||||
case kTfLiteBool:
|
||||
return copyToTensor(input->data.b, output, num_elements);
|
||||
return copyToTensor(context, input->data.b, output, num_elements);
|
||||
case kTfLiteComplex64:
|
||||
return copyToTensor(
|
||||
reinterpret_cast<std::complex<float>*>(input->data.c64), output,
|
||||
num_elements);
|
||||
context, reinterpret_cast<std::complex<float>*>(input->data.c64),
|
||||
output, num_elements);
|
||||
default:
|
||||
// Unsupported type.
|
||||
return kTfLiteError;
|
||||
TF_LITE_UNSUPPORTED_TYPE(context, input->type, "Cast");
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
@ -41,6 +42,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
if (input->type != kTfLiteFloat32) {
|
||||
TF_LITE_UNSUPPORTED_TYPE(context, input->type, "Ceil");
|
||||
}
|
||||
|
||||
optimized_ops::Ceil(GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(output), GetTensorData<float>(output));
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
@ -39,17 +40,15 @@ bool IsLogicalSupportedType(const TfLiteType type) {
|
||||
}
|
||||
|
||||
typedef bool (*IsSupportedType)(TfLiteType);
|
||||
template <IsSupportedType>
|
||||
template <IsSupportedType is_supported_type, const char* op_name>
|
||||
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||
TfLiteTensor* output = GetOutput(context, node, 0);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
if (!IsSupportedType(input->type)) {
|
||||
context->ReportError(context, "Current data type %d is not supported.",
|
||||
input->type);
|
||||
return kTfLiteError;
|
||||
if (!is_supported_type(input->type)) {
|
||||
TF_LITE_UNSUPPORTED_TYPE(context, input->type, op_name);
|
||||
}
|
||||
return context->ResizeTensor(context, output,
|
||||
TfLiteIntArrayCopy(input->dims));
|
||||
@ -112,13 +111,23 @@ TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return EvalLogical(context, node, [](bool v) { return !v; });
|
||||
}
|
||||
|
||||
constexpr char kAbsName[] = "Abs";
|
||||
constexpr char kSinName[] = "Sin";
|
||||
constexpr char kCosName[] = "Cos";
|
||||
constexpr char kLogName[] = "Log";
|
||||
constexpr char kSqrtName[] = "Sqrt";
|
||||
constexpr char kRsqrtName[] = "Rsqrt";
|
||||
constexpr char kSquareName[] = "Square";
|
||||
constexpr char kNotName[] = "Not";
|
||||
|
||||
} // namespace
|
||||
} // namespace elementwise
|
||||
|
||||
TfLiteRegistration* Register_ABS() {
|
||||
static TfLiteRegistration r = {
|
||||
/*init=*/nullptr, /*free=*/nullptr,
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType,
|
||||
elementwise::kAbsName>,
|
||||
elementwise::AbsEval};
|
||||
return &r;
|
||||
}
|
||||
@ -126,7 +135,8 @@ TfLiteRegistration* Register_ABS() {
|
||||
TfLiteRegistration* Register_SIN() {
|
||||
static TfLiteRegistration r = {
|
||||
/*init=*/nullptr, /*free=*/nullptr,
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType,
|
||||
elementwise::kSinName>,
|
||||
elementwise::SinEval};
|
||||
return &r;
|
||||
}
|
||||
@ -134,7 +144,8 @@ TfLiteRegistration* Register_SIN() {
|
||||
TfLiteRegistration* Register_COS() {
|
||||
static TfLiteRegistration r = {
|
||||
/*init=*/nullptr, /*free=*/nullptr,
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType,
|
||||
elementwise::kCosName>,
|
||||
elementwise::CosEval};
|
||||
return &r;
|
||||
}
|
||||
@ -142,7 +153,8 @@ TfLiteRegistration* Register_COS() {
|
||||
TfLiteRegistration* Register_LOG() {
|
||||
static TfLiteRegistration r = {
|
||||
/*init=*/nullptr, /*free=*/nullptr,
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType,
|
||||
elementwise::kLogName>,
|
||||
elementwise::LogEval};
|
||||
return &r;
|
||||
}
|
||||
@ -150,7 +162,8 @@ TfLiteRegistration* Register_LOG() {
|
||||
TfLiteRegistration* Register_SQRT() {
|
||||
static TfLiteRegistration r = {
|
||||
/*init=*/nullptr, /*free=*/nullptr,
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType,
|
||||
elementwise::kSqrtName>,
|
||||
elementwise::SqrtEval};
|
||||
return &r;
|
||||
}
|
||||
@ -158,7 +171,8 @@ TfLiteRegistration* Register_SQRT() {
|
||||
TfLiteRegistration* Register_RSQRT() {
|
||||
static TfLiteRegistration r = {
|
||||
/*init=*/nullptr, /*free=*/nullptr,
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType,
|
||||
elementwise::kRsqrtName>,
|
||||
elementwise::RsqrtEval};
|
||||
return &r;
|
||||
}
|
||||
@ -166,7 +180,8 @@ TfLiteRegistration* Register_RSQRT() {
|
||||
TfLiteRegistration* Register_SQUARE() {
|
||||
static TfLiteRegistration r = {
|
||||
/*init=*/nullptr, /*free=*/nullptr,
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType,
|
||||
elementwise::kSquareName>,
|
||||
elementwise::SquareEval};
|
||||
return &r;
|
||||
}
|
||||
@ -174,7 +189,8 @@ TfLiteRegistration* Register_SQUARE() {
|
||||
TfLiteRegistration* Register_LOGICAL_NOT() {
|
||||
static TfLiteRegistration r = {
|
||||
/*init=*/nullptr, /*free=*/nullptr,
|
||||
elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
|
||||
elementwise::GenericPrepare<elementwise::IsLogicalSupportedType,
|
||||
elementwise::kNotName>,
|
||||
elementwise::LogicalNotEval};
|
||||
return &r;
|
||||
}
|
||||
|
@ -44,6 +44,15 @@ inline void InfiniteLoop() {
|
||||
fprintf(stderr, "%s", (x)); \
|
||||
} while (0)
|
||||
|
||||
// Report Error for unsupported type by op 'op_name' and returns kTfLiteError.
|
||||
#define TF_LITE_UNSUPPORTED_TYPE(context, type, op_name) \
|
||||
do { \
|
||||
TF_LITE_KERNEL_LOG((context), "%s:%d Type %s is unsupported by op %s.", \
|
||||
__FILE__, __LINE__, TfLiteTypeGetName(type), \
|
||||
(op_name)); \
|
||||
return kTfLiteError; \
|
||||
} while (0)
|
||||
|
||||
#define TFLITE_ABORT abort()
|
||||
|
||||
#endif // TF_LITE_MCU_DEBUG_LOG
|
||||
|
Loading…
Reference in New Issue
Block a user