diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
index 30f754156a4..6d17fa260a8 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
@@ -1729,6 +1729,21 @@ class NNAPIDelegateKernel {
           return BasicMappingFn;
         }
       } break;
+      case kTfLiteBuiltinCast: {
+        const TfLiteType input_type =
+            context->tensors[node->inputs->data[0]].type;
+        const TfLiteType output_type =
+            context->tensors[node->outputs->data[0]].type;
+        auto is_supported_tensor_type = [](const TfLiteType& type) {
+          return (type == kTfLiteFloat32 || type == kTfLiteInt32 ||
+                  type == kTfLiteUInt8);
+        };
+        if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+            is_supported_tensor_type(input_type) &&
+            is_supported_tensor_type(output_type)) {
+          return BasicMappingFn;
+        }
+      } break;
       case kTfLiteBuiltinPrelu:
         if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) {
           if (!IsFloatOrUint8Operator(context, node)) {
diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index 0e40f3fa0cd..f70ccf3a3d9 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -669,6 +669,7 @@ cc_test(
     name = "cast_test",
     size = "small",
     srcs = ["cast_test.cc"],
+    tags = ["tflite_nnapi"],
     deps = [
         ":builtin_ops",
         ":test_main",
diff --git a/tensorflow/lite/kernels/cast_test.cc b/tensorflow/lite/kernels/cast_test.cc
index 6bad3d6e7b3..8f1cb44f1c9 100644
--- a/tensorflow/lite/kernels/cast_test.cc
+++ b/tensorflow/lite/kernels/cast_test.cc
@@ -43,7 +43,23 @@ class CastOpModel : public SingleOpModel {
   int output_;
 };
 
-TEST(CastOpModel, CastIntToFloat) {
+TEST(CastOpModel, CastInt32ToFloat) {
+  CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
+  m.PopulateTensor<int32_t>(m.input(), {100, 200, 300, 400, 500, 600});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<float>(m.output()),
+              ElementsAreArray({100.f, 200.f, 300.f, 400.f, 500.f, 600.f}));
+}
+
+TEST(CastOpModel, CastFloatToInt32) {
+  CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT32, {3, 2}});
+  m.PopulateTensor<float>(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<int32_t>(m.output()),
+              ElementsAreArray({100, 20, 3, 0, 0, 1}));
+}
+
+TEST(CastOpModel, CastInt64ToFloat) {
   CastOpModel m({TensorType_INT64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
   m.PopulateTensor<int64_t>(m.input(), {100, 200, 300, 400, 500, 600});
   m.Invoke();
@@ -51,11 +67,11 @@ TEST(CastOpModel, CastIntToFloat) {
               ElementsAreArray({100.f, 200.f, 300.f, 400.f, 500.f, 600.f}));
 }
 
-TEST(CastOpModel, CastFloatToInt) {
-  CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT32, {3, 2}});
+TEST(CastOpModel, CastFloatToInt64) {
+  CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT64, {3, 2}});
   m.PopulateTensor<float>(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
   m.Invoke();
-  EXPECT_THAT(m.ExtractVector<int>(m.output()),
+  EXPECT_THAT(m.ExtractVector<int64_t>(m.output()),
               ElementsAreArray({100, 20, 3, 0, 0, 1}));
 }
 
@@ -75,6 +91,38 @@ TEST(CastOpModel, CastBoolToFloat) {
               ElementsAreArray({1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f}));
 }
 
+TEST(CastOpModel, CastFloatToUInt8) {
+  CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_UINT8, {3, 2}});
+  m.PopulateTensor<float>(m.input(), {100.f, 1.0f, 0.f, 0.4f, 1.999f, 1.1f});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<uint8_t>(m.output()),
+              ElementsAreArray({100, 1, 0, 0, 1, 1}));
+}
+
+TEST(CastOpModel, CastUInt8ToFloat) {
+  CastOpModel m({TensorType_UINT8, {3, 2}}, {TensorType_FLOAT32, {3, 2}});
+  m.PopulateTensor<uint8_t>(m.input(), {123, 0, 1, 2, 3, 4});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<float>(m.output()),
+              ElementsAreArray({123.f, 0.f, 1.f, 2.f, 3.f, 4.f}));
+}
+
+TEST(CastOpModel, CastInt32ToUInt8) {
+  CastOpModel m({TensorType_INT32, {3, 2}}, {TensorType_UINT8, {3, 2}});
+  m.PopulateTensor<int32_t>(m.input(), {100, 1, 200, 2, 255, 3});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<uint8_t>(m.output()),
+              ElementsAreArray({100, 1, 200, 2, 255, 3}));
+}
+
+TEST(CastOpModel, CastUInt8ToInt32) {
+  CastOpModel m({TensorType_UINT8, {3, 2}}, {TensorType_INT32, {3, 2}});
+  m.PopulateTensor<uint8_t>(m.input(), {100, 1, 200, 2, 255, 3});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<int32_t>(m.output()),
+              ElementsAreArray({100, 1, 200, 2, 255, 3}));
+}
+
 TEST(CastOpModel, CastComplex64ToFloat) {
   CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
   m.PopulateTensor<std::complex<float>>(
diff --git a/tensorflow/lite/nnapi/NeuralNetworksTypes.h b/tensorflow/lite/nnapi/NeuralNetworksTypes.h
index c5519b93c6d..40c3ecf3c91 100644
--- a/tensorflow/lite/nnapi/NeuralNetworksTypes.h
+++ b/tensorflow/lite/nnapi/NeuralNetworksTypes.h
@@ -93,6 +93,7 @@ enum {
   ANEURALNETWORKS_ARGMAX = 39,
   ANEURALNETWORKS_ARGMIN = 40,
   ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM = 42,
+  ANEURALNETWORKS_CAST = 45,
   ANEURALNETWORKS_EQUAL = 48,
   ANEURALNETWORKS_EXP = 49,
   ANEURALNETWORKS_EXPAND_DIMS = 50,