Merge pull request #36384 from nouiz:xla_reduction_branch

PiperOrigin-RevId: 293563326
Change-Id: I8ee92fc645e3c443c6f2ff04a9db89a33208df5d

commit af49800eda
Author: TensorFlower Gardener
Date:   2020-02-06 04:08:06 -08:00

4 changed files with 112 additions and 120 deletions

tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc

@@ -1908,23 +1908,47 @@ static llvm::Value* GetStartOffsetX(const KernelMappingScheme& mapping_scheme,
       llvm::ConstantInt::get(index_ty, x_num_steps));
 }
 
-void IrEmitterUnnested::EmitTile(
+// Emits code to process up to
+// (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a tile,
+// given `emit_elem_function` is the function to emit code to process one
+// element, `y` and `x` are the intra-tile coordinates for the first element
+// to process, and `index` is the index for the origin of the tile. Information
+// about tile_size_x/y and num_threads_x/y are stored in `mapping_scheme`. Emits
+// bounds check to ensure that each processed element is within the boundary
+// defined by `tile_width` and `tile_height`.
+//
+// Pseudocode:
+//
+// for (y_loc = 0; y_loc < tile_height; y_loc += num_threads_y) {
+//   for (j = 0; j < tile_size_x / num_threads_x; j++) { // unrolled
+//     if (dilated) {
+//       x_loc = x + j * num_threads_x;
+//     } else {
+//       x_loc = x * (tile_size_x / num_threads_x) + j;
+//     }
+//
+//     if (x_loc < tile_width) {
+//       emit_elem_function(y + y_loc, x_loc);
+//     }
+//   }
+// }
+//
+static void EmitTile(
     const KernelMappingScheme& mapping_scheme,
     const IrArray::Index& tile_origin_index, const string& loop_name,
-    KernelSupportLibrary* ksl, llvm::Value* thread_id_y,
-    llvm::Value* thread_id_x, llvm::Value* tile_height, llvm::Value* tile_width,
+    KernelSupportLibrary* ksl, llvm::IRBuilder<>* b, llvm::Value* y,
+    llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width,
     const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
   llvm::Type* index_ty = tile_width->getType();
   auto constant = [&](int64 val) {
     return llvm::ConstantInt::get(index_ty, val);
   };
   int64 num_threads_x = mapping_scheme.GetNumThreadsX();
-  llvm::Value* num_threads_y = constant(mapping_scheme.GetNumThreadsY());
+  int64 num_threads_y = mapping_scheme.GetNumThreadsY();
   int64 tile_size_x = mapping_scheme.GetTileSizeX();
 
   int64 x_num_steps = tile_size_x / num_threads_x;
-  llvm::Value* start_offset_x =
-      GetStartOffsetX(mapping_scheme, thread_id_x, index_ty, &b_);
+  llvm::Value* start_offset_x = GetStartOffsetX(mapping_scheme, x, index_ty, b);
 
   // Using dilated mapping scheme, each thread steps with a stride of number
   // of threads.
@@ -1933,42 +1957,33 @@ void IrEmitterUnnested::EmitTile(
   int64 step_x = mapping_scheme.DilatedX() ? num_threads_x : 1;
   IrArray::Index source_idx =
-      tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, &b_);
+      tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, b);
 
-  auto ceil_of_ratio = [&](llvm::Value* a, llvm::Value* b) {
-    return b_.CreateUDiv(b_.CreateAdd(b_.CreateAdd(a, b), constant(-1)), b);
-  };
-
-  // The outer loop below is simply doing:
-  //
-  // for (int y_loc=thread_id_y; y_loc<tile_height; y_loc+=num_threads_y)
-  //
-  // However, in order to avoid an LLVM optimization triggering the ptxas bug,
-  // we write this loop in a convoluted way:
-  //
-  // y_bound = ceil_of_ratio(tile_height - thread_id_y, num_threads_y)
-  // for (int y_indvar=0; y_indvar<y_bound; y_indvar+=1)
-  //   y_loc = thread_id_y + y_indvar * num_threads_y
-  //
-  // TODO(cheshire): Once ptxas is fixed and TF switches to it, remove the
-  // workaround.
-  ksl->For(
-      loop_name + "_y_in_tile",
-      /*start=*/constant(0),
-      /*end=*/
-      ceil_of_ratio(b_.CreateSub(tile_height, thread_id_y), num_threads_y),
-      /*step=*/constant(1), [&](llvm::Value* y_indvar) {
-        llvm::Value* y_loc =
-            b_.CreateAdd(thread_id_y, b_.CreateMul(y_indvar, num_threads_y));
-        for (int64 j = 0; j < x_num_steps; j++) {
-          llvm::Value* x_loc =
-              b_.CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
-          IrArray::Index source_idx_x =
-              source_idx.AddOffsetToDim(y_loc, kDimY, &b_)
-                  .AddOffsetToDim(constant(j * step_x), kDimX, &b_);
-          ksl->If(loop_name + "_x_in_tile", b_.CreateICmpULT(x_loc, tile_width),
-                  [&] { emit_elem_function(source_idx_x, y_loc, x_loc, j); });
-        }
-      });
+  // True when all threads will always execute all instructions.
+  // So we do not need to emit condition.
+  bool always_full_tile = mapping_scheme.GetDimsInElems()[2] % tile_size_x == 0;
+
+  ksl->For(loop_name + "_y_in_tile",
+           /*start=*/y,
+           /*end=*/tile_height,
+           /*step=*/constant(num_threads_y), [&](llvm::Value* y_loc) {
+             IrArray::Index source_idx_y =
+                 source_idx.AddOffsetToDim(y_loc, kDimY, b);
+             for (int64 j = 0; j < x_num_steps; j++) {
+               llvm::Value* x_loc =
+                   b->CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
+               IrArray::Index source_idx_x =
+                   source_idx_y.AddOffsetToDim(constant(j * step_x), kDimX, b);
+               // The if-statement below always evaluates to true for the blocks
+               // where the entire processed tile fits within the input buffer.
+               if (!always_full_tile) {
+                 ksl->If(loop_name + "_x_in_tile",
+                         b->CreateICmpULT(x_loc, tile_width), [&] {
+                           emit_elem_function(source_idx_x, y_loc, x_loc, j);
+                         });
+               } else {
+                 emit_elem_function(source_idx_x, y_loc, x_loc, j);
+               }
+             }
+           });
 }
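
For intuition, the mapping that the rewritten EmitTile emits can be simulated on the host. The sketch below is not part of the change; the tile and thread-count parameters are hypothetical, chosen only to show which (y_loc, x_loc) pairs each thread visits and when the x bounds check can be skipped.

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical stand-ins for the KernelMappingScheme parameters.
  const int64_t tile_size_x = 8, num_threads_x = 4, num_threads_y = 2;
  const int64_t tile_height = 4, tile_width = 7;  // tile is partial in x
  const bool dilated = true;
  const int64_t x_num_steps = tile_size_x / num_threads_x;
  const int64_t step_x = dilated ? num_threads_x : 1;
  // Stand-in for always_full_tile (dims_in_elems[2] % tile_size_x == 0);
  // here tile_width < tile_size_x, so the bounds check is still required.
  const bool always_full_tile = false;

  for (int64_t tid_y = 0; tid_y < num_threads_y; ++tid_y) {
    for (int64_t tid_x = 0; tid_x < num_threads_x; ++tid_x) {
      // GetStartOffsetX: dilated threads start at tid_x and stride by
      // num_threads_x; undilated threads own a contiguous run of elements.
      const int64_t start_x = dilated ? tid_x : tid_x * x_num_steps;
      // The emitted y-loop now steps directly by num_threads_y.
      for (int64_t y_loc = tid_y; y_loc < tile_height;
           y_loc += num_threads_y) {
        for (int64_t j = 0; j < x_num_steps; ++j) {
          const int64_t x_loc = start_x + j * step_x;
          if (always_full_tile || x_loc < tile_width) {
            std::printf("thread(%lld,%lld) -> element(%lld,%lld)\n",
                        static_cast<long long>(tid_y),
                        static_cast<long long>(tid_x),
                        static_cast<long long>(y_loc),
                        static_cast<long long>(x_loc));
          }
        }
      }
    }
  }
  return 0;
}
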
@@ -2608,7 +2623,8 @@ void IrEmitterUnnested::EmitHlo021Tile(
         // tile[y, x] = input[index]
         // Note that tile_width and tile_height are flipped here because we
         // are reading a transposed tile.
-        EmitTile(mapping_scheme, input_tile_origin, "input", ksl, y, x,
-                 tile_width, tile_height,
+        xla::gpu::EmitTile(
+            mapping_scheme, input_tile_origin, "input", ksl, &b_, y, x,
+            tile_width, tile_height,
             [&](const IrArray::Index& index, llvm::Value* y_loc,
                 llvm::Value* x_loc, int64 /*x_iter_num*/) {
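
The flip is easiest to see on the host: for a 0-2-1 transpose, output[y][x] = input[x][y], so the input region backing an output tile of extents (tile_height, tile_width) has extents (tile_width, tile_height). A minimal sketch with illustrative names (ReadTransposedTile is not an XLA function):

#include <cstdint>
#include <vector>

// Stages the input region that backs one transposed output tile into `tile`
// (tile_height x tile_width, row-major).
void ReadTransposedTile(const std::vector<float>& input, int64_t input_width,
                        int64_t y0, int64_t x0,  // output-tile origin
                        int64_t tile_height, int64_t tile_width,
                        std::vector<float>* tile) {
  // Note the swapped loop bounds: we walk the *input* tile, whose height is
  // the output tile's width and whose width is the output tile's height.
  for (int64_t i = 0; i < tile_width; ++i) {     // input-tile rows
    for (int64_t j = 0; j < tile_height; ++j) {  // input-tile columns
      // tile[j][i] = input[x0 + i][y0 + j], so the write phase can later
      // produce output[y0 + j][x0 + i] from tile[j][i].
      (*tile)[j * tile_width + i] = input[(x0 + i) * input_width + (y0 + j)];
    }
  }
}
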
@@ -2617,8 +2633,7 @@ void IrEmitterUnnested::EmitHlo021Tile(
                       param_in_reduced_shape_arrays[id];
                   llvm::Value* shmem_buffer = param_shmem_buffers[id];
-                  llvm::Value* zero =
-                      llvm::ConstantInt::get(index_type, 0);
+                  llvm::Value* zero = llvm::ConstantInt::get(index_type, 0);
                   // TODO(jlebar): Add AA metadata to this store. Tile
                   // buffers are global variables, so LLVM can't infer much
                   // about it.
@@ -2633,8 +2648,8 @@ void IrEmitterUnnested::EmitHlo021Tile(
     EmitSyncThreads();
   }
 
-  EmitTile(mapping_scheme, index, loop_name, ksl, y, x, tile_height,
-           tile_width, element_generator);
+  xla::gpu::EmitTile(mapping_scheme, index, loop_name, ksl, &b_, y, x,
+                     tile_height, tile_width, element_generator);
   bool block_contains_multi_tiles = mapping_scheme.GetTileSizeZ() > 1;
 
   // If a tile block contains multiple tiles and shared memory buffers are
@@ -3059,8 +3074,9 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
       [&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
           const string& loop_name, llvm::Value* tile_height,
           llvm::Value* tile_width, KernelSupportLibrary* ksl) {
-        EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl,
-                 y, x, tile_height, tile_width, emit_reduction_tile);
+        xla::gpu::EmitTile(reduction_info.GetKernelMappingScheme(), index,
+                           loop_name, ksl, &b_, y, x, tile_height, tile_width,
+                           emit_reduction_tile);
       });
   EmitEpilogueForReduction(index_ty, unnested_hlo, reduction_info,
                            reduce_instructions, reduction_output_shape_indices,

tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h

@@ -235,39 +235,6 @@ class IrEmitterUnnested : public IrEmitter,
       const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty,
       const TileElementGenerator& tile_element_generator);
 
-  // Emits code to process up to
-  // (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a tile,
-  // given `emit_elem_function` is the function to emit code to process one
-  // element, `thread_id_y` and `thread_id_x` are the intra-tile coordinates for
-  // the first element to process, and `index` is the index for the origin of
-  // the tile. Information about tile_size_x/y and num_threads_x/y are stored in
-  // `mapping_scheme`. Emits bounds check to ensure that each processed element
-  // is within the boundary defined by `tile_width` and `tile_height`.
-  //
-  // Pseudocode:
-  //
-  // for (y_loc = 0; y_loc < tile_height; y_loc += num_threads_y) {
-  //   for (j = 0; j < tile_size_x / num_threads_x; j++) { // unrolled
-  //     if (dilated) {
-  //       x_loc = x + j * num_threads_x;
-  //     } else {
-  //       x_loc = x * (tile_size_x / num_threads_x) + j;
-  //     }
-  //
-  //     if (x_loc < tile_width) {
-  //       emit_elem_function(y + y_loc, x_loc);
-  //     }
-  //   }
-  // }
-  //
-  void EmitTile(
-      const KernelMappingScheme& mapping_scheme,
-      const llvm_ir::IrArray::Index& tile_origin_index, const string& loop_name,
-      KernelSupportLibrary* ksl, llvm::Value* thread_id_y,
-      llvm::Value* thread_id_x, llvm::Value* tile_height,
-      llvm::Value* tile_width,
-      const IrEmitterUnnested::EmitElementFunction& emit_elem_function);
-
   // Emits code to process a tensor element in a tile for the given kCopy HLO
   // that performs a 0-2-1 transpose.
   void EmitTileElementForCopy(

tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc

@@ -60,11 +60,15 @@ Status KernelSupportLibrary::IfWithStatus(
     absl::string_view name, llvm::Value* condition,
     const std::function<Status()>& true_block_generator,
     const std::function<Status()>& false_block_generator) {
-  llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(condition, name, b_);
+  llvm_ir::LlvmIfData if_data =
+      llvm_ir::EmitIfThenElse(condition, name, b_,
+                              /*emit_else=*/false_block_generator != nullptr);
   b_->SetInsertPoint(&if_data.true_block->back());
   TF_RETURN_IF_ERROR(true_block_generator());
-  b_->SetInsertPoint(&if_data.false_block->back());
-  TF_RETURN_IF_ERROR(false_block_generator());
+  if (false_block_generator != nullptr) {
+    b_->SetInsertPoint(&if_data.false_block->back());
+    TF_RETURN_IF_ERROR(false_block_generator());
+  }
   llvm_ir::SetToLastInsertPoint(if_data.after_block, b_);
   return Status::OK();
 }
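
With the nullptr default, callers that only need a then-arm no longer pay for an empty else block: EmitIfThenElse is told up front whether to create one. A minimal usage sketch (the helper is hypothetical; the include path is assumed from the surrounding file):

#include <functional>

#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"

// Emits `emit_store` guarded by `in_bounds`. Because false_block_generator
// defaults to nullptr, only a true block and the after block are created.
xla::Status EmitGuardedStore(xla::KernelSupportLibrary* ksl,
                             llvm::Value* in_bounds,
                             const std::function<xla::Status()>& emit_store) {
  return ksl->IfWithStatus("in_bounds", in_bounds, emit_store);
}
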

tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h

@@ -203,12 +203,11 @@ class KernelSupportLibrary {
   //   `true_block_generator()`;
   // else
   //   `false_block_generator()`;
+  // The else is skipped if false_block_generator is null.
   Status IfWithStatus(
       absl::string_view name, llvm::Value* condition,
       const std::function<Status()>& true_block_generator,
-      const std::function<Status()>& false_block_generator = []() -> Status {
-        return Status::OK();
-      });
+      const std::function<Status()>& false_block_generator = nullptr);
 
   Status IfWithStatus(
       llvm::Value* condition,
@@ -220,16 +219,16 @@ class KernelSupportLibrary {
                          false_block_generator);
   }
 
-  void If(
-      llvm::Value* condition, const std::function<void()>& true_block_generator,
-      const std::function<void()>& false_block_generator = []() {}) {
+  void If(llvm::Value* condition,
+          const std::function<void()>& true_block_generator,
+          const std::function<void()>& false_block_generator = nullptr) {
     If("", condition, true_block_generator, false_block_generator);
   }
 
-  void If(
-      absl::string_view name, llvm::Value* condition,
-      const std::function<void()>& true_block_generator,
-      const std::function<void()>& false_block_generator = []() {}) {
-    TF_CHECK_OK(IfWithStatus(
-        name, condition,
-        [&]() {
+  void If(absl::string_view name, llvm::Value* condition,
+          const std::function<void()>& true_block_generator,
+          const std::function<void()>& false_block_generator = nullptr) {
+    if (false_block_generator != nullptr) {
+      TF_CHECK_OK(IfWithStatus(
+          name, condition,
+          [&]() {
@@ -240,6 +239,12 @@ class KernelSupportLibrary {
             false_block_generator();
             return Status::OK();
           }));
+    } else {
+      TF_CHECK_OK(IfWithStatus(name, condition, [&]() {
+        true_block_generator();
+        return Status::OK();
+      }));
+    }
   }
 
   using ArgumentVector = absl::Span<llvm::Value* const>;
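
The dispatch above mirrors this pure-C++ analogue, a runnable sketch of the new contract (names are illustrative): a null else-generator means no else arm is materialized at all, rather than an empty one.

#include <cstdio>
#include <functional>

// Analogue of KernelSupportLibrary::If: the else arm exists only when a
// generator is supplied, instead of defaulting to an empty no-op block.
void EmitIf(bool condition, const std::function<void()>& then_gen,
            const std::function<void()>& else_gen = nullptr) {
  if (condition) {
    then_gen();
  } else if (else_gen != nullptr) {
    else_gen();
  }
}

int main() {
  EmitIf(true, [] { std::puts("then only"); });  // one-armed: no else emitted
  EmitIf(false, [] { std::puts("then"); }, [] { std::puts("else"); });
  return 0;
}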