[XLA] Improve scatter cost analysis.
Scatter needs to access input, output, updates, and indices. PiperOrigin-RevId: 239091685
This commit is contained in:
parent
ba994787e6
commit
ba3e07c51e
@ -718,8 +718,10 @@ Status HloCostAnalysis::HandleGather(const HloInstruction* gather) {
|
||||
}
|
||||
|
||||
Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) {
|
||||
// Scatter accesses the equivalent of 3 update shapes (input, output, and
|
||||
// updates), and the scatter indices.
|
||||
current_properties_[kBytesAccessedKey] =
|
||||
GetShapeSize(scatter->operand(2)->shape()) * 2 +
|
||||
GetShapeSize(scatter->operand(2)->shape()) * 3 +
|
||||
GetShapeSize(scatter->operand(1)->shape());
|
||||
const int64 element_count =
|
||||
ShapeUtil::ElementsIn(scatter->operand(2)->shape());
|
||||
|
@ -688,7 +688,7 @@ TEST_F(HloCostAnalysisTest, Scatter) {
|
||||
ASSERT_IS_OK(
|
||||
hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
|
||||
|
||||
EXPECT_EQ(analysis.bytes_accessed(), 4 * (2 + 2 * (2 * 3)));
|
||||
EXPECT_EQ(analysis.bytes_accessed(), 4 * (2 + 3 * (2 * 3)));
|
||||
}
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user