Using new defines in elementwise kernels to make unified code.

PiperOrigin-RevId: 351643998
Change-Id: Ia09df5b827b3d59c823921d0dca37a9c8ec2c51c
This commit is contained in:
Raman Sarokin 2021-01-13 12:39:03 -08:00 committed by TensorFlower Gardener
parent e2b680adb7
commit 100b443d8f
9 changed files with 102 additions and 86 deletions

View File

@ -64,6 +64,7 @@ std::string GetCommonOpenCLDefines(CalculationsPrecision precision) {
result += "#define TO_FLT4 convert_float4\n";
result += "#define TO_ACCUM_TYPE convert_float4\n";
result += "#define TO_ACCUM_FLT convert_float\n";
result += "#define INIT_FLT(value) (float)(value)\n";
result += "#define INIT_FLT4(value) (float4)(value)\n";
break;
case CalculationsPrecision::F16:
@ -77,6 +78,7 @@ std::string GetCommonOpenCLDefines(CalculationsPrecision precision) {
result += "#define TO_FLT4 convert_half4\n";
result += "#define TO_ACCUM_TYPE convert_half4\n";
result += "#define TO_ACCUM_FLT convert_half\n";
result += "#define INIT_FLT(value) (half)(value)\n";
result += "#define INIT_FLT4(value) (half4)(value)\n";
break;
case CalculationsPrecision::F32_F16:
@ -90,6 +92,7 @@ std::string GetCommonOpenCLDefines(CalculationsPrecision precision) {
result += "#define TO_FLT4 convert_half4\n";
result += "#define TO_ACCUM_TYPE convert_float4\n";
result += "#define TO_ACCUM_FLT convert_float\n";
result += "#define INIT_FLT(value) (half)(value)\n";
result += "#define INIT_FLT4(value) (half4)(value)\n";
break;
}

View File

@ -45,8 +45,8 @@ TEST_F(OpenCLOperationTest, Abs) {
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreateElementwiseOneInput(op_def, OperationType::ABS);
GPUOperation operation = CreateElementwiseOneInput(
creation_context_.GetGpuInfo(), op_def, OperationType::ABS);
ASSERT_OK(ExecuteGPUOperation(
src_tensor, creation_context_,
absl::make_unique<GPUOperation>(std::move(operation)),
@ -72,8 +72,8 @@ TEST_F(OpenCLOperationTest, Cos) {
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreateElementwiseOneInput(op_def, OperationType::COS);
GPUOperation operation = CreateElementwiseOneInput(
creation_context_.GetGpuInfo(), op_def, OperationType::COS);
ASSERT_OK(ExecuteGPUOperation(
src_tensor, creation_context_,
absl::make_unique<GPUOperation>(std::move(operation)),
@ -99,8 +99,8 @@ TEST_F(OpenCLOperationTest, Copy) {
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreateElementwiseOneInput(op_def, OperationType::COPY);
GPUOperation operation = CreateElementwiseOneInput(
creation_context_.GetGpuInfo(), op_def, OperationType::COPY);
ASSERT_OK(ExecuteGPUOperation(
src_tensor, creation_context_,
absl::make_unique<GPUOperation>(std::move(operation)),
@ -124,8 +124,8 @@ TEST_F(OpenCLOperationTest, Elu) {
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreateElementwiseOneInput(op_def, OperationType::ELU);
GPUOperation operation = CreateElementwiseOneInput(
creation_context_.GetGpuInfo(), op_def, OperationType::ELU);
ASSERT_OK(ExecuteGPUOperation(
src_tensor, creation_context_,
absl::make_unique<GPUOperation>(std::move(operation)),
@ -152,8 +152,8 @@ TEST_F(OpenCLOperationTest, Exp) {
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreateElementwiseOneInput(op_def, OperationType::EXP);
GPUOperation operation = CreateElementwiseOneInput(
creation_context_.GetGpuInfo(), op_def, OperationType::EXP);
ASSERT_OK(ExecuteGPUOperation(
src_tensor, creation_context_,
absl::make_unique<GPUOperation>(std::move(operation)),
@ -181,8 +181,8 @@ TEST_F(OpenCLOperationTest, HardSwish) {
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreateElementwiseOneInput(op_def, OperationType::HARD_SWISH);
GPUOperation operation = CreateElementwiseOneInput(
creation_context_.GetGpuInfo(), op_def, OperationType::HARD_SWISH);
ASSERT_OK(ExecuteGPUOperation(
src_tensor, creation_context_,
absl::make_unique<GPUOperation>(std::move(operation)),
@ -209,8 +209,8 @@ TEST_F(OpenCLOperationTest, Log) {
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreateElementwiseOneInput(op_def, OperationType::LOG);
GPUOperation operation = CreateElementwiseOneInput(
creation_context_.GetGpuInfo(), op_def, OperationType::LOG);
ASSERT_OK(ExecuteGPUOperation(
src_tensor, creation_context_,
absl::make_unique<GPUOperation>(std::move(operation)),
@ -236,8 +236,8 @@ TEST_F(OpenCLOperationTest, Neg) {
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreateElementwiseOneInput(op_def, OperationType::NEG);
GPUOperation operation = CreateElementwiseOneInput(
creation_context_.GetGpuInfo(), op_def, OperationType::NEG);
ASSERT_OK(ExecuteGPUOperation(
src_tensor, creation_context_,
absl::make_unique<GPUOperation>(std::move(operation)),
@ -262,8 +262,8 @@ TEST_F(OpenCLOperationTest, Rsqrt) {
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreateElementwiseOneInput(op_def, OperationType::RSQRT);
GPUOperation operation = CreateElementwiseOneInput(
creation_context_.GetGpuInfo(), op_def, OperationType::RSQRT);
ASSERT_OK(ExecuteGPUOperation(
src_tensor, creation_context_,
absl::make_unique<GPUOperation>(std::move(operation)),
@ -291,8 +291,8 @@ TEST_F(OpenCLOperationTest, Sigmoid) {
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreateElementwiseOneInput(op_def, OperationType::SIGMOID);
GPUOperation operation = CreateElementwiseOneInput(
creation_context_.GetGpuInfo(), op_def, OperationType::SIGMOID);
ASSERT_OK(ExecuteGPUOperation(
src_tensor, creation_context_,
absl::make_unique<GPUOperation>(std::move(operation)),
@ -317,8 +317,8 @@ TEST_F(OpenCLOperationTest, Sin) {
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreateElementwiseOneInput(op_def, OperationType::SIN);
GPUOperation operation = CreateElementwiseOneInput(
creation_context_.GetGpuInfo(), op_def, OperationType::SIN);
ASSERT_OK(ExecuteGPUOperation(
src_tensor, creation_context_,
absl::make_unique<GPUOperation>(std::move(operation)),
@ -345,8 +345,8 @@ TEST_F(OpenCLOperationTest, Sqrt) {
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreateElementwiseOneInput(op_def, OperationType::SQRT);
GPUOperation operation = CreateElementwiseOneInput(
creation_context_.GetGpuInfo(), op_def, OperationType::SQRT);
ASSERT_OK(ExecuteGPUOperation(
src_tensor, creation_context_,
absl::make_unique<GPUOperation>(std::move(operation)),
@ -373,8 +373,8 @@ TEST_F(OpenCLOperationTest, Square) {
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreateElementwiseOneInput(op_def, OperationType::SQUARE);
GPUOperation operation = CreateElementwiseOneInput(
creation_context_.GetGpuInfo(), op_def, OperationType::SQUARE);
ASSERT_OK(ExecuteGPUOperation(
src_tensor, creation_context_,
absl::make_unique<GPUOperation>(std::move(operation)),
@ -399,8 +399,8 @@ TEST_F(OpenCLOperationTest, Tanh) {
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreateElementwiseOneInput(op_def, OperationType::TANH);
GPUOperation operation = CreateElementwiseOneInput(
creation_context_.GetGpuInfo(), op_def, OperationType::TANH);
ASSERT_OK(ExecuteGPUOperation(
src_tensor, creation_context_,
absl::make_unique<GPUOperation>(std::move(operation)),

View File

@ -41,7 +41,7 @@ TEST_F(OpenCLOperationTest, MeanHW) {
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
@ -68,7 +68,7 @@ TEST_F(OpenCLOperationTest, ReduceSumChannels) {
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
@ -95,7 +95,7 @@ TEST_F(OpenCLOperationTest, ReduceProductChannels) {
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
@ -123,7 +123,7 @@ TEST_F(OpenCLOperationTest, ReduceMaxChannels) {
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
@ -151,7 +151,7 @@ TEST_F(OpenCLOperationTest, ReduceMinChannels) {
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);

View File

@ -481,7 +481,8 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
case OperationType::SQRT:
case OperationType::SQUARE:
case OperationType::TANH: {
GPUOperation operation = CreateElementwiseOneInput(op_def, op_type);
GPUOperation operation =
CreateElementwiseOneInput(gpu_info, op_def, op_type);
*gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
return absl::OkStatus();
}

View File

@ -25,7 +25,8 @@ namespace tflite {
namespace gpu {
namespace {
std::string GetOneInputCode(const OperationType& op_type,
std::string GetOneInputCode(const GpuInfo& gpu_info,
const OperationType& op_type,
CalculationsPrecision precision,
const std::string& input0) {
std::string result;
@ -41,18 +42,28 @@ std::string GetOneInputCode(const OperationType& op_type,
result = "\n";
break;
case OperationType::ELU:
result = "$0.x = $0.x < (FLT)(0.0f) ? expm1($0.x) : $0.x;\n";
result += "$0.y = $0.y < (FLT)(0.0f) ? expm1($0.y) : $0.y;\n";
result += "$0.z = $0.z < (FLT)(0.0f) ? expm1($0.z) : $0.z;\n";
result += "$0.w = $0.w < (FLT)(0.0f) ? expm1($0.w) : $0.w;\n";
if (gpu_info.IsApiOpenCl()) {
result = R"(
$0.x = $0.x < INIT_FLT(0.0f) ? expm1($0.x) : $0.x;
$0.y = $0.y < INIT_FLT(0.0f) ? expm1($0.y) : $0.y;
$0.z = $0.z < INIT_FLT(0.0f) ? expm1($0.z) : $0.z;
$0.w = $0.w < INIT_FLT(0.0f) ? expm1($0.w) : $0.w;)";
} else {
result = R"(
$0.x = $0.x < INIT_FLT(0.0f) ? exp($0.x) - INIT_FLT(1.0f) : $0.x;
$0.y = $0.y < INIT_FLT(0.0f) ? exp($0.y) - INIT_FLT(1.0f) : $0.y;
$0.z = $0.z < INIT_FLT(0.0f) ? exp($0.z) - INIT_FLT(1.0f) : $0.z;
$0.w = $0.w < INIT_FLT(0.0f) ? exp($0.w) - INIT_FLT(1.0f) : $0.w;)";
}
break;
case OperationType::EXP:
result = "$0 = exp($0);\n";
break;
case OperationType::HARD_SWISH:
result =
"$0 *= clamp($0 * (FLT)(0.16666667f) + (FLT)(0.5f), (FLT4)(0.0f), "
"(FLT4)(1.0f));\n";
"$0 *= clamp($0 * INIT_FLT(0.16666667f) + INIT_FLT(0.5f), "
"INIT_FLT4(0.0f), "
"INIT_FLT4(1.0f));\n";
break;
case OperationType::LOG:
result = "$0 = log($0);\n";
@ -64,12 +75,12 @@ std::string GetOneInputCode(const OperationType& op_type,
result = "$0 = rsqrt($0);\n";
break;
case OperationType::SIGMOID:
if (precision != CalculationsPrecision::F32) {
if (gpu_info.IsApiOpenCl() && precision != CalculationsPrecision::F32) {
result =
"$0 = convert_half4(native_recip(1.0f + "
"native_exp(convert_float4(-$0))));\n";
} else {
result = "$0 = (FLT4)(1.0f) / ((FLT4)(1.0f) + exp(-($0)));\n";
result = "$0 = INIT_FLT4(1.0f) / (INIT_FLT4(1.0f) + exp(-($0)));\n";
}
break;
case OperationType::SIN:
@ -123,40 +134,40 @@ std::string GetTwoInputCode(const OperationType& op_type,
break;
// Comparison operators
case OperationType::LESS:
result = "$0.x = $1.x < $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.y = $1.y < $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.z = $1.z < $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.w = $1.w < $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result = "$0.x = $1.x < $2.x ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
result += "$0.y = $1.y < $2.y ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
result += "$0.z = $1.z < $2.z ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
result += "$0.w = $1.w < $2.w ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
break;
case OperationType::LESS_EQUAL:
result = "$0.x = $1.x <= $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.y = $1.y <= $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.z = $1.z <= $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.w = $1.w <= $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result = "$0.x = $1.x <= $2.x ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
result += "$0.y = $1.y <= $2.y ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
result += "$0.z = $1.z <= $2.z ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
result += "$0.w = $1.w <= $2.w ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
break;
case OperationType::GREATER:
result = "$0.x = $1.x > $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.y = $1.y > $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.z = $1.z > $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.w = $1.w > $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result = "$0.x = $1.x > $2.x ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
result += "$0.y = $1.y > $2.y ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
result += "$0.z = $1.z > $2.z ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
result += "$0.w = $1.w > $2.w ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
break;
case OperationType::GREATER_EQUAL:
result = "$0.x = $1.x >= $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.y = $1.y >= $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.z = $1.z >= $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.w = $1.w >= $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result = "$0.x = $1.x >= $2.x ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
result += "$0.y = $1.y >= $2.y ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
result += "$0.z = $1.z >= $2.z ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
result += "$0.w = $1.w >= $2.w ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
break;
case OperationType::EQUAL:
result = "$0.x = $1.x == $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.y = $1.y == $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.z = $1.z == $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.w = $1.w == $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result = "$0.x = $1.x == $2.x ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
result += "$0.y = $1.y == $2.y ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
result += "$0.z = $1.z == $2.z ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
result += "$0.w = $1.w == $2.w ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
break;
case OperationType::NOT_EQUAL:
result = "$0.x = $1.x != $2.x ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.y = $1.y != $2.y ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.z = $1.z != $2.z ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result += "$0.w = $1.w != $2.w ? (FLT)(1.0f) : (FLT)(0.0f);\n";
result = "$0.x = $1.x != $2.x ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
result += "$0.y = $1.y != $2.y ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
result += "$0.z = $1.z != $2.z ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
result += "$0.w = $1.w != $2.w ? INIT_FLT(1.0f) : INIT_FLT(0.0f);\n";
break;
default:
return "Unknown operation type;\n";
@ -180,9 +191,7 @@ GPUOperation CreateElementwiseOneRuntimeOneScalar(
} else {
op.args_.AddHalf("scalar", half(scalar_parameter));
}
op.code_ =
"FLT4 second_val = (FLT4)(args.scalar, args.scalar, args.scalar, "
"args.scalar);\n";
op.code_ = "FLT4 second_val = INIT_FLT4(args.scalar);\n";
op.code_ += GetTwoInputCode(op_type, "in_out_value", "in_out_value",
"second_val", swap_inputs);
return op;
@ -256,11 +265,13 @@ GPUOperation CreateElementwiseTwoInput(
} // namespace
GPUOperation CreateElementwiseOneInput(const OperationDef& definition,
GPUOperation CreateElementwiseOneInput(const GpuInfo& gpu_info,
const OperationDef& definition,
const OperationType& op_type) {
GPUOperation op(definition);
op.elementwise_ = true;
op.code_ = GetOneInputCode(op_type, definition.precision, "in_out_value");
op.code_ =
GetOneInputCode(gpu_info, op_type, definition.precision, "in_out_value");
return op;
}

View File

@ -27,7 +27,8 @@ namespace gpu {
// Creates simple one input operation without any parameters, for example
// log, sin, cos, etc.
GPUOperation CreateElementwiseOneInput(const OperationDef& definition,
GPUOperation CreateElementwiseOneInput(const GpuInfo& gpu_info,
const OperationDef& definition,
const OperationType& op_type);
// Creates simple two input(first input is runtime tensor and second input is

View File

@ -73,15 +73,15 @@ GPUOperation CreatePReLU(const GpuInfo& gpu_info,
} else {
result.args_.AddHalf("clip", half(attr.clip));
}
result.code_ =
alpha_read +
"in_out_value = clamp(in_out_value, (FLT4)(0.0f), (FLT4)(args.clip)) + "
"min((FLT4)(0.0f), in_out_value) * alpha_val;";
result.code_ = alpha_read +
"in_out_value = clamp(in_out_value, INIT_FLT4(0.0f), "
"INIT_FLT4(args.clip)) + "
"min(INIT_FLT4(0.0f), in_out_value) * alpha_val;";
} else {
result.code_ =
alpha_read +
"in_out_value = max((FLT4)(0.0f), in_out_value) + min((FLT4)(0.0f), "
"in_out_value) * alpha_val;";
result.code_ = alpha_read +
"in_out_value = max(INIT_FLT4(0.0f), in_out_value) + "
"min(INIT_FLT4(0.0f), "
"in_out_value) * alpha_val;";
}
return result;

View File

@ -49,9 +49,9 @@ GPUOperation CreateQuantizeAndDequantize(
op.args_.AddHalf("scale", half(adjusted_attr.scale));
}
op.code_ = R"(
FLT4 clamped_value = min((FLT4)(args.max), max((FLT4)(args.min), in_out_value));
FLT4 quantized_value = round((clamped_value - (FLT4)(args.min)) / (FLT4)(args.scale));
FLT4 dequantized_value = quantized_value * (FLT4)(args.scale) + (FLT4)(args.min);
FLT4 clamped_value = min(INIT_FLT4(args.max), max(INIT_FLT4(args.min), in_out_value));
FLT4 quantized_value = round((clamped_value - INIT_FLT4(args.min)) / INIT_FLT4(args.scale));
FLT4 dequantized_value = quantized_value * INIT_FLT4(args.scale) + INIT_FLT4(args.min);
in_out_value = dequantized_value;)";
return op;

View File

@ -27,14 +27,14 @@ GPUOperation CreateReLU(const OperationDef& definition,
std::string min_func;
if (attr.alpha != 0.0f) {
min_func = "min(in_out_value * args.alpha, (FLT)(0.0f))";
min_func = "min(in_out_value * args.alpha, INIT_FLT(0.0f))";
if (definition.precision == CalculationsPrecision::F32) {
op.args_.AddFloat("alpha", attr.alpha);
} else {
op.args_.AddHalf("alpha", half(attr.alpha));
}
} else {
min_func = "(FLT)(0.0f)";
min_func = "INIT_FLT(0.0f)";
}
if (attr.clip != 0.0f) {
if (definition.precision == CalculationsPrecision::F32) {