[XLA:TPU] Move per-memory-space bytes read/written code to HloCostAnalysis.
PiperOrigin-RevId: 313284279 Change-Id: I544c7089c51cb4dad733732149e5bb8fb3b05fa9
This commit is contained in:
parent
bba3595ebf
commit
fe523d826d
@ -1041,6 +1041,42 @@ float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const {
|
||||
return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_);
|
||||
}
|
||||
|
||||
int64 HloCostAnalysis::GetBytesRead(const HloInstruction& hlo,
|
||||
absl::optional<int64> memory_space) const {
|
||||
int64 bytes_read = 0;
|
||||
for (int operand_number = 0; operand_number < hlo.operand_count();
|
||||
++operand_number) {
|
||||
for (const ShapeUtil::IndexedShape& indexed_shape :
|
||||
ShapeUtil::GetLeafShapes(hlo.operand(operand_number)->shape())) {
|
||||
absl::optional<int64> index_memory_space;
|
||||
if (indexed_shape.shape.has_layout()) {
|
||||
index_memory_space = indexed_shape.shape.layout().memory_space();
|
||||
}
|
||||
if (!memory_space || memory_space == index_memory_space) {
|
||||
bytes_read +=
|
||||
operand_bytes_accessed(hlo, operand_number, indexed_shape.index);
|
||||
}
|
||||
}
|
||||
}
|
||||
return bytes_read;
|
||||
}
|
||||
|
||||
int64 HloCostAnalysis::GetBytesWritten(
|
||||
const HloInstruction& hlo, absl::optional<int64> memory_space) const {
|
||||
int64 bytes_written = 0;
|
||||
for (const ShapeUtil::IndexedShape& indexed_shape :
|
||||
ShapeUtil::GetLeafShapes(hlo.shape())) {
|
||||
absl::optional<int64> index_memory_space;
|
||||
if (indexed_shape.shape.has_layout()) {
|
||||
index_memory_space = indexed_shape.shape.layout().memory_space();
|
||||
}
|
||||
if (!memory_space || memory_space == index_memory_space) {
|
||||
bytes_written += output_bytes_accessed(hlo, indexed_shape.index);
|
||||
}
|
||||
}
|
||||
return bytes_written;
|
||||
}
|
||||
|
||||
StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation(
|
||||
HloComputation* computation) {
|
||||
auto visitor = CreateNestedCostAnalysis(shape_size_, per_second_rates_);
|
||||
|
@ -164,6 +164,14 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
|
||||
ShapeIndex index = {}) const;
|
||||
float optimal_seconds(const HloInstruction& hlo) const;
|
||||
|
||||
// Get bytes read/written by this HLO. If memory_space is provided, it returns
|
||||
// the bytes read/written from/to the given memory space only.
|
||||
int64 GetBytesRead(const HloInstruction& hlo,
|
||||
absl::optional<int64> memory_space = absl::nullopt) const;
|
||||
int64 GetBytesWritten(
|
||||
const HloInstruction& hlo,
|
||||
absl::optional<int64> memory_space = absl::nullopt) const;
|
||||
|
||||
const Properties& properties() const { return properties_sum_; }
|
||||
const float property(const string& key) const {
|
||||
return GetProperty(key, properties());
|
||||
|
Loading…
Reference in New Issue
Block a user