Fix CRS combiner for spatial partitioning
PiperOrigin-RevId: 211519250
This commit is contained in:
parent
ffd9519c3f
commit
97039a80b3
@ -51,6 +51,10 @@ int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const {
|
|||||||
return FindOrDefault(instruction_to_domain_, instruction, -1);
|
return FindOrDefault(instruction_to_domain_, instruction, -1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64 HloDomainMap::GetDomainMetadataId(HloInstruction* instruction) const {
|
||||||
|
return FindOrDie(domain_metadata_id_, instruction);
|
||||||
|
}
|
||||||
|
|
||||||
Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
|
Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
|
||||||
TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain);
|
TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain);
|
||||||
// We only check operands, so we are sure to not process the empty domain from
|
// We only check operands, so we are sure to not process the empty domain from
|
||||||
@ -93,6 +97,43 @@ Status HloDomainMap::Populate(HloComputation* computation) {
|
|||||||
CreateDomain(instruction, instructions_post_order));
|
CreateDomain(instruction, instructions_post_order));
|
||||||
TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
|
TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
|
||||||
}
|
}
|
||||||
|
TF_RETURN_IF_ERROR(PopulateDomainMetadataMap());
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status HloDomainMap::PopulateDomainMetadataMap() {
|
||||||
|
auto hash = [](const DomainMetadata* m) { return m->Hash(); };
|
||||||
|
auto equal = [](const DomainMetadata* a, const DomainMetadata* b) {
|
||||||
|
return a->Matches(*b);
|
||||||
|
};
|
||||||
|
tensorflow::gtl::FlatMap<const DomainMetadata*, int64, decltype(hash),
|
||||||
|
decltype(equal)>
|
||||||
|
domain_metadata(1024, hash, equal);
|
||||||
|
|
||||||
|
for (auto& domain : instruction_domains_) {
|
||||||
|
int64 domain_metadata_id = -1;
|
||||||
|
if (!domain->enter_domains.empty()) {
|
||||||
|
const HloInstruction* domain_instruction = *domain->enter_domains.begin();
|
||||||
|
domain_metadata_id =
|
||||||
|
domain_metadata
|
||||||
|
.insert({&domain_instruction->user_side_metadata(),
|
||||||
|
domain_metadata.size() + 1})
|
||||||
|
.first->second;
|
||||||
|
} else if (!domain->exit_domains.empty()) {
|
||||||
|
const HloInstruction* domain_instruction = *domain->exit_domains.begin();
|
||||||
|
domain_metadata_id =
|
||||||
|
domain_metadata
|
||||||
|
.insert({&domain_instruction->operand_side_metadata(),
|
||||||
|
domain_metadata.size() + 1})
|
||||||
|
.first->second;
|
||||||
|
} else {
|
||||||
|
domain_metadata_id = 0;
|
||||||
|
}
|
||||||
|
TF_RET_CHECK(domain_metadata_id >= 0);
|
||||||
|
for (HloInstruction* instruction : domain->instructions) {
|
||||||
|
domain_metadata_id_[instruction] = domain_metadata_id;
|
||||||
|
}
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,6 +69,11 @@ class HloDomainMap {
|
|||||||
// instruction is not found within any domain.
|
// instruction is not found within any domain.
|
||||||
int64 GetDomainId(HloInstruction* instruction) const;
|
int64 GetDomainId(HloInstruction* instruction) const;
|
||||||
|
|
||||||
|
// Returns the unique id of the domain metadata for the domain the given
|
||||||
|
// instruction belongs to. The given instruction must not be a kDomain
|
||||||
|
// instruction since each domain instruction is associated with 2 domains.
|
||||||
|
int64 GetDomainMetadataId(HloInstruction* instruction) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Map used for representing instruction ordering, i.e.
|
// Map used for representing instruction ordering, i.e.
|
||||||
// order_map[a] < order_map[b] means a must be ordered before b.
|
// order_map[a] < order_map[b] means a must be ordered before b.
|
||||||
@ -109,9 +114,14 @@ class HloDomainMap {
|
|||||||
const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set,
|
const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set,
|
||||||
const InstructionOrderMap& instructions_order);
|
const InstructionOrderMap& instructions_order);
|
||||||
|
|
||||||
|
// Populates domain_metadata_id_ that maps each HloInstruction to the unique
|
||||||
|
// ID of its associated domain metatadata.
|
||||||
|
Status PopulateDomainMetadataMap();
|
||||||
|
|
||||||
string domain_kind_;
|
string domain_kind_;
|
||||||
std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_;
|
std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_;
|
||||||
tensorflow::gtl::FlatMap<HloInstruction*, int64> instruction_to_domain_;
|
tensorflow::gtl::FlatMap<HloInstruction*, int64> instruction_to_domain_;
|
||||||
|
tensorflow::gtl::FlatMap<HloInstruction*, int64> domain_metadata_id_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -72,6 +72,9 @@ class DomainMetadata {
|
|||||||
// two matches.
|
// two matches.
|
||||||
virtual bool Matches(const DomainMetadata& other) const = 0;
|
virtual bool Matches(const DomainMetadata& other) const = 0;
|
||||||
|
|
||||||
|
// Returns the hash value of the metadata.
|
||||||
|
virtual size_t Hash() const = 0;
|
||||||
|
|
||||||
// Returns a string representation of the metadata.
|
// Returns a string representation of the metadata.
|
||||||
virtual string ToString() const = 0;
|
virtual string ToString() const = 0;
|
||||||
};
|
};
|
||||||
|
@ -99,6 +99,8 @@ class OpNameMetadata : public DomainMetadata {
|
|||||||
|
|
||||||
static absl::string_view KindName() { return "opname"; }
|
static absl::string_view KindName() { return "opname"; }
|
||||||
|
|
||||||
|
size_t Hash() const override { return std::hash<string>()(opname_); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
string opname_;
|
string opname_;
|
||||||
};
|
};
|
||||||
|
@ -422,6 +422,13 @@ bool ShardingMetadata::Matches(const DomainMetadata& other) const {
|
|||||||
: false;
|
: false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t ShardingMetadata::Hash() const {
|
||||||
|
if (sharding_ != nullptr) {
|
||||||
|
return sharding_->Hash();
|
||||||
|
}
|
||||||
|
return static_cast<size_t>(0x297814aaad196e6dULL);
|
||||||
|
}
|
||||||
|
|
||||||
string ShardingMetadata::ToString() const {
|
string ShardingMetadata::ToString() const {
|
||||||
return sharding_ != nullptr ? sharding_->ToString() : "{}";
|
return sharding_ != nullptr ? sharding_->ToString() : "{}";
|
||||||
}
|
}
|
||||||
|
@ -36,6 +36,8 @@ class ShardingMetadata : public DomainMetadata {
|
|||||||
|
|
||||||
bool Matches(const DomainMetadata& other) const override;
|
bool Matches(const DomainMetadata& other) const override;
|
||||||
|
|
||||||
|
size_t Hash() const override;
|
||||||
|
|
||||||
string ToString() const override;
|
string ToString() const override;
|
||||||
|
|
||||||
const HloSharding* sharding() const { return sharding_.get(); }
|
const HloSharding* sharding() const { return sharding_.get(); }
|
||||||
|
Loading…
x
Reference in New Issue
Block a user