[XLA:TPU] Move per-memory-space bytes read/written code to HloCostAnalysis.
PiperOrigin-RevId: 313284279 Change-Id: I544c7089c51cb4dad733732149e5bb8fb3b05fa9
This commit is contained in:
parent
bba3595ebf
commit
fe523d826d
@ -1041,6 +1041,42 @@ float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const {
|
|||||||
return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_);
|
return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the total number of bytes `hlo` reads from its operands.
//
// Every operand is visited, and within each operand every shape returned by
// ShapeUtil::GetLeafShapes() contributes operand_bytes_accessed() for its
// shape index. When `memory_space` holds a value, a leaf is counted only if it
// has a layout whose memory space equals that value; leaves without a layout
// are then skipped. When `memory_space` is empty, every leaf is counted.
int64 HloCostAnalysis::GetBytesRead(const HloInstruction& hlo,
                                    absl::optional<int64> memory_space) const {
  int64 total = 0;
  for (int i = 0; i < hlo.operand_count(); ++i) {
    const Shape& operand_shape = hlo.operand(i)->shape();
    for (const ShapeUtil::IndexedShape& leaf :
         ShapeUtil::GetLeafShapes(operand_shape)) {
      if (memory_space) {
        // A filter was requested: drop leaves that have no layout or that live
        // in a different memory space.
        if (!leaf.shape.has_layout() ||
            leaf.shape.layout().memory_space() != *memory_space) {
          continue;
        }
      }
      total += operand_bytes_accessed(hlo, i, leaf.index);
    }
  }
  return total;
}
|
||||||
|
|
||||||
|
// Returns the total number of bytes `hlo` writes to its output.
//
// Each shape returned by ShapeUtil::GetLeafShapes() for the instruction's
// result shape contributes output_bytes_accessed() for its shape index. When
// `memory_space` holds a value, a leaf is counted only if it has a layout
// whose memory space equals that value; leaves without a layout are then
// skipped. When `memory_space` is empty, every leaf is counted.
int64 HloCostAnalysis::GetBytesWritten(
    const HloInstruction& hlo, absl::optional<int64> memory_space) const {
  int64 total = 0;
  for (const ShapeUtil::IndexedShape& leaf :
       ShapeUtil::GetLeafShapes(hlo.shape())) {
    // No filter means everything counts; otherwise the leaf must carry a
    // layout in the requested memory space.
    const bool counted =
        !memory_space ||
        (leaf.shape.has_layout() &&
         leaf.shape.layout().memory_space() == *memory_space);
    if (counted) {
      total += output_bytes_accessed(hlo, leaf.index);
    }
  }
  return total;
}
|
||||||
|
|
||||||
StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation(
|
StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation(
|
||||||
HloComputation* computation) {
|
HloComputation* computation) {
|
||||||
auto visitor = CreateNestedCostAnalysis(shape_size_, per_second_rates_);
|
auto visitor = CreateNestedCostAnalysis(shape_size_, per_second_rates_);
|
||||||
|
@ -164,6 +164,14 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
|
|||||||
ShapeIndex index = {}) const;
|
ShapeIndex index = {}) const;
|
||||||
float optimal_seconds(const HloInstruction& hlo) const;
|
float optimal_seconds(const HloInstruction& hlo) const;
|
||||||
|
|
||||||
|
// Returns the number of bytes read/written by this HLO. If memory_space is
// provided, only bytes read from / written to that memory space are counted
// (shapes without a layout are then excluded).
|
||||||
|
int64 GetBytesRead(const HloInstruction& hlo,
|
||||||
|
absl::optional<int64> memory_space = absl::nullopt) const;
|
||||||
|
int64 GetBytesWritten(
|
||||||
|
const HloInstruction& hlo,
|
||||||
|
absl::optional<int64> memory_space = absl::nullopt) const;
|
||||||
|
|
||||||
const Properties& properties() const { return properties_sum_; }
|
const Properties& properties() const { return properties_sum_; }
|
||||||
const float property(const string& key) const {
|
const float property(const string& key) const {
|
||||||
return GetProperty(key, properties());
|
return GetProperty(key, properties());
|
||||||
|
Loading…
Reference in New Issue
Block a user