Added comparison operations to OpenCL backend.

PiperOrigin-RevId: 330028303
Change-Id: I320ebb451c508c639479a6729470a9cd883df882
This commit is contained in:
Raman Sarokin 2020-09-03 17:53:34 -07:00 committed by TensorFlower Gardener
parent 4cd4a57c3b
commit a982edc004
3 changed files with 211 additions and 0 deletions

View File

@ -131,6 +131,43 @@ std::string GetTwoInputCode(const OperationType& op_type,
case OperationType::SUB:
result += "$0 = $1 - $2;\n";
break;
// Comparison operators
case OperationType::LESS:
result = "$0.x = $1.x < $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.y = $1.y < $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.z = $1.z < $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.w = $1.w < $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n";
break;
case OperationType::LESS_EQUAL:
result = "$0.x = $1.x <= $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.y = $1.y <= $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.z = $1.z <= $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.w = $1.w <= $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n";
break;
case OperationType::GREATER:
result = "$0.x = $1.x > $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.y = $1.y > $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.z = $1.z > $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.w = $1.w > $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n";
break;
case OperationType::GREATER_EQUAL:
result = "$0.x = $1.x >= $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.y = $1.y >= $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.z = $1.z >= $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.w = $1.w >= $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n";
break;
case OperationType::EQUAL:
result = "$0.x = $1.x == $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.y = $1.y == $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.z = $1.z == $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.w = $1.w == $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n";
break;
case OperationType::NOT_EQUAL:
result = "$0.x = $1.x != $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.y = $1.y != $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.z = $1.z != $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.w = $1.w != $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n";
break;
default:
return "Unknown operation type;\n";
}

View File

@ -841,6 +841,174 @@ TEST_F(OpenCLOperationTest, SubWithScalarAtFirstPosition) {
}
}
TEST_F(OpenCLOperationTest, Less) {
TensorFloat32 src_tensor_0, src_tensor_1;
src_tensor_0.shape = BHWC(1, 2, 1, 2);
src_tensor_1.shape = BHWC(1, 2, 1, 2);
src_tensor_0.data = {0.0f, 1.0f, 2.0f, 3.0f};
src_tensor_1.data = {1.0f, 0.0f, 2.0f, -4.0f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation = CreateElementwiseTwoInput(
op_def, OperationType::LESS, src_tensor_1.shape);
ASSERT_OK(ExecuteGPUOperation({src_tensor_0, src_tensor_1},
creation_context_, &operation,
BHWC(1, 2, 1, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {1.0f, 0.0f, 0.0f, 0.0f}));
}
}
}
TEST_F(OpenCLOperationTest, LessEqual) {
TensorFloat32 src_tensor_0;
src_tensor_0.shape = BHWC(1, 2, 1, 2);
src_tensor_0.data = {0.0f, 1.0f, 2.0f, 3.0f};
ElementwiseAttributes attr;
attr.param = 2.0f;
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreateElementwise(creation_context_.GetDeviceInfo(), op_def,
OperationType::LESS_EQUAL, attr);
ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation,
BHWC(1, 2, 1, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {1.0f, 1.0f, 1.0f, 0.0f}));
}
}
}
TEST_F(OpenCLOperationTest, Greater) {
TensorFloat32 src_tensor_0;
src_tensor_0.shape = BHWC(1, 2, 1, 2);
src_tensor_0.data = {0.0f, 1.0f, 2.0f, 3.0f};
ElementwiseAttributes attr;
attr.param = 2.0f;
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreateElementwise(creation_context_.GetDeviceInfo(), op_def,
OperationType::GREATER, attr);
ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation,
BHWC(1, 2, 1, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {0.0f, 0.0f, 0.0f, 1.0f}));
}
}
}
TEST_F(OpenCLOperationTest, GreaterEqual) {
TensorFloat32 src_tensor_0;
src_tensor_0.shape = BHWC(1, 2, 1, 2);
src_tensor_0.data = {0.0f, 1.0f, 2.0f, 3.0f};
ElementwiseAttributes attr;
attr.param = 2.0f;
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreateElementwise(creation_context_.GetDeviceInfo(), op_def,
OperationType::GREATER_EQUAL, attr);
ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation,
BHWC(1, 2, 1, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {0.0f, 0.0f, 1.0f, 1.0f}));
}
}
}
TEST_F(OpenCLOperationTest, Equal) {
TensorFloat32 src_tensor_0;
src_tensor_0.shape = BHWC(1, 2, 1, 2);
src_tensor_0.data = {0.0f, 1.0f, 2.0f, 3.0f};
ElementwiseAttributes attr;
attr.param = 2.0f;
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreateElementwise(creation_context_.GetDeviceInfo(), op_def,
OperationType::EQUAL, attr);
ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation,
BHWC(1, 2, 1, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {0.0f, 0.0f, 1.0f, 0.0f}));
}
}
}
TEST_F(OpenCLOperationTest, NotEqual) {
TensorFloat32 src_tensor_0;
src_tensor_0.shape = BHWC(1, 2, 1, 2);
src_tensor_0.data = {0.0f, 1.0f, 2.0f, 3.0f};
ElementwiseAttributes attr;
attr.param = 2.0f;
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreateElementwise(creation_context_.GetDeviceInfo(), op_def,
OperationType::NOT_EQUAL, attr);
ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation,
BHWC(1, 2, 1, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {1.0f, 1.0f, 0.0f, 1.0f}));
}
}
}
} // namespace
} // namespace cl
} // namespace gpu

View File

@ -348,9 +348,15 @@ absl::Status GPUOperationFromNode(const DeviceInfo& device_info,
return absl::OkStatus();
}
case OperationType::DIV:
case OperationType::EQUAL:
case OperationType::GREATER:
case OperationType::GREATER_EQUAL:
case OperationType::LESS:
case OperationType::LESS_EQUAL:
case OperationType::MAXIMUM:
case OperationType::MINIMUM:
case OperationType::MUL:
case OperationType::NOT_EQUAL:
case OperationType::POW:
case OperationType::SQUARED_DIFF:
case OperationType::SUB: {