Properly configure block and grid dimensions when launching generated kernels.
Also lay the groundwork for specifying unrolling factors. Unrolling is not enabled yet because it depends on LLVM changes that are still required. PiperOrigin-RevId: 317056534 Change-Id: I3de5dda52d80b528c4bd0026a5e160fda4296c32
This commit is contained in:
parent
f9c4663043
commit
e4af590df8
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
@ -45,9 +46,41 @@ Status CreateKernel(absl::string_view kernel_name, uint64_t num_args,
|
||||
return stream_exec->GetKernel(loader_spec, kernel_base.get());
|
||||
}
|
||||
|
||||
class MlirGenerateTanhOp : public OpKernel {
|
||||
struct LaunchConfig {
|
||||
se::BlockDim blockDim;
|
||||
se::ThreadDim threadDim;
|
||||
};
|
||||
|
||||
// Computes the block and thread dimensions for launching a generated kernel.
//
// All three arguments are taken by value and resized to exactly three entries
// (x, y, z): missing entries are padded with 1 and any entries beyond the
// third are dropped.
//
//   tile_sizes         Threads per block in each dimension; used directly as
//                      the thread dimensions. Entries must be non-zero.
//   unrolling_factors  Per-dimension divisor applied to `shape` (rounding up)
//                      to obtain the number of threads required. Entries must
//                      be non-zero.
//   shape              Extent of the iteration space in each dimension.
//
// Returns a LaunchConfig whose threadDim equals the padded tile sizes and
// whose blockDim holds, per dimension,
//   ceil(ceil(shape / unrolling_factor) / tile_size).
LaunchConfig GetLaunchConfiguration(std::vector<uint64> tile_sizes,
                                    std::vector<uint64> unrolling_factors,
                                    std::vector<uint64> shape) {
  constexpr int kNumDims = 3;

  // Ensure the vectors are length 3 and pad with ones.
  tile_sizes.resize(kNumDims, 1);
  unrolling_factors.resize(kNumDims, 1);
  shape.resize(kNumDims, 1);

  // Rounding-up division. Deliberately not written as `(a + b - 1) / b`, which
  // can wrap around for large unsigned operands.
  const auto ceil_div = [](uint64 value, uint64 divisor) -> uint64 {
    return value / divisor + (value % divisor != 0 ? 1 : 0);
  };

  LaunchConfig result;

  // The number of threads per block is given by the tiling size.
  result.threadDim = se::ThreadDim(tile_sizes[0], tile_sizes[1], tile_sizes[2]);

  // We know that the kernel was generated by mapping the three outer-most
  // dimensions to x,y,z dimensions. So we only need to compute those.
  //
  // All arithmetic stays in uint64 so that large shapes are not truncated
  // before they reach se::BlockDim.
  uint64 block_counts[kNumDims];
  for (int i = 0; i < kNumDims; ++i) {
    // We round up in both steps because an extra thread/block has to be
    // allocated whenever the division is not even. The kernel contains code to
    // handle the boundaries.
    const uint64 threads_needed = ceil_div(shape[i], unrolling_factors[i]);
    block_counts[i] = ceil_div(threads_needed, tile_sizes[i]);
  }
  result.blockDim =
      se::BlockDim(block_counts[0], block_counts[1], block_counts[2]);
  return result;
}
|
||||
|
||||
class MlirGeneratedTanhOp : public OpKernel {
|
||||
public:
|
||||
explicit MlirGenerateTanhOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
explicit MlirGeneratedTanhOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto* stream = ctx->op_device_context()->stream();
|
||||
@ -88,11 +121,13 @@ class MlirGenerateTanhOp : public OpKernel {
|
||||
args.add_argument<int64_t>(inp.NumElements());
|
||||
args.add_argument<int64_t>(1);
|
||||
|
||||
// TODO(b/158649746): Choose block size and thread dim according to the
|
||||
// number of input elements. For now, this supports at most 1024 elements.
|
||||
// This has to be aligned with the configuration that was used when building
|
||||
// the kernels. See the corresponding build rules in `cubin_headers/BUILD`.
|
||||
LaunchConfig config = GetLaunchConfiguration(
|
||||
{256}, {}, {static_cast<uint64>(inp.NumElements())});
|
||||
OP_REQUIRES_OK(
|
||||
ctx, stream->parent()->Launch(stream, se::ThreadDim(inp.NumElements()),
|
||||
se::BlockDim(1), *kernel, args));
|
||||
ctx, stream->parent()->Launch(stream, config.threadDim, config.blockDim,
|
||||
*kernel, args));
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -103,26 +138,26 @@ class MlirGenerateTanhOp : public OpKernel {
|
||||
std::mutex mu_;
|
||||
};
|
||||
|
||||
class MlirGenerateTanhF16Op : public MlirGenerateTanhOp {
|
||||
class MlirGeneratedTanhF16Op : public MlirGeneratedTanhOp {
|
||||
public:
|
||||
explicit MlirGenerateTanhF16Op(OpKernelConstruction* ctx)
|
||||
: MlirGenerateTanhOp(ctx) {
|
||||
explicit MlirGeneratedTanhF16Op(OpKernelConstruction* ctx)
|
||||
: MlirGeneratedTanhOp(ctx) {
|
||||
cubin_data_ = kTanhF16Kernel;
|
||||
}
|
||||
};
|
||||
|
||||
class MlirGenerateTanhF32Op : public MlirGenerateTanhOp {
|
||||
class MlirGeneratedTanhF32Op : public MlirGeneratedTanhOp {
|
||||
public:
|
||||
explicit MlirGenerateTanhF32Op(OpKernelConstruction* ctx)
|
||||
: MlirGenerateTanhOp(ctx) {
|
||||
explicit MlirGeneratedTanhF32Op(OpKernelConstruction* ctx)
|
||||
: MlirGeneratedTanhOp(ctx) {
|
||||
cubin_data_ = kTanhF32Kernel;
|
||||
}
|
||||
};
|
||||
|
||||
class MlirGenerateTanhF64Op : public MlirGenerateTanhOp {
|
||||
class MlirGeneratedTanhF64Op : public MlirGeneratedTanhOp {
|
||||
public:
|
||||
explicit MlirGenerateTanhF64Op(OpKernelConstruction* ctx)
|
||||
: MlirGenerateTanhOp(ctx) {
|
||||
explicit MlirGeneratedTanhF64Op(OpKernelConstruction* ctx)
|
||||
: MlirGeneratedTanhOp(ctx) {
|
||||
cubin_data_ = kTanhF64Kernel;
|
||||
}
|
||||
};
|
||||
@ -130,11 +165,11 @@ class MlirGenerateTanhF64Op : public MlirGenerateTanhOp {
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("Tanh").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
|
||||
MlirGenerateTanhF16Op);
|
||||
MlirGeneratedTanhF16Op);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("Tanh").Device(DEVICE_GPU).TypeConstraint<float>("T"),
|
||||
MlirGenerateTanhF32Op);
|
||||
MlirGeneratedTanhF32Op);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("Tanh").Device(DEVICE_GPU).TypeConstraint<double>("T"),
|
||||
MlirGenerateTanhF64Op);
|
||||
MlirGeneratedTanhF64Op);
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user