Properly configure the block and grid dimensions when launching generated kernels.
PiperOrigin-RevId: 316867392
Change-Id: I975a9d1e29e954760532a82985dd016707aa6d02
parent 6128ffea46
commit 9eafb72689
@@ -37,4 +37,5 @@ gen_kernel_library(
         "f32",
         "f64",
     ],
+    unroll_factors = "4",
 )
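The unroll_factors = "4" attribute records in the build rule how many output elements each GPU thread of the generated kernel processes. The host-side launch computation added below has to use the same value, which is why this commit threads it through both places. As a hedged sketch (not the generated code; the function name and float element type are assumptions for illustration), an unroll factor of 4 means one logical thread covers four consecutive elements, with a bounds check for the ragged tail:

    #include <cmath>
    #include <cstdint>

    // Models the work of one logical GPU thread under an unroll factor of 4:
    // it handles 4 consecutive elements, so only ceil(N / 4) threads are
    // needed in total.
    void UnrolledTanhBody(int64_t thread_id, int64_t num_elements,
                          const float* in, float* out) {
      constexpr int64_t kUnroll = 4;
      for (int64_t j = 0; j < kUnroll; ++j) {
        const int64_t idx = thread_id * kUnroll + j;
        if (idx < num_elements) out[idx] = std::tanh(in[idx]);
      }
    }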
@@ -45,9 +45,40 @@ Status CreateKernel(absl::string_view kernel_name, uint64_t num_args,
   return stream_exec->GetKernel(loader_spec, kernel_base.get());
 }
 
-class MlirGenerateTanhOp : public OpKernel {
+struct LaunchConfig {
+  se::BlockDim blockDim;
+  se::ThreadDim threadDim;
+};
+
+LaunchConfig
+GetLaunchConfiguration(std::vector<uint64> tile_sizes,
+                       std::vector<uint64> unrolling_factors,
+                       std::vector<uint64> shape) {
+  LaunchConfig result;
+  // Ensure the vectors are length 3 and pad with ones.
+  tile_sizes.resize(3, 1);
+  unrolling_factors.resize(3, 1);
+  shape.resize(3, 1);
+  // We know that the kernel was generated by mapping the three outer-most
+  // dimensions to x,y,z dimensions. So we only need to compute those.
+  for (int i = 0; i < 3; ++i) {
+    // The number of threads is given by the tiling size.
+    result.threadDim[i] = tile_sizes[i];
+    // Compute the number of grids. We use ceildiv here as we have to allocate
+    // an extra thread/block if the division is not even. The kernel contains
+    // code to handle the boundaries.
+    int number_of_threads =
+        (shape[i] + unrolling_factors[i] - 1) / unrolling_factors[i];
+    int number_of_grids =
+        (number_of_threads + tile_sizes[i] - 1) / tile_sizes[i];
+    result.blockDim[i] = number_of_grids;
+  }
+  return result;
+}
+
+class MlirGeneratedTanhOp : public OpKernel {
  public:
-  explicit MlirGenerateTanhOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+  explicit MlirGeneratedTanhOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
   void Compute(OpKernelContext* ctx) override {
     auto* stream = ctx->op_device_context()->stream();
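The ceildiv arithmetic above is the heart of the fix. Below is a minimal standalone sketch (assumed names; plain uint64_t triples stand in for se::ThreadDim and se::BlockDim so there are no StreamExecutor dependencies) that can be compiled and run to check the numbers:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    struct Dim3 {
      uint64_t v[3] = {1, 1, 1};
    };

    // Mirrors GetLaunchConfiguration: pad the inputs to rank 3, then compute
    // how many blocks of tile_size threads cover the shape once each thread
    // handles unroll_factor elements. ceildiv rounds up so leftover elements
    // still get a thread/block; the kernel's bounds checks make that safe.
    Dim3 BlockCount(std::vector<uint64_t> tile_sizes,
                    std::vector<uint64_t> unrolling_factors,
                    std::vector<uint64_t> shape) {
      tile_sizes.resize(3, 1);
      unrolling_factors.resize(3, 1);
      shape.resize(3, 1);
      Dim3 blocks;
      for (int i = 0; i < 3; ++i) {
        const uint64_t threads =
            (shape[i] + unrolling_factors[i] - 1) / unrolling_factors[i];
        blocks.v[i] = (threads + tile_sizes[i] - 1) / tile_sizes[i];
      }
      return blocks;
    }

    int main() {
      // 1,000,000 elements, tile size 256, unroll factor 4:
      // threads = ceil(1000000 / 4) = 250000; blocks = ceil(250000 / 256) = 977.
      std::cout << BlockCount({256}, {4}, {1000000}).v[0] << "\n";  // prints 977
    }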
@@ -88,11 +119,13 @@ class MlirGeneratedTanhOp : public OpKernel {
     args.add_argument<int64_t>(inp.NumElements());
     args.add_argument<int64_t>(1);
 
-    // TODO(b/158649746): Choose block size and thread dim according to the
-    // number of input elements. For now, this supports at most 1024 elements.
+    // This has to be aligned with the configuration that was used when building
+    // the kernels. See the corresponding build rules in `cubin_headers/BUILD`.
+    LaunchConfig config =
+        GetLaunchConfiguration({256}, {4}, {inp.NumElements()});
     OP_REQUIRES_OK(
-        ctx, stream->parent()->Launch(stream, se::ThreadDim(inp.NumElements()),
-                                      se::BlockDim(1), *kernel, args));
+        ctx, stream->parent()->Launch(stream, config.threadDim, config.blockDim,
+                                      *kernel, args));
   }
 
  protected:
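The before/after difference is easiest to see with concrete numbers. The old code launched a single block of inp.NumElements() threads, which is why the removed TODO capped support at 1024 elements: CUDA limits a block to 1024 threads. With the new configuration, BlockCount({256}, {4}, {1000000}) from the sketch above yields threadDim = (256, 1, 1) and blockDim = (977, 1, 1), so a million-element tensor launches 977 blocks of 256 threads instead of one oversized block.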
@@ -103,26 +136,26 @@ class MlirGeneratedTanhOp : public OpKernel {
   std::mutex mu_;
 };
 
-class MlirGenerateTanhF16Op : public MlirGenerateTanhOp {
+class MlirGeneratedTanhF16Op : public MlirGeneratedTanhOp {
  public:
-  explicit MlirGenerateTanhF16Op(OpKernelConstruction* ctx)
-      : MlirGenerateTanhOp(ctx) {
+  explicit MlirGeneratedTanhF16Op(OpKernelConstruction* ctx)
+      : MlirGeneratedTanhOp(ctx) {
     cubin_data_ = kTanhF16Kernel;
   }
 };
 
-class MlirGenerateTanhF32Op : public MlirGenerateTanhOp {
+class MlirGeneratedTanhF32Op : public MlirGeneratedTanhOp {
  public:
-  explicit MlirGenerateTanhF32Op(OpKernelConstruction* ctx)
-      : MlirGenerateTanhOp(ctx) {
+  explicit MlirGeneratedTanhF32Op(OpKernelConstruction* ctx)
+      : MlirGeneratedTanhOp(ctx) {
     cubin_data_ = kTanhF32Kernel;
   }
 };
 
-class MlirGenerateTanhF64Op : public MlirGenerateTanhOp {
+class MlirGeneratedTanhF64Op : public MlirGeneratedTanhOp {
  public:
-  explicit MlirGenerateTanhF64Op(OpKernelConstruction* ctx)
-      : MlirGenerateTanhOp(ctx) {
+  explicit MlirGeneratedTanhF64Op(OpKernelConstruction* ctx)
+      : MlirGeneratedTanhOp(ctx) {
     cubin_data_ = kTanhF64Kernel;
   }
 };
@@ -130,11 +163,11 @@ class MlirGeneratedTanhF64Op : public MlirGeneratedTanhOp {
 
 REGISTER_KERNEL_BUILDER(
     Name("Tanh").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
-    MlirGenerateTanhF16Op);
+    MlirGeneratedTanhF16Op);
 REGISTER_KERNEL_BUILDER(
     Name("Tanh").Device(DEVICE_GPU).TypeConstraint<float>("T"),
-    MlirGenerateTanhF32Op);
+    MlirGeneratedTanhF32Op);
 REGISTER_KERNEL_BUILDER(
     Name("Tanh").Device(DEVICE_GPU).TypeConstraint<double>("T"),
-    MlirGenerateTanhF64Op);
+    MlirGeneratedTanhF64Op);
 }  // namespace tensorflow