Automated rollback of commit af49800eda.

PiperOrigin-RevId: 293603359
Change-Id: I7f44f8566b29458baffe30589c3d52f5caeddc33
This commit is contained in:
George Karpenkov 2020-02-06 09:00:52 -08:00 committed by TensorFlower Gardener
parent db7d83f9ca
commit 6a72b2d95d
4 changed files with 120 additions and 112 deletions

View File

@ -1908,47 +1908,23 @@ static llvm::Value* GetStartOffsetX(const KernelMappingScheme& mapping_scheme,
llvm::ConstantInt::get(index_ty, x_num_steps));
}
// Emits code to process up to
// (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a tile,
// given `emit_elem_function` is the function to emit code to process one
// element, `y` and `x` are the intra-tile coordinates for the first element
// to process, and `index` is the index for the origin of the tile. Information
// about tile_size_x/y and num_threads_x/y are stored in `mapping_scheme`. Emits
// bounds check to ensure that each processed element is within the boundary
// defined by `tile_width` and `tile_height`.
//
// Pseudocode:
//
// for (y_loc = 0; y_loc < tile_height; y_loc += num_threads_y) {
// for (j = 0; j < tile_size_x / num_threads_x; j++) { // unrolled
// if (dilated) {
// x_loc = x + j * num_threads_x;
// } else {
// x_loc = x * (tile_size_x / num_threads_x) + j;
// }
//
// if (x_loc < tile_width) {
// emit_elem_function(y + y_loc, x_loc);
// }
// }
// }
//
static void EmitTile(
void IrEmitterUnnested::EmitTile(
const KernelMappingScheme& mapping_scheme,
const IrArray::Index& tile_origin_index, const string& loop_name,
KernelSupportLibrary* ksl, llvm::IRBuilder<>* b, llvm::Value* y,
llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width,
KernelSupportLibrary* ksl, llvm::Value* thread_id_y,
llvm::Value* thread_id_x, llvm::Value* tile_height, llvm::Value* tile_width,
const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
llvm::Type* index_ty = tile_width->getType();
auto constant = [&](int64 val) {
return llvm::ConstantInt::get(index_ty, val);
};
int64 num_threads_x = mapping_scheme.GetNumThreadsX();
int64 num_threads_y = mapping_scheme.GetNumThreadsY();
llvm::Value* num_threads_y = constant(mapping_scheme.GetNumThreadsY());
int64 tile_size_x = mapping_scheme.GetTileSizeX();
int64 x_num_steps = tile_size_x / num_threads_x;
llvm::Value* start_offset_x = GetStartOffsetX(mapping_scheme, x, index_ty, b);
llvm::Value* start_offset_x =
GetStartOffsetX(mapping_scheme, thread_id_x, index_ty, &b_);
// Using dilated mapping scheme, each thread steps with a stride of number
// of threads.
@ -1957,35 +1933,44 @@ static void EmitTile(
int64 step_x = mapping_scheme.DilatedX() ? num_threads_x : 1;
IrArray::Index source_idx =
tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, b);
tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, &b_);
// True when all threads will always execute all instructions.
// So we do not need to emit condition.
bool always_full_tile = mapping_scheme.GetDimsInElems()[2] % tile_size_x == 0;
auto ceil_of_ratio = [&](llvm::Value* a, llvm::Value* b) {
return b_.CreateUDiv(b_.CreateAdd(b_.CreateAdd(a, b), constant(-1)), b);
};
ksl->For(loop_name + "_y_in_tile",
/*start=*/y,
/*end=*/tile_height,
/*step=*/constant(num_threads_y), [&](llvm::Value* y_loc) {
IrArray::Index source_idx_y =
source_idx.AddOffsetToDim(y_loc, kDimY, b);
for (int64 j = 0; j < x_num_steps; j++) {
llvm::Value* x_loc =
b->CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
IrArray::Index source_idx_x =
source_idx_y.AddOffsetToDim(constant(j * step_x), kDimX, b);
// The if-statement below always evaluates to true for the blocks
// where the entire processed tile fits within the input buffer.
if (!always_full_tile) {
ksl->If(loop_name + "_x_in_tile",
b->CreateICmpULT(x_loc, tile_width), [&] {
emit_elem_function(source_idx_x, y_loc, x_loc, j);
});
} else {
emit_elem_function(source_idx_x, y_loc, x_loc, j);
}
}
});
// The outer loop below is simply doing:
//
// for (int y_loc=thread_id_y; y_loc<tile_height; y_loc+=num_threads_y)
//
//
// However, in order to avoid an LLVM optimization triggering the ptxas bug,
// we write this loop in a convoluted way:
//
// y_bound = ceil_of_ratio(tile_height - thread_id_y, num_threads_y)
// for (int y_indvar=0; y_indvar<y_bound; y_indvar+=1)
// y_loc = thread_id_y + y_indvar * num_threads_y
//
// TODO(cheshire): Once ptxas is fixed and TF switches to it, remove the
// workaround.
ksl->For(
loop_name + "_y_in_tile",
/*start=*/constant(0),
/*end=*/
ceil_of_ratio(b_.CreateSub(tile_height, thread_id_y), num_threads_y),
/*step=*/constant(1), [&](llvm::Value* y_indvar) {
llvm::Value* y_loc =
b_.CreateAdd(thread_id_y, b_.CreateMul(y_indvar, num_threads_y));
for (int64 j = 0; j < x_num_steps; j++) {
llvm::Value* x_loc =
b_.CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
IrArray::Index source_idx_x =
source_idx.AddOffsetToDim(y_loc, kDimY, &b_)
.AddOffsetToDim(constant(j * step_x), kDimX, &b_);
ksl->If(loop_name + "_x_in_tile", b_.CreateICmpULT(x_loc, tile_width),
[&] { emit_elem_function(source_idx_x, y_loc, x_loc, j); });
}
});
}
// Emits code to process a tensor element in a tile for the given kCopy HLO that
@ -2623,33 +2608,33 @@ void IrEmitterUnnested::EmitHlo021Tile(
// tile[y, x] = input[index]
// Note that tile_width and tile_height are flipped here because we
// are reading a transposed tile.
xla::gpu::EmitTile(
mapping_scheme, input_tile_origin, "input", ksl, &b_, y, x,
tile_width, tile_height,
[&](const IrArray::Index& index, llvm::Value* y_loc,
llvm::Value* x_loc, int64 /*x_iter_num*/) {
for (int64 id : tiled_param_ids) {
IrArray& input_in_logical_shape =
param_in_reduced_shape_arrays[id];
EmitTile(mapping_scheme, input_tile_origin, "input", ksl, y, x,
tile_width, tile_height,
[&](const IrArray::Index& index, llvm::Value* y_loc,
llvm::Value* x_loc, int64 /*x_iter_num*/) {
for (int64 id : tiled_param_ids) {
IrArray& input_in_logical_shape =
param_in_reduced_shape_arrays[id];
llvm::Value* shmem_buffer = param_shmem_buffers[id];
llvm::Value* zero = llvm::ConstantInt::get(index_type, 0);
// TODO(jlebar): Add AA metadata to this store. Tile
// buffers are global variables, so LLVM can't infer much
// about it.
Store(input_in_logical_shape.EmitReadArrayElement(
index, &b_, "input_element"),
GEP(shmem_buffer, {zero, y_loc, x_loc}));
}
});
llvm::Value* shmem_buffer = param_shmem_buffers[id];
llvm::Value* zero =
llvm::ConstantInt::get(index_type, 0);
// TODO(jlebar): Add AA metadata to this store. Tile
// buffers are global variables, so LLVM can't infer much
// about it.
Store(input_in_logical_shape.EmitReadArrayElement(
index, &b_, "input_element"),
GEP(shmem_buffer, {zero, y_loc, x_loc}));
}
});
// Wait for all threads to reach this point using `__syncthreads` in
// CUDA.
EmitSyncThreads();
}
xla::gpu::EmitTile(mapping_scheme, index, loop_name, ksl, &b_, y, x,
tile_height, tile_width, element_generator);
EmitTile(mapping_scheme, index, loop_name, ksl, y, x, tile_height,
tile_width, element_generator);
bool block_contains_multi_tiles = mapping_scheme.GetTileSizeZ() > 1;
// If a tile block contains multiple tiles and shared memory buffers are
@ -3074,9 +3059,8 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
[&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
const string& loop_name, llvm::Value* tile_height,
llvm::Value* tile_width, KernelSupportLibrary* ksl) {
xla::gpu::EmitTile(reduction_info.GetKernelMappingScheme(), index,
loop_name, ksl, &b_, y, x, tile_height, tile_width,
emit_reduction_tile);
EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl,
y, x, tile_height, tile_width, emit_reduction_tile);
});
EmitEpilogueForReduction(index_ty, unnested_hlo, reduction_info,
reduce_instructions, reduction_output_shape_indices,

View File

@ -235,6 +235,39 @@ class IrEmitterUnnested : public IrEmitter,
const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty,
const TileElementGenerator& tile_element_generator);
// Emits code to process up to
// (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a tile,
// given `emit_elem_function` is the function to emit code to process one
// element, `thread_id_y` and `thread_id_x` are the intra-tile coordinates for
// the first element to process, and `index` is the index for the origin of
// the tile. Information about tile_size_x/y and num_threads_x/y are stored in
// `mapping_scheme`. Emits bounds check to ensure that each processed element
// is within the boundary defined by `tile_width` and `tile_height`.
//
// Pseudocode:
//
// for (y_loc = 0; y_loc < tile_height; y_loc += num_threads_y) {
// for (j = 0; j < tile_size_x / num_threads_x; j++) { // unrolled
// if (dilated) {
// x_loc = x + j * num_threads_x;
// } else {
// x_loc = x * (tile_size_x / num_threads_x) + j;
// }
//
// if (x_loc < tile_width) {
// emit_elem_function(y + y_loc, x_loc);
// }
// }
// }
//
void EmitTile(
const KernelMappingScheme& mapping_scheme,
const llvm_ir::IrArray::Index& tile_origin_index, const string& loop_name,
KernelSupportLibrary* ksl, llvm::Value* thread_id_y,
llvm::Value* thread_id_x, llvm::Value* tile_height,
llvm::Value* tile_width,
const IrEmitterUnnested::EmitElementFunction& emit_elem_function);
// Emits code to process a tensor element in a tile for the given kCopy HLO
// that performs a 0-2-1 transpose.
void EmitTileElementForCopy(

View File

@ -60,15 +60,11 @@ Status KernelSupportLibrary::IfWithStatus(
absl::string_view name, llvm::Value* condition,
const std::function<Status()>& true_block_generator,
const std::function<Status()>& false_block_generator) {
llvm_ir::LlvmIfData if_data =
llvm_ir::EmitIfThenElse(condition, name, b_,
/*emit_else=*/false_block_generator != nullptr);
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(condition, name, b_);
b_->SetInsertPoint(&if_data.true_block->back());
TF_RETURN_IF_ERROR(true_block_generator());
if (false_block_generator != nullptr) {
b_->SetInsertPoint(&if_data.false_block->back());
TF_RETURN_IF_ERROR(false_block_generator());
}
b_->SetInsertPoint(&if_data.false_block->back());
TF_RETURN_IF_ERROR(false_block_generator());
llvm_ir::SetToLastInsertPoint(if_data.after_block, b_);
return Status::OK();
}

View File

@ -203,11 +203,12 @@ class KernelSupportLibrary {
// `true_block_generator()`;
// else
// `false_block_generator()`;
// The else is skipped if false_block_generator is null.
Status IfWithStatus(
absl::string_view name, llvm::Value* condition,
const std::function<Status()>& true_block_generator,
const std::function<Status()>& false_block_generator = nullptr);
const std::function<Status()>& false_block_generator = []() -> Status {
return Status::OK();
});
Status IfWithStatus(
llvm::Value* condition,
@ -219,32 +220,26 @@ class KernelSupportLibrary {
false_block_generator);
}
void If(llvm::Value* condition,
const std::function<void()>& true_block_generator,
const std::function<void()>& false_block_generator = nullptr) {
void If(
llvm::Value* condition, const std::function<void()>& true_block_generator,
const std::function<void()>& false_block_generator = []() {}) {
If("", condition, true_block_generator, false_block_generator);
}
void If(absl::string_view name, llvm::Value* condition,
const std::function<void()>& true_block_generator,
const std::function<void()>& false_block_generator = nullptr) {
if (false_block_generator != nullptr) {
TF_CHECK_OK(IfWithStatus(
name, condition,
[&]() {
true_block_generator();
return Status::OK();
},
[&]() {
false_block_generator();
return Status::OK();
}));
} else {
TF_CHECK_OK(IfWithStatus(name, condition, [&]() {
true_block_generator();
return Status::OK();
}));
}
void If(
absl::string_view name, llvm::Value* condition,
const std::function<void()>& true_block_generator,
const std::function<void()>& false_block_generator = []() {}) {
TF_CHECK_OK(IfWithStatus(
name, condition,
[&]() {
true_block_generator();
return Status::OK();
},
[&]() {
false_block_generator();
return Status::OK();
}));
}
using ArgumentVector = absl::Span<llvm::Value* const>;