[XLA] Make HloCostAnalysis account for nested (tuple) shapes when calculating bytes accessed for fusion, infeed, and outfeed
PiperOrigin-RevId: 356785877 Change-Id: Ic482891d04b7dfad63be11d01c3ecc7dae69916a
This commit is contained in:
parent
b681785237
commit
2753c4ce37
@ -146,7 +146,8 @@ int64 HloCostAnalysis::FusionParameterReadBytes(
|
||||
const HloInstruction* hlo) const {
|
||||
int64 size = 0;
|
||||
bool seen_trivial_user = false;
|
||||
CHECK(hlo->IsFused() && hlo->opcode() == HloOpcode::kParameter);
|
||||
CHECK(hlo->IsFused() && (hlo->opcode() == HloOpcode::kParameter ||
|
||||
hlo->opcode() == HloOpcode::kGetTupleElement));
|
||||
for (const HloInstruction* user : hlo->users()) {
|
||||
switch (user->opcode()) {
|
||||
case HloOpcode::kFusion: {
|
||||
@ -335,11 +336,34 @@ Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloCostAnalysis::HandleInfeed(const HloInstruction*) {
|
||||
// Records the bytes accessed by an infeed instruction.
//
// Visits every leaf of the (possibly nested tuple) result shape of `infeed`,
// records that leaf's size as the output bytes accessed at the leaf's shape
// index, and then stores the sum over all leaves both as the instruction's
// total output bytes accessed and as its kBytesAccessedKey property.
//
// Always returns Status::OK().
Status HloCostAnalysis::HandleInfeed(const HloInstruction* infeed) {
  // Count nested infeed output tuples.
  int64 size = 0;
  for (const auto& indexed_shape : ShapeUtil::GetLeafShapes(infeed->shape())) {
    // Evaluate the leaf size once; it feeds both the running total and the
    // per-index record.
    const int64 leaf_size = GetShapeSize(indexed_shape.shape);
    size += leaf_size;
    SetOutputBytesAccessed(indexed_shape.index, leaf_size);
  }
  SetOutputBytesAccessed(size);
  current_properties_[kBytesAccessedKey] = size;
  return Status::OK();
}
|
||||
|
||||
Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) {
|
||||
// Records the bytes accessed by an outfeed instruction.
//
// For each operand of `outfeed`, visits every leaf of the operand's (possibly
// nested tuple) shape, records that leaf's size as the operand bytes accessed
// at the leaf's shape index, and records the per-operand sum as that operand's
// total. The kBytesAccessedKey property is reset to zero first and then
// accumulates the totals of all operands.
//
// Always returns Status::OK().
Status HloCostAnalysis::HandleOutfeed(const HloInstruction* outfeed) {
  // Count nested outfeed operand tuples.
  current_properties_[kBytesAccessedKey] = 0;
  for (int64 i = 0; i < outfeed->operand_count(); ++i) {
    const HloInstruction* operand = outfeed->operand(i);
    int64 size = 0;
    for (const auto& indexed_shape :
         ShapeUtil::GetLeafShapes(operand->shape())) {
      // Evaluate the leaf size once; it feeds both the running total and the
      // per-index record.
      const int64 leaf_size = GetShapeSize(indexed_shape.shape);
      size += leaf_size;
      SetOperandBytesAccessed(i, indexed_shape.index, leaf_size);
    }
    SetOperandBytesAccessed(i, size);
    current_properties_[kBytesAccessedKey] += size;
  }
  return Status::OK();
}
|
||||
|
||||
@ -872,9 +896,31 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
|
||||
|
||||
for (int64 i = 0; i < fusion->fused_parameters().size(); ++i) {
|
||||
const HloInstruction* operand = fusion->fused_parameter(i);
|
||||
int64 size = FusionParameterReadBytes(operand);
|
||||
current_properties_[kBytesAccessedKey] += size;
|
||||
SetOperandBytesAccessed(i, size);
|
||||
int64 operand_size = 0;
|
||||
if (!fusion->shape().IsTuple()) {
|
||||
operand_size = FusionParameterReadBytes(operand);
|
||||
} else {
|
||||
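// Taken when the fusion's own result shape is a tuple (see the guard above).
|
||||
// For each leaf of the operand's shape, follow the matching chain of
// get-tuple-element users and count the bytes read through that leaf.
// NOTE(review): the original wording says "if the fusion parameter is a tuple
// type", but the guard tests fusion->shape(), not operand->shape() -- confirm.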
|
||||
for (const auto& indexed_shape :
|
||||
ShapeUtil::GetLeafShapes(operand->shape())) {
|
||||
const HloInstruction* gte = operand;
|
||||
for (int64 index : indexed_shape.index) {
|
||||
for (const HloInstruction* user : gte->users()) {
|
||||
if (user->opcode() == HloOpcode::kGetTupleElement &&
|
||||
user->tuple_index() == index) {
|
||||
gte = user;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
int64 size = FusionParameterReadBytes(gte);
|
||||
operand_size += size;
|
||||
SetOperandBytesAccessed(i, indexed_shape.index, size);
|
||||
}
|
||||
}
|
||||
current_properties_[kBytesAccessedKey] += operand_size;
|
||||
SetOperandBytesAccessed(i, operand_size);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
||||
@ -693,10 +693,10 @@ TEST_F(FusionCostAnalysis, LoopFusionTupleOutput) {
|
||||
EXPECT_EQ(fusion_analysis.flop_count(), 16);
|
||||
EXPECT_EQ(fusion_analysis.transcendental_count(), 4);
|
||||
EXPECT_EQ(fusion_analysis.bytes_accessed(*fusion),
|
||||
sizeof(float) * (3 + 5) * 2 * 2 + kPointerSize * 2);
|
||||
sizeof(float) * (5 + 5) * 2 * 2);
|
||||
|
||||
EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0),
|
||||
kPointerSize * 2);
|
||||
sizeof(float) * 2 * 2 * 2);
|
||||
EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 1),
|
||||
sizeof(float) * 2 * 2);
|
||||
EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 2),
|
||||
@ -758,6 +758,78 @@ TEST_F(FusionCostAnalysis, NoLayout) {
|
||||
sizeof(float) * 2 * 3 * 4 * 5);
|
||||
}
|
||||
|
||||
// A loop fusion that takes one two-element tuple operand and produces a
// two-element tuple: the expectations below charge the operand and the output
// for their f32[2,2] array leaves.
TEST_F(FusionCostAnalysis, TupleBytesAccessed) {
  const absl::string_view module_text = R"(
HloModule module, is_scheduled=true

fused_computation {
  param = (f32[2,2]{1,0}, f32[2,2]{1,0}) parameter(0)
  gte0 = f32[2,2]{1,0} get-tuple-element(param), index=0
  gte1 = f32[2,2]{1,0} get-tuple-element(param), index=1
  add = f32[2,2]{1,0} add(gte0, gte1)
  mul = f32[2,2]{1,0} multiply(gte0, gte1)
  ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(add, mul)
}

ENTRY entry {
  param0 = f32[2,2]{1,0} parameter(0)
  param1 = f32[2,2]{1,0} parameter(1)
  tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(param0, param1)
  ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(tuple), kind=kLoop, calls=fused_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));

  // The fusion under test is the root of the entry computation.
  HloInstruction* root_fusion =
      module->entry_computation()->root_instruction();

  HloCostAnalysis cost(ShapeSize);
  ASSERT_IS_OK(root_fusion->Accept(&cost));

  // Size in bytes of one f32[2,2] array.
  constexpr auto kArrayBytes = sizeof(float) * 2 * 2;

  // Two arrays read through the tuple operand plus two arrays written.
  EXPECT_EQ(cost.bytes_accessed(*root_fusion), kArrayBytes * 4);
  EXPECT_EQ(cost.operand_bytes_accessed(*root_fusion, 0), kArrayBytes * 2);
  EXPECT_EQ(cost.output_bytes_accessed(*root_fusion), kArrayBytes * 2);
}
|
||||
|
||||
// An infeed producing ((f32[2,3]), token[]) feeds an add whose result is sent
// to an outfeed: the expectations below charge each instruction for the single
// f32[2,3] array it transfers.
TEST_F(FusionCostAnalysis, InfeedOutfeed) {
  const absl::string_view module_text = R"(
HloModule module, is_scheduled=true

ENTRY entry {
  after-all = token[] after-all()
  infeed = ((f32[2,3]{1,0}), token[]) infeed(after-all)
  gte0 = (f32[2,3]{1,0}) get-tuple-element(infeed), index=0
  gte1 = f32[2,3]{1,0} get-tuple-element(gte0), index=0
  add = f32[2,3]{1,0} add(gte1, gte1)
  tuple = (f32[2,3]{1,0}) tuple(add)
  tok = token[] get-tuple-element(infeed), index=1
  ROOT outfeed = token[] outfeed(tuple, tok)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));

  // Look both instructions up by name in the entry computation.
  auto* entry = module->entry_computation();
  HloInstruction* infeed = entry->GetInstructionWithName("infeed");
  HloInstruction* outfeed = entry->GetInstructionWithName("outfeed");

  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(infeed->Accept(&analysis));
  ASSERT_IS_OK(outfeed->Accept(&analysis));

  // Size in bytes of one f32[2,3] array.
  constexpr auto kArrayBytes = sizeof(float) * 2 * 3;

  // Infeed: nothing read from its operand, one array produced.
  EXPECT_EQ(analysis.bytes_accessed(*infeed), kArrayBytes);
  EXPECT_EQ(analysis.operand_bytes_accessed(*infeed, 0), 0);
  EXPECT_EQ(analysis.output_bytes_accessed(*infeed), kArrayBytes);

  // Outfeed: one array read from operand 0, nothing produced.
  EXPECT_EQ(analysis.bytes_accessed(*outfeed), kArrayBytes);
  EXPECT_EQ(analysis.operand_bytes_accessed(*outfeed, 0), kArrayBytes);
  EXPECT_EQ(analysis.output_bytes_accessed(*outfeed), 0);
}
|
||||
|
||||
TEST_F(HloCostAnalysisTest, TupleCost) {
|
||||
HloCostAnalysis analysis(ShapeSize);
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user