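Compute Softmax in OpenCL in 3 passes (maximum, sum of exponentials, normalization).
Improves numerical stability: the maximum input value is subtracted before exp(), so large inputs no longer overflow to inf.
PiperOrigin-RevId: 347941516
Change-Id: Ibe344c9922e1e267501f42ce1123ec943ee3eb97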
This commit is contained in:
		
							parent
							
								
									9a03eedc45
								
							
						
					
					
						commit
						115623e2fc
					
				| @ -59,6 +59,44 @@ TEST_F(OpenCLOperationTest, Softmax1x1) { | |||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TEST_F(OpenCLOperationTest, Softmax1x1BigNumber) { | ||||||
|  |   TensorFloat32 src_tensor; | ||||||
|  |   src_tensor.shape = BHWC(1, 1, 1, 4); | ||||||
|  |   double doubles[4] = {1.0, 2.0, 3.0, 100.0}; | ||||||
|  |   // exp(100) overflows to inf in float (32 bit) but stays finite in double (64 bit), so the expected softmax values below are computed in double.
 | ||||||
|  |   src_tensor.data.resize(4); | ||||||
|  |   src_tensor.data[0] = doubles[0]; | ||||||
|  |   src_tensor.data[1] = doubles[1]; | ||||||
|  |   src_tensor.data[2] = doubles[2]; | ||||||
|  |   src_tensor.data[3] = doubles[3]; | ||||||
|  |   EXPECT_TRUE(std::isinf(std::exp(src_tensor.data[3]))); | ||||||
|  |   EXPECT_FALSE(std::isinf(std::exp(doubles[3]))); | ||||||
|  |   double s0 = std::exp(doubles[0]) + std::exp(doubles[1]) + | ||||||
|  |               std::exp(doubles[2]) + std::exp(doubles[3]); | ||||||
|  | 
 | ||||||
|  |   for (auto storage : env_.GetSupportedStorages()) { | ||||||
|  |     for (auto precision : env_.GetSupportedPrecisions()) { | ||||||
|  |       const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f; | ||||||
|  |       OperationDef op_def; | ||||||
|  |       op_def.precision = precision; | ||||||
|  |       auto data_type = DeduceDataTypeFromPrecision(precision); | ||||||
|  |       op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); | ||||||
|  |       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); | ||||||
|  |       TensorFloat32 dst_tensor; | ||||||
|  |       Softmax1x1 operation = CreateSoftmax1x1(op_def); | ||||||
|  |       ASSERT_OK(ExecuteGPUOperation( | ||||||
|  |           src_tensor, creation_context_, | ||||||
|  |           absl::make_unique<Softmax1x1>(std::move(operation)), BHWC(1, 1, 1, 4), | ||||||
|  |           &dst_tensor)); | ||||||
|  |       EXPECT_THAT( | ||||||
|  |           dst_tensor.data, | ||||||
|  |           Pointwise(FloatNear(eps), | ||||||
|  |                     {std::exp(doubles[0]) / s0, std::exp(doubles[1]) / s0, | ||||||
|  |                      std::exp(doubles[2]) / s0, std::exp(doubles[3]) / s0})); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| }  // namespace cl
 | }  // namespace cl
 | ||||||
| }  // namespace gpu
 | }  // namespace gpu
 | ||||||
|  | |||||||
| @ -60,6 +60,44 @@ TEST_F(OpenCLOperationTest, Softmax) { | |||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TEST_F(OpenCLOperationTest, SoftmaxBigNumber) { | ||||||
|  |   TensorFloat32 src_tensor; | ||||||
|  |   src_tensor.shape = BHWC(1, 2, 1, 2); | ||||||
|  |   double doubles[4] = {1.0, 2.0, 3.0, 100.0}; | ||||||
|  |   // exp(100) overflows to inf in float (32 bit) but stays finite in double (64 bit), so the expected softmax values below are computed in double.
 | ||||||
|  |   src_tensor.data.resize(4); | ||||||
|  |   src_tensor.data[0] = doubles[0]; | ||||||
|  |   src_tensor.data[1] = doubles[1]; | ||||||
|  |   src_tensor.data[2] = doubles[2]; | ||||||
|  |   src_tensor.data[3] = doubles[3]; | ||||||
|  |   EXPECT_TRUE(std::isinf(std::exp(src_tensor.data[3]))); | ||||||
|  |   EXPECT_FALSE(std::isinf(std::exp(doubles[3]))); | ||||||
|  |   double s0 = std::exp(doubles[0]) + std::exp(doubles[1]); | ||||||
|  |   double s1 = std::exp(doubles[2]) + std::exp(doubles[3]); | ||||||
|  | 
 | ||||||
|  |   for (auto storage : env_.GetSupportedStorages()) { | ||||||
|  |     for (auto precision : env_.GetSupportedPrecisions()) { | ||||||
|  |       const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f; | ||||||
|  |       OperationDef op_def; | ||||||
|  |       op_def.precision = precision; | ||||||
|  |       auto data_type = DeduceDataTypeFromPrecision(precision); | ||||||
|  |       op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); | ||||||
|  |       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); | ||||||
|  |       TensorFloat32 dst_tensor; | ||||||
|  |       GPUOperation operation = CreateSoftmax(op_def); | ||||||
|  |       ASSERT_OK(ExecuteGPUOperation( | ||||||
|  |           src_tensor, creation_context_, | ||||||
|  |           absl::make_unique<GPUOperation>(std::move(operation)), | ||||||
|  |           BHWC(1, 2, 1, 2), &dst_tensor)); | ||||||
|  |       EXPECT_THAT( | ||||||
|  |           dst_tensor.data, | ||||||
|  |           Pointwise(FloatNear(eps), | ||||||
|  |                     {std::exp(doubles[0]) / s0, std::exp(doubles[1]) / s0, | ||||||
|  |                      std::exp(doubles[2]) / s1, std::exp(doubles[3]) / s1})); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| }  // namespace cl
 | }  // namespace cl
 | ||||||
| }  // namespace gpu
 | }  // namespace gpu
 | ||||||
|  | |||||||
| @ -33,15 +33,28 @@ std::string GetSoftmaxKernelCode(const OperationDef& op_def) { | |||||||
|   c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height()) " |   c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height()) " | ||||||
|        "return; \n"; |        "return; \n"; | ||||||
|   c += "  float sum = 0.0f;\n"; |   c += "  float sum = 0.0f;\n"; | ||||||
|  |   c += "  float maximum = args.src_tensor.Read<float>(X, Y, 0).x;\n"; | ||||||
|   c += "  for (int d = 0; d < args.dst_tensor.Slices(); ++d) {\n"; |   c += "  for (int d = 0; d < args.dst_tensor.Slices(); ++d) {\n"; | ||||||
|   c += "    float4 t = args.src_tensor.Read<float>(X, Y, d);\n"; |   c += "    float4 t = args.src_tensor.Read<float>(X, Y, d);\n"; | ||||||
|  |   c += "    maximum = max(maximum, t.x);\n"; | ||||||
|  |   c += "    if (d * 4 + 1 < args.dst_tensor.Channels()) maximum = max(maximum, " | ||||||
|  |        "t.y);\n"; | ||||||
|  |   c += "    if (d * 4 + 2 < args.dst_tensor.Channels()) maximum = max(maximum, " | ||||||
|  |        "t.z);\n"; | ||||||
|  |   c += "    if (d * 4 + 3 < args.dst_tensor.Channels()) maximum = max(maximum, " | ||||||
|  |        "t.w);\n"; | ||||||
|  |   c += "  }\n"; | ||||||
|  |   c += "  for (int d = 0; d < args.dst_tensor.Slices(); ++d) {\n"; | ||||||
|  |   c += "    float4 t = args.src_tensor.Read<float>(X, Y, d) - " | ||||||
|  |        "(float4)(maximum);\n"; | ||||||
|   c += "    sum += exp(t.x);\n"; |   c += "    sum += exp(t.x);\n"; | ||||||
|   c += "    if (d * 4 + 1 < args.dst_tensor.Channels()) sum += exp(t.y);\n"; |   c += "    if (d * 4 + 1 < args.dst_tensor.Channels()) sum += exp(t.y);\n"; | ||||||
|   c += "    if (d * 4 + 2 < args.dst_tensor.Channels()) sum += exp(t.z);\n"; |   c += "    if (d * 4 + 2 < args.dst_tensor.Channels()) sum += exp(t.z);\n"; | ||||||
|   c += "    if (d * 4 + 3 < args.dst_tensor.Channels()) sum += exp(t.w);\n"; |   c += "    if (d * 4 + 3 < args.dst_tensor.Channels()) sum += exp(t.w);\n"; | ||||||
|   c += "  }\n"; |   c += "  }\n"; | ||||||
|   c += "  for (int d = 0; d < args.dst_tensor.Slices(); ++d) {\n"; |   c += "  for (int d = 0; d < args.dst_tensor.Slices(); ++d) {\n"; | ||||||
|   c += "    float4 t = args.src_tensor.Read<float>(X, Y, d);\n"; |   c += "    float4 t = args.src_tensor.Read<float>(X, Y, d) - " | ||||||
|  |        "(float4)(maximum);\n"; | ||||||
|   c += "    t = exp(t) / sum;\n"; |   c += "    t = exp(t) / sum;\n"; | ||||||
|   c += "    FLT4 result = TO_FLT4(t);\n"; |   c += "    FLT4 result = TO_FLT4(t);\n"; | ||||||
|   c += "    args.dst_tensor.Write(result, X, Y, d);\n"; |   c += "    args.dst_tensor.Write(result, X, Y, d);\n"; | ||||||
|  | |||||||
| @ -45,7 +45,6 @@ std::string Softmax1x1::GetSoftmaxKernelCode(const OperationDef& op_def) { | |||||||
|   args_.AddFloat("mask_y"); |   args_.AddFloat("mask_y"); | ||||||
|   args_.AddFloat("mask_z"); |   args_.AddFloat("mask_z"); | ||||||
|   args_.AddFloat("mask_w"); |   args_.AddFloat("mask_w"); | ||||||
|   args_.AddInt("slices_x32"); |  | ||||||
| 
 | 
 | ||||||
|   std::string c; |   std::string c; | ||||||
|   c += "__kernel void main_function(\n"; |   c += "__kernel void main_function(\n"; | ||||||
| @ -58,24 +57,47 @@ std::string Softmax1x1::GetSoftmaxKernelCode(const OperationDef& op_def) { | |||||||
|   } |   } | ||||||
|   c += "  float4 mask = (float4)(args.mask_x, args.mask_y, args.mask_z, " |   c += "  float4 mask = (float4)(args.mask_x, args.mask_y, args.mask_z, " | ||||||
|        "args.mask_w);\n"; |        "args.mask_w);\n"; | ||||||
|   c += "  int offset = 0;\n"; |   c += "  float4 maxx4 = (float4)(args.src_tensor.Read<float>(0, 0, 0).x);\n"; | ||||||
|   c += "  float sum = 0.0f;\n"; |  | ||||||
|   c += "  int s = 0;\n"; |  | ||||||
|   c += "  int tid = get_local_id(0);\n"; |   c += "  int tid = get_local_id(0);\n"; | ||||||
|   c += "  do {\n"; |   c += "  for (int s = tid; s < args.src_tensor.Slices(); s += 32) {\n"; | ||||||
|   c += "    int z = offset + tid;\n"; |   c += "    float4 mask_a = s == args.src_tensor.Slices() - 1 ? mask : " | ||||||
|   c += "    if (z < args.dst_tensor.Slices()) {\n"; |  | ||||||
|   c += "      float4 mask_temp = z == args.dst_tensor.Slices() - 1 ? mask : " |  | ||||||
|        "(float4)(1.0f);\n"; |        "(float4)(1.0f);\n"; | ||||||
|   c += "      float4 src = args.src_tensor.Read<float>(0, 0, z);\n"; |   c += "    float4 mask_b = (float4)(1.0f) - mask_a;\n"; | ||||||
|   c += "      sum += dot(mask_temp, exp(src));\n"; |   c += "    float4 src = args.src_tensor.Read<float>(0, 0, s);\n"; | ||||||
|   c += "      offset += 32;\n"; |   c += "    src = src * mask_a + mask_b * src.x;\n"; | ||||||
|  |   c += "    maxx4 = max(maxx4, src);\n"; | ||||||
|   c += "  }\n"; |   c += "  }\n"; | ||||||
|   c += "    s++;\n"; |   c += "  float maximum = max(maxx4.x, maxx4.y);\n"; | ||||||
|   c += "  } while (s < args.slices_x32);\n"; |   c += "  maximum = max(maximum, maxx4.z);\n"; | ||||||
|   c += "\n"; |   c += "  maximum = max(maximum, maxx4.w);\n"; | ||||||
|   c += "  __local float4 tmp[8];\n"; |   c += "  __local float4 tmp[8];\n"; | ||||||
|   c += "  __local float* tmpx1 = (__local float*)tmp;\n"; |   c += "  __local float* tmpx1 = (__local float*)tmp;\n"; | ||||||
|  |   c += "  tmpx1[tid] = maximum;\n"; | ||||||
|  |   c += "  barrier(CLK_LOCAL_MEM_FENCE);\n"; | ||||||
|  |   c += "  if (tid == 0) {\n"; | ||||||
|  |   c += "    maxx4 = max(tmp[0], tmp[1]);\n"; | ||||||
|  |   c += "    maxx4 = max(maxx4, tmp[2]);\n"; | ||||||
|  |   c += "    maxx4 = max(maxx4, tmp[3]);\n"; | ||||||
|  |   c += "    maxx4 = max(maxx4, tmp[4]);\n"; | ||||||
|  |   c += "    maxx4 = max(maxx4, tmp[5]);\n"; | ||||||
|  |   c += "    maxx4 = max(maxx4, tmp[6]);\n"; | ||||||
|  |   c += "    maxx4 = max(maxx4, tmp[7]);\n"; | ||||||
|  |   c += "    maximum = max(maxx4.x, maxx4.y);\n"; | ||||||
|  |   c += "    maximum = max(maximum, maxx4.z);\n"; | ||||||
|  |   c += "    maximum = max(maximum, maxx4.w);\n"; | ||||||
|  |   c += "    tmpx1[0] = maximum;\n"; | ||||||
|  |   c += "  }\n"; | ||||||
|  |   c += "  barrier(CLK_LOCAL_MEM_FENCE);\n"; | ||||||
|  |   c += "  maximum = tmpx1[0];\n"; | ||||||
|  |   c += "  float sum = 0.0f;\n"; | ||||||
|  |   c += "  for (int s = tid; s < args.src_tensor.Slices(); s += 32) {\n"; | ||||||
|  |   c += "    float4 mask_temp = s == args.src_tensor.Slices() - 1 ? mask : " | ||||||
|  |        "(float4)(1.0f);\n"; | ||||||
|  |   c += "    float4 src = args.src_tensor.Read<float>(0, 0, s) - " | ||||||
|  |        "(float4)(maximum);\n"; | ||||||
|  |   c += "    sum += dot(mask_temp, exp(src));\n"; | ||||||
|  |   c += "  }\n"; | ||||||
|  |   c += "  barrier(CLK_LOCAL_MEM_FENCE);\n"; | ||||||
|   c += "  tmpx1[tid] = sum;\n"; |   c += "  tmpx1[tid] = sum;\n"; | ||||||
|   c += "  barrier(CLK_LOCAL_MEM_FENCE);\n"; |   c += "  barrier(CLK_LOCAL_MEM_FENCE);\n"; | ||||||
|   c += "  if (tid == 0) {\n"; |   c += "  if (tid == 0) {\n"; | ||||||
| @ -92,18 +114,13 @@ std::string Softmax1x1::GetSoftmaxKernelCode(const OperationDef& op_def) { | |||||||
|   c += "  barrier(CLK_LOCAL_MEM_FENCE);\n"; |   c += "  barrier(CLK_LOCAL_MEM_FENCE);\n"; | ||||||
|   c += "  sum = tmpx1[0];\n"; |   c += "  sum = tmpx1[0];\n"; | ||||||
|   c += "\n"; |   c += "\n"; | ||||||
|   c += "  offset = 0;\n"; |   c += "  int dst_s = get_global_id(0);\n"; | ||||||
|   c += "  s = 0;\n"; |   c += "  if (dst_s < args.dst_tensor.Slices()) {\n"; | ||||||
|   c += "  do {\n"; |   c += "    float4 src = args.src_tensor.Read<float>(0, 0, dst_s) - " | ||||||
|   c += "    int z = offset + tid;\n"; |        "(float4)(maximum);\n"; | ||||||
|   c += "    if (z < args.dst_tensor.Slices()) {\n"; |   c += "    FLT4 res = TO_FLT4(exp(src) * sum);\n"; | ||||||
|   c += "      FLT4 res = TO_FLT4(exp(args.src_tensor.Read<float>(0, 0, " |   c += "    args.dst_tensor.Write(res, 0, 0, dst_s);\n"; | ||||||
|        "z))*sum);\n"; |  | ||||||
|   c += "      args.dst_tensor.Write(res, 0, 0, z);\n"; |  | ||||||
|   c += "      offset += 32;\n"; |  | ||||||
|   c += "  }\n"; |   c += "  }\n"; | ||||||
|   c += "    s++;\n"; |  | ||||||
|   c += "  } while (s < args.slices_x32);\n"; |  | ||||||
|   c += "}\n"; |   c += "}\n"; | ||||||
|   return c; |   return c; | ||||||
| } | } | ||||||
| @ -114,12 +131,12 @@ absl::Status Softmax1x1::BindArguments(ArgumentsBinder* args) { | |||||||
|   RETURN_IF_ERROR(args->SetFloat("mask_y", mask.y)); |   RETURN_IF_ERROR(args->SetFloat("mask_y", mask.y)); | ||||||
|   RETURN_IF_ERROR(args->SetFloat("mask_z", mask.z)); |   RETURN_IF_ERROR(args->SetFloat("mask_z", mask.z)); | ||||||
|   RETURN_IF_ERROR(args->SetFloat("mask_w", mask.w)); |   RETURN_IF_ERROR(args->SetFloat("mask_w", mask.w)); | ||||||
|   RETURN_IF_ERROR( |  | ||||||
|       args->SetInt("slices_x32", DivideRoundUp(src_[0]->Slices(), 32))); |  | ||||||
|   return absl::OkStatus(); |   return absl::OkStatus(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| int3 Softmax1x1::GetGridSize() const { return int3(32, dst_[0]->Batch(), 1); } | int3 Softmax1x1::GetGridSize() const { | ||||||
|  |   return int3(dst_[0]->Slices(), dst_[0]->Batch(), 1); | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
| Softmax1x1 CreateSoftmax1x1(const OperationDef& definition) { | Softmax1x1 CreateSoftmax1x1(const OperationDef& definition) { | ||||||
|   return Softmax1x1(definition); |   return Softmax1x1(definition); | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user