[XLA:TPU] Move per-memory-space bytes read/written code to HloCostAnalysis.

PiperOrigin-RevId: 313284279
Change-Id: I544c7089c51cb4dad733732149e5bb8fb3b05fa9
This commit is contained in:
Berkin Ilbeyi 2020-05-26 15:59:48 -07:00 committed by TensorFlower Gardener
parent bba3595ebf
commit fe523d826d
2 changed files with 44 additions and 0 deletions

View File

@ -1041,6 +1041,42 @@ float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const {
return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_); return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_);
} }
int64 HloCostAnalysis::GetBytesRead(const HloInstruction& hlo,
absl::optional<int64> memory_space) const {
int64 bytes_read = 0;
for (int operand_number = 0; operand_number < hlo.operand_count();
++operand_number) {
for (const ShapeUtil::IndexedShape& indexed_shape :
ShapeUtil::GetLeafShapes(hlo.operand(operand_number)->shape())) {
absl::optional<int64> index_memory_space;
if (indexed_shape.shape.has_layout()) {
index_memory_space = indexed_shape.shape.layout().memory_space();
}
if (!memory_space || memory_space == index_memory_space) {
bytes_read +=
operand_bytes_accessed(hlo, operand_number, indexed_shape.index);
}
}
}
return bytes_read;
}
int64 HloCostAnalysis::GetBytesWritten(
const HloInstruction& hlo, absl::optional<int64> memory_space) const {
int64 bytes_written = 0;
for (const ShapeUtil::IndexedShape& indexed_shape :
ShapeUtil::GetLeafShapes(hlo.shape())) {
absl::optional<int64> index_memory_space;
if (indexed_shape.shape.has_layout()) {
index_memory_space = indexed_shape.shape.layout().memory_space();
}
if (!memory_space || memory_space == index_memory_space) {
bytes_written += output_bytes_accessed(hlo, indexed_shape.index);
}
}
return bytes_written;
}
StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation( StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation(
HloComputation* computation) { HloComputation* computation) {
auto visitor = CreateNestedCostAnalysis(shape_size_, per_second_rates_); auto visitor = CreateNestedCostAnalysis(shape_size_, per_second_rates_);

View File

@ -164,6 +164,14 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
ShapeIndex index = {}) const; ShapeIndex index = {}) const;
float optimal_seconds(const HloInstruction& hlo) const; float optimal_seconds(const HloInstruction& hlo) const;
// Get bytes read/written by this HLO. If memory_space is provided, it returns
// the bytes read/written from/to the given memory space only.
int64 GetBytesRead(const HloInstruction& hlo,
absl::optional<int64> memory_space = absl::nullopt) const;
int64 GetBytesWritten(
const HloInstruction& hlo,
absl::optional<int64> memory_space = absl::nullopt) const;
const Properties& properties() const { return properties_sum_; } const Properties& properties() const { return properties_sum_; }
const float property(const string& key) const { const float property(const string& key) const {
return GetProperty(key, properties()); return GetProperty(key, properties());