Fix CRS combiner for spatial partitioning

PiperOrigin-RevId: 211519250
This commit is contained in:
HyoukJoong Lee 2018-09-04 13:56:42 -07:00 committed by TensorFlower Gardener
parent ffd9519c3f
commit 97039a80b3
6 changed files with 65 additions and 0 deletions

View File

@ -51,6 +51,10 @@ int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const {
return FindOrDefault(instruction_to_domain_, instruction, -1);
}
int64 HloDomainMap::GetDomainMetadataId(HloInstruction* instruction) const {
return FindOrDie(domain_metadata_id_, instruction);
}
Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain);
// We only check operands, so we are sure to not process the empty domain from
@ -93,6 +97,43 @@ Status HloDomainMap::Populate(HloComputation* computation) {
CreateDomain(instruction, instructions_post_order));
TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
}
TF_RETURN_IF_ERROR(PopulateDomainMetadataMap());
return Status::OK();
}
Status HloDomainMap::PopulateDomainMetadataMap() {
auto hash = [](const DomainMetadata* m) { return m->Hash(); };
auto equal = [](const DomainMetadata* a, const DomainMetadata* b) {
return a->Matches(*b);
};
tensorflow::gtl::FlatMap<const DomainMetadata*, int64, decltype(hash),
decltype(equal)>
domain_metadata(1024, hash, equal);
for (auto& domain : instruction_domains_) {
int64 domain_metadata_id = -1;
if (!domain->enter_domains.empty()) {
const HloInstruction* domain_instruction = *domain->enter_domains.begin();
domain_metadata_id =
domain_metadata
.insert({&domain_instruction->user_side_metadata(),
domain_metadata.size() + 1})
.first->second;
} else if (!domain->exit_domains.empty()) {
const HloInstruction* domain_instruction = *domain->exit_domains.begin();
domain_metadata_id =
domain_metadata
.insert({&domain_instruction->operand_side_metadata(),
domain_metadata.size() + 1})
.first->second;
} else {
domain_metadata_id = 0;
}
TF_RET_CHECK(domain_metadata_id >= 0);
for (HloInstruction* instruction : domain->instructions) {
domain_metadata_id_[instruction] = domain_metadata_id;
}
}
return Status::OK();
}

View File

@ -69,6 +69,11 @@ class HloDomainMap {
// instruction is not found within any domain.
int64 GetDomainId(HloInstruction* instruction) const;
// Returns the unique id of the domain metadata for the domain the given
// instruction belongs to. The given instruction must not be a kDomain
// instruction since each domain instruction is associated with 2 domains.
int64 GetDomainMetadataId(HloInstruction* instruction) const;
private:
// Map used for representing instruction ordering, i.e.
// order_map[a] < order_map[b] means a must be ordered before b.
@ -109,9 +114,14 @@ class HloDomainMap {
const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set,
const InstructionOrderMap& instructions_order);
// Populates domain_metadata_id_ that maps each HloInstruction to the unique
// ID of its associated domain metatadata.
Status PopulateDomainMetadataMap();
string domain_kind_;
std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_;
tensorflow::gtl::FlatMap<HloInstruction*, int64> instruction_to_domain_;
tensorflow::gtl::FlatMap<HloInstruction*, int64> domain_metadata_id_;
};
} // namespace xla

View File

@ -72,6 +72,9 @@ class DomainMetadata {
// two matches.
virtual bool Matches(const DomainMetadata& other) const = 0;
// Returns the hash value of the metadata.
virtual size_t Hash() const = 0;
// Returns a string representation of the metadata.
virtual string ToString() const = 0;
};

View File

@ -99,6 +99,8 @@ class OpNameMetadata : public DomainMetadata {
static absl::string_view KindName() { return "opname"; }
size_t Hash() const override { return std::hash<string>()(opname_); }
private:
string opname_;
};

View File

@ -422,6 +422,13 @@ bool ShardingMetadata::Matches(const DomainMetadata& other) const {
: false;
}
size_t ShardingMetadata::Hash() const {
if (sharding_ != nullptr) {
return sharding_->Hash();
}
return static_cast<size_t>(0x297814aaad196e6dULL);
}
string ShardingMetadata::ToString() const {
return sharding_ != nullptr ? sharding_->ToString() : "{}";
}

View File

@ -36,6 +36,8 @@ class ShardingMetadata : public DomainMetadata {
bool Matches(const DomainMetadata& other) const override;
size_t Hash() const override;
string ToString() const override;
const HloSharding* sharding() const { return sharding_.get(); }