diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD index 0843fe5d5dc..02f5f9c4a4a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD @@ -941,6 +941,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/cl:cl_context", "//tensorflow/lite/delegates/gpu/cl:cl_kernel", "//tensorflow/lite/delegates/gpu/cl:linear_storage", + "//tensorflow/lite/delegates/gpu/cl:storage_type_util", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:status", diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc index 7a29d5752fe..bcda1f6a628 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc @@ -18,47 +18,75 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/types/variant.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" +#include "tensorflow/lite/delegates/gpu/cl/storage_type_util.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" namespace tflite { namespace gpu { namespace cl { -absl::Status CreatePReLU(const CreationContext& creation_context, +GPUOperation CreatePReLU(const DeviceInfo& device_info, const OperationDef& definition, - const PReLUAttributes& attr, GPUOperation* result) { - *result = GPUOperation(definition); - result->elementwise_ = true; + const PReLUAttributes& attr) { + GPUOperation result(definition); + result.elementwise_ = true; + + std::string alpha_read; + auto alpha_linear = + absl::get_if>(&attr.alpha); + if (alpha_linear) { + TensorLinearDescriptor desc; + desc.storage_type = + DeduceLinearStorageType(definition.GetPrimaryStorageType()); + desc.element_type = definition.GetPrimaryDataType(); + desc.UploadLinearData(*alpha_linear); + result.args_.AddObject( + "alpha", absl::make_unique(std::move(desc))); + alpha_read = "FLT4 alpha_val = args.alpha.Read(S_COORD);\n"; + } + + auto alpha_hwc = + absl::get_if>(&attr.alpha); + if (alpha_hwc) { + const BHWC shape = + BHWC(1, alpha_hwc->shape.h, alpha_hwc->shape.w, alpha_hwc->shape.c); + TensorStorageType storage_type = SelectBestStorageType( + device_info, shape, definition.GetPrimaryStorageType(), + definition.GetDataType(), Layout::HWC); + TensorDescriptor desc{definition.GetDataType(), storage_type, Layout::HWC}; + desc.UploadData(*alpha_hwc); + result.args_.AddObject( + "alpha", absl::make_unique(std::move(desc))); + const std::string x_coord = shape.w == 1 ? "0" : "X_COORD"; + const std::string y_coord = shape.h == 1 ? "0" : "Y_COORD"; + const std::string s_coord = shape.c == 1 ? "0" : "S_COORD"; + alpha_read = absl::StrCat("FLT4 alpha_val = args.alpha.Read(", x_coord, + ", ", y_coord, ", ", s_coord, ");\n"); + if (shape.c == 1) { + alpha_read += " alpha_val.y = alpha_val.x;\n"; + alpha_read += " alpha_val.z = alpha_val.x;\n"; + alpha_read += " alpha_val.w = alpha_val.x;\n"; + } + } + if (attr.clip != 0) { if (definition.precision == CalculationsPrecision::F32) { - result->args_.AddFloat("clip", attr.clip); + result.args_.AddFloat("clip", attr.clip); } else { - result->args_.AddHalf("clip", half(attr.clip)); + result.args_.AddHalf("clip", half(attr.clip)); } - result->code_ = + result.code_ = + alpha_read + "in_out_value = clamp(in_out_value, (FLT4)(0.0f), (FLT4)(args.clip)) + " - "min((FLT4)(0.0f), in_out_value) * args.alpha.Read(S_COORD);"; + "min((FLT4)(0.0f), in_out_value) * alpha_val;"; } else { - result->code_ = + result.code_ = + alpha_read + "in_out_value = max((FLT4)(0.0f), in_out_value) + min((FLT4)(0.0f), " - "in_out_value) * args.alpha.Read(S_COORD);"; + "in_out_value) * alpha_val;"; } - auto alpha = - absl::get_if>(&attr.alpha); - if (!alpha) { - return absl::InvalidArgumentError("Alpha is missing"); - } - TensorLinearDescriptor desc; - desc.storage_type = - DeduceLinearStorageType(definition.GetPrimaryStorageType()); - desc.element_type = definition.GetPrimaryDataType(); - desc.UploadLinearData(*alpha); - - result->args_.AddObject( - "alpha", absl::make_unique(std::move(desc))); - - return absl::OkStatus(); + return result; } } // namespace cl diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h index b673217c799..5d2a41bc6de 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h @@ -31,9 +31,9 @@ namespace tflite { namespace gpu { namespace cl { -absl::Status CreatePReLU(const CreationContext& creation_context, +GPUOperation CreatePReLU(const DeviceInfo& device_info, const OperationDef& definition, - const PReLUAttributes& attr, GPUOperation* result); + const PReLUAttributes& attr); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/prelu_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/prelu_test.cc index 06ff09ccca7..ef4b8c17324 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/prelu_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/prelu_test.cc @@ -52,8 +52,8 @@ TEST_F(OpenCLOperationTest, PReLUAlpha) { op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); TensorFloat32 dst_tensor; - GPUOperation operation; - ASSERT_OK(CreatePReLU(creation_context_, op_def, attr, &operation)); + GPUOperation operation = + CreatePReLU(creation_context_.GetDeviceInfo(), op_def, attr); ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, BHWC(1, 2, 1, 2), &dst_tensor)); EXPECT_THAT(dst_tensor.data, @@ -83,8 +83,8 @@ TEST_F(OpenCLOperationTest, PReLUAlphaClip) { op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); TensorFloat32 dst_tensor; - GPUOperation operation; - ASSERT_OK(CreatePReLU(creation_context_, op_def, attr, &operation)); + GPUOperation operation = + CreatePReLU(creation_context_.GetDeviceInfo(), op_def, attr); ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, BHWC(1, 2, 1, 2), &dst_tensor)); EXPECT_THAT(dst_tensor.data, @@ -93,6 +93,37 @@ TEST_F(OpenCLOperationTest, PReLUAlphaClip) { } } +TEST_F(OpenCLOperationTest, PReLUHWCAlpha) { + TensorFloat32 src_tensor; + src_tensor.shape = BHWC(1, 2, 1, 2); + src_tensor.data = {0.0f, -1.0f, -2.0f, 3.0f}; + + PReLUAttributes attr; + ::tflite::gpu::Tensor hwc_tensor; + hwc_tensor.shape = HWC(2, 1, 2); + hwc_tensor.data = {0.5f, -2.0f, 0.7f, 4.7f}; + attr.alpha = hwc_tensor; + attr.clip = 0.0; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = + CreatePReLU(creation_context_.GetDeviceInfo(), op_def, attr); + ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, + BHWC(1, 2, 1, 2), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, + Pointwise(FloatNear(eps), {0.0f, 2.0f, -1.4f, 3.0f})); + } + } +} + } // namespace } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc index dc18cde25c2..4d67dd60a50 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc @@ -284,7 +284,8 @@ absl::Status GPUOperationFromNode(const CreationContext& creation_context, } case OperationType::PRELU: { auto attr = absl::any_cast(node.operation.attributes); - return SelectPReLU(attr, creation_context, op_def, gpu_op); + *gpu_op = SelectPReLU(attr, creation_context.GetDeviceInfo(), op_def); + return absl::OkStatus(); } case OperationType::QUANTIZE_AND_DEQUANTIZE: { auto attr = absl::any_cast( diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc index 5f2f8f05cb2..4464342be16 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc @@ -56,14 +56,11 @@ std::unique_ptr SelectReLU(const ReLUAttributes& attr, return absl::make_unique(CreateReLU(op_def, attr)); } -absl::Status SelectPReLU(const PReLUAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr) { - GPUOperation operation; - RETURN_IF_ERROR(CreatePReLU(creation_context, op_def, attr, &operation)); - *ptr = absl::make_unique(std::move(operation)); - return absl::OkStatus(); +std::unique_ptr SelectPReLU(const PReLUAttributes& attr, + const DeviceInfo& device_info, + const OperationDef& op_def) { + return absl::make_unique( + CreatePReLU(device_info, op_def, attr)); } void SelectPooling(const Pooling2DAttributes& attr, const OperationDef& op_def, diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h index 71d4c1f5c07..2a97e8aac08 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h @@ -34,10 +34,9 @@ void SelectLSTM(const OperationDef& op_def, const DeviceInfo& device_info, std::unique_ptr SelectReLU(const ReLUAttributes& attr, const OperationDef& op_def); -absl::Status SelectPReLU(const PReLUAttributes& attr, - const CreationContext& creation_context, - const OperationDef& op_def, - std::unique_ptr* ptr); +std::unique_ptr SelectPReLU(const PReLUAttributes& attr, + const DeviceInfo& device_info, + const OperationDef& op_def); void SelectPooling(const Pooling2DAttributes& attr, const OperationDef& op_def, std::unique_ptr* ptr);