Don't replace Transposes with Bitcasts on the GPU backend.
We can generate the index for accessing the memory more efficiently if we still know which dimensions are permuted. Outside of fusion nodes, we still want to replace transposes with bitcasts, so we add another run of AlgebraicSimplifier at the end of OptimizeHloModule(). PiperOrigin-RevId: 304354090 Change-Id: I7314476397a6e24dd32b4a85f90d0fa243db382f
This commit is contained in:
parent
5c209b39d1
commit
7d529df64c
@ -4128,7 +4128,9 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
|
||||
return ReplaceInstruction(transpose, operand);
|
||||
}
|
||||
|
||||
if (options_.is_layout_sensitive() && TransposeIsBitcast(transpose)) {
|
||||
if (options_.is_layout_sensitive() &&
|
||||
options_.replace_transpose_with_bitcast() &&
|
||||
TransposeIsBitcast(transpose)) {
|
||||
ReplaceWithBitcast(transpose);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -113,6 +113,14 @@ class AlgebraicSimplifierOptions {
|
||||
|
||||
bool enable_reduce_of_reshape() const { return enable_reduce_of_reshape_; }
|
||||
|
||||
void set_replace_transpose_with_bitcast(bool replace_transpose_with_bitcast) {
|
||||
replace_transpose_with_bitcast_ = replace_transpose_with_bitcast;
|
||||
}
|
||||
|
||||
bool replace_transpose_with_bitcast() const {
|
||||
return replace_transpose_with_bitcast_;
|
||||
}
|
||||
|
||||
private:
|
||||
// Metadata struct can be used to store any metadata information encapsulated
|
||||
// with the AlgebraicSimplierOptions that can be later used in an
|
||||
@ -133,6 +141,7 @@ class AlgebraicSimplifierOptions {
|
||||
bool enable_conv_simplification_{true};
|
||||
bool enable_window_reduce_to_reduce_replacement_{true};
|
||||
bool enable_reduce_of_reshape_{true};
|
||||
bool replace_transpose_with_bitcast_{true};
|
||||
int64 very_small_gather_size_{4};
|
||||
Metadata metadata_;
|
||||
};
|
||||
|
||||
@ -2437,7 +2437,7 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
|
||||
AlgebraicSimplifier simplifier(options);
|
||||
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
|
||||
|
||||
// Verify that the reshape is replaced.
|
||||
// Verify that the transpose is replaced.
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
GmockMatch(m::Bitcast(m::Parameter(0))));
|
||||
}
|
||||
@ -2464,10 +2464,17 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) {
|
||||
|
||||
AlgebraicSimplifierOptions options;
|
||||
options.set_is_layout_sensitive(true);
|
||||
// Don't replace transposes with bitcasts.
|
||||
options.set_replace_transpose_with_bitcast(false);
|
||||
AlgebraicSimplifier simplifier_no_replace(options);
|
||||
ASSERT_FALSE(simplifier_no_replace.Run(m.get()).ValueOrDie());
|
||||
|
||||
// Replace transposes with bitcasts if possible.
|
||||
options.set_replace_transpose_with_bitcast(true);
|
||||
AlgebraicSimplifier simplifier(options);
|
||||
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
|
||||
|
||||
// Verify that the reshape is replaced.
|
||||
// Verify that the transpose is replaced.
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
GmockMatch(m::Bitcast(m::Parameter(0))));
|
||||
}
|
||||
|
||||
@ -122,8 +122,11 @@ void FusionNodeIndexingEvaluation::UpdateIndexingUsersOfOperands(
|
||||
operand = fusion_->operand(operand->parameter_number());
|
||||
}
|
||||
// For simplicity we assume that all shape and layout changing
|
||||
// operations invalidate index reuse.
|
||||
if (Shape::Equal().IgnoreElementType()(operand->shape(),
|
||||
// operations except Transposes invalidate index reuse. Transposes are
|
||||
// special: although they are shape changing, we can reuse the
|
||||
// multi-dimensional index for the operand by permuting it.
|
||||
if (instruction->opcode() == HloOpcode::kTranspose ||
|
||||
Shape::Equal().IgnoreElementType()(operand->shape(),
|
||||
instruction->shape())) {
|
||||
// If the index is reused, it means the operand gets index values
|
||||
// from the same set of (indirect) users as 'instruction' itself.
|
||||
|
||||
@ -205,6 +205,15 @@ Status GpuCompiler::OptimizeHloModule(
|
||||
pipeline.AddPass<ZeroSizedHloElimination>();
|
||||
|
||||
AlgebraicSimplifierOptions options;
|
||||
// When transposes appear in a fusion node, we can easily adjust the
|
||||
// multi-dimensional index to create the one needed for the operand. This
|
||||
// is not as easy with bitcasts, because we don't have the information
|
||||
// readily available which dimensions are permuted. In addition to that,
|
||||
// if we have a transpose and a reshape next to each other, they will both
|
||||
// be replaced by a bitcast, and we replace bitcast(bitcast) with one
|
||||
// bitcast. This leads to having to linearize and then delinearize the
|
||||
// index.
|
||||
options.set_replace_transpose_with_bitcast(false);
|
||||
pass.AddPass<AlgebraicSimplifier>(options);
|
||||
// AlgebraicSimplifier may add contracting dimensions to a dot.
|
||||
pass.AddPass<DotDecomposer>();
|
||||
@ -306,6 +315,13 @@ Status GpuCompiler::OptimizeHloModule(
|
||||
/*combine_threshold_count=*/256);
|
||||
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
|
||||
}
|
||||
{
|
||||
// Now we allow to replace any transposes outside of fusions with bitcasts.
|
||||
HloPassPipeline pipeline("final_algebraic_simplifier");
|
||||
AlgebraicSimplifierOptions options;
|
||||
options.set_is_layout_sensitive(true);
|
||||
pipeline.AddPass<AlgebraicSimplifier>(options);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -372,6 +388,15 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment(
|
||||
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
|
||||
AlgebraicSimplifierOptions options;
|
||||
options.set_is_layout_sensitive(true);
|
||||
// When transposes appear in a fusion node, we can easily adjust the
|
||||
// multi-dimensional index to create the one needed for the operand. This
|
||||
// is not as easy with bitcasts, because we don't have the information
|
||||
// readily available which dimensions are permuted. In addition to that,
|
||||
// if we have a transpose and a reshape next to each other, they will both
|
||||
// be replaced by a bitcast, and we replace bitcast(bitcast) with one
|
||||
// bitcast. This leads to having to linearize and then delinearize the
|
||||
// index.
|
||||
options.set_replace_transpose_with_bitcast(false);
|
||||
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
|
||||
|
||||
if (RequireDeterminism() ||
|
||||
|
||||
@ -1363,7 +1363,7 @@ GetHloBufferSlices(const HloInstruction* hlo,
|
||||
// appear before any GTE instructions, because it's illegal to bitcast to a
|
||||
// tuple type.
|
||||
const HloInstruction* parent = instr;
|
||||
while (parent->opcode() == HloOpcode::kBitcast) {
|
||||
while (parent->IsEffectiveBitcast()) {
|
||||
parent = parent->operand(0);
|
||||
|
||||
auto slice = buffer_assn.GetUniqueSlice(parent, {});
|
||||
|
||||
@ -131,6 +131,15 @@ Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization(
|
||||
/*allow_mixed_precision=*/false);
|
||||
|
||||
AlgebraicSimplifierOptions options;
|
||||
// When transposes appear in a fusion node, we can easily adjust the
|
||||
// multi-dimensional index to create the one needed for the operand. This
|
||||
// is not as easy with bitcasts, because we don't have the information
|
||||
// readily available which dimensions are permuted. In addition to that,
|
||||
// if we have a transpose and a reshape next to each other, they will both
|
||||
// be replaced by a bitcast, and we replace bitcast(bitcast) with one
|
||||
// bitcast. This leads to having to linearize and then delinearize the
|
||||
// index.
|
||||
options.set_replace_transpose_with_bitcast(false);
|
||||
options.set_cudnn_batchnorm_forward_training_metadata(
|
||||
kCudnnBatchNormForwardTrainingCallTarget);
|
||||
pass.AddPass<AlgebraicSimplifier>(options);
|
||||
|
||||
@ -498,7 +498,10 @@ Status HloCostAnalysis::HandleBatchNormGrad(const HloInstruction*) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloCostAnalysis::HandleTranspose(const HloInstruction*) {
|
||||
Status HloCostAnalysis::HandleTranspose(const HloInstruction* transpose) {
|
||||
if (transpose->IsEffectiveBitcast()) {
|
||||
return HandleBitcast(transpose);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
@ -2160,6 +2160,13 @@ Status HloInstruction::ReplaceAllUsesWithDifferentShape(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool HloInstruction::IsEffectiveBitcast() const {
|
||||
return opcode_ == HloOpcode::kBitcast ||
|
||||
(opcode_ == HloOpcode::kTranspose &&
|
||||
ShapeUtil::TransposeIsBitcast(operand(0)->shape(), shape(),
|
||||
dimensions()));
|
||||
}
|
||||
|
||||
HloComputation* HloInstruction::to_apply() const {
|
||||
switch (opcode_) {
|
||||
case HloOpcode::kCall:
|
||||
|
||||
@ -1233,6 +1233,11 @@ class HloInstruction {
|
||||
const_cast<const HloInstruction*>(this)->LatestNonGteAncestor());
|
||||
}
|
||||
|
||||
// Returns true whether this instruction is effectively a bitcast. Currently,
|
||||
// this means it either is a bitcast, or it is a transpose that is effectively
|
||||
// a bitcast.
|
||||
bool IsEffectiveBitcast() const;
|
||||
|
||||
// Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc.
|
||||
// The setter should only be called by HloModule or HloComputation methods.
|
||||
//
|
||||
|
||||
@ -257,8 +257,11 @@ bool FusedIrEmitter::IsFusedIrEmitterInefficient(
|
||||
}
|
||||
for (const auto* operand : instruction->operands()) {
|
||||
// For simplicity we assume that all shape and layout changing
|
||||
// operations invalidate index reuse.
|
||||
if (Shape::Equal().IgnoreElementType()(operand->shape(),
|
||||
// operations except Transposes invalidate index reuse. Transposes are
|
||||
// special: although they are shape changing, we can reuse the
|
||||
// multi-dimensional index for the operand by permuting it.
|
||||
if (instruction->opcode() == HloOpcode::kTranspose ||
|
||||
Shape::Equal().IgnoreElementType()(operand->shape(),
|
||||
instruction->shape())) {
|
||||
// If the index is reused, it means the operand gets index values
|
||||
// from the same set of (indirect) users as 'instruction' itself.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user