Properly configure the block and grid dimensions when launching generated kernels.

PiperOrigin-RevId: 316867392
Change-Id: I975a9d1e29e954760532a82985dd016707aa6d02
Adrian Kuegel 2020-06-17 05:06:09 -07:00 committed by TensorFlower Gardener
parent 6128ffea46
commit 9eafb72689
2 changed files with 52 additions and 18 deletions
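For context on the change: the grid size introduced below comes from two ceiling divisions. The element count is first divided by the unroll factor (the number of elements each thread handles), and that thread count is then divided by the tile size (threads per block). Below is a minimal standalone sketch of that arithmetic, using the tile size 256 and unroll factor 4 that the launch call passes to GetLaunchConfiguration; the CeilDiv helper, the main() wrapper, and the 1,000,000-element input are illustrative only and not part of the patch.

#include <cstdint>
#include <iostream>

// Ceiling division: allocate one extra thread/block when the division is not
// even; the generated kernel contains boundary checks for the overhang.
uint64_t CeilDiv(uint64_t a, uint64_t b) { return (a + b - 1) / b; }

int main() {
  const uint64_t num_elements = 1000000;  // example input size
  const uint64_t tile_size = 256;         // threads per block (x dimension)
  const uint64_t unroll_factor = 4;       // elements processed by each thread
  const uint64_t threads_needed = CeilDiv(num_elements, unroll_factor);  // 250000
  const uint64_t blocks_needed = CeilDiv(threads_needed, tile_size);     // 977
  // Mirrors the 1-D case of GetLaunchConfiguration below: the remaining
  // y/z dimensions are padded with ones.
  std::cout << "ThreadDim = (" << tile_size << ", 1, 1), BlockDim = ("
            << blocks_needed << ", 1, 1)\n";
  return 0;
}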


@@ -37,4 +37,5 @@ gen_kernel_library(
         "f32",
         "f64",
     ],
+    unroll_factors = "4",
 )
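The unroll factor added to the build rule above has to stay aligned with the values the kernel launch code below passes to GetLaunchConfiguration; the comment at the launch site calls this out explicitly. One way to keep the two from drifting apart, sketched here with hypothetical shared constants that are not part of the patch:

#include <cstdint>

// Hypothetical constants (not in the patch): a single definition that the
// launch code could reference instead of repeating the literals 256 and 4.
constexpr uint64_t kTanhTileSize = 256;    // threads per block used when the kernel was built
constexpr uint64_t kTanhUnrollFactor = 4;  // must match unroll_factors = "4" above

// The launch site would then read:
//   LaunchConfig config = GetLaunchConfiguration(
//       {kTanhTileSize}, {kTanhUnrollFactor}, {inp.NumElements()});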


@@ -45,9 +45,40 @@ Status CreateKernel(absl::string_view kernel_name, uint64_t num_args,
   return stream_exec->GetKernel(loader_spec, kernel_base.get());
 }
 
-class MlirGenerateTanhOp : public OpKernel {
+struct LaunchConfig {
+  se::BlockDim blockDim;
+  se::ThreadDim threadDim;
+};
+
+LaunchConfig GetLaunchConfiguration(std::vector<uint64> tile_sizes,
+                                    std::vector<uint64> unrolling_factors,
+                                    std::vector<uint64> shape) {
+  LaunchConfig result;
+  // Ensure the vectors are length 3 and pad with ones.
+  tile_sizes.resize(3, 1);
+  unrolling_factors.resize(3, 1);
+  shape.resize(3, 1);
+  // We know that the kernel was generated by mapping the three outer-most
+  // dimensions to x,y,z dimensions. So we only need to compute those.
+  for (int i = 0; i < 3; ++i) {
+    // The number of threads is given by the tiling size.
+    result.threadDim[i] = tile_sizes[i];
+    // Compute the number of grids. We use ceildiv here as we have to allocate
+    // an extra thread/block if the division is not even. The kernel contains
+    // code to handle the boundaries.
+    int number_of_threads =
+        (shape[i] + unrolling_factors[i] - 1) / unrolling_factors[i];
+    int number_of_grids =
+        (number_of_threads + tile_sizes[i] - 1) / tile_sizes[i];
+    result.blockDim[i] = number_of_grids;
+  }
+  return result;
+}
+
+class MlirGeneratedTanhOp : public OpKernel {
  public:
-  explicit MlirGenerateTanhOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+  explicit MlirGeneratedTanhOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
   void Compute(OpKernelContext* ctx) override {
     auto* stream = ctx->op_device_context()->stream();
@@ -88,11 +119,13 @@ class MlirGeneratedTanhOp : public OpKernel {
     args.add_argument<int64_t>(inp.NumElements());
     args.add_argument<int64_t>(1);
 
-    // TODO(b/158649746): Choose block size and thread dim according to the
-    // number of input elements. For now, this supports at most 1024 elements.
+    // This has to be aligned with the configuration that was used when building
+    // the kernels. See the corresponding build rules in `cubin_headers/BUILD`.
+    LaunchConfig config =
+        GetLaunchConfiguration({256}, {4}, {inp.NumElements()});
     OP_REQUIRES_OK(
-        ctx, stream->parent()->Launch(stream, se::ThreadDim(inp.NumElements()),
-                                      se::BlockDim(1), *kernel, args));
+        ctx, stream->parent()->Launch(stream, config.threadDim, config.blockDim,
+                                      *kernel, args));
   }
 
  protected:
@@ -103,26 +136,26 @@ class MlirGeneratedTanhOp : public OpKernel {
   std::mutex mu_;
 };
 
-class MlirGenerateTanhF16Op : public MlirGenerateTanhOp {
+class MlirGeneratedTanhF16Op : public MlirGeneratedTanhOp {
  public:
-  explicit MlirGenerateTanhF16Op(OpKernelConstruction* ctx)
-      : MlirGenerateTanhOp(ctx) {
+  explicit MlirGeneratedTanhF16Op(OpKernelConstruction* ctx)
+      : MlirGeneratedTanhOp(ctx) {
     cubin_data_ = kTanhF16Kernel;
   }
 };
 
-class MlirGenerateTanhF32Op : public MlirGenerateTanhOp {
+class MlirGeneratedTanhF32Op : public MlirGeneratedTanhOp {
  public:
-  explicit MlirGenerateTanhF32Op(OpKernelConstruction* ctx)
-      : MlirGenerateTanhOp(ctx) {
+  explicit MlirGeneratedTanhF32Op(OpKernelConstruction* ctx)
+      : MlirGeneratedTanhOp(ctx) {
    cubin_data_ = kTanhF32Kernel;
   }
 };
 
-class MlirGenerateTanhF64Op : public MlirGenerateTanhOp {
+class MlirGeneratedTanhF64Op : public MlirGeneratedTanhOp {
  public:
-  explicit MlirGenerateTanhF64Op(OpKernelConstruction* ctx)
-      : MlirGenerateTanhOp(ctx) {
+  explicit MlirGeneratedTanhF64Op(OpKernelConstruction* ctx)
+      : MlirGeneratedTanhOp(ctx) {
     cubin_data_ = kTanhF64Kernel;
   }
 };
@@ -130,11 +163,11 @@ class MlirGeneratedTanhF64Op : public MlirGeneratedTanhOp {
 REGISTER_KERNEL_BUILDER(
     Name("Tanh").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
-    MlirGenerateTanhF16Op);
+    MlirGeneratedTanhF16Op);
 REGISTER_KERNEL_BUILDER(
     Name("Tanh").Device(DEVICE_GPU).TypeConstraint<float>("T"),
-    MlirGenerateTanhF32Op);
+    MlirGeneratedTanhF32Op);
 REGISTER_KERNEL_BUILDER(
     Name("Tanh").Device(DEVICE_GPU).TypeConstraint<double>("T"),
-    MlirGenerateTanhF64Op);
+    MlirGeneratedTanhF64Op);
 
 }  // namespace tensorflow