Added comparison operations to OpenCL backend.
PiperOrigin-RevId: 330028303 Change-Id: I320ebb451c508c639479a6729470a9cd883df882
This commit is contained in:
parent
4cd4a57c3b
commit
a982edc004
tensorflow/lite/delegates/gpu/cl
@ -131,6 +131,43 @@ std::string GetTwoInputCode(const OperationType& op_type,
|
||||
case OperationType::SUB:
|
||||
result += "$0 = $1 - $2;\n";
|
||||
break;
|
||||
// Comparison operators
|
||||
case OperationType::LESS:
|
||||
result = "$0.x = $1.x < $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
result += "$0.y = $1.y < $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
result += "$0.z = $1.z < $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
result += "$0.w = $1.w < $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
break;
|
||||
case OperationType::LESS_EQUAL:
|
||||
result = "$0.x = $1.x <= $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
result += "$0.y = $1.y <= $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
result += "$0.z = $1.z <= $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
result += "$0.w = $1.w <= $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
break;
|
||||
case OperationType::GREATER:
|
||||
result = "$0.x = $1.x > $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
result += "$0.y = $1.y > $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
result += "$0.z = $1.z > $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
result += "$0.w = $1.w > $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
break;
|
||||
case OperationType::GREATER_EQUAL:
|
||||
result = "$0.x = $1.x >= $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
result += "$0.y = $1.y >= $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
result += "$0.z = $1.z >= $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
result += "$0.w = $1.w >= $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
break;
|
||||
case OperationType::EQUAL:
|
||||
result = "$0.x = $1.x == $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
result += "$0.y = $1.y == $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
result += "$0.z = $1.z == $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
result += "$0.w = $1.w == $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
break;
|
||||
case OperationType::NOT_EQUAL:
|
||||
result = "$0.x = $1.x != $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
result += "$0.y = $1.y != $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
result += "$0.z = $1.z != $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
result += "$0.w = $1.w != $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n";
|
||||
break;
|
||||
default:
|
||||
return "Unknown operation type;\n";
|
||||
}
|
||||
|
@ -841,6 +841,174 @@ TEST_F(OpenCLOperationTest, SubWithScalarAtFirstPosition) {
|
||||
}
|
||||
}
|
||||
|
||||
// Runs the two-input element-wise LESS operation on two tensors of identical
// shape, once per storage type and precision reported by the environment, and
// checks the output against 1.0f / 0.0f values for lhs < rhs.
TEST_F(OpenCLOperationTest, Less) {
  TensorFloat32 lhs, rhs;
  lhs.shape = BHWC(1, 2, 1, 2);
  rhs.shape = BHWC(1, 2, 1, 2);
  lhs.data = {0.0f, 1.0f, 2.0f, 3.0f};
  rhs.data = {1.0f, 0.0f, 2.0f, -4.0f};

  for (const auto storage_type : env_.GetSupportedStorages()) {
    for (const auto calc_precision : env_.GetSupportedPrecisions()) {
      // F32 is compared with a tight tolerance; any other precision gets a
      // looser one.
      float tolerance = 1e-2f;
      if (calc_precision == CalculationsPrecision::F32) {
        tolerance = 1e-6f;
      }

      const auto element_type = DeduceDataTypeFromPrecision(calc_precision);

      // Both inputs and the output share the same descriptor settings.
      OperationDef definition;
      definition.precision = calc_precision;
      definition.src_tensors.push_back(
          {element_type, storage_type, Layout::HWC});
      definition.src_tensors.push_back(
          {element_type, storage_type, Layout::HWC});
      definition.dst_tensors.push_back(
          {element_type, storage_type, Layout::HWC});

      GPUOperation less_op = CreateElementwiseTwoInput(
          definition, OperationType::LESS, rhs.shape);

      TensorFloat32 output;
      ASSERT_OK(ExecuteGPUOperation({lhs, rhs}, creation_context_, &less_op,
                                    BHWC(1, 2, 1, 2), &output));
      EXPECT_THAT(output.data,
                  Pointwise(FloatNear(tolerance), {1.0f, 0.0f, 0.0f, 0.0f}));
    }
  }
}
|
||||
|
||||
// Runs the LESS_EQUAL element-wise operation created with a scalar attribute
// (attr.param = 2.0f), once per supported storage type and precision. The
// expected output matches `input <= 2.0f` encoded as 1.0f / 0.0f.
TEST_F(OpenCLOperationTest, LessEqual) {
  TensorFloat32 input;
  input.shape = BHWC(1, 2, 1, 2);
  input.data = {0.0f, 1.0f, 2.0f, 3.0f};

  ElementwiseAttributes scalar_attr;
  scalar_attr.param = 2.0f;

  for (const auto storage_type : env_.GetSupportedStorages()) {
    for (const auto calc_precision : env_.GetSupportedPrecisions()) {
      // F32 is compared with a tight tolerance; any other precision gets a
      // looser one.
      float tolerance = 1e-2f;
      if (calc_precision == CalculationsPrecision::F32) {
        tolerance = 1e-6f;
      }

      const auto element_type = DeduceDataTypeFromPrecision(calc_precision);

      OperationDef definition;
      definition.precision = calc_precision;
      definition.src_tensors.push_back(
          {element_type, storage_type, Layout::HWC});
      definition.dst_tensors.push_back(
          {element_type, storage_type, Layout::HWC});

      GPUOperation compare_op =
          CreateElementwise(creation_context_.GetDeviceInfo(), definition,
                            OperationType::LESS_EQUAL, scalar_attr);

      TensorFloat32 output;
      ASSERT_OK(ExecuteGPUOperation(input, creation_context_, &compare_op,
                                    BHWC(1, 2, 1, 2), &output));
      EXPECT_THAT(output.data,
                  Pointwise(FloatNear(tolerance), {1.0f, 1.0f, 1.0f, 0.0f}));
    }
  }
}
|
||||
|
||||
// Runs the GREATER element-wise operation created with a scalar attribute
// (attr.param = 2.0f), once per supported storage type and precision. The
// expected output matches `input > 2.0f` encoded as 1.0f / 0.0f.
TEST_F(OpenCLOperationTest, Greater) {
  TensorFloat32 input;
  input.shape = BHWC(1, 2, 1, 2);
  input.data = {0.0f, 1.0f, 2.0f, 3.0f};

  ElementwiseAttributes scalar_attr;
  scalar_attr.param = 2.0f;

  for (const auto storage_type : env_.GetSupportedStorages()) {
    for (const auto calc_precision : env_.GetSupportedPrecisions()) {
      // F32 is compared with a tight tolerance; any other precision gets a
      // looser one.
      float tolerance = 1e-2f;
      if (calc_precision == CalculationsPrecision::F32) {
        tolerance = 1e-6f;
      }

      const auto element_type = DeduceDataTypeFromPrecision(calc_precision);

      OperationDef definition;
      definition.precision = calc_precision;
      definition.src_tensors.push_back(
          {element_type, storage_type, Layout::HWC});
      definition.dst_tensors.push_back(
          {element_type, storage_type, Layout::HWC});

      GPUOperation compare_op =
          CreateElementwise(creation_context_.GetDeviceInfo(), definition,
                            OperationType::GREATER, scalar_attr);

      TensorFloat32 output;
      ASSERT_OK(ExecuteGPUOperation(input, creation_context_, &compare_op,
                                    BHWC(1, 2, 1, 2), &output));
      EXPECT_THAT(output.data,
                  Pointwise(FloatNear(tolerance), {0.0f, 0.0f, 0.0f, 1.0f}));
    }
  }
}
|
||||
|
||||
// Runs the GREATER_EQUAL element-wise operation created with a scalar
// attribute (attr.param = 2.0f), once per supported storage type and
// precision. The expected output matches `input >= 2.0f` encoded as
// 1.0f / 0.0f.
TEST_F(OpenCLOperationTest, GreaterEqual) {
  TensorFloat32 input;
  input.shape = BHWC(1, 2, 1, 2);
  input.data = {0.0f, 1.0f, 2.0f, 3.0f};

  ElementwiseAttributes scalar_attr;
  scalar_attr.param = 2.0f;

  for (const auto storage_type : env_.GetSupportedStorages()) {
    for (const auto calc_precision : env_.GetSupportedPrecisions()) {
      // F32 is compared with a tight tolerance; any other precision gets a
      // looser one.
      float tolerance = 1e-2f;
      if (calc_precision == CalculationsPrecision::F32) {
        tolerance = 1e-6f;
      }

      const auto element_type = DeduceDataTypeFromPrecision(calc_precision);

      OperationDef definition;
      definition.precision = calc_precision;
      definition.src_tensors.push_back(
          {element_type, storage_type, Layout::HWC});
      definition.dst_tensors.push_back(
          {element_type, storage_type, Layout::HWC});

      GPUOperation compare_op =
          CreateElementwise(creation_context_.GetDeviceInfo(), definition,
                            OperationType::GREATER_EQUAL, scalar_attr);

      TensorFloat32 output;
      ASSERT_OK(ExecuteGPUOperation(input, creation_context_, &compare_op,
                                    BHWC(1, 2, 1, 2), &output));
      EXPECT_THAT(output.data,
                  Pointwise(FloatNear(tolerance), {0.0f, 0.0f, 1.0f, 1.0f}));
    }
  }
}
|
||||
|
||||
// Runs the EQUAL element-wise operation created with a scalar attribute
// (attr.param = 2.0f), once per supported storage type and precision. The
// expected output matches `input == 2.0f` encoded as 1.0f / 0.0f.
TEST_F(OpenCLOperationTest, Equal) {
  TensorFloat32 input;
  input.shape = BHWC(1, 2, 1, 2);
  input.data = {0.0f, 1.0f, 2.0f, 3.0f};

  ElementwiseAttributes scalar_attr;
  scalar_attr.param = 2.0f;

  for (const auto storage_type : env_.GetSupportedStorages()) {
    for (const auto calc_precision : env_.GetSupportedPrecisions()) {
      // F32 is compared with a tight tolerance; any other precision gets a
      // looser one.
      float tolerance = 1e-2f;
      if (calc_precision == CalculationsPrecision::F32) {
        tolerance = 1e-6f;
      }

      const auto element_type = DeduceDataTypeFromPrecision(calc_precision);

      OperationDef definition;
      definition.precision = calc_precision;
      definition.src_tensors.push_back(
          {element_type, storage_type, Layout::HWC});
      definition.dst_tensors.push_back(
          {element_type, storage_type, Layout::HWC});

      GPUOperation compare_op =
          CreateElementwise(creation_context_.GetDeviceInfo(), definition,
                            OperationType::EQUAL, scalar_attr);

      TensorFloat32 output;
      ASSERT_OK(ExecuteGPUOperation(input, creation_context_, &compare_op,
                                    BHWC(1, 2, 1, 2), &output));
      EXPECT_THAT(output.data,
                  Pointwise(FloatNear(tolerance), {0.0f, 0.0f, 1.0f, 0.0f}));
    }
  }
}
|
||||
|
||||
// Runs the NOT_EQUAL element-wise operation created with a scalar attribute
// (attr.param = 2.0f), once per supported storage type and precision. The
// expected output matches `input != 2.0f` encoded as 1.0f / 0.0f.
TEST_F(OpenCLOperationTest, NotEqual) {
  TensorFloat32 input;
  input.shape = BHWC(1, 2, 1, 2);
  input.data = {0.0f, 1.0f, 2.0f, 3.0f};

  ElementwiseAttributes scalar_attr;
  scalar_attr.param = 2.0f;

  for (const auto storage_type : env_.GetSupportedStorages()) {
    for (const auto calc_precision : env_.GetSupportedPrecisions()) {
      // F32 is compared with a tight tolerance; any other precision gets a
      // looser one.
      float tolerance = 1e-2f;
      if (calc_precision == CalculationsPrecision::F32) {
        tolerance = 1e-6f;
      }

      const auto element_type = DeduceDataTypeFromPrecision(calc_precision);

      OperationDef definition;
      definition.precision = calc_precision;
      definition.src_tensors.push_back(
          {element_type, storage_type, Layout::HWC});
      definition.dst_tensors.push_back(
          {element_type, storage_type, Layout::HWC});

      GPUOperation compare_op =
          CreateElementwise(creation_context_.GetDeviceInfo(), definition,
                            OperationType::NOT_EQUAL, scalar_attr);

      TensorFloat32 output;
      ASSERT_OK(ExecuteGPUOperation(input, creation_context_, &compare_op,
                                    BHWC(1, 2, 1, 2), &output));
      EXPECT_THAT(output.data,
                  Pointwise(FloatNear(tolerance), {1.0f, 1.0f, 0.0f, 1.0f}));
    }
  }
}
|
||||
|
||||
} // namespace
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
|
@ -348,9 +348,15 @@ absl::Status GPUOperationFromNode(const DeviceInfo& device_info,
|
||||
return absl::OkStatus();
|
||||
}
|
||||
case OperationType::DIV:
|
||||
case OperationType::EQUAL:
|
||||
case OperationType::GREATER:
|
||||
case OperationType::GREATER_EQUAL:
|
||||
case OperationType::LESS:
|
||||
case OperationType::LESS_EQUAL:
|
||||
case OperationType::MAXIMUM:
|
||||
case OperationType::MINIMUM:
|
||||
case OperationType::MUL:
|
||||
case OperationType::NOT_EQUAL:
|
||||
case OperationType::POW:
|
||||
case OperationType::SQUARED_DIFF:
|
||||
case OperationType::SUB: {
|
||||
|
Loading…
Reference in New Issue
Block a user