Add cast op support to the NNAPI delegate
PiperOrigin-RevId: 257774029
This commit is contained in:
parent
cbe94fd66e
commit
323564b496
@ -1729,6 +1729,21 @@ class NNAPIDelegateKernel {
|
|||||||
return BasicMappingFn<ANEURALNETWORKS_MINIMUM>;
|
return BasicMappingFn<ANEURALNETWORKS_MINIMUM>;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case kTfLiteBuiltinCast: {
|
||||||
|
const TfLiteType input_type =
|
||||||
|
context->tensors[node->inputs->data[0]].type;
|
||||||
|
const TfLiteType output_type =
|
||||||
|
context->tensors[node->outputs->data[0]].type;
|
||||||
|
auto is_supported_tensor_type = [](const TfLiteType& type) {
|
||||||
|
return (type == kTfLiteFloat32 || type == kTfLiteInt32 ||
|
||||||
|
type == kTfLiteUInt8);
|
||||||
|
};
|
||||||
|
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
|
||||||
|
is_supported_tensor_type(input_type) &&
|
||||||
|
is_supported_tensor_type(output_type)) {
|
||||||
|
return BasicMappingFn<ANEURALNETWORKS_CAST>;
|
||||||
|
}
|
||||||
|
} break;
|
||||||
case kTfLiteBuiltinPrelu:
|
case kTfLiteBuiltinPrelu:
|
||||||
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
|
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
|
||||||
if (!IsFloatOrUint8Operator(context, node)) {
|
if (!IsFloatOrUint8Operator(context, node)) {
|
||||||
|
@ -669,6 +669,7 @@ cc_test(
|
|||||||
name = "cast_test",
|
name = "cast_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["cast_test.cc"],
|
srcs = ["cast_test.cc"],
|
||||||
|
tags = ["tflite_nnapi"],
|
||||||
deps = [
|
deps = [
|
||||||
":builtin_ops",
|
":builtin_ops",
|
||||||
":test_main",
|
":test_main",
|
||||||
|
@ -43,7 +43,23 @@ class CastOpModel : public SingleOpModel {
|
|||||||
int output_;
|
int output_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST(CastOpModel, CastIntToFloat) {
|
TEST(CastOpModel, CastInt32ToFloat) {
|
||||||
|
CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
|
||||||
|
m.PopulateTensor<int32_t>(m.input(), {100, 200, 300, 400, 500, 600});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.ExtractVector<float>(m.output()),
|
||||||
|
ElementsAreArray({100.f, 200.f, 300.f, 400.f, 500.f, 600.f}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CastOpModel, CastFloatToInt32) {
|
||||||
|
CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT32, {3, 2}});
|
||||||
|
m.PopulateTensor<float>(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.ExtractVector<int32_t>(m.output()),
|
||||||
|
ElementsAreArray({100, 20, 3, 0, 0, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CastOpModel, CastInt64ToFloat) {
|
||||||
CastOpModel m({TensorType_INT64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
|
CastOpModel m({TensorType_INT64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
|
||||||
m.PopulateTensor<int64_t>(m.input(), {100, 200, 300, 400, 500, 600});
|
m.PopulateTensor<int64_t>(m.input(), {100, 200, 300, 400, 500, 600});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
@ -51,11 +67,11 @@ TEST(CastOpModel, CastIntToFloat) {
|
|||||||
ElementsAreArray({100.f, 200.f, 300.f, 400.f, 500.f, 600.f}));
|
ElementsAreArray({100.f, 200.f, 300.f, 400.f, 500.f, 600.f}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CastOpModel, CastFloatToInt) {
|
TEST(CastOpModel, CastFloatToInt64) {
|
||||||
CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT32, {3, 2}});
|
CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT64, {3, 2}});
|
||||||
m.PopulateTensor<float>(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
|
m.PopulateTensor<float>(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.ExtractVector<int>(m.output()),
|
EXPECT_THAT(m.ExtractVector<int64_t>(m.output()),
|
||||||
ElementsAreArray({100, 20, 3, 0, 0, 1}));
|
ElementsAreArray({100, 20, 3, 0, 0, 1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -75,6 +91,38 @@ TEST(CastOpModel, CastBoolToFloat) {
|
|||||||
ElementsAreArray({1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f}));
|
ElementsAreArray({1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CastOpModel, CastFloatToUInt8) {
|
||||||
|
CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_UINT8, {3, 2}});
|
||||||
|
m.PopulateTensor<float>(m.input(), {100.f, 1.0f, 0.f, 0.4f, 1.999f, 1.1f});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.ExtractVector<uint8_t>(m.output()),
|
||||||
|
ElementsAreArray({100, 1, 0, 0, 1, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CastOpModel, CastUInt8ToFloat) {
|
||||||
|
CastOpModel m({TensorType_UINT8, {3, 2}}, {TensorType_FLOAT32, {3, 2}});
|
||||||
|
m.PopulateTensor<uint8_t>(m.input(), {123, 0, 1, 2, 3, 4});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.ExtractVector<float>(m.output()),
|
||||||
|
ElementsAreArray({123.f, 0.f, 1.f, 2.f, 3.f, 4.f}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CastOpModel, CastInt32ToUInt8) {
|
||||||
|
CastOpModel m({TensorType_INT32, {3, 2}}, {TensorType_UINT8, {3, 2}});
|
||||||
|
m.PopulateTensor<int32_t>(m.input(), {100, 1, 200, 2, 255, 3});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.ExtractVector<uint8_t>(m.output()),
|
||||||
|
ElementsAreArray({100, 1, 200, 2, 255, 3}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CastOpModel, CastUInt8ToInt32) {
|
||||||
|
CastOpModel m({TensorType_UINT8, {3, 2}}, {TensorType_INT32, {3, 2}});
|
||||||
|
m.PopulateTensor<uint8_t>(m.input(), {100, 1, 200, 2, 255, 3});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.ExtractVector<int32_t>(m.output()),
|
||||||
|
ElementsAreArray({100, 1, 200, 2, 255, 3}));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(CastOpModel, CastComplex64ToFloat) {
|
TEST(CastOpModel, CastComplex64ToFloat) {
|
||||||
CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
|
CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
|
||||||
m.PopulateTensor<std::complex<float>>(
|
m.PopulateTensor<std::complex<float>>(
|
||||||
|
@ -93,6 +93,7 @@ enum {
|
|||||||
ANEURALNETWORKS_ARGMAX = 39,
|
ANEURALNETWORKS_ARGMAX = 39,
|
||||||
ANEURALNETWORKS_ARGMIN = 40,
|
ANEURALNETWORKS_ARGMIN = 40,
|
||||||
ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM = 42,
|
ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM = 42,
|
||||||
|
ANEURALNETWORKS_CAST = 45,
|
||||||
ANEURALNETWORKS_EQUAL = 48,
|
ANEURALNETWORKS_EQUAL = 48,
|
||||||
ANEURALNETWORKS_EXP = 49,
|
ANEURALNETWORKS_EXP = 49,
|
||||||
ANEURALNETWORKS_EXPAND_DIMS = 50,
|
ANEURALNETWORKS_EXPAND_DIMS = 50,
|
||||||
|
Loading…
Reference in New Issue
Block a user