From 115623e2fc21affeaeee5167daec9c1f0db27069 Mon Sep 17 00:00:00 2001
From: Raman Sarokin
Date: Wed, 16 Dec 2020 19:14:57 -0800
Subject: [PATCH] Making Softmax in OpenCL in 3 passes. Improves numerical
 stability.

PiperOrigin-RevId: 347941516
Change-Id: Ibe344c9922e1e267501f42ce1123ec943ee3eb97
---
 .../gpu/cl/kernels/softmax1x1_test.cc         | 38 +++++++++
 .../delegates/gpu/cl/kernels/softmax_test.cc  | 38 +++++++++
 .../delegates/gpu/common/tasks/softmax.cc     | 15 +++-
 .../delegates/gpu/common/tasks/softmax1x1.cc  | 77 +++++++++++--------
 4 files changed, 137 insertions(+), 31 deletions(-)

diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1_test.cc
index 9db5315f886..717a6f6224b 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1_test.cc
@@ -59,6 +59,44 @@ TEST_F(OpenCLOperationTest, Softmax1x1) {
   }
 }
 
+TEST_F(OpenCLOperationTest, Softmax1x1BigNumber) {
+  TensorFloat32 src_tensor;
+  src_tensor.shape = BHWC(1, 1, 1, 4);
+  double doubles[4] = {1.0, 2.0, 3.0, 100.0};
+  // exp(100) is inf in float (32 bit) but representable in double (64 bit)
+  src_tensor.data.resize(4);
+  src_tensor.data[0] = doubles[0];
+  src_tensor.data[1] = doubles[1];
+  src_tensor.data[2] = doubles[2];
+  src_tensor.data[3] = doubles[3];
+  EXPECT_TRUE(std::isinf(std::exp(src_tensor.data[3])));
+  EXPECT_FALSE(std::isinf(std::exp(doubles[3])));
+  double s0 = std::exp(doubles[0]) + std::exp(doubles[1]) +
+              std::exp(doubles[2]) + std::exp(doubles[3]);
+
+  for (auto storage : env_.GetSupportedStorages()) {
+    for (auto precision : env_.GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      Softmax1x1 operation = CreateSoftmax1x1(op_def);
+      ASSERT_OK(ExecuteGPUOperation(
+          src_tensor, creation_context_,
+          absl::make_unique<Softmax1x1>(std::move(operation)),
+          BHWC(1, 1, 1, 4), &dst_tensor));
+      EXPECT_THAT(
+          dst_tensor.data,
+          Pointwise(FloatNear(eps),
+                    {std::exp(doubles[0]) / s0, std::exp(doubles[1]) / s0,
+                     std::exp(doubles[2]) / s0, std::exp(doubles[3]) / s0}));
+    }
+  }
+}
+
 }  // namespace
 }  // namespace cl
 }  // namespace gpu
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/softmax_test.cc
index 8b1675b0b54..09f247d5e01 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax_test.cc
@@ -60,6 +60,44 @@ TEST_F(OpenCLOperationTest, Softmax) {
   }
 }
 
+TEST_F(OpenCLOperationTest, SoftmaxBigNumber) {
+  TensorFloat32 src_tensor;
+  src_tensor.shape = BHWC(1, 2, 1, 2);
+  double doubles[4] = {1.0, 2.0, 3.0, 100.0};
+  // exp(100) is inf in float (32 bit) but representable in double (64 bit)
+  src_tensor.data.resize(4);
+  src_tensor.data[0] = doubles[0];
+  src_tensor.data[1] = doubles[1];
+  src_tensor.data[2] = doubles[2];
+  src_tensor.data[3] = doubles[3];
+  EXPECT_TRUE(std::isinf(std::exp(src_tensor.data[3])));
+  EXPECT_FALSE(std::isinf(std::exp(doubles[3])));
+  double s0 = std::exp(doubles[0]) + std::exp(doubles[1]);
+  double s1 = std::exp(doubles[2]) + std::exp(doubles[3]);
+
+  for (auto storage : env_.GetSupportedStorages()) {
+    for (auto precision : env_.GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      GPUOperation operation = CreateSoftmax(op_def);
+      ASSERT_OK(ExecuteGPUOperation(
+          src_tensor, creation_context_,
+          absl::make_unique<GPUOperation>(std::move(operation)),
+          BHWC(1, 2, 1, 2), &dst_tensor));
+      EXPECT_THAT(
+          dst_tensor.data,
+          Pointwise(FloatNear(eps),
+                    {std::exp(doubles[0]) / s0, std::exp(doubles[1]) / s0,
+                     std::exp(doubles[2]) / s1, std::exp(doubles[3]) / s1}));
+    }
+  }
+}
+
 }  // namespace
 }  // namespace cl
 }  // namespace gpu
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/softmax.cc b/tensorflow/lite/delegates/gpu/common/tasks/softmax.cc
index 2bbad5ca48c..09316cd1449 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/softmax.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/softmax.cc
@@ -33,15 +33,28 @@ std::string GetSoftmaxKernelCode(const OperationDef& op_def) {
   c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height()) "
        "return; \n";
   c += "  float sum = 0.0f;\n";
+  c += "  float maximum = args.src_tensor.Read<float>(X, Y, 0).x;\n";
   c += "  for (int d = 0; d < args.dst_tensor.Slices(); ++d) {\n";
   c += "    float4 t = args.src_tensor.Read<float>(X, Y, d);\n";
+  c += "    maximum = max(maximum, t.x);\n";
+  c += "    if (d * 4 + 1 < args.dst_tensor.Channels()) maximum = max(maximum, "
+       "t.y);\n";
+  c += "    if (d * 4 + 2 < args.dst_tensor.Channels()) maximum = max(maximum, "
+       "t.z);\n";
+  c += "    if (d * 4 + 3 < args.dst_tensor.Channels()) maximum = max(maximum, "
+       "t.w);\n";
+  c += "  }\n";
+  c += "  for (int d = 0; d < args.dst_tensor.Slices(); ++d) {\n";
+  c += "    float4 t = args.src_tensor.Read<float>(X, Y, d) - "
+       "(float4)(maximum);\n";
   c += "    sum += exp(t.x);\n";
   c += "    if (d * 4 + 1 < args.dst_tensor.Channels()) sum += exp(t.y);\n";
   c += "    if (d * 4 + 2 < args.dst_tensor.Channels()) sum += exp(t.z);\n";
   c += "    if (d * 4 + 3 < args.dst_tensor.Channels()) sum += exp(t.w);\n";
   c += "  }\n";
   c += "  for (int d = 0; d < args.dst_tensor.Slices(); ++d) {\n";
-  c += "    float4 t = args.src_tensor.Read<float>(X, Y, d);\n";
+  c += "    float4 t = args.src_tensor.Read<float>(X, Y, d) - "
+       "(float4)(maximum);\n";
   c += "    t = exp(t) / sum;\n";
   c += "    FLT4 result = TO_FLT4(t);\n";
   c += "    args.dst_tensor.Write(result, X, Y, d);\n";
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/softmax1x1.cc b/tensorflow/lite/delegates/gpu/common/tasks/softmax1x1.cc
index 952f08126e1..b5fe6685907 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/softmax1x1.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/softmax1x1.cc
@@ -45,7 +45,6 @@ std::string Softmax1x1::GetSoftmaxKernelCode(const OperationDef& op_def) {
   args_.AddFloat("mask_y");
   args_.AddFloat("mask_z");
   args_.AddFloat("mask_w");
-  args_.AddInt("slices_x32");
 
   std::string c;
   c += "__kernel void main_function(\n";
@@ -58,24 +57,47 @@ std::string Softmax1x1::GetSoftmaxKernelCode(const OperationDef& op_def) {
   }
   c += "  float4 mask = (float4)(args.mask_x, args.mask_y, args.mask_z, "
        "args.mask_w);\n";
-  c += "  int offset = 0;\n";
-  c += "  float sum = 0.0f;\n";
-  c += "  int s = 0;\n";
+  c += "  float4 maxx4 = (float4)(args.src_tensor.Read<float>(0, 0, 0).x);\n";
   c += "  int tid = get_local_id(0);\n";
-  c += "  do {\n";
-  c += "    int z = offset + tid;\n";
-  c += "    if (z < args.dst_tensor.Slices()) {\n";
-  c += "      float4 mask_temp = z == args.dst_tensor.Slices() - 1 ? mask : "
-       "(float4)(1.0f);\n";
-  c += "      float4 src = args.src_tensor.Read<float>(0, 0, z);\n";
-  c += "      sum += dot(mask_temp, exp(src));\n";
-  c += "      offset += 32;\n";
-  c += "    }\n";
-  c += "    s++;\n";
-  c += "  } while (s < args.slices_x32);\n";
-  c += "\n";
+  c += "  for (int s = tid; s < args.src_tensor.Slices(); s += 32) {\n";
+  c += "    float4 mask_a = s == args.src_tensor.Slices() - 1 ? mask : "
+       "(float4)(1.0f);\n";
+  c += "    float4 mask_b = (float4)(1.0f) - mask_a;\n";
+  c += "    float4 src = args.src_tensor.Read<float>(0, 0, s);\n";
+  c += "    src = src * mask_a + mask_b * src.x;\n";
+  c += "    maxx4 = max(maxx4, src);\n";
+  c += "  }\n";
+  c += "  float maximum = max(maxx4.x, maxx4.y);\n";
+  c += "  maximum = max(maximum, maxx4.z);\n";
+  c += "  maximum = max(maximum, maxx4.w);\n";
   c += "  __local float4 tmp[8];\n";
   c += "  __local float* tmpx1 = (__local float*)tmp;\n";
+  c += "  tmpx1[tid] = maximum;\n";
+  c += "  barrier(CLK_LOCAL_MEM_FENCE);\n";
+  c += "  if (tid == 0) {\n";
+  c += "    maxx4 = max(tmp[0], tmp[1]);\n";
+  c += "    maxx4 = max(maxx4, tmp[2]);\n";
+  c += "    maxx4 = max(maxx4, tmp[3]);\n";
+  c += "    maxx4 = max(maxx4, tmp[4]);\n";
+  c += "    maxx4 = max(maxx4, tmp[5]);\n";
+  c += "    maxx4 = max(maxx4, tmp[6]);\n";
+  c += "    maxx4 = max(maxx4, tmp[7]);\n";
+  c += "    maximum = max(maxx4.x, maxx4.y);\n";
+  c += "    maximum = max(maximum, maxx4.z);\n";
+  c += "    maximum = max(maximum, maxx4.w);\n";
+  c += "    tmpx1[0] = maximum;\n";
+  c += "  }\n";
+  c += "  barrier(CLK_LOCAL_MEM_FENCE);\n";
+  c += "  maximum = tmpx1[0];\n";
+  c += "  float sum = 0.0f;\n";
+  c += "  for (int s = tid; s < args.src_tensor.Slices(); s += 32) {\n";
+  c += "    float4 mask_temp = s == args.src_tensor.Slices() - 1 ? mask : "
+       "(float4)(1.0f);\n";
+  c += "    float4 src = args.src_tensor.Read<float>(0, 0, s) - "
+       "(float4)(maximum);\n";
+  c += "    sum += dot(mask_temp, exp(src));\n";
+  c += "  }\n";
+  c += "  barrier(CLK_LOCAL_MEM_FENCE);\n";
   c += "  tmpx1[tid] = sum;\n";
   c += "  barrier(CLK_LOCAL_MEM_FENCE);\n";
   c += "  if (tid == 0) {\n";
@@ -92,18 +114,13 @@ std::string Softmax1x1::GetSoftmaxKernelCode(const OperationDef& op_def) {
   c += "  barrier(CLK_LOCAL_MEM_FENCE);\n";
   c += "  sum = tmpx1[0];\n";
   c += "\n";
-  c += "  offset = 0;\n";
-  c += "  s = 0;\n";
-  c += "  do {\n";
-  c += "    int z = offset + tid;\n";
-  c += "    if (z < args.dst_tensor.Slices()) {\n";
-  c += "      FLT4 res = TO_FLT4(exp(args.src_tensor.Read<float>(0, 0, "
-       "z))*sum);\n";
-  c += "      args.dst_tensor.Write(res, 0, 0, z);\n";
-  c += "      offset += 32;\n";
-  c += "    }\n";
-  c += "    s++;\n";
-  c += "  } while (s < args.slices_x32);\n";
+  c += "  int dst_s = get_global_id(0);\n";
+  c += "  if (dst_s < args.dst_tensor.Slices()) {\n";
+  c += "    float4 src = args.src_tensor.Read<float>(0, 0, dst_s) - "
+       "(float4)(maximum);\n";
+  c += "    FLT4 res = TO_FLT4(exp(src) * sum);\n";
+  c += "    args.dst_tensor.Write(res, 0, 0, dst_s);\n";
+  c += "  }\n";
   c += "}\n";
   return c;
 }
@@ -114,12 +131,12 @@ absl::Status Softmax1x1::BindArguments(ArgumentsBinder* args) {
   RETURN_IF_ERROR(args->SetFloat("mask_y", mask.y));
   RETURN_IF_ERROR(args->SetFloat("mask_z", mask.z));
   RETURN_IF_ERROR(args->SetFloat("mask_w", mask.w));
-  RETURN_IF_ERROR(
-      args->SetInt("slices_x32", DivideRoundUp(src_[0]->Slices(), 32)));
   return absl::OkStatus();
 }
 
-int3 Softmax1x1::GetGridSize() const { return int3(32, dst_[0]->Batch(), 1); }
+int3 Softmax1x1::GetGridSize() const {
+  return int3(dst_[0]->Slices(), dst_[0]->Batch(), 1);
+}
 
 Softmax1x1 CreateSoftmax1x1(const OperationDef& definition) {
   return Softmax1x1(definition);
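
Note (not part of the patch): the fix in all three kernels is the standard
max-subtraction trick. Mathematically softmax(x) == softmax(x - max(x)), but
the shifted form keeps every exp() argument <= 0, so float32/float16 exp can
no longer overflow the way exp(100.0f) does in the tests above. A minimal
host-side C++ sketch of the same 3-pass scheme, for reference only; the
function name StableSoftmax is illustrative and assumes a non-empty input:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// 3-pass softmax, mirroring the kernels: pass 1 finds the maximum, pass 2
// sums exp(x - max) (every term is <= 1, so the sum cannot overflow), and
// pass 3 normalizes using the same shift.
std::vector<float> StableSoftmax(const std::vector<float>& x) {
  float maximum = x[0];
  for (float v : x) maximum = std::max(maximum, v);  // pass 1: max reduce
  float sum = 0.0f;
  for (float v : x) sum += std::exp(v - maximum);    // pass 2: sum reduce
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    out[i] = std::exp(x[i] - maximum) / sum;         // pass 3: normalize
  }
  return out;
}

The sketch divides by the sum directly; the Softmax1x1 kernel instead
multiplies by a value it computed from the reduced sum in the unchanged
tid == 0 block (not shown in these hunks).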