diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index a676d65c5bd..40048f79929 100755 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -116,7 +116,9 @@ cc_library( deps = [ ":bfloat16_support", ":hlo", + ":hlo_dce", ":hlo_pass", + ":tuple_simplifier", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index f1ab34d6141..49e354e53c0 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -18,8 +18,10 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -29,6 +31,8 @@ limitations under the License. namespace xla { +namespace { + class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { public: explicit BFloat16NormalizationVisitor( @@ -51,19 +55,30 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { // independently. Status HandleMultipleOutputs(HloInstruction* hlo); - // Inserts a conversion HLO that changes the given HLO's output type. - Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType to, + // Creates a copy of `hlo` with subshapes matching `from` type converted to + // `to` type. 
If no matching subshapes are found, returns the original `hlo`. + StatusOr<HloInstruction*> ConvertType(HloInstruction* hlo, PrimitiveType from, + PrimitiveType to, + HloComputation* computation); + + // Inserts a conversion HLO that changes the given HLO's output type. If the + // output is a tuple, change all elements that match the from type. + Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType from, + PrimitiveType to, HloComputation* computation); // Changes the output type to the specified type, then inserts a conversion - // to the original type. + // to the original type. If the output is a tuple, change all elements that + // match the from type. Status ChangeOutputTypeThenInsertConvertBack(HloInstruction* hlo, + PrimitiveType from, PrimitiveType to, HloComputation* computation); - // Inserts a conversion HLO that changes the given HLO's operand type. + // Inserts a conversion HLO that changes the given HLO's operand type. If the + // operand is a tuple, change all elements that match the from type. 
Status InsertConvertBeforeOperand(HloInstruction* hlo, int64 operand_idx, - PrimitiveType to, + PrimitiveType from, PrimitiveType to, HloComputation* computation); // Inserts conversion HLOs to replace the called computations' BF16 @@ -77,47 +92,140 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { bool changed_ = false; }; +int64 CountSubshapesWithMatchingType(const Shape& shape, PrimitiveType type) { + int64 count = 0; + ShapeUtil::ForEachSubshape( + shape, [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.element_type() == type) { + ++count; + } + }); + return count; +} + +int64 ShapeLeafCount(const Shape& shape) { + int64 count = 0; + ShapeUtil::ForEachSubshape( + shape, [&](const Shape& subshape, const ShapeIndex& index) { + if (ShapeUtil::IsLeafIndex(shape, index)) { + ++count; + } + }); + return count; +} + +StatusOr<HloInstruction*> BFloat16NormalizationVisitor::ConvertType( + HloInstruction* hlo, PrimitiveType from, PrimitiveType to, + HloComputation* computation) { + if (CountSubshapesWithMatchingType(hlo->shape(), from) == 0) { + return hlo; + } + // If `hlo` is a convert from `to` to `from`, then we can return its operand, + // if it is a BF16->F32 convert which doesn't do rounding. 
+ if (hlo->opcode() == HloOpcode::kConvert && + hlo->operand(0)->shape().element_type() == to && to == BF16 && + from == F32) { + return hlo->mutable_operand(0); + } + TF_ASSIGN_OR_RETURN( + auto new_hlo, + computation->DeepCopyInstructionWithCustomCopier( + hlo, [&](HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* comp) { + const auto& original_subshape = + ShapeUtil::GetSubshape(hlo->shape(), leaf_index); + if (original_subshape.element_type() != from) { + return leaf; + } + auto new_subshape = + ShapeUtil::ChangeElementType(original_subshape, to); + bfloat16_normalization_->UpdateLayout(&new_subshape); + return computation->AddInstruction( + HloInstruction::CreateConvert(new_subshape, leaf)); + })); + return new_hlo; +} + Status BFloat16NormalizationVisitor::InsertConvertAfterOutput( - HloInstruction* hlo, PrimitiveType to, HloComputation* computation) { + HloInstruction* hlo, PrimitiveType from, PrimitiveType to, + HloComputation* computation) { bool is_root = computation->root_instruction() == hlo; std::vector<HloInstruction*> materialized_users = hlo->users(); - // Use inst's shape temporarily, in order to pass checks in ReplaceUseWith. 
- auto convert = computation->AddInstruction( - HloInstruction::CreateConvert(hlo->shape(), hlo)); + + TF_ASSIGN_OR_RETURN(auto new_hlo, ConvertType(hlo, from, to, computation)); + if (new_hlo == hlo) { + return Status::OK(); + } + for (auto* user : materialized_users) { - if (user->opcode() == HloOpcode::kConvert && - user->shape().element_type() == F32) { - TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo)); - } else { - TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, convert)); - } + TF_RETURN_IF_ERROR(hlo->ReplaceUseWithDifferentShape(user, new_hlo)); } if (is_root) { - computation->set_root_instruction(convert); + computation->set_root_instruction(new_hlo, /*accept_different_shape=*/true); } - convert->mutable_shape()->set_element_type(to); - bfloat16_normalization_->UpdateLayout(convert->mutable_shape()); changed_ = true; return Status::OK(); } Status BFloat16NormalizationVisitor::ChangeOutputTypeThenInsertConvertBack( - HloInstruction* hlo, PrimitiveType to, HloComputation* computation) { - auto original_type = hlo->shape().element_type(); - hlo->mutable_shape()->set_element_type(to); + HloInstruction* hlo, PrimitiveType from, PrimitiveType to, + HloComputation* computation) { + auto original_shape = hlo->shape(); + if (CountSubshapesWithMatchingType(original_shape, from) == 0) { + return Status::OK(); + } + ShapeUtil::ForEachMutableSubshape( + hlo->mutable_shape(), [&](Shape* subshape, const xla::ShapeIndex& index) { + if (subshape->element_type() == from) { + subshape->set_element_type(to); + } + }); bfloat16_normalization_->UpdateLayout(hlo->mutable_shape()); - return InsertConvertAfterOutput(hlo, original_type, computation); + bool is_root = computation->root_instruction() == hlo; + std::vector<HloInstruction*> materialized_users = hlo->users(); + TF_ASSIGN_OR_RETURN( + auto new_hlo, + computation->DeepCopyInstructionWithCustomCopier( + hlo, [&](HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* comp) { + const auto& original_subshape = + 
ShapeUtil::GetSubshape(original_shape, leaf_index); + if (original_subshape.element_type() == + leaf->shape().element_type()) { + return leaf; + } + return computation->AddInstruction( + HloInstruction::CreateConvert(original_subshape, leaf)); + })); + + for (auto* user : materialized_users) { + // If the user is a BF16 -> F32 convert, we can replace it with `hlo`, which + // has its input changed to F32. + if (user->opcode() == HloOpcode::kConvert && + user->shape().element_type() == to && to == F32 && from == BF16) { + TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo)); + } else { + TF_RETURN_IF_ERROR(hlo->ReplaceUseWithDifferentShape(user, new_hlo)); + } + } + if (is_root) { + computation->set_root_instruction(new_hlo, /*accept_different_shape=*/true); + } + changed_ = true; + return Status::OK(); } Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand( - HloInstruction* hlo, int64 operand_idx, PrimitiveType to, - HloComputation* computation) { + HloInstruction* hlo, int64 operand_idx, PrimitiveType from, + PrimitiveType to, HloComputation* computation) { auto operand = hlo->mutable_operand(operand_idx); - auto shape = ShapeUtil::ChangeElementType(operand->shape(), to); - bfloat16_normalization_->UpdateLayout(&shape); - auto convert = computation->AddInstruction( - HloInstruction::CreateConvert(shape, operand)); - TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(operand_idx, convert)); + TF_ASSIGN_OR_RETURN(auto new_operand, + ConvertType(operand, from, to, computation)); + if (new_operand == operand) { + return Status::OK(); + } + TF_RETURN_IF_ERROR( + hlo->ReplaceOperandWithDifferentShape(operand_idx, new_operand)); changed_ = true; return Status::OK(); } @@ -139,16 +247,12 @@ Status BFloat16NormalizationVisitor::ConvertCalledComputations( }); for (auto& comp_pair : cloned_computations) { auto comp = comp_pair.second; - if (comp->root_instruction()->shape().element_type() == BF16) { - TF_RETURN_IF_ERROR( - InsertConvertAfterOutput(comp->root_instruction(), 
F32, comp)); - } + TF_RETURN_IF_ERROR( + InsertConvertAfterOutput(comp->root_instruction(), BF16, F32, comp)); for (auto* param : comp->parameter_instructions()) { - if (param->shape().element_type() == BF16) { - // This changes the parameter to F32 then inserts a convert after it. - TF_RETURN_IF_ERROR( - ChangeOutputTypeThenInsertConvertBack(param, F32, comp)); - } + // This changes the parameter to F32 then inserts a convert after it. + TF_RETURN_IF_ERROR( + ChangeOutputTypeThenInsertConvertBack(param, BF16, F32, comp)); } } return Status::OK(); @@ -163,6 +267,8 @@ Status BFloat16NormalizationVisitor::HandleMultipleOutputs( bool has_unsupported_bf16_operand = false; bool has_unsupported_bf16_output = false; for (int64 i = 0; i < hlo->operand_count(); ++i) { + CHECK(hlo->operand(i)->shape().IsArray()); + CHECK(ShapeUtil::GetSubshape(hlo->shape(), {i}).IsArray()); operand_types[i] = hlo->operand(i)->shape().element_type(); output_types[i] = ShapeUtil::GetSubshape(hlo->shape(), {i}).element_type(); if (operand_types[i] == F32) { @@ -203,7 +309,8 @@ Status BFloat16NormalizationVisitor::HandleMultipleOutputs( for (int64 i = 0; i < hlo->operand_count(); ++i) { if (should_convert_operand(i)) { - TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); + TF_RETURN_IF_ERROR( + InsertConvertBeforeOperand(hlo, i, BF16, F32, computation_)); f32_count += 1; bf16_count -= 1; } @@ -275,36 +382,34 @@ Status BFloat16NormalizationVisitor::HandleMultipleOutputs( Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) { int f32_count = 0; - int bf16_count = 1; + int bf16_count = 0; for (int64 i = 0; i < hlo->operand_count(); ++i) { - if (hlo->operand(i)->shape().element_type() == F32) { - f32_count += 1; - } else if (hlo->operand(i)->shape().element_type() == BF16) { - bf16_count += 1; - } + f32_count += CountSubshapesWithMatchingType(hlo->operand(i)->shape(), F32); + bf16_count += + CountSubshapesWithMatchingType(hlo->operand(i)->shape(), 
BF16); } - if (hlo->shape().element_type() == F32) { - f32_count += 1; - } else if (hlo->shape().element_type() == BF16) { - bf16_count += 1; - } + f32_count += CountSubshapesWithMatchingType(hlo->shape(), F32); + bf16_count += CountSubshapesWithMatchingType(hlo->shape(), BF16); std::vector<HloComputation*> bf16_called_comps; for (auto* comp : hlo->called_computations()) { bool comp_has_bf16 = false; - if (comp->root_instruction()->shape().element_type() == F32) { - f32_count += 1; - } else if (comp->root_instruction()->shape().element_type() == BF16) { - bf16_count += 1; + f32_count += + CountSubshapesWithMatchingType(comp->root_instruction()->shape(), F32); + int64 bf16_count_comp_root = + CountSubshapesWithMatchingType(comp->root_instruction()->shape(), BF16); + if (bf16_count_comp_root > 0) { + bf16_count += bf16_count_comp_root; comp_has_bf16 = true; } for (auto* param : comp->parameter_instructions()) { - if (param->shape().element_type() == F32) { - f32_count += 1; - } else if (param->shape().element_type() == BF16) { - bf16_count += 1; + f32_count += CountSubshapesWithMatchingType(param->shape(), F32); + int64 bf16_count_comp_param = + CountSubshapesWithMatchingType(param->shape(), BF16); + if (bf16_count_comp_param > 0) { + bf16_count += bf16_count_comp_param; comp_has_bf16 = true; } } @@ -315,21 +420,27 @@ Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) { // Resolve unsupported BF16 operands. 
for (int i = 0; i < hlo->operand_count(); ++i) { - if (hlo->operand(i)->shape().element_type() == BF16 && + int64 bf16_count_in_operand = + CountSubshapesWithMatchingType(hlo->operand(i)->shape(), BF16); + if (bf16_count_in_operand > 0 && !bfloat16_support_->SupportsBF16Operand(*hlo, i)) { - TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); - bf16_count -= 1; - f32_count += 1; + TF_RETURN_IF_ERROR( + InsertConvertBeforeOperand(hlo, i, BF16, F32, computation_)); + bf16_count -= bf16_count_in_operand; + f32_count += bf16_count_in_operand; } } // Resolve unsupported BF16 output. - if (hlo->shape().element_type() == BF16 && - !bfloat16_support_->SupportsBF16Output(*hlo)) { - TF_RETURN_IF_ERROR( - ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_)); - bf16_count -= 1; - f32_count += 1; + if (!bfloat16_support_->SupportsBF16Output(*hlo)) { + int64 bf16_count_in_hlo = + CountSubshapesWithMatchingType(hlo->shape(), BF16); + if (bf16_count_in_hlo > 0) { + TF_RETURN_IF_ERROR( + ChangeOutputTypeThenInsertConvertBack(hlo, BF16, F32, computation_)); + bf16_count -= bf16_count_in_hlo; + f32_count += bf16_count_in_hlo; + } } // Resolve unsupported mixed precision after resolving unsupported BF16 @@ -341,10 +452,12 @@ Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) { } // See if we can change everything to BF16. 
if (hlo->called_computations().empty() && - hlo->shape().element_type() == BF16) { + CountSubshapesWithMatchingType(hlo->shape(), BF16) == + ShapeLeafCount(hlo->shape())) { bool can_use_bf16 = true; for (int i = 0; i < hlo->operand_count(); ++i) { - if (hlo->operand(i)->shape().element_type() == BF16) { + if (CountSubshapesWithMatchingType(hlo->operand(i)->shape(), BF16) == + ShapeLeafCount(hlo->operand(i)->shape())) { continue; } if ((bfloat16_support_->EffectiveOperandPrecisionIsBF16(*hlo, i) || @@ -358,22 +471,17 @@ Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) { } if (can_use_bf16) { for (int i = 0; i < hlo->operand_count(); ++i) { - if (hlo->operand(i)->shape().element_type() == F32) { - TF_RETURN_IF_ERROR( - InsertConvertBeforeOperand(hlo, i, BF16, computation_)); - } + TF_RETURN_IF_ERROR( + InsertConvertBeforeOperand(hlo, i, F32, BF16, computation_)); } return Status::OK(); } } - if (hlo->shape().element_type() == BF16) { - TF_RETURN_IF_ERROR( - ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_)); - } + TF_RETURN_IF_ERROR( + ChangeOutputTypeThenInsertConvertBack(hlo, BF16, F32, computation_)); for (int i = 0; i < hlo->operand_count(); ++i) { - if (hlo->operand(i)->shape().element_type() == BF16) { - TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); - } + TF_RETURN_IF_ERROR( + InsertConvertBeforeOperand(hlo, i, BF16, F32, computation_)); } return ConvertCalledComputations(hlo, bf16_called_comps); } @@ -385,6 +493,7 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { if (hlo->opcode() == HloOpcode::kTuple || // hlo->opcode() == HloOpcode::kGetTupleElement || // hlo->opcode() == HloOpcode::kConstant || // + hlo->opcode() == HloOpcode::kDomain || // hlo->opcode() == HloOpcode::kParameter || // hlo->opcode() == HloOpcode::kFusion || // hlo->opcode() == HloOpcode::kConvert || // @@ -410,6 +519,8 @@ Status BFloat16NormalizationVisitor::Preprocess(HloInstruction* hlo) { return 
Status::OK(); } +} // namespace + StatusOr<bool> BFloat16Normalization::Run(HloModule* module) { XLA_VLOG_LINES( 2, "BFloat16Normalization::Run(), before:\n" + module->ToString()); @@ -419,6 +530,12 @@ StatusOr<bool> BFloat16Normalization::Run(HloModule* module) { } XLA_VLOG_LINES(2, "BFloat16Normalization::Run(), after:\n" + module->ToString()); + if (visitor.changed()) { + TupleSimplifier tuple_simplifier; + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + HloDCE dce; + TF_RETURN_IF_ERROR(dce.Run(module).status()); + } return visitor.changed(); } diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index dbbcb358ce3..8abc3c7c5dd 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -40,7 +40,8 @@ class TestBFloat16Support : public BFloat16Support { hlo.opcode() == HloOpcode::kSubtract || hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kTuple || - hlo.opcode() == HloOpcode::kGetTupleElement) { + hlo.opcode() == HloOpcode::kGetTupleElement || + hlo.opcode() == HloOpcode::kAllToAll) { return true; } if (hlo.opcode() == HloOpcode::kDot) { @@ -54,7 +55,8 @@ class TestBFloat16Support : public BFloat16Support { if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kSubtract || hlo.opcode() == HloOpcode::kDot || hlo.opcode() == HloOpcode::kTuple || - hlo.opcode() == HloOpcode::kGetTupleElement) { + hlo.opcode() == HloOpcode::kGetTupleElement || + hlo.opcode() == HloOpcode::kAllToAll) { return true; } return false; @@ -258,19 +260,76 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) { ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, /*replica_groups=*/{}, /*channel_id=*/absl::nullopt)); - HloInstruction* gte = builder.AddInstruction( + builder.AddInstruction( 
HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_TRUE(Normalize(module.get())); - EXPECT_EQ(computation->root_instruction(), gte); - EXPECT_EQ(gte->shape().element_type(), BF16); + EXPECT_EQ(computation->root_instruction()->shape().element_type(), BF16); EXPECT_EQ(crs->operand(1)->shape().element_type(), F32); EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32); } +TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToBF16) { + auto module = CreateNewVerifiedModule(); + + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + + std::vector<ReplicaGroup> replica_groups(1); + replica_groups[0].add_replica_ids(0); + replica_groups[0].add_replica_ids(1); + HloInstruction* a2a = builder.AddInstruction(HloInstruction::CreateAllToAll( + ShapeUtil::MakeTupleShape({bf16_shape, bf16_shape}), {a, a}, + replica_groups)); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction(), a2a); + EXPECT_EQ(ShapeUtil::GetSubshape(a2a->shape(), {0}).element_type(), BF16); + EXPECT_EQ(ShapeUtil::GetSubshape(a2a->shape(), {1}).element_type(), BF16); + EXPECT_EQ(a2a->operand(0)->opcode(), HloOpcode::kConvert); + EXPECT_EQ(a2a->operand(0)->shape().element_type(), BF16); + EXPECT_EQ(a2a->operand(1)->opcode(), HloOpcode::kConvert); + EXPECT_EQ(a2a->operand(1)->shape().element_type(), BF16); +} + +TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToF32) { + auto module = CreateNewVerifiedModule(); + + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + 
HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + + std::vector<ReplicaGroup> replica_groups(1); + replica_groups[0].add_replica_ids(0); + replica_groups[0].add_replica_ids(1); + HloInstruction* a2a = builder.AddInstruction(HloInstruction::CreateAllToAll( + ShapeUtil::MakeTupleShape({bf16_shape, f32_shape}), {a, a}, + replica_groups)); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kTuple); + EXPECT_EQ(ShapeUtil::GetSubshape(a2a->shape(), {0}).element_type(), F32); + EXPECT_EQ(ShapeUtil::GetSubshape(a2a->shape(), {1}).element_type(), F32); + EXPECT_EQ(a2a->operand(0)->opcode(), HloOpcode::kParameter); + EXPECT_EQ(a2a->operand(0)->shape().element_type(), F32); + EXPECT_EQ(a2a->operand(1)->opcode(), HloOpcode::kParameter); + EXPECT_EQ(a2a->operand(1)->shape().element_type(), F32); +} + TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); @@ -288,15 +347,14 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), {key, value}, 0, /*is_stable=*/false, &builder, module.get())); - HloInstruction* gte = builder.AddInstruction( + builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0)); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_TRUE(Normalize(module.get())); - EXPECT_EQ(computation->root_instruction(), gte); - EXPECT_EQ(gte->shape().element_type(), BF16); + EXPECT_EQ(computation->root_instruction()->shape().element_type(), BF16); EXPECT_EQ(sort->operand(0)->shape().element_type(), F32); EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32); }