Fix CRS combiner for spatial partitioning
PiperOrigin-RevId: 211519250
This commit is contained in:
parent
ffd9519c3f
commit
97039a80b3
@ -51,6 +51,10 @@ int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const {
|
||||
return FindOrDefault(instruction_to_domain_, instruction, -1);
|
||||
}
|
||||
|
||||
// Looks up the domain metadata id recorded for `instruction` in
// domain_metadata_id_ (the map filled in by PopulateDomainMetadataMap()).
// The lookup goes through FindOrDie, so the instruction is expected to be
// present in the map -- presumably a fatal error otherwise; confirm against
// the FindOrDie helper.
int64 HloDomainMap::GetDomainMetadataId(HloInstruction* instruction) const {
  const int64 metadata_id = FindOrDie(domain_metadata_id_, instruction);
  return metadata_id;
}
|
||||
|
||||
Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
|
||||
TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain);
|
||||
// We only check operands, so we are sure to not process the empty domain from
|
||||
@ -93,6 +97,43 @@ Status HloDomainMap::Populate(HloComputation* computation) {
|
||||
CreateDomain(instruction, instructions_post_order));
|
||||
TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(PopulateDomainMetadataMap());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Fills domain_metadata_id_ with one entry per instruction of every domain in
// instruction_domains_. The stored value identifies the metadata of the
// instruction's domain:
//   * 0 for a domain that has neither enter nor exit domain instructions;
//   * otherwise an id >= 1 shared by all domains whose metadata hash equal and
//     Matches() each other.
// For a domain with enter-domain instructions, the user-side metadata of the
// first one (in set iteration order) is used; failing that, the operand-side
// metadata of the first exit-domain instruction.
Status HloDomainMap::PopulateDomainMetadataMap() {
  auto metadata_hash = [](const DomainMetadata* metadata) {
    return metadata->Hash();
  };
  auto metadata_equal = [](const DomainMetadata* lhs,
                           const DomainMetadata* rhs) {
    return lhs->Matches(*rhs);
  };
  using MetadataIdMap =
      tensorflow::gtl::FlatMap<const DomainMetadata*, int64,
                               decltype(metadata_hash),
                               decltype(metadata_equal)>;
  MetadataIdMap known_metadata(1024, metadata_hash, metadata_equal);

  // Returns the id already associated with a matching metadata, or registers
  // `metadata` under the next free id (current map size + 1). This relies on
  // insert() leaving an existing entry untouched, as std::unordered_map does.
  auto id_for = [&known_metadata](const DomainMetadata& metadata) -> int64 {
    const int64 candidate_id = known_metadata.size() + 1;
    return known_metadata.insert({&metadata, candidate_id}).first->second;
  };

  for (auto& domain : instruction_domains_) {
    int64 metadata_id = -1;
    if (!domain->enter_domains.empty()) {
      const HloInstruction* boundary = *domain->enter_domains.begin();
      metadata_id = id_for(boundary->user_side_metadata());
    } else if (!domain->exit_domains.empty()) {
      const HloInstruction* boundary = *domain->exit_domains.begin();
      metadata_id = id_for(boundary->operand_side_metadata());
    } else {
      // No boundary instruction to take metadata from.
      metadata_id = 0;
    }
    TF_RET_CHECK(metadata_id >= 0);
    for (HloInstruction* member : domain->instructions) {
      domain_metadata_id_[member] = metadata_id;
    }
  }
  return Status::OK();
}
|
||||
|
||||
|
@ -69,6 +69,11 @@ class HloDomainMap {
|
||||
// instruction is not found within any domain.
|
||||
int64 GetDomainId(HloInstruction* instruction) const;
|
||||
|
||||
// Returns the unique id of the domain metadata for the domain the given
|
||||
// instruction belongs to. The given instruction must not be a kDomain
|
||||
// instruction since each domain instruction is associated with 2 domains.
|
||||
int64 GetDomainMetadataId(HloInstruction* instruction) const;
|
||||
|
||||
private:
|
||||
// Map used for representing instruction ordering, i.e.
|
||||
// order_map[a] < order_map[b] means a must be ordered before b.
|
||||
@ -109,9 +114,14 @@ class HloDomainMap {
|
||||
const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set,
|
||||
const InstructionOrderMap& instructions_order);
|
||||
|
||||
// Populates domain_metadata_id_ that maps each HloInstruction to the unique
|
||||
// ID of its associated domain metadata.
|
||||
Status PopulateDomainMetadataMap();
|
||||
|
||||
string domain_kind_;
|
||||
std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_;
|
||||
tensorflow::gtl::FlatMap<HloInstruction*, int64> instruction_to_domain_;
|
||||
tensorflow::gtl::FlatMap<HloInstruction*, int64> domain_metadata_id_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -72,6 +72,9 @@ class DomainMetadata {
|
||||
// two matches.
|
||||
virtual bool Matches(const DomainMetadata& other) const = 0;
|
||||
|
||||
// Returns the hash value of the metadata.
|
||||
virtual size_t Hash() const = 0;
|
||||
|
||||
// Returns a string representation of the metadata.
|
||||
virtual string ToString() const = 0;
|
||||
};
|
||||
|
@ -99,6 +99,8 @@ class OpNameMetadata : public DomainMetadata {
|
||||
|
||||
static absl::string_view KindName() { return "opname"; }
|
||||
|
||||
size_t Hash() const override { return std::hash<string>()(opname_); }
|
||||
|
||||
private:
|
||||
string opname_;
|
||||
};
|
||||
|
@ -422,6 +422,13 @@ bool ShardingMetadata::Matches(const DomainMetadata& other) const {
|
||||
: false;
|
||||
}
|
||||
|
||||
size_t ShardingMetadata::Hash() const {
|
||||
if (sharding_ != nullptr) {
|
||||
return sharding_->Hash();
|
||||
}
|
||||
return static_cast<size_t>(0x297814aaad196e6dULL);
|
||||
}
|
||||
|
||||
// Returns the string form of the wrapped sharding, or "{}" when this metadata
// carries no sharding (sharding_ is null).
string ShardingMetadata::ToString() const {
  if (sharding_ == nullptr) {
    return "{}";
  }
  return sharding_->ToString();
}
|
||||
|
@ -36,6 +36,8 @@ class ShardingMetadata : public DomainMetadata {
|
||||
|
||||
bool Matches(const DomainMetadata& other) const override;
|
||||
|
||||
size_t Hash() const override;
|
||||
|
||||
string ToString() const override;
|
||||
|
||||
// Non-owning access to the wrapped sharding; the result is whatever
// sharding_.get() yields, which may be null when no sharding is set.
const HloSharding* sharding() const {
  return sharding_.get();
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user