From 97039a80b3dabb5ed2e4fb5d0d0bdc5229293718 Mon Sep 17 00:00:00 2001 From: HyoukJoong Lee Date: Tue, 4 Sep 2018 13:56:42 -0700 Subject: [PATCH] Fix CRS combiner for spatial partitioning PiperOrigin-RevId: 211519250 --- .../compiler/xla/service/hlo_domain_map.cc | 41 +++++++++++++++++++ .../compiler/xla/service/hlo_domain_map.h | 10 +++++ .../xla/service/hlo_domain_metadata.h | 3 ++ .../compiler/xla/service/hlo_domain_test.cc | 2 + .../xla/service/hlo_sharding_metadata.cc | 7 ++++ .../xla/service/hlo_sharding_metadata.h | 2 + 6 files changed, 65 insertions(+) diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index 8b2846e0c27..113fd18eae7 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -51,6 +51,10 @@ int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const { return FindOrDefault(instruction_to_domain_, instruction, -1); } +int64 HloDomainMap::GetDomainMetadataId(HloInstruction* instruction) const { + return FindOrDie(domain_metadata_id_, instruction); +} + Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain); // We only check operands, so we are sure to not process the empty domain from @@ -93,6 +97,43 @@ Status HloDomainMap::Populate(HloComputation* computation) { CreateDomain(instruction, instructions_post_order)); TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } + TF_RETURN_IF_ERROR(PopulateDomainMetadataMap()); + return Status::OK(); +} + +Status HloDomainMap::PopulateDomainMetadataMap() { + auto hash = [](const DomainMetadata* m) { return m->Hash(); }; + auto equal = [](const DomainMetadata* a, const DomainMetadata* b) { + return a->Matches(*b); + }; + tensorflow::gtl::FlatMap + domain_metadata(1024, hash, equal); + + for (auto& domain : instruction_domains_) { + int64 domain_metadata_id = -1; + if (!domain->enter_domains.empty()) { + const HloInstruction* domain_instruction = *domain->enter_domains.begin(); + domain_metadata_id = + domain_metadata + .insert({&domain_instruction->user_side_metadata(), + domain_metadata.size() + 1}) + .first->second; + } else if (!domain->exit_domains.empty()) { + const HloInstruction* domain_instruction = *domain->exit_domains.begin(); + domain_metadata_id = + domain_metadata + .insert({&domain_instruction->operand_side_metadata(), + domain_metadata.size() + 1}) + .first->second; + } else { + domain_metadata_id = 0; + } + TF_RET_CHECK(domain_metadata_id >= 0); + for (HloInstruction* instruction : domain->instructions) { + domain_metadata_id_[instruction] = domain_metadata_id; + } + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h index 633109249a9..56b557d7cea 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.h +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -69,6 +69,11 @@ class HloDomainMap { // instruction is not found within any domain. int64 GetDomainId(HloInstruction* instruction) const; + // Returns the unique id of the domain metadata for the domain the given + // instruction belongs to. The given instruction must not be a kDomain + // instruction since each domain instruction is associated with 2 domains. + int64 GetDomainMetadataId(HloInstruction* instruction) const; + private: // Map used for representing instruction ordering, i.e. // order_map[a] < order_map[b] means a must be ordered before b. @@ -109,9 +114,14 @@ class HloDomainMap { const tensorflow::gtl::FlatSet& instruction_set, const InstructionOrderMap& instructions_order); + // Populates domain_metadata_id_ that maps each HloInstruction to the unique + // ID of its associated domain metatadata. + Status PopulateDomainMetadataMap(); + string domain_kind_; std::vector> instruction_domains_; tensorflow::gtl::FlatMap instruction_to_domain_; + tensorflow::gtl::FlatMap domain_metadata_id_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h index 6c142ee4742..302807f816e 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -72,6 +72,9 @@ class DomainMetadata { // two matches. virtual bool Matches(const DomainMetadata& other) const = 0; + // Returns the hash value of the metadata. + virtual size_t Hash() const = 0; + // Returns a string representation of the metadata. virtual string ToString() const = 0; }; diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index 974ab94467d..43e74d2f6f0 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -99,6 +99,8 @@ class OpNameMetadata : public DomainMetadata { static absl::string_view KindName() { return "opname"; } + size_t Hash() const override { return std::hash()(opname_); } + private: string opname_; }; diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 34cba6136ff..e3f4a9852ac 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -422,6 +422,13 @@ bool ShardingMetadata::Matches(const DomainMetadata& other) const { : false; } +size_t ShardingMetadata::Hash() const { + if (sharding_ != nullptr) { + return sharding_->Hash(); + } + return static_cast(0x297814aaad196e6dULL); +} + string ShardingMetadata::ToString() const { return sharding_ != nullptr ? sharding_->ToString() : "{}"; } diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h index cba5db927a0..e3ae82a0706 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h @@ -36,6 +36,8 @@ class ShardingMetadata : public DomainMetadata { bool Matches(const DomainMetadata& other) const override; + size_t Hash() const override; + string ToString() const override; const HloSharding* sharding() const { return sharding_.get(); }