[XLA] Make BF16 normalization better handle tuple-shaped inputs/outputs.

Use per-subshape accounting and conversion for mismatched types.

PiperOrigin-RevId: 275410631
Change-Id: I098fcb7cbdfebc015ba68156563f8350077fd4ea
Yuanzhong Xu 2019-10-17 23:02:00 -07:00 committed by TensorFlower Gardener
parent 94794f435d
commit 976e7d7f58
3 changed files with 266 additions and 89 deletions
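In short: instead of inspecting only an instruction's top-level element type, the pass now counts and converts matching subshapes within tuple shapes. A hedged sketch of the idea, mirroring the new CountSubshapesWithMatchingType() helper in the diff below (CountBF16Subshapes is an illustrative name, not part of this commit):

#include "tensorflow/compiler/xla/shape_util.h"

// Walks every subshape of `shape` (tuple elements included) and counts the
// subshapes whose element type is BF16.
int64 CountBF16Subshapes(const xla::Shape& shape) {
  int64 count = 0;
  xla::ShapeUtil::ForEachSubshape(
      shape, [&](const xla::Shape& subshape, const xla::ShapeIndex&) {
        if (subshape.element_type() == xla::BF16) {
          ++count;
        }
      });
  return count;
}

// For a tuple shape (f32[2,4], bf16[2,4]) this returns 1; checking only
// shape.element_type() would see TUPLE and miss the BF16 element entirely.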

tensorflow/compiler/xla/service/BUILD

@@ -116,7 +116,9 @@ cc_library(
     deps = [
         ":bfloat16_support",
         ":hlo",
+        ":hlo_dce",
         ":hlo_pass",
+        ":tuple_simplifier",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:xla_data_proto",

tensorflow/compiler/xla/service/bfloat16_normalization.cc

@@ -18,8 +18,10 @@ limitations under the License.
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -29,6 +31,8 @@ limitations under the License.
 namespace xla {
 namespace {
 
 class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
  public:
  explicit BFloat16NormalizationVisitor(
@@ -51,19 +55,30 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
   // independently.
   Status HandleMultipleOutputs(HloInstruction* hlo);
 
-  // Inserts a conversion HLO that changes the given HLO's output type.
-  Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType to,
+  // Creates a copy of `hlo` with subshapes matching `from` type converted to
+  // `to` type. If no matching subshapes are found, returns the original `hlo`.
+  StatusOr<HloInstruction*> ConvertType(HloInstruction* hlo, PrimitiveType from,
+                                        PrimitiveType to,
+                                        HloComputation* computation);
+
+  // Inserts a conversion HLO that changes the given HLO's output type. If the
+  // output is a tuple, change all elements that match the from type.
+  Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType from,
+                                  PrimitiveType to,
                                   HloComputation* computation);
 
   // Changes the output type to the specified type, then inserts a conversion
-  // to the original type.
+  // to the original type. If the output is a tuple, change all elements that
+  // match the from type.
   Status ChangeOutputTypeThenInsertConvertBack(HloInstruction* hlo,
+                                               PrimitiveType from,
                                                PrimitiveType to,
                                                HloComputation* computation);
 
-  // Inserts a conversion HLO that changes the given HLO's operand type.
+  // Inserts a conversion HLO that changes the given HLO's operand type. If the
+  // operand is a tuple, change all elements that match the from type.
   Status InsertConvertBeforeOperand(HloInstruction* hlo, int64 operand_idx,
-                                    PrimitiveType to,
+                                    PrimitiveType from, PrimitiveType to,
                                     HloComputation* computation);
 
   // Inserts conversion HLOs to replace the called computations' BF16
@@ -77,47 +92,140 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
   bool changed_ = false;
 };
 
+int64 CountSubshapesWithMatchingType(const Shape& shape, PrimitiveType type) {
+  int64 count = 0;
+  ShapeUtil::ForEachSubshape(
+      shape, [&](const Shape& subshape, const ShapeIndex& index) {
+        if (subshape.element_type() == type) {
+          ++count;
+        }
+      });
+  return count;
+}
+
+int64 ShapeLeafCount(const Shape& shape) {
+  int64 count = 0;
+  ShapeUtil::ForEachSubshape(
+      shape, [&](const Shape& subshape, const ShapeIndex& index) {
+        if (ShapeUtil::IsLeafIndex(shape, index)) {
+          ++count;
+        }
+      });
+  return count;
+}
+
+StatusOr<HloInstruction*> BFloat16NormalizationVisitor::ConvertType(
+    HloInstruction* hlo, PrimitiveType from, PrimitiveType to,
+    HloComputation* computation) {
+  if (CountSubshapesWithMatchingType(hlo->shape(), from) == 0) {
+    return hlo;
+  }
+  // If `hlo` is a convert from `to` to `from`, then we can return its operand,
+  // if it is a BF16->F32 convert which doesn't do rounding.
+  if (hlo->opcode() == HloOpcode::kConvert &&
+      hlo->operand(0)->shape().element_type() == to && to == BF16 &&
+      from == F32) {
+    return hlo->mutable_operand(0);
+  }
+  TF_ASSIGN_OR_RETURN(
+      auto new_hlo,
+      computation->DeepCopyInstructionWithCustomCopier(
+          hlo, [&](HloInstruction* leaf, const ShapeIndex& leaf_index,
+                   HloComputation* comp) {
+            const auto& original_subshape =
+                ShapeUtil::GetSubshape(hlo->shape(), leaf_index);
+            if (original_subshape.element_type() != from) {
+              return leaf;
+            }
+            auto new_subshape =
+                ShapeUtil::ChangeElementType(original_subshape, to);
+            bfloat16_normalization_->UpdateLayout(&new_subshape);
+            return computation->AddInstruction(
+                HloInstruction::CreateConvert(new_subshape, leaf));
+          }));
+  return new_hlo;
+}
+
 Status BFloat16NormalizationVisitor::InsertConvertAfterOutput(
-    HloInstruction* hlo, PrimitiveType to, HloComputation* computation) {
+    HloInstruction* hlo, PrimitiveType from, PrimitiveType to,
+    HloComputation* computation) {
   bool is_root = computation->root_instruction() == hlo;
   std::vector<HloInstruction*> materialized_users = hlo->users();
-  // Use inst's shape temporarily, in order to pass checks in ReplaceUseWith.
-  auto convert = computation->AddInstruction(
-      HloInstruction::CreateConvert(hlo->shape(), hlo));
+  TF_ASSIGN_OR_RETURN(auto new_hlo, ConvertType(hlo, from, to, computation));
+  if (new_hlo == hlo) {
+    return Status::OK();
+  }
   for (auto* user : materialized_users) {
-    if (user->opcode() == HloOpcode::kConvert &&
-        user->shape().element_type() == F32) {
-      TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo));
-    } else {
-      TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, convert));
-    }
+    TF_RETURN_IF_ERROR(hlo->ReplaceUseWithDifferentShape(user, new_hlo));
   }
   if (is_root) {
-    computation->set_root_instruction(convert);
+    computation->set_root_instruction(new_hlo, /*accept_different_shape=*/true);
   }
-  convert->mutable_shape()->set_element_type(to);
-  bfloat16_normalization_->UpdateLayout(convert->mutable_shape());
   changed_ = true;
   return Status::OK();
 }
 
 Status BFloat16NormalizationVisitor::ChangeOutputTypeThenInsertConvertBack(
-    HloInstruction* hlo, PrimitiveType to, HloComputation* computation) {
-  auto original_type = hlo->shape().element_type();
-  hlo->mutable_shape()->set_element_type(to);
+    HloInstruction* hlo, PrimitiveType from, PrimitiveType to,
+    HloComputation* computation) {
+  auto original_shape = hlo->shape();
+  if (CountSubshapesWithMatchingType(original_shape, from) == 0) {
+    return Status::OK();
+  }
+  ShapeUtil::ForEachMutableSubshape(
+      hlo->mutable_shape(), [&](Shape* subshape, const xla::ShapeIndex& index) {
+        if (subshape->element_type() == from) {
+          subshape->set_element_type(to);
+        }
+      });
   bfloat16_normalization_->UpdateLayout(hlo->mutable_shape());
-  return InsertConvertAfterOutput(hlo, original_type, computation);
+  bool is_root = computation->root_instruction() == hlo;
+  std::vector<HloInstruction*> materialized_users = hlo->users();
+  TF_ASSIGN_OR_RETURN(
+      auto new_hlo,
+      computation->DeepCopyInstructionWithCustomCopier(
+          hlo, [&](HloInstruction* leaf, const ShapeIndex& leaf_index,
+                   HloComputation* comp) {
+            const auto& original_subshape =
+                ShapeUtil::GetSubshape(original_shape, leaf_index);
+            if (original_subshape.element_type() ==
+                leaf->shape().element_type()) {
+              return leaf;
+            }
+            return computation->AddInstruction(
+                HloInstruction::CreateConvert(original_subshape, leaf));
+          }));
+  for (auto* user : materialized_users) {
+    // If the user is a BF16 -> F32 convert, we can replace it with `hlo`, which
+    // has its input changed to F32.
+    if (user->opcode() == HloOpcode::kConvert &&
+        user->shape().element_type() == to && to == F32 && from == BF16) {
+      TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo));
+    } else {
+      TF_RETURN_IF_ERROR(hlo->ReplaceUseWithDifferentShape(user, new_hlo));
+    }
+  }
+  if (is_root) {
+    computation->set_root_instruction(new_hlo, /*accept_different_shape=*/true);
+  }
+  changed_ = true;
+  return Status::OK();
 }
 
 Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand(
-    HloInstruction* hlo, int64 operand_idx, PrimitiveType to,
-    HloComputation* computation) {
+    HloInstruction* hlo, int64 operand_idx, PrimitiveType from,
+    PrimitiveType to, HloComputation* computation) {
   auto operand = hlo->mutable_operand(operand_idx);
-  auto shape = ShapeUtil::ChangeElementType(operand->shape(), to);
-  bfloat16_normalization_->UpdateLayout(&shape);
-  auto convert = computation->AddInstruction(
-      HloInstruction::CreateConvert(shape, operand));
-  TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(operand_idx, convert));
+  TF_ASSIGN_OR_RETURN(auto new_operand,
+                      ConvertType(operand, from, to, computation));
+  if (new_operand == operand) {
+    return Status::OK();
+  }
+  TF_RETURN_IF_ERROR(
+      hlo->ReplaceOperandWithDifferentShape(operand_idx, new_operand));
   changed_ = true;
   return Status::OK();
 }
@@ -139,16 +247,12 @@ Status BFloat16NormalizationVisitor::ConvertCalledComputations(
       });
   for (auto& comp_pair : cloned_computations) {
     auto comp = comp_pair.second;
-    if (comp->root_instruction()->shape().element_type() == BF16) {
-      TF_RETURN_IF_ERROR(
-          InsertConvertAfterOutput(comp->root_instruction(), F32, comp));
-    }
+    TF_RETURN_IF_ERROR(
+        InsertConvertAfterOutput(comp->root_instruction(), BF16, F32, comp));
     for (auto* param : comp->parameter_instructions()) {
-      if (param->shape().element_type() == BF16) {
-        // This changes the parameter to F32 then inserts a convert after it.
-        TF_RETURN_IF_ERROR(
-            ChangeOutputTypeThenInsertConvertBack(param, F32, comp));
-      }
+      // This changes the parameter to F32 then inserts a convert after it.
+      TF_RETURN_IF_ERROR(
+          ChangeOutputTypeThenInsertConvertBack(param, BF16, F32, comp));
     }
   }
   return Status::OK();
@@ -163,6 +267,8 @@ Status BFloat16NormalizationVisitor::HandleMultipleOutputs(
   bool has_unsupported_bf16_operand = false;
   bool has_unsupported_bf16_output = false;
   for (int64 i = 0; i < hlo->operand_count(); ++i) {
+    CHECK(hlo->operand(i)->shape().IsArray());
+    CHECK(ShapeUtil::GetSubshape(hlo->shape(), {i}).IsArray());
     operand_types[i] = hlo->operand(i)->shape().element_type();
     output_types[i] = ShapeUtil::GetSubshape(hlo->shape(), {i}).element_type();
     if (operand_types[i] == F32) {
@@ -203,7 +309,8 @@ Status BFloat16NormalizationVisitor::HandleMultipleOutputs(
   for (int64 i = 0; i < hlo->operand_count(); ++i) {
     if (should_convert_operand(i)) {
-      TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
+      TF_RETURN_IF_ERROR(
+          InsertConvertBeforeOperand(hlo, i, BF16, F32, computation_));
       f32_count += 1;
       bf16_count -= 1;
     }
@@ -275,36 +382,34 @@ Status BFloat16NormalizationVisitor::HandleMultipleOutputs(
 Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) {
   int f32_count = 0;
-  int bf16_count = 1;
+  int bf16_count = 0;
   for (int64 i = 0; i < hlo->operand_count(); ++i) {
-    if (hlo->operand(i)->shape().element_type() == F32) {
-      f32_count += 1;
-    } else if (hlo->operand(i)->shape().element_type() == BF16) {
-      bf16_count += 1;
-    }
+    f32_count += CountSubshapesWithMatchingType(hlo->operand(i)->shape(), F32);
+    bf16_count +=
+        CountSubshapesWithMatchingType(hlo->operand(i)->shape(), BF16);
   }
-  if (hlo->shape().element_type() == F32) {
-    f32_count += 1;
-  } else if (hlo->shape().element_type() == BF16) {
-    bf16_count += 1;
-  }
+  f32_count += CountSubshapesWithMatchingType(hlo->shape(), F32);
+  bf16_count += CountSubshapesWithMatchingType(hlo->shape(), BF16);
 
   std::vector<HloComputation*> bf16_called_comps;
   for (auto* comp : hlo->called_computations()) {
     bool comp_has_bf16 = false;
-    if (comp->root_instruction()->shape().element_type() == F32) {
-      f32_count += 1;
-    } else if (comp->root_instruction()->shape().element_type() == BF16) {
-      bf16_count += 1;
+    f32_count +=
+        CountSubshapesWithMatchingType(comp->root_instruction()->shape(), F32);
+    int64 bf16_count_comp_root =
+        CountSubshapesWithMatchingType(comp->root_instruction()->shape(), BF16);
+    if (bf16_count_comp_root > 0) {
+      bf16_count += bf16_count_comp_root;
       comp_has_bf16 = true;
     }
     for (auto* param : comp->parameter_instructions()) {
-      if (param->shape().element_type() == F32) {
-        f32_count += 1;
-      } else if (param->shape().element_type() == BF16) {
-        bf16_count += 1;
+      f32_count += CountSubshapesWithMatchingType(param->shape(), F32);
+      int64 bf16_count_comp_param =
+          CountSubshapesWithMatchingType(param->shape(), BF16);
+      if (bf16_count_comp_param > 0) {
+        bf16_count += bf16_count_comp_param;
         comp_has_bf16 = true;
       }
     }
@@ -315,21 +420,27 @@ Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) {
 
   // Resolve unsupported BF16 operands.
   for (int i = 0; i < hlo->operand_count(); ++i) {
-    if (hlo->operand(i)->shape().element_type() == BF16 &&
+    int64 bf16_count_in_operand =
+        CountSubshapesWithMatchingType(hlo->operand(i)->shape(), BF16);
+    if (bf16_count_in_operand > 0 &&
         !bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
-      TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
-      bf16_count -= 1;
-      f32_count += 1;
+      TF_RETURN_IF_ERROR(
+          InsertConvertBeforeOperand(hlo, i, BF16, F32, computation_));
+      bf16_count -= bf16_count_in_operand;
+      f32_count += bf16_count_in_operand;
     }
   }
 
   // Resolve unsupported BF16 output.
-  if (hlo->shape().element_type() == BF16 &&
-      !bfloat16_support_->SupportsBF16Output(*hlo)) {
-    TF_RETURN_IF_ERROR(
-        ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_));
-    bf16_count -= 1;
-    f32_count += 1;
+  if (!bfloat16_support_->SupportsBF16Output(*hlo)) {
+    int64 bf16_count_in_hlo =
+        CountSubshapesWithMatchingType(hlo->shape(), BF16);
+    if (bf16_count_in_hlo > 0) {
+      TF_RETURN_IF_ERROR(
+          ChangeOutputTypeThenInsertConvertBack(hlo, BF16, F32, computation_));
+      bf16_count -= bf16_count_in_hlo;
+      f32_count += bf16_count_in_hlo;
+    }
   }
 
   // Resolve unsupported mixed precision after resolving unsupported BF16
@@ -341,10 +452,12 @@ Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) {
   }
 
   // See if we can change everything to BF16.
   if (hlo->called_computations().empty() &&
-      hlo->shape().element_type() == BF16) {
+      CountSubshapesWithMatchingType(hlo->shape(), BF16) ==
+          ShapeLeafCount(hlo->shape())) {
     bool can_use_bf16 = true;
     for (int i = 0; i < hlo->operand_count(); ++i) {
-      if (hlo->operand(i)->shape().element_type() == BF16) {
+      if (CountSubshapesWithMatchingType(hlo->operand(i)->shape(), BF16) ==
+          ShapeLeafCount(hlo->operand(i)->shape())) {
         continue;
       }
       if ((bfloat16_support_->EffectiveOperandPrecisionIsBF16(*hlo, i) ||
@@ -358,22 +471,17 @@ Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) {
     }
     if (can_use_bf16) {
       for (int i = 0; i < hlo->operand_count(); ++i) {
-        if (hlo->operand(i)->shape().element_type() == F32) {
-          TF_RETURN_IF_ERROR(
-              InsertConvertBeforeOperand(hlo, i, BF16, computation_));
-        }
+        TF_RETURN_IF_ERROR(
+            InsertConvertBeforeOperand(hlo, i, F32, BF16, computation_));
       }
       return Status::OK();
     }
   }
-  if (hlo->shape().element_type() == BF16) {
-    TF_RETURN_IF_ERROR(
-        ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_));
-  }
+  TF_RETURN_IF_ERROR(
+      ChangeOutputTypeThenInsertConvertBack(hlo, BF16, F32, computation_));
   for (int i = 0; i < hlo->operand_count(); ++i) {
-    if (hlo->operand(i)->shape().element_type() == BF16) {
-      TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
-    }
+    TF_RETURN_IF_ERROR(
+        InsertConvertBeforeOperand(hlo, i, BF16, F32, computation_));
   }
   return ConvertCalledComputations(hlo, bf16_called_comps);
 }
@@ -385,6 +493,7 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) {
   if (hlo->opcode() == HloOpcode::kTuple ||            //
       hlo->opcode() == HloOpcode::kGetTupleElement ||  //
       hlo->opcode() == HloOpcode::kConstant ||         //
+      hlo->opcode() == HloOpcode::kDomain ||           //
      hlo->opcode() == HloOpcode::kParameter ||        //
      hlo->opcode() == HloOpcode::kFusion ||           //
      hlo->opcode() == HloOpcode::kConvert ||          //
@@ -410,6 +519,8 @@ Status BFloat16NormalizationVisitor::Preprocess(HloInstruction* hlo) {
   return Status::OK();
 }
 
 }  // namespace
 
 StatusOr<bool> BFloat16Normalization::Run(HloModule* module) {
   XLA_VLOG_LINES(
       2, "BFloat16Normalization::Run(), before:\n" + module->ToString());
@@ -419,6 +530,12 @@ StatusOr<bool> BFloat16Normalization::Run(HloModule* module) {
   }
   XLA_VLOG_LINES(2,
                  "BFloat16Normalization::Run(), after:\n" + module->ToString());
+  if (visitor.changed()) {
+    TupleSimplifier tuple_simplifier;
+    TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
+    HloDCE dce;
+    TF_RETURN_IF_ERROR(dce.Run(module).status());
+  }
   return visitor.changed();
 }
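For orientation, a hedged sketch of how a backend typically drives this pass. RunBF16Normalization is an illustrative wrapper, not part of this commit; BFloat16Normalization and BFloat16Support are the real classes this file defines and uses:

// Illustrative only. `support` describes which ops handle BF16 natively on
// the target backend.
Status RunBF16Normalization(HloModule* module,
                            const BFloat16Support* support) {
  BFloat16Normalization pass(support);
  // With this change, Run() also invokes TupleSimplifier and HloDCE when the
  // visitor modified the module, cleaning up the converts and tuples it
  // created around tuple-shaped values.
  TF_ASSIGN_OR_RETURN(bool changed, pass.Run(module));
  VLOG(1) << "BF16 normalization changed module: " << changed;
  return Status::OK();
}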

tensorflow/compiler/xla/service/bfloat16_normalization_test.cc

@@ -40,7 +40,8 @@ class TestBFloat16Support : public BFloat16Support {
         hlo.opcode() == HloOpcode::kSubtract ||
         hlo.opcode() == HloOpcode::kReduce ||
         hlo.opcode() == HloOpcode::kTuple ||
-        hlo.opcode() == HloOpcode::kGetTupleElement) {
+        hlo.opcode() == HloOpcode::kGetTupleElement ||
+        hlo.opcode() == HloOpcode::kAllToAll) {
       return true;
     }
     if (hlo.opcode() == HloOpcode::kDot) {
@@ -54,7 +55,8 @@ class TestBFloat16Support : public BFloat16Support {
     if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kReduce ||
         hlo.opcode() == HloOpcode::kSubtract ||
         hlo.opcode() == HloOpcode::kDot || hlo.opcode() == HloOpcode::kTuple ||
-        hlo.opcode() == HloOpcode::kGetTupleElement) {
+        hlo.opcode() == HloOpcode::kGetTupleElement ||
+        hlo.opcode() == HloOpcode::kAllToAll) {
       return true;
     }
     return false;
@@ -258,19 +260,76 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) {
       ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction,
       /*replica_groups=*/{},
       /*channel_id=*/absl::nullopt));
-  HloInstruction* gte = builder.AddInstruction(
+  builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));
 
   auto computation = module->AddEntryComputation(builder.Build());
 
   EXPECT_TRUE(Normalize(module.get()));
 
-  EXPECT_EQ(computation->root_instruction(), gte);
-  EXPECT_EQ(gte->shape().element_type(), BF16);
+  EXPECT_EQ(computation->root_instruction()->shape().element_type(), BF16);
   EXPECT_EQ(crs->operand(1)->shape().element_type(), F32);
   EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32);
 }
 
+TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToBF16) {
+  auto module = CreateNewVerifiedModule();
+  auto builder = HloComputation::Builder(TestName());
+  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
+  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+  HloInstruction* a = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32_shape, "a"));
+
+  std::vector<ReplicaGroup> replica_groups(1);
+  replica_groups[0].add_replica_ids(0);
+  replica_groups[0].add_replica_ids(1);
+  HloInstruction* a2a = builder.AddInstruction(HloInstruction::CreateAllToAll(
+      ShapeUtil::MakeTupleShape({bf16_shape, bf16_shape}), {a, a},
+      replica_groups));
+
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_TRUE(Normalize(module.get()));
+
+  EXPECT_EQ(computation->root_instruction(), a2a);
+  EXPECT_EQ(ShapeUtil::GetSubshape(a2a->shape(), {0}).element_type(), BF16);
+  EXPECT_EQ(ShapeUtil::GetSubshape(a2a->shape(), {1}).element_type(), BF16);
+  EXPECT_EQ(a2a->operand(0)->opcode(), HloOpcode::kConvert);
+  EXPECT_EQ(a2a->operand(0)->shape().element_type(), BF16);
+  EXPECT_EQ(a2a->operand(1)->opcode(), HloOpcode::kConvert);
+  EXPECT_EQ(a2a->operand(1)->shape().element_type(), BF16);
+}
+
+TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToF32) {
+  auto module = CreateNewVerifiedModule();
+  auto builder = HloComputation::Builder(TestName());
+  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
+  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
+  HloInstruction* a = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32_shape, "a"));
+
+  std::vector<ReplicaGroup> replica_groups(1);
+  replica_groups[0].add_replica_ids(0);
+  replica_groups[0].add_replica_ids(1);
+  HloInstruction* a2a = builder.AddInstruction(HloInstruction::CreateAllToAll(
+      ShapeUtil::MakeTupleShape({bf16_shape, f32_shape}), {a, a},
+      replica_groups));
+
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_TRUE(Normalize(module.get()));
+
+  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kTuple);
+  EXPECT_EQ(ShapeUtil::GetSubshape(a2a->shape(), {0}).element_type(), F32);
+  EXPECT_EQ(ShapeUtil::GetSubshape(a2a->shape(), {1}).element_type(), F32);
+  EXPECT_EQ(a2a->operand(0)->opcode(), HloOpcode::kParameter);
+  EXPECT_EQ(a2a->operand(0)->shape().element_type(), F32);
+  EXPECT_EQ(a2a->operand(1)->opcode(), HloOpcode::kParameter);
+  EXPECT_EQ(a2a->operand(1)->shape().element_type(), F32);
+}
+
 TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) {
   auto module = CreateNewVerifiedModule();
   auto builder = HloComputation::Builder(TestName());
@@ -288,15 +347,14 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) {
       MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}),
                   {key, value}, 0, /*is_stable=*/false, &builder,
                   module.get()));
-  HloInstruction* gte = builder.AddInstruction(
+  builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0));
 
   auto computation = module->AddEntryComputation(builder.Build());
 
   EXPECT_TRUE(Normalize(module.get()));
 
-  EXPECT_EQ(computation->root_instruction(), gte);
-  EXPECT_EQ(gte->shape().element_type(), BF16);
+  EXPECT_EQ(computation->root_instruction()->shape().element_type(), BF16);
   EXPECT_EQ(sort->operand(0)->shape().element_type(), F32);
   EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32);
 }