Automated rollback of commit af49800eda
.
PiperOrigin-RevId: 293603359 Change-Id: I7f44f8566b29458baffe30589c3d52f5caeddc33
This commit is contained in:
parent
db7d83f9ca
commit
6a72b2d95d
@ -1908,47 +1908,23 @@ static llvm::Value* GetStartOffsetX(const KernelMappingScheme& mapping_scheme,
|
||||
llvm::ConstantInt::get(index_ty, x_num_steps));
|
||||
}
|
||||
|
||||
// Emits code to process up to
|
||||
// (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a tile,
|
||||
// given `emit_elem_function` is the function to emit code to process one
|
||||
// element, `y` and `x` are the intra-tile coordinates for the first element
|
||||
// to process, and `index` is the index for the origin of the tile. Information
|
||||
// about tile_size_x/y and num_threads_x/y are stored in `mapping_scheme`. Emits
|
||||
// bounds check to ensure that each processed element is within the boundary
|
||||
// defined by `tile_width` and `tile_height`.
|
||||
//
|
||||
// Pseudocode:
|
||||
//
|
||||
// for (y_loc = 0; y_loc < tile_height; y_loc += num_threads_y) {
|
||||
// for (j = 0; j < tile_size_x / num_threads_x; j++) { // unrolled
|
||||
// if (dilated) {
|
||||
// x_loc = x + j * num_threads_x;
|
||||
// } else {
|
||||
// x_loc = x * (tile_size_x / num_threads_x) + j;
|
||||
// }
|
||||
//
|
||||
// if (x_loc < tile_width) {
|
||||
// emit_elem_function(y + y_loc, x_loc);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
static void EmitTile(
|
||||
void IrEmitterUnnested::EmitTile(
|
||||
const KernelMappingScheme& mapping_scheme,
|
||||
const IrArray::Index& tile_origin_index, const string& loop_name,
|
||||
KernelSupportLibrary* ksl, llvm::IRBuilder<>* b, llvm::Value* y,
|
||||
llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width,
|
||||
KernelSupportLibrary* ksl, llvm::Value* thread_id_y,
|
||||
llvm::Value* thread_id_x, llvm::Value* tile_height, llvm::Value* tile_width,
|
||||
const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
|
||||
llvm::Type* index_ty = tile_width->getType();
|
||||
auto constant = [&](int64 val) {
|
||||
return llvm::ConstantInt::get(index_ty, val);
|
||||
};
|
||||
int64 num_threads_x = mapping_scheme.GetNumThreadsX();
|
||||
int64 num_threads_y = mapping_scheme.GetNumThreadsY();
|
||||
llvm::Value* num_threads_y = constant(mapping_scheme.GetNumThreadsY());
|
||||
int64 tile_size_x = mapping_scheme.GetTileSizeX();
|
||||
|
||||
int64 x_num_steps = tile_size_x / num_threads_x;
|
||||
llvm::Value* start_offset_x = GetStartOffsetX(mapping_scheme, x, index_ty, b);
|
||||
llvm::Value* start_offset_x =
|
||||
GetStartOffsetX(mapping_scheme, thread_id_x, index_ty, &b_);
|
||||
|
||||
// Using dilated mapping scheme, each thread steps with a stride of number
|
||||
// of threads.
|
||||
@ -1957,35 +1933,44 @@ static void EmitTile(
|
||||
int64 step_x = mapping_scheme.DilatedX() ? num_threads_x : 1;
|
||||
|
||||
IrArray::Index source_idx =
|
||||
tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, b);
|
||||
tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, &b_);
|
||||
|
||||
// True when all threads will always execute all instructions.
|
||||
// So we do not need to emit condition.
|
||||
bool always_full_tile = mapping_scheme.GetDimsInElems()[2] % tile_size_x == 0;
|
||||
auto ceil_of_ratio = [&](llvm::Value* a, llvm::Value* b) {
|
||||
return b_.CreateUDiv(b_.CreateAdd(b_.CreateAdd(a, b), constant(-1)), b);
|
||||
};
|
||||
|
||||
ksl->For(loop_name + "_y_in_tile",
|
||||
/*start=*/y,
|
||||
/*end=*/tile_height,
|
||||
/*step=*/constant(num_threads_y), [&](llvm::Value* y_loc) {
|
||||
IrArray::Index source_idx_y =
|
||||
source_idx.AddOffsetToDim(y_loc, kDimY, b);
|
||||
for (int64 j = 0; j < x_num_steps; j++) {
|
||||
llvm::Value* x_loc =
|
||||
b->CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
|
||||
IrArray::Index source_idx_x =
|
||||
source_idx_y.AddOffsetToDim(constant(j * step_x), kDimX, b);
|
||||
// The if-statement below always evaluates to true for the blocks
|
||||
// where the entire processed tile fits within the input buffer.
|
||||
if (!always_full_tile) {
|
||||
ksl->If(loop_name + "_x_in_tile",
|
||||
b->CreateICmpULT(x_loc, tile_width), [&] {
|
||||
emit_elem_function(source_idx_x, y_loc, x_loc, j);
|
||||
});
|
||||
} else {
|
||||
emit_elem_function(source_idx_x, y_loc, x_loc, j);
|
||||
}
|
||||
}
|
||||
});
|
||||
// The outer loop below is simply doing:
|
||||
//
|
||||
// for (int y_loc=thread_id_y; y_loc<tile_height; y_loc+=num_threads_y)
|
||||
//
|
||||
//
|
||||
// However, in order to avoid an LLVM optimization triggering the ptxas bug,
|
||||
// we write this loop in a convoluted way:
|
||||
//
|
||||
// y_bound = ceil_of_ratio(tile_height - thread_id_y, num_threads_y)
|
||||
// for (int y_indvar=0; y_indvar<y_bound; y_indvar+=1)
|
||||
// y_loc = thread_id_y + y_indvar * num_threads_y
|
||||
//
|
||||
// TODO(cheshire): Once ptxas is fixed and TF switches to it, remove the
|
||||
// workaround.
|
||||
ksl->For(
|
||||
loop_name + "_y_in_tile",
|
||||
/*start=*/constant(0),
|
||||
/*end=*/
|
||||
ceil_of_ratio(b_.CreateSub(tile_height, thread_id_y), num_threads_y),
|
||||
/*step=*/constant(1), [&](llvm::Value* y_indvar) {
|
||||
llvm::Value* y_loc =
|
||||
b_.CreateAdd(thread_id_y, b_.CreateMul(y_indvar, num_threads_y));
|
||||
for (int64 j = 0; j < x_num_steps; j++) {
|
||||
llvm::Value* x_loc =
|
||||
b_.CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
|
||||
IrArray::Index source_idx_x =
|
||||
source_idx.AddOffsetToDim(y_loc, kDimY, &b_)
|
||||
.AddOffsetToDim(constant(j * step_x), kDimX, &b_);
|
||||
ksl->If(loop_name + "_x_in_tile", b_.CreateICmpULT(x_loc, tile_width),
|
||||
[&] { emit_elem_function(source_idx_x, y_loc, x_loc, j); });
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Emits code to process a tensor element in a tile for the given kCopy HLO that
|
||||
@ -2623,33 +2608,33 @@ void IrEmitterUnnested::EmitHlo021Tile(
|
||||
// tile[y, x] = input[index]
|
||||
// Note that tile_width and tile_height are flipped here because we
|
||||
// are reading a transposed tile.
|
||||
xla::gpu::EmitTile(
|
||||
mapping_scheme, input_tile_origin, "input", ksl, &b_, y, x,
|
||||
tile_width, tile_height,
|
||||
[&](const IrArray::Index& index, llvm::Value* y_loc,
|
||||
llvm::Value* x_loc, int64 /*x_iter_num*/) {
|
||||
for (int64 id : tiled_param_ids) {
|
||||
IrArray& input_in_logical_shape =
|
||||
param_in_reduced_shape_arrays[id];
|
||||
EmitTile(mapping_scheme, input_tile_origin, "input", ksl, y, x,
|
||||
tile_width, tile_height,
|
||||
[&](const IrArray::Index& index, llvm::Value* y_loc,
|
||||
llvm::Value* x_loc, int64 /*x_iter_num*/) {
|
||||
for (int64 id : tiled_param_ids) {
|
||||
IrArray& input_in_logical_shape =
|
||||
param_in_reduced_shape_arrays[id];
|
||||
|
||||
llvm::Value* shmem_buffer = param_shmem_buffers[id];
|
||||
llvm::Value* zero = llvm::ConstantInt::get(index_type, 0);
|
||||
// TODO(jlebar): Add AA metadata to this store. Tile
|
||||
// buffers are global variables, so LLVM can't infer much
|
||||
// about it.
|
||||
Store(input_in_logical_shape.EmitReadArrayElement(
|
||||
index, &b_, "input_element"),
|
||||
GEP(shmem_buffer, {zero, y_loc, x_loc}));
|
||||
}
|
||||
});
|
||||
llvm::Value* shmem_buffer = param_shmem_buffers[id];
|
||||
llvm::Value* zero =
|
||||
llvm::ConstantInt::get(index_type, 0);
|
||||
// TODO(jlebar): Add AA metadata to this store. Tile
|
||||
// buffers are global variables, so LLVM can't infer much
|
||||
// about it.
|
||||
Store(input_in_logical_shape.EmitReadArrayElement(
|
||||
index, &b_, "input_element"),
|
||||
GEP(shmem_buffer, {zero, y_loc, x_loc}));
|
||||
}
|
||||
});
|
||||
|
||||
// Wait for all threads to reach this point using `__syncthreads` in
|
||||
// CUDA.
|
||||
EmitSyncThreads();
|
||||
}
|
||||
|
||||
xla::gpu::EmitTile(mapping_scheme, index, loop_name, ksl, &b_, y, x,
|
||||
tile_height, tile_width, element_generator);
|
||||
EmitTile(mapping_scheme, index, loop_name, ksl, y, x, tile_height,
|
||||
tile_width, element_generator);
|
||||
bool block_contains_multi_tiles = mapping_scheme.GetTileSizeZ() > 1;
|
||||
|
||||
// If a tile block contains multiple tiles and shared memory buffers are
|
||||
@ -3074,9 +3059,8 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
|
||||
[&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
|
||||
const string& loop_name, llvm::Value* tile_height,
|
||||
llvm::Value* tile_width, KernelSupportLibrary* ksl) {
|
||||
xla::gpu::EmitTile(reduction_info.GetKernelMappingScheme(), index,
|
||||
loop_name, ksl, &b_, y, x, tile_height, tile_width,
|
||||
emit_reduction_tile);
|
||||
EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl,
|
||||
y, x, tile_height, tile_width, emit_reduction_tile);
|
||||
});
|
||||
EmitEpilogueForReduction(index_ty, unnested_hlo, reduction_info,
|
||||
reduce_instructions, reduction_output_shape_indices,
|
||||
|
@ -235,6 +235,39 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty,
|
||||
const TileElementGenerator& tile_element_generator);
|
||||
|
||||
// Emits code to process up to
|
||||
// (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a tile,
|
||||
// given `emit_elem_function` is the function to emit code to process one
|
||||
// element, `thread_id_y` and `thread_id_x` are the intra-tile coordinates for
|
||||
// the first element to process, and `index` is the index for the origin of
|
||||
// the tile. Information about tile_size_x/y and num_threads_x/y are stored in
|
||||
// `mapping_scheme`. Emits bounds check to ensure that each processed element
|
||||
// is within the boundary defined by `tile_width` and `tile_height`.
|
||||
//
|
||||
// Pseudocode:
|
||||
//
|
||||
// for (y_loc = 0; y_loc < tile_height; y_loc += num_threads_y) {
|
||||
// for (j = 0; j < tile_size_x / num_threads_x; j++) { // unrolled
|
||||
// if (dilated) {
|
||||
// x_loc = x + j * num_threads_x;
|
||||
// } else {
|
||||
// x_loc = x * (tile_size_x / num_threads_x) + j;
|
||||
// }
|
||||
//
|
||||
// if (x_loc < tile_width) {
|
||||
// emit_elem_function(y + y_loc, x_loc);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
void EmitTile(
|
||||
const KernelMappingScheme& mapping_scheme,
|
||||
const llvm_ir::IrArray::Index& tile_origin_index, const string& loop_name,
|
||||
KernelSupportLibrary* ksl, llvm::Value* thread_id_y,
|
||||
llvm::Value* thread_id_x, llvm::Value* tile_height,
|
||||
llvm::Value* tile_width,
|
||||
const IrEmitterUnnested::EmitElementFunction& emit_elem_function);
|
||||
|
||||
// Emits code to process a tensor element in a tile for the given kCopy HLO
|
||||
// that performs a 0-2-1 transpose.
|
||||
void EmitTileElementForCopy(
|
||||
|
@ -60,15 +60,11 @@ Status KernelSupportLibrary::IfWithStatus(
|
||||
absl::string_view name, llvm::Value* condition,
|
||||
const std::function<Status()>& true_block_generator,
|
||||
const std::function<Status()>& false_block_generator) {
|
||||
llvm_ir::LlvmIfData if_data =
|
||||
llvm_ir::EmitIfThenElse(condition, name, b_,
|
||||
/*emit_else=*/false_block_generator != nullptr);
|
||||
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(condition, name, b_);
|
||||
b_->SetInsertPoint(&if_data.true_block->back());
|
||||
TF_RETURN_IF_ERROR(true_block_generator());
|
||||
if (false_block_generator != nullptr) {
|
||||
b_->SetInsertPoint(&if_data.false_block->back());
|
||||
TF_RETURN_IF_ERROR(false_block_generator());
|
||||
}
|
||||
b_->SetInsertPoint(&if_data.false_block->back());
|
||||
TF_RETURN_IF_ERROR(false_block_generator());
|
||||
llvm_ir::SetToLastInsertPoint(if_data.after_block, b_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -203,11 +203,12 @@ class KernelSupportLibrary {
|
||||
// `true_block_generator()`;
|
||||
// else
|
||||
// `false_block_generator()`;
|
||||
// The else is skipped if false_block_generator is null.
|
||||
Status IfWithStatus(
|
||||
absl::string_view name, llvm::Value* condition,
|
||||
const std::function<Status()>& true_block_generator,
|
||||
const std::function<Status()>& false_block_generator = nullptr);
|
||||
const std::function<Status()>& false_block_generator = []() -> Status {
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
Status IfWithStatus(
|
||||
llvm::Value* condition,
|
||||
@ -219,32 +220,26 @@ class KernelSupportLibrary {
|
||||
false_block_generator);
|
||||
}
|
||||
|
||||
void If(llvm::Value* condition,
|
||||
const std::function<void()>& true_block_generator,
|
||||
const std::function<void()>& false_block_generator = nullptr) {
|
||||
void If(
|
||||
llvm::Value* condition, const std::function<void()>& true_block_generator,
|
||||
const std::function<void()>& false_block_generator = []() {}) {
|
||||
If("", condition, true_block_generator, false_block_generator);
|
||||
}
|
||||
|
||||
void If(absl::string_view name, llvm::Value* condition,
|
||||
const std::function<void()>& true_block_generator,
|
||||
const std::function<void()>& false_block_generator = nullptr) {
|
||||
if (false_block_generator != nullptr) {
|
||||
TF_CHECK_OK(IfWithStatus(
|
||||
name, condition,
|
||||
[&]() {
|
||||
true_block_generator();
|
||||
return Status::OK();
|
||||
},
|
||||
[&]() {
|
||||
false_block_generator();
|
||||
return Status::OK();
|
||||
}));
|
||||
} else {
|
||||
TF_CHECK_OK(IfWithStatus(name, condition, [&]() {
|
||||
true_block_generator();
|
||||
return Status::OK();
|
||||
}));
|
||||
}
|
||||
void If(
|
||||
absl::string_view name, llvm::Value* condition,
|
||||
const std::function<void()>& true_block_generator,
|
||||
const std::function<void()>& false_block_generator = []() {}) {
|
||||
TF_CHECK_OK(IfWithStatus(
|
||||
name, condition,
|
||||
[&]() {
|
||||
true_block_generator();
|
||||
return Status::OK();
|
||||
},
|
||||
[&]() {
|
||||
false_block_generator();
|
||||
return Status::OK();
|
||||
}));
|
||||
}
|
||||
|
||||
using ArgumentVector = absl::Span<llvm::Value* const>;
|
||||
|
Loading…
Reference in New Issue
Block a user