Add cast op support to the NNAPI delegate
PiperOrigin-RevId: 257774029
This commit is contained in:
parent
cbe94fd66e
commit
323564b496
@ -1729,6 +1729,21 @@ class NNAPIDelegateKernel {
|
||||
return BasicMappingFn<ANEURALNETWORKS_MINIMUM>;
|
||||
}
|
||||
} break;
|
||||
case kTfLiteBuiltinCast: {
|
||||
const TfLiteType input_type =
|
||||
context->tensors[node->inputs->data[0]].type;
|
||||
const TfLiteType output_type =
|
||||
context->tensors[node->outputs->data[0]].type;
|
||||
auto is_supported_tensor_type = [](const TfLiteType& type) {
|
||||
return (type == kTfLiteFloat32 || type == kTfLiteInt32 ||
|
||||
type == kTfLiteUInt8);
|
||||
};
|
||||
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
|
||||
is_supported_tensor_type(input_type) &&
|
||||
is_supported_tensor_type(output_type)) {
|
||||
return BasicMappingFn<ANEURALNETWORKS_CAST>;
|
||||
}
|
||||
} break;
|
||||
case kTfLiteBuiltinPrelu:
|
||||
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
|
||||
if (!IsFloatOrUint8Operator(context, node)) {
|
||||
|
@ -669,6 +669,7 @@ cc_test(
|
||||
name = "cast_test",
|
||||
size = "small",
|
||||
srcs = ["cast_test.cc"],
|
||||
tags = ["tflite_nnapi"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":test_main",
|
||||
|
@ -43,7 +43,23 @@ class CastOpModel : public SingleOpModel {
|
||||
int output_;
|
||||
};
|
||||
|
||||
TEST(CastOpModel, CastIntToFloat) {
|
||||
TEST(CastOpModel, CastInt32ToFloat) {
|
||||
CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
|
||||
m.PopulateTensor<int32_t>(m.input(), {100, 200, 300, 400, 500, 600});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.ExtractVector<float>(m.output()),
|
||||
ElementsAreArray({100.f, 200.f, 300.f, 400.f, 500.f, 600.f}));
|
||||
}
|
||||
|
||||
TEST(CastOpModel, CastFloatToInt32) {
|
||||
CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT32, {3, 2}});
|
||||
m.PopulateTensor<float>(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.ExtractVector<int32_t>(m.output()),
|
||||
ElementsAreArray({100, 20, 3, 0, 0, 1}));
|
||||
}
|
||||
|
||||
TEST(CastOpModel, CastInt64ToFloat) {
|
||||
CastOpModel m({TensorType_INT64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
|
||||
m.PopulateTensor<int64_t>(m.input(), {100, 200, 300, 400, 500, 600});
|
||||
m.Invoke();
|
||||
@ -51,11 +67,11 @@ TEST(CastOpModel, CastIntToFloat) {
|
||||
ElementsAreArray({100.f, 200.f, 300.f, 400.f, 500.f, 600.f}));
|
||||
}
|
||||
|
||||
TEST(CastOpModel, CastFloatToInt) {
|
||||
CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT32, {3, 2}});
|
||||
TEST(CastOpModel, CastFloatToInt64) {
|
||||
CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT64, {3, 2}});
|
||||
m.PopulateTensor<float>(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.ExtractVector<int>(m.output()),
|
||||
EXPECT_THAT(m.ExtractVector<int64_t>(m.output()),
|
||||
ElementsAreArray({100, 20, 3, 0, 0, 1}));
|
||||
}
|
||||
|
||||
@ -75,6 +91,38 @@ TEST(CastOpModel, CastBoolToFloat) {
|
||||
ElementsAreArray({1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f}));
|
||||
}
|
||||
|
||||
TEST(CastOpModel, CastFloatToUInt8) {
|
||||
CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_UINT8, {3, 2}});
|
||||
m.PopulateTensor<float>(m.input(), {100.f, 1.0f, 0.f, 0.4f, 1.999f, 1.1f});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.ExtractVector<uint8_t>(m.output()),
|
||||
ElementsAreArray({100, 1, 0, 0, 1, 1}));
|
||||
}
|
||||
|
||||
TEST(CastOpModel, CastUInt8ToFloat) {
|
||||
CastOpModel m({TensorType_UINT8, {3, 2}}, {TensorType_FLOAT32, {3, 2}});
|
||||
m.PopulateTensor<uint8_t>(m.input(), {123, 0, 1, 2, 3, 4});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.ExtractVector<float>(m.output()),
|
||||
ElementsAreArray({123.f, 0.f, 1.f, 2.f, 3.f, 4.f}));
|
||||
}
|
||||
|
||||
TEST(CastOpModel, CastInt32ToUInt8) {
|
||||
CastOpModel m({TensorType_INT32, {3, 2}}, {TensorType_UINT8, {3, 2}});
|
||||
m.PopulateTensor<int32_t>(m.input(), {100, 1, 200, 2, 255, 3});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.ExtractVector<uint8_t>(m.output()),
|
||||
ElementsAreArray({100, 1, 200, 2, 255, 3}));
|
||||
}
|
||||
|
||||
TEST(CastOpModel, CastUInt8ToInt32) {
|
||||
CastOpModel m({TensorType_UINT8, {3, 2}}, {TensorType_INT32, {3, 2}});
|
||||
m.PopulateTensor<uint8_t>(m.input(), {100, 1, 200, 2, 255, 3});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.ExtractVector<int32_t>(m.output()),
|
||||
ElementsAreArray({100, 1, 200, 2, 255, 3}));
|
||||
}
|
||||
|
||||
TEST(CastOpModel, CastComplex64ToFloat) {
|
||||
CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
|
||||
m.PopulateTensor<std::complex<float>>(
|
||||
|
@ -93,6 +93,7 @@ enum {
|
||||
ANEURALNETWORKS_ARGMAX = 39,
|
||||
ANEURALNETWORKS_ARGMIN = 40,
|
||||
ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM = 42,
|
||||
ANEURALNETWORKS_CAST = 45,
|
||||
ANEURALNETWORKS_EQUAL = 48,
|
||||
ANEURALNETWORKS_EXP = 49,
|
||||
ANEURALNETWORKS_EXPAND_DIMS = 50,
|
||||
|
Loading…
Reference in New Issue
Block a user