Add cast op support to the NNAPI delegate

PiperOrigin-RevId: 257774029
This commit is contained in:
A. Unique TensorFlower 2019-07-12 03:19:54 -07:00 committed by TensorFlower Gardener
parent cbe94fd66e
commit 323564b496
4 changed files with 69 additions and 4 deletions

View File

@ -1729,6 +1729,21 @@ class NNAPIDelegateKernel {
return BasicMappingFn<ANEURALNETWORKS_MINIMUM>;
}
} break;
// CAST delegation: only offload to NNAPI when the driver can actually run
// it — requires NNAPI 1.2 (see kMinSdkVersionForNNAPI12) and tensor types
// supported on both sides (float32, int32, uint8). Other TFLite types
// (e.g. int64, bool, complex64 — exercised in cast_test.cc) fall back to
// the CPU kernel by not returning a mapping here.
case kTfLiteBuiltinCast: {
const TfLiteType input_type =
context->tensors[node->inputs->data[0]].type;
const TfLiteType output_type =
context->tensors[node->outputs->data[0]].type;
// Both the source and destination element types must be NNAPI-supported.
auto is_supported_tensor_type = [](const TfLiteType& type) {
return (type == kTfLiteFloat32 || type == kTfLiteInt32 ||
type == kTfLiteUInt8);
};
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
is_supported_tensor_type(input_type) &&
is_supported_tensor_type(output_type)) {
return BasicMappingFn<ANEURALNETWORKS_CAST>;
}
} break;
case kTfLiteBuiltinPrelu:
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
if (!IsFloatOrUint8Operator(context, node)) {

View File

@ -669,6 +669,7 @@ cc_test(
name = "cast_test",
size = "small",
srcs = ["cast_test.cc"],
tags = ["tflite_nnapi"],
deps = [
":builtin_ops",
":test_main",

View File

@ -43,7 +43,23 @@ class CastOpModel : public SingleOpModel {
int output_;
};
TEST(CastOpModel, CastIntToFloat) {
// int32 -> float32: these values are small enough to be represented
// exactly in a float, so the comparison can be exact.
TEST(CastOpModel, CastInt32ToFloat) {
CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
m.PopulateTensor<int32_t>(m.input(), {100, 200, 300, 400, 500, 600});
m.Invoke();
EXPECT_THAT(m.ExtractVector<float>(m.output()),
ElementsAreArray({100.f, 200.f, 300.f, 400.f, 500.f, 600.f}));
}
// float32 -> int32: fractional inputs are expected to truncate toward
// zero (0.4 -> 0, 0.999 -> 0, 1.1 -> 1), not round.
TEST(CastOpModel, CastFloatToInt32) {
CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT32, {3, 2}});
m.PopulateTensor<float>(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
m.Invoke();
EXPECT_THAT(m.ExtractVector<int32_t>(m.output()),
ElementsAreArray({100, 20, 3, 0, 0, 1}));
}
TEST(CastOpModel, CastInt64ToFloat) {
CastOpModel m({TensorType_INT64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
m.PopulateTensor<int64_t>(m.input(), {100, 200, 300, 400, 500, 600});
m.Invoke();
@ -51,11 +67,11 @@ TEST(CastOpModel, CastIntToFloat) {
ElementsAreArray({100.f, 200.f, 300.f, 400.f, 500.f, 600.f}));
}
TEST(CastOpModel, CastFloatToInt) {
CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT32, {3, 2}});
TEST(CastOpModel, CastFloatToInt64) {
CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT64, {3, 2}});
m.PopulateTensor<float>(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
m.Invoke();
EXPECT_THAT(m.ExtractVector<int>(m.output()),
EXPECT_THAT(m.ExtractVector<int64_t>(m.output()),
ElementsAreArray({100, 20, 3, 0, 0, 1}));
}
@ -75,6 +91,38 @@ TEST(CastOpModel, CastBoolToFloat) {
ElementsAreArray({1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f}));
}
// float32 -> uint8: fractional values truncate toward zero
// (0.4 -> 0, 1.999 -> 1, 1.1 -> 1); all inputs fit in [0, 255].
TEST(CastOpModel, CastFloatToUInt8) {
CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_UINT8, {3, 2}});
m.PopulateTensor<float>(m.input(), {100.f, 1.0f, 0.f, 0.4f, 1.999f, 1.1f});
m.Invoke();
EXPECT_THAT(m.ExtractVector<uint8_t>(m.output()),
ElementsAreArray({100, 1, 0, 0, 1, 1}));
}
// uint8 -> float32: every uint8 value is exactly representable as a
// float, so the widening cast is lossless.
TEST(CastOpModel, CastUInt8ToFloat) {
CastOpModel m({TensorType_UINT8, {3, 2}}, {TensorType_FLOAT32, {3, 2}});
m.PopulateTensor<uint8_t>(m.input(), {123, 0, 1, 2, 3, 4});
m.Invoke();
EXPECT_THAT(m.ExtractVector<float>(m.output()),
ElementsAreArray({123.f, 0.f, 1.f, 2.f, 3.f, 4.f}));
}
// int32 -> uint8: inputs are chosen inside [0, 255] (including the
// boundary value 255) so the narrowing cast preserves every value.
TEST(CastOpModel, CastInt32ToUInt8) {
CastOpModel m({TensorType_INT32, {3, 2}}, {TensorType_UINT8, {3, 2}});
m.PopulateTensor<int32_t>(m.input(), {100, 1, 200, 2, 255, 3});
m.Invoke();
EXPECT_THAT(m.ExtractVector<uint8_t>(m.output()),
ElementsAreArray({100, 1, 200, 2, 255, 3}));
}
// uint8 -> int32: widening cast; mirrors CastInt32ToUInt8 with the same
// values to demonstrate the round trip is lossless.
TEST(CastOpModel, CastUInt8ToInt32) {
CastOpModel m({TensorType_UINT8, {3, 2}}, {TensorType_INT32, {3, 2}});
m.PopulateTensor<uint8_t>(m.input(), {100, 1, 200, 2, 255, 3});
m.Invoke();
EXPECT_THAT(m.ExtractVector<int32_t>(m.output()),
ElementsAreArray({100, 1, 200, 2, 255, 3}));
}
TEST(CastOpModel, CastComplex64ToFloat) {
CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
m.PopulateTensor<std::complex<float>>(

View File

@ -93,6 +93,7 @@ enum {
ANEURALNETWORKS_ARGMAX = 39,
ANEURALNETWORKS_ARGMIN = 40,
ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM = 42,
ANEURALNETWORKS_CAST = 45,
ANEURALNETWORKS_EQUAL = 48,
ANEURALNETWORKS_EXP = 49,
ANEURALNETWORKS_EXPAND_DIMS = 50,