[XLA:GPU] Miscellaneous cleanup of 0-2-1-transpose-related code.

We use the terms reduced shape and unreduced shape to refer to the logical
shapes and the original shapes of a 0-2-1 transpose. Since reduced shape and
unreduced shape could also refer to the result shape and the source shape of a
reduction operation, this CL changes the 0-2-1-transpose-related code outside
the reduction implementation to use the terms normalized/unnormalized instead
of reduced/unreduced. The reduction implementation will be fixed in a separate
CL that migrates it to the kernel mapping scheme.

PiperOrigin-RevId: 221575772
Bixia Zheng 2018-11-14 23:56:51 -08:00 committed by TensorFlower Gardener
parent 8dd83e34a0
commit 387a062dd3
3 changed files with 58 additions and 63 deletions
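
To illustrate the terminology used in the description above (an editorial sketch, not part of this change; the concrete shapes and the expected result are hypothetical and were derived by hand from FindTranspose021 as modified below):

  // Unnormalized shapes: the original shapes together with their layouts. A
  // layout-changing copy between them is, in the normalized view, a 0-2-1
  // transpose whose leading logical component has size one.
  Shape a = ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{32, 48},
                                           /*minor_to_major=*/{1, 0});
  Shape b = ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{32, 48},
                                           /*minor_to_major=*/{0, 1});
  // Normalized shapes: the logical 0-1-2 view of `a` is [1, 32, 48]; the
  // 0-2-1 view of `b` is [1, 48, 32].
  absl::optional<std::vector<int64> > dims_021 = llvm_ir::FindTranspose021(a, b);
  // Expected to contain {1, 48, 32}, the dimensions of the normalized 0-2-1
  // shape of `b`.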

@@ -3375,7 +3375,7 @@ void IrEmitterUnnested::EmitTileElementForFusion(
   fused_emitter.SetTiledParameterInfo(tiled_param_info);
   TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter));
   IrArray::Index untiled_index =
-      kernel_info->GetKernelMappingScheme()->GetReshapedOutputIndex(
+      kernel_info->GetKernelMappingScheme()->GetUnnormalizedIndex(
           index, output_arrays[0].GetShape());
   const llvm_ir::ElementGenerator& output_generator =
       fused_emitter.GetRootGenerator();

@@ -52,6 +52,29 @@ Shape MergeDimensions(absl::Span<const size_t> segs, const Shape& shape) {
   return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
                                                   dimensions);
 }
+
+// Given an index for a shape, return the equivalent new index if the shape is
+// reshaped to another shape.
+IrArray::Index GetReshapedIndex(const IrArray::Index& index, const Shape& shape,
+                                const Shape& reshaped_shape,
+                                llvm::IRBuilder<>* b) {
+  auto bounds = shape.dimensions();
+  auto minor_to_major = shape.layout().minor_to_major();
+  llvm::Value* linear_index = index.GetConstantWithIndexType(0);
+  int64 multiplier = 1;
+  for (int i = 0; i < index.size(); ++i) {
+    int64 dim = minor_to_major[i];
+    llvm::Value* addend = b->CreateMul(
+        index[dim], index.GetConstantWithIndexType(multiplier), "linearizing",
+        /*HasNUW=*/true, /*HasNSW=*/true);
+    linear_index = b->CreateAdd(linear_index, addend, "",
+                                /*HasNUW=*/true, /*HasNSW=*/true);
+    multiplier *= bounds[dim];
+  }
+  return IrArray::Index(linear_index, reshaped_shape, b);
+}
+
 }  // namespace
 
 absl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
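
Editorial aside, not part of the diff: the new GetReshapedIndex above linearizes the incoming index according to the minor-to-major layout of `shape` and then hands the linear value to the IrArray::Index constructor, which delinearizes it for `reshaped_shape`. A scalar sketch of the same arithmetic, with hypothetical helper names and plain int64_t values in place of llvm::Value*, and assuming a descending (row-major) layout for the reshaped shape in the delinearization step:

#include <cstdint>
#include <vector>

// Accumulate index[dim] * multiplier while walking dimensions from minor to
// major, mirroring the loop in GetReshapedIndex.
int64_t LinearizeIndex(const std::vector<int64_t>& index,
                       const std::vector<int64_t>& bounds,
                       const std::vector<int64_t>& minor_to_major) {
  int64_t linear = 0;
  int64_t multiplier = 1;
  for (size_t i = 0; i < index.size(); ++i) {
    const int64_t dim = minor_to_major[i];
    linear += index[dim] * multiplier;
    multiplier *= bounds[dim];
  }
  return linear;
}

// Peel dimensions off a linear offset for a shape with descending layout
// (dimension 0 major-most); this is roughly the delinearization step that the
// IR version delegates to the IrArray::Index constructor, under the row-major
// assumption stated above.
std::vector<int64_t> DelinearizeIndex(int64_t linear,
                                      const std::vector<int64_t>& bounds) {
  std::vector<int64_t> index(bounds.size());
  for (size_t i = bounds.size(); i-- > 0;) {
    index[i] = linear % bounds[i];
    linear /= bounds[i];
  }
  return index;
}

For example, with bounds {2, 3}, the index {1, 0} under a column-major layout (minor_to_major {0, 1}) linearizes to offset 1, which delinearizes for a row-major shape of the same bounds to {0, 1}.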
@@ -60,28 +83,30 @@ absl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
     return absl::nullopt;
   }
-  std::vector<int64> perm(a.dimensions().size());
-  {
-    auto layout_a_orig = LayoutUtil::MinorToMajor(a);
-    std::vector<int64> layout_a(layout_a_orig.rbegin(), layout_a_orig.rend());
-    auto layout_b_orig = LayoutUtil::MinorToMajor(b);
-    std::vector<int64> layout_b(layout_b_orig.rbegin(), layout_b_orig.rend());
-    for (size_t i = 0; i < perm.size(); ++i) {
-      perm[i] = PositionInContainer(layout_b, layout_a[i]);
-    }
-  }
+  std::vector<int64> permutation(a.dimensions().size());
+  absl::Span<const int64> minor_to_major_a = LayoutUtil::MinorToMajor(a);
+  std::vector<int64> major_to_minor_a(minor_to_major_a.rbegin(),
+                                      minor_to_major_a.rend());
+  absl::Span<const int64> minor_to_major_b = LayoutUtil::MinorToMajor(b);
+  std::vector<int64> major_to_minor_b(minor_to_major_b.rbegin(),
+                                      minor_to_major_b.rend());
+  for (size_t i = 0; i < permutation.size(); ++i) {
+    permutation[i] = PositionInContainer(major_to_minor_b, major_to_minor_a[i]);
+  }
-  auto segs = ConsecutiveSegments(perm);
-  if ((3 == segs.size() && 0 == perm[0]) || 2 == segs.size()) {
-    Shape norm_a =
+  std::vector<size_t> segments = ConsecutiveSegments(permutation);
+  if ((3 == segments.size() && 0 == permutation[0]) || 2 == segments.size()) {
+    Shape descending_layout_shape =
         ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
-    Shape reduced_a = MergeDimensions(segs, norm_a);
-    auto reduced_a_dims = reduced_a.dimensions();
+    Shape normalized_shape = MergeDimensions(segments, descending_layout_shape);
+    absl::Span<const int64> normalized_dims =
+        AsInt64Slice(normalized_shape.dimensions());
     std::vector<int64> dims_021;
-    if (2 == segs.size()) {
+    if (2 == segments.size()) {
       // The logical component-0 is of size one.
-      dims_021 = {1, reduced_a_dims[1], reduced_a_dims[0]};
+      dims_021 = {1, normalized_dims[1], normalized_dims[0]};
     } else {
-      dims_021 = {reduced_a_dims[0], reduced_a_dims[2], reduced_a_dims[1]};
+      dims_021 = {normalized_dims[0], normalized_dims[2], normalized_dims[1]};
     }
     return dims_021;
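
Editorial aside, not part of the diff: a worked trace of the renamed permutation logic above, for a hypothetical pair of layouts; the values come from stepping through the code by hand, not from running it.

  // a = f32[8,64,128]{2,1,0}, b = f32[8,64,128]{1,2,0}
  // major_to_minor_a = {0, 1, 2}     major_to_minor_b = {0, 2, 1}
  // permutation      = {0, 2, 1}     // position of each a-dimension within b
  // segments         = {0, 1, 2}     // three singleton segments, permutation[0] == 0
  // normalized_dims  = {8, 64, 128}  // a's dimensions in physical major-to-minor order
  // dims_021         = {8, 128, 64}  // returned: the normalized 0-2-1 shape of b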
@@ -90,29 +115,6 @@ absl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
   return absl::nullopt;
 }
-IrArray::Index GetUnreducedOutputIndex(
-    const IrArray::Index& reduced_output_index,
-    const Shape& reduced_output_shape, const Shape& unreduced_output_shape,
-    llvm::IRBuilder<>* b) {
-  auto bounds = reduced_output_shape.dimensions();
-  auto minor_to_major = reduced_output_shape.layout().minor_to_major();
-  llvm::Value* linear_index = reduced_output_index.GetConstantWithIndexType(0);
-  int64 multiplier = 1;
-  for (int i = 0; i < reduced_output_index.size(); ++i) {
-    int64 dim = minor_to_major[i];
-    llvm::Value* addend =
-        b->CreateMul(reduced_output_index[dim],
-                     reduced_output_index.GetConstantWithIndexType(multiplier),
-                     "linearizing",
-                     /*HasNUW=*/true, /*HasNSW=*/true);
-    linear_index = b->CreateAdd(linear_index, addend, "",
-                                /*HasNUW=*/true, /*HasNSW=*/true);
-    multiplier *= bounds[dim];
-  }
-  return IrArray::Index(linear_index, unreduced_output_shape, b);
-}
 
 KernelMappingScheme::KernelMappingScheme(
     absl::Span<const int64> dims_in_elems, int64 tile_size_y, int64 tile_size_x,
     absl::Span<const int64> req_block_sizes, int64 num_threads_y,
@@ -143,13 +145,14 @@ KernelMappingScheme::KernelMappingScheme(
            << "]";
 }
 
-IrArray::Index KernelMappingScheme::GetReshapedOutputIndex(
-    const IrArray::Index& output_index, const Shape& reshaped_output_shape) {
-  DCHECK_EQ(output_index.size(), dims_in_elems_.size());
+IrArray::Index KernelMappingScheme::GetUnnormalizedIndex(
+    const IrArray::Index& normalized_shape_index,
+    const Shape& unnormalized_shape) {
+  DCHECK_EQ(normalized_shape_index.size(), dims_in_elems_.size());
   Shape output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
-      reshaped_output_shape.element_type(), GetDimensionsInElements());
-  return llvm_ir::GetUnreducedOutputIndex(output_index, output_shape,
-                                          reshaped_output_shape, b_);
+      unnormalized_shape.element_type(), GetDimensionsInElements());
+  return GetReshapedIndex(normalized_shape_index, output_shape,
+                          unnormalized_shape, b_);
 }
 
 IrArray::Index KernelMappingScheme::EmitBlockIndex(llvm::Type* index_ty) {

@@ -28,24 +28,15 @@ namespace llvm_ir {
 // If a shape can be viewed as three logical components 0-1-2 in the order of
 // major to minor, a 0-2-1-transpose changes the order of such logical
 // components to 0-2-1. We call the shape being transposed the input shape and
-// the transposed shape the output shape. The logical view of the input and
-// output shapes for the transpose are called the 0-1-2 shape or reduced input
-// shape and the 0-2-1 shape or the reduced output shape respectively. The
-// original input and output shapes are called the unreduced input and output
-// shapes.
+// the transposed shape the output shape. The logical view of the input/output
+// shapes for the transpose are called the 0-1-2/0-2-1 shapes or the normalized
+// shapes. The original input/output shapes are called unnormalized shapes.
 //
 // If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the
-// reduced shape of `b` or the 0-2-1 shape.
+// normalized shape of `b` or the 0-2-1 shape.
 absl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
                                                      const Shape& b);
-
-// Return the unreduced output index corresponding to the given reduced output
-// index.
-IrArray::Index GetUnreducedOutputIndex(
-    const IrArray::Index& reduced_output_index,
-    const Shape& reduced_output_shape, const Shape& unreduced_output_shape,
-    llvm::IRBuilder<>* b);
 
 // A tile is a spatial subdivision of a tensor. We group tensor elements into
 // tiles so that we can launch kernels to process the tensor elements in blocks
 // of tiles.
@@ -158,8 +149,9 @@ class KernelMappingScheme {
   std::tuple<llvm::Value*, llvm::Value*> EmitThreadYXCoordinate(
       llvm::Type* index_ty);
 
-  IrArray::Index GetReshapedOutputIndex(const IrArray::Index& output_index,
-                                        const Shape& reshaped_output_shape);
+  IrArray::Index GetUnnormalizedIndex(
+      const IrArray::Index& normalized_shape_index,
+      const Shape& unnormalized_shape);
 
   llvm::GlobalVariable* GetSharedMemoryBufferForElementType(
       llvm::Type* elem_ty, absl::string_view buffer_name);