[XLA] Improve scatter cost analysis.

Scatter needs to access input, output, updates, and indices.

PiperOrigin-RevId: 239091685
This commit is contained in:
A. Unique TensorFlower 2019-03-18 17:03:20 -07:00 committed by TensorFlower Gardener
parent ba994787e6
commit ba3e07c51e
2 changed files with 4 additions and 2 deletions

View File

@ -718,8 +718,10 @@ Status HloCostAnalysis::HandleGather(const HloInstruction* gather) {
}
Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) {
// Scatter accesses the equivalent of 3 update shapes (input, output, and
// updates), and the scatter indices.
current_properties_[kBytesAccessedKey] =
GetShapeSize(scatter->operand(2)->shape()) * 2 +
GetShapeSize(scatter->operand(2)->shape()) * 3 +
GetShapeSize(scatter->operand(1)->shape());
const int64 element_count =
ShapeUtil::ElementsIn(scatter->operand(2)->shape());

View File

@ -688,7 +688,7 @@ TEST_F(HloCostAnalysisTest, Scatter) {
ASSERT_IS_OK(
hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
EXPECT_EQ(analysis.bytes_accessed(), 4 * (2 + 2 * (2 * 3)));
EXPECT_EQ(analysis.bytes_accessed(), 4 * (2 + 3 * (2 * 3)));
}
} // namespace
} // namespace xla