diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc index a81c8be9ea1..29565b0910e 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc @@ -131,6 +131,43 @@ std::string GetTwoInputCode(const OperationType& op_type, case OperationType::SUB: result += "$0 = $1 - $2;\n"; break; + // Comparison operators + case OperationType::LESS: + result = "$0.x = $1.x < $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.y = $1.y < $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.z = $1.z < $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.w = $1.w < $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + break; + case OperationType::LESS_EQUAL: + result = "$0.x = $1.x <= $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.y = $1.y <= $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.z = $1.z <= $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.w = $1.w <= $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + break; + case OperationType::GREATER: + result = "$0.x = $1.x > $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.y = $1.y > $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.z = $1.z > $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.w = $1.w > $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + break; + case OperationType::GREATER_EQUAL: + result = "$0.x = $1.x >= $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.y = $1.y >= $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.z = $1.z >= $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.w = $1.w >= $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + break; + case OperationType::EQUAL: + result = "$0.x = $1.x == $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.y = $1.y == $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.z = $1.z == $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.w = $1.w == $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + break; + case OperationType::NOT_EQUAL: + result = "$0.x = $1.x != $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.y = $1.y != $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.z = $1.z != $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + result += "$0.w = $1.w != $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n"; + break; default: return "Unknown operation type;\n"; } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc index b6b501bd11b..b48f66ce600 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise_test.cc @@ -841,6 +841,174 @@ TEST_F(OpenCLOperationTest, SubWithScalarAtFirstPosition) { } } +TEST_F(OpenCLOperationTest, Less) { + TensorFloat32 src_tensor_0, src_tensor_1; + src_tensor_0.shape = BHWC(1, 2, 1, 2); + src_tensor_1.shape = BHWC(1, 2, 1, 2); + src_tensor_0.data = {0.0f, 1.0f, 2.0f, 3.0f}; + src_tensor_1.data = {1.0f, 0.0f, 2.0f, -4.0f}; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = CreateElementwiseTwoInput( + op_def, OperationType::LESS, src_tensor_1.shape); + ASSERT_OK(ExecuteGPUOperation({src_tensor_0, src_tensor_1}, + creation_context_, &operation, + BHWC(1, 2, 1, 2), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, + Pointwise(FloatNear(eps), {1.0f, 0.0f, 0.0f, 0.0f})); + } + } +} + +TEST_F(OpenCLOperationTest, LessEqual) { + TensorFloat32 src_tensor_0; + src_tensor_0.shape = BHWC(1, 2, 1, 2); + src_tensor_0.data = {0.0f, 1.0f, 2.0f, 3.0f}; + + ElementwiseAttributes attr; + attr.param = 2.0f; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = + CreateElementwise(creation_context_.GetDeviceInfo(), op_def, + OperationType::LESS_EQUAL, attr); + ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation, + BHWC(1, 2, 1, 2), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, + Pointwise(FloatNear(eps), {1.0f, 1.0f, 1.0f, 0.0f})); + } + } +} + +TEST_F(OpenCLOperationTest, Greater) { + TensorFloat32 src_tensor_0; + src_tensor_0.shape = BHWC(1, 2, 1, 2); + src_tensor_0.data = {0.0f, 1.0f, 2.0f, 3.0f}; + + ElementwiseAttributes attr; + attr.param = 2.0f; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = + CreateElementwise(creation_context_.GetDeviceInfo(), op_def, + OperationType::GREATER, attr); + ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation, + BHWC(1, 2, 1, 2), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, + Pointwise(FloatNear(eps), {0.0f, 0.0f, 0.0f, 1.0f})); + } + } +} + +TEST_F(OpenCLOperationTest, GreaterEqual) { + TensorFloat32 src_tensor_0; + src_tensor_0.shape = BHWC(1, 2, 1, 2); + src_tensor_0.data = {0.0f, 1.0f, 2.0f, 3.0f}; + + ElementwiseAttributes attr; + attr.param = 2.0f; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = + CreateElementwise(creation_context_.GetDeviceInfo(), op_def, + OperationType::GREATER_EQUAL, attr); + ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation, + BHWC(1, 2, 1, 2), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, + Pointwise(FloatNear(eps), {0.0f, 0.0f, 1.0f, 1.0f})); + } + } +} + +TEST_F(OpenCLOperationTest, Equal) { + TensorFloat32 src_tensor_0; + src_tensor_0.shape = BHWC(1, 2, 1, 2); + src_tensor_0.data = {0.0f, 1.0f, 2.0f, 3.0f}; + + ElementwiseAttributes attr; + attr.param = 2.0f; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = + CreateElementwise(creation_context_.GetDeviceInfo(), op_def, + OperationType::EQUAL, attr); + ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation, + BHWC(1, 2, 1, 2), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, + Pointwise(FloatNear(eps), {0.0f, 0.0f, 1.0f, 0.0f})); + } + } +} + +TEST_F(OpenCLOperationTest, NotEqual) { + TensorFloat32 src_tensor_0; + src_tensor_0.shape = BHWC(1, 2, 1, 2); + src_tensor_0.data = {0.0f, 1.0f, 2.0f, 3.0f}; + + ElementwiseAttributes attr; + attr.param = 2.0f; + + for (auto storage : env_.GetSupportedStorages()) { + for (auto precision : env_.GetSupportedPrecisions()) { + const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f; + OperationDef op_def; + op_def.precision = precision; + auto data_type = DeduceDataTypeFromPrecision(precision); + op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); + op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); + TensorFloat32 dst_tensor; + GPUOperation operation = + CreateElementwise(creation_context_.GetDeviceInfo(), op_def, + OperationType::NOT_EQUAL, attr); + ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation, + BHWC(1, 2, 1, 2), &dst_tensor)); + EXPECT_THAT(dst_tensor.data, + Pointwise(FloatNear(eps), {1.0f, 1.0f, 0.0f, 1.0f})); + } + } +} + } // namespace } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc index d259e989dd9..0bd78103409 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc @@ -348,9 +348,15 @@ absl::Status GPUOperationFromNode(const DeviceInfo& device_info, return absl::OkStatus(); } case OperationType::DIV: + case OperationType::EQUAL: + case OperationType::GREATER: + case OperationType::GREATER_EQUAL: + case OperationType::LESS: + case OperationType::LESS_EQUAL: case OperationType::MAXIMUM: case OperationType::MINIMUM: case OperationType::MUL: + case OperationType::NOT_EQUAL: case OperationType::POW: case OperationType::SQUARED_DIFF: case OperationType::SUB: {