diff --git a/tensorflow/core/kernels/mlir_generated_cwise_op_gpu_tanh.cu.cc b/tensorflow/core/kernels/mlir_generated_cwise_op_gpu_tanh.cu.cc index 40dd7c7e49e..70de777239f 100644 --- a/tensorflow/core/kernels/mlir_generated_cwise_op_gpu_tanh.cu.cc +++ b/tensorflow/core/kernels/mlir_generated_cwise_op_gpu_tanh.cu.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -45,9 +46,41 @@ Status CreateKernel(absl::string_view kernel_name, uint64_t num_args, return stream_exec->GetKernel(loader_spec, kernel_base.get()); } -class MlirGenerateTanhOp : public OpKernel { +struct LaunchConfig { + se::BlockDim blockDim; + se::ThreadDim threadDim; +}; + +LaunchConfig GetLaunchConfiguration(std::vector tile_sizes, + std::vector unrolling_factors, + std::vector shape) { + LaunchConfig result; + // Ensure the vectors are length 3 and pad with ones. + tile_sizes.resize(3, 1); + unrolling_factors.resize(3, 1); + shape.resize(3, 1); + // The number of threads is given by the tiling size. + result.threadDim = se::ThreadDim(tile_sizes[0], tile_sizes[1], tile_sizes[2]); + // We know that the kernel was generated by mapping the three outer-most + // dimensions to x,y,z dimensions. So we only need to compute those. + std::vector block_dims(3); + for (int i = 0; i < 3; ++i) { + // Compute the number of grids. We use ceildiv here as we have to allocate + // an extra thread/block if the division is not even. The kernel contains + // code to handle the boundaries. 
+    int number_of_threads =
+        (shape[i] + unrolling_factors[i] - 1) / unrolling_factors[i];
+    int number_of_grids =
+        (number_of_threads + tile_sizes[i] - 1) / tile_sizes[i];
+    block_dims[i] = number_of_grids;
+  }
+  result.blockDim = se::BlockDim(block_dims[0], block_dims[1], block_dims[2]);
+  return result;
+}
+
+class MlirGeneratedTanhOp : public OpKernel {
  public:
-  explicit MlirGenerateTanhOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+  explicit MlirGeneratedTanhOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
   void Compute(OpKernelContext* ctx) override {
     auto* stream = ctx->op_device_context()->stream();
@@ -88,11 +121,13 @@ class MlirGenerateTanhOp : public OpKernel {
     args.add_argument(inp.NumElements());
     args.add_argument(1);
 
-    // TODO(b/158649746): Choose block size and thread dim according to the
-    // number of input elements. For now, this supports at most 1024 elements.
+    // This has to be aligned with the configuration that was used when building
+    // the kernels. See the corresponding build rules in `cubin_headers/BUILD`.
+    LaunchConfig config = GetLaunchConfiguration(
+        {256}, {}, {static_cast<uint64>(inp.NumElements())});
     OP_REQUIRES_OK(
-        ctx, stream->parent()->Launch(stream, se::ThreadDim(inp.NumElements()),
-                                      se::BlockDim(1), *kernel, args));
+        ctx, stream->parent()->Launch(stream, config.threadDim, config.blockDim,
+                                      *kernel, args));
   }
 
  protected:
@@ -103,26 +138,26 @@ class MlirGenerateTanhOp : public OpKernel {
   std::mutex mu_;
 };
 
-class MlirGenerateTanhF16Op : public MlirGenerateTanhOp {
+class MlirGeneratedTanhF16Op : public MlirGeneratedTanhOp {
  public:
-  explicit MlirGenerateTanhF16Op(OpKernelConstruction* ctx)
-      : MlirGenerateTanhOp(ctx) {
+  explicit MlirGeneratedTanhF16Op(OpKernelConstruction* ctx)
+      : MlirGeneratedTanhOp(ctx) {
     cubin_data_ = kTanhF16Kernel;
   }
 };
 
-class MlirGenerateTanhF32Op : public MlirGenerateTanhOp {
+class MlirGeneratedTanhF32Op : public MlirGeneratedTanhOp {
  public:
-  explicit MlirGenerateTanhF32Op(OpKernelConstruction* ctx)
-      : MlirGenerateTanhOp(ctx) {
+  explicit MlirGeneratedTanhF32Op(OpKernelConstruction* ctx)
+      : MlirGeneratedTanhOp(ctx) {
    cubin_data_ = kTanhF32Kernel;
   }
 };
 
-class MlirGenerateTanhF64Op : public MlirGenerateTanhOp {
+class MlirGeneratedTanhF64Op : public MlirGeneratedTanhOp {
  public:
-  explicit MlirGenerateTanhF64Op(OpKernelConstruction* ctx)
-      : MlirGenerateTanhOp(ctx) {
+  explicit MlirGeneratedTanhF64Op(OpKernelConstruction* ctx)
+      : MlirGeneratedTanhOp(ctx) {
     cubin_data_ = kTanhF64Kernel;
   }
 };
@@ -130,11 +165,11 @@ class MlirGenerateTanhF64Op : public MlirGenerateTanhOp {
 REGISTER_KERNEL_BUILDER(
     Name("Tanh").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
-    MlirGenerateTanhF16Op);
+    MlirGeneratedTanhF16Op);
 REGISTER_KERNEL_BUILDER(
     Name("Tanh").Device(DEVICE_GPU).TypeConstraint<float>("T"),
-    MlirGenerateTanhF32Op);
+    MlirGeneratedTanhF32Op);
 REGISTER_KERNEL_BUILDER(
     Name("Tanh").Device(DEVICE_GPU).TypeConstraint<double>("T"),
-    MlirGenerateTanhF64Op);
+    MlirGeneratedTanhF64Op);
 
 }  // namespace tensorflow