- Fix type checking in elementwise.cc

- Update error messages for some
Abs
Cast
Ceil
Cos
Sin
Not
Square
Sqrt
RSqrt
Log

PiperOrigin-RevId: 317807251
Change-Id: I2a4f359f04346551eda5a382b25f34bab2c73dc7
This commit is contained in:
Karim Nosir 2020-06-22 23:10:17 -07:00 committed by TensorFlower Gardener
parent 51ee63247d
commit e5023a1738
4 changed files with 55 additions and 24 deletions

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
@ -67,8 +68,8 @@ void copyCast(const std::complex<float>* in, std::complex<float>* out,
}
template <typename FromT>
TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out,
int num_elements) {
TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
TfLiteTensor* out, int num_elements) {
switch (out->type) {
case kTfLiteInt64:
copyCast(in, out->data.i64, num_elements);
@ -91,7 +92,7 @@ TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out,
break;
default:
// Unsupported type.
return kTfLiteError;
TF_LITE_UNSUPPORTED_TYPE(context, out->type, "Cast");
}
return kTfLiteOk;
}
@ -103,22 +104,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, num_elements, NumElements(output));
switch (input->type) {
case kTfLiteInt64:
return copyToTensor(input->data.i64, output, num_elements);
return copyToTensor(context, input->data.i64, output, num_elements);
case kTfLiteInt32:
return copyToTensor(input->data.i32, output, num_elements);
return copyToTensor(context, input->data.i32, output, num_elements);
case kTfLiteUInt8:
return copyToTensor(input->data.uint8, output, num_elements);
return copyToTensor(context, input->data.uint8, output, num_elements);
case kTfLiteFloat32:
return copyToTensor(GetTensorData<float>(input), output, num_elements);
return copyToTensor(context, GetTensorData<float>(input), output,
num_elements);
case kTfLiteBool:
return copyToTensor(input->data.b, output, num_elements);
return copyToTensor(context, input->data.b, output, num_elements);
case kTfLiteComplex64:
return copyToTensor(
reinterpret_cast<std::complex<float>*>(input->data.c64), output,
num_elements);
context, reinterpret_cast<std::complex<float>*>(input->data.c64),
output, num_elements);
default:
// Unsupported type.
return kTfLiteError;
TF_LITE_UNSUPPORTED_TYPE(context, input->type, "Cast");
}
return kTfLiteOk;
}

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
@ -41,6 +42,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (input->type != kTfLiteFloat32) {
TF_LITE_UNSUPPORTED_TYPE(context, input->type, "Ceil");
}
optimized_ops::Ceil(GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(output), GetTensorData<float>(output));

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
@ -39,17 +40,15 @@ bool IsLogicalSupportedType(const TfLiteType type) {
}
typedef bool (*IsSupportedType)(TfLiteType);
template <IsSupportedType>
template <IsSupportedType is_supported_type, const char* op_name>
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (!IsSupportedType(input->type)) {
context->ReportError(context, "Current data type %d is not supported.",
input->type);
return kTfLiteError;
if (!is_supported_type(input->type)) {
TF_LITE_UNSUPPORTED_TYPE(context, input->type, op_name);
}
return context->ResizeTensor(context, output,
TfLiteIntArrayCopy(input->dims));
@ -112,13 +111,23 @@ TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
return EvalLogical(context, node, [](bool v) { return !v; });
}
constexpr char kAbsName[] = "Abs";
constexpr char kSinName[] = "Sin";
constexpr char kCosName[] = "Cos";
constexpr char kLogName[] = "Log";
constexpr char kSqrtName[] = "Sqrt";
constexpr char kRsqrtName[] = "Rsqrt";
constexpr char kSquareName[] = "Square";
constexpr char kNotName[] = "Not";
} // namespace
} // namespace elementwise
TfLiteRegistration* Register_ABS() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType,
elementwise::kAbsName>,
elementwise::AbsEval};
return &r;
}
@ -126,7 +135,8 @@ TfLiteRegistration* Register_ABS() {
TfLiteRegistration* Register_SIN() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType,
elementwise::kSinName>,
elementwise::SinEval};
return &r;
}
@ -134,7 +144,8 @@ TfLiteRegistration* Register_SIN() {
TfLiteRegistration* Register_COS() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType,
elementwise::kCosName>,
elementwise::CosEval};
return &r;
}
@ -142,7 +153,8 @@ TfLiteRegistration* Register_COS() {
TfLiteRegistration* Register_LOG() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType,
elementwise::kLogName>,
elementwise::LogEval};
return &r;
}
@ -150,7 +162,8 @@ TfLiteRegistration* Register_LOG() {
TfLiteRegistration* Register_SQRT() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType,
elementwise::kSqrtName>,
elementwise::SqrtEval};
return &r;
}
@ -158,7 +171,8 @@ TfLiteRegistration* Register_SQRT() {
TfLiteRegistration* Register_RSQRT() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType,
elementwise::kRsqrtName>,
elementwise::RsqrtEval};
return &r;
}
@ -166,7 +180,8 @@ TfLiteRegistration* Register_RSQRT() {
TfLiteRegistration* Register_SQUARE() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType,
elementwise::kSquareName>,
elementwise::SquareEval};
return &r;
}
@ -174,7 +189,8 @@ TfLiteRegistration* Register_SQUARE() {
TfLiteRegistration* Register_LOGICAL_NOT() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
elementwise::GenericPrepare<elementwise::IsLogicalSupportedType,
elementwise::kNotName>,
elementwise::LogicalNotEval};
return &r;
}

View File

@ -44,6 +44,15 @@ inline void InfiniteLoop() {
fprintf(stderr, "%s", (x)); \
} while (0)
// Report Error for unsupported type by op 'op_name' and returns kTfLiteError.
#define TF_LITE_UNSUPPORTED_TYPE(context, type, op_name) \
do { \
TF_LITE_KERNEL_LOG((context), "%s:%d Type %s is unsupported by op %s.", \
__FILE__, __LINE__, TfLiteTypeGetName(type), \
(op_name)); \
return kTfLiteError; \
} while (0)
#define TFLITE_ABORT abort()
#endif // TF_LITE_MCU_DEBUG_LOG