From 551297fd69c6efcf737821cd70d5f7f242542f76 Mon Sep 17 00:00:00 2001 From: Marcello Maggioni <maggioni@google.com> Date: Wed, 27 Jan 2021 10:08:45 -0800 Subject: [PATCH] [XLA] Uniquify sharding metadata when propagated. Having the same metadata multiple times doesn't add any additional information and is a side-effect of merging it without checking for duplicates. Update the tests that relied on that. PiperOrigin-RevId: 354116186 Change-Id: Ie0e5544a913480076bfee11dd04126c36ce14c6c --- tensorflow/compiler/xla/BUILD | 1 + tensorflow/compiler/xla/protobuf_util.cc | 10 +++++++++ tensorflow/compiler/xla/protobuf_util.h | 21 +++++++++++++++++++ .../xla/service/sharding_propagation.cc | 15 +++++++------ .../xla/service/sharding_propagation_test.cc | 6 ++---- 5 files changed, 43 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 0643af0ba48..f9d0f667006 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -264,6 +264,7 @@ cc_library( ":types", ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/hash", "@com_google_absl//absl/time", ], ) diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc index b7c30531923..036bc09ab4d 100644 --- a/tensorflow/compiler/xla/protobuf_util.cc +++ b/tensorflow/compiler/xla/protobuf_util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/protobuf_util.h" +#include "absl/hash/hash.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -38,6 +39,15 @@ bool ProtobufEquals(const tensorflow::protobuf::Message& m1, return (serialized1 == serialized2); } +size_t ProtobufHash(const tensorflow::protobuf::Message& m) { + // This is a bit fast and loose, but avoids introducing a dependency on + // the much more complex protobuf::util::MessageDifferencer class. 
+ // We perform the hash on its serialized representation. + string serialized; + m.AppendToString(&serialized); + return absl::Hash<string>()(serialized); +} + Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, const string& directory, const string& file_name, string* full_path) { diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h index 7db020982b9..29fea36a041 100644 --- a/tensorflow/compiler/xla/protobuf_util.h +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -33,6 +33,27 @@ namespace protobuf_util { extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1, const tensorflow::protobuf::Message& m2); +// Return the hash of the message "m". +// +// WARNING: This uses the same serialization approach used by ProtobufEquals, +// so the WARNING for that function applies here. +size_t ProtobufHash(const tensorflow::protobuf::Message& m); + +// Wrappers for the above methods so that they can be used in containers. +class ProtobufEqualsWrapper { + public: + bool operator()(const tensorflow::protobuf::Message& m1, + const tensorflow::protobuf::Message& m2) const { + return ProtobufEquals(m1, m2); + } +}; + +class ProtobufHashWrapper { + public: + size_t operator()(const tensorflow::protobuf::Message& m) const { + return ProtobufHash(m); + } +}; // Writes the given message in binary proto to the path formed by joining // 'directory/file_name.pb'. 
The 'directory' is recursively created if it // doesn't already exist, and the 'file_name' is sanitized by replacing diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc index a032d33d4b5..af338f5e1fa 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation.cc @@ -212,12 +212,15 @@ bool MergeSharding(const HloSharding& old, HloSharding* to_merge, new_group_members[new_group_id].erase(*device); }); if (compatible) { - std::vector<OpMetadata> merged_metadata; - std::swap(merged_metadata, to_merge->metadata()); - merged_metadata.reserve(to_merge->metadata().size() + - old.metadata().size()); - merged_metadata.insert(merged_metadata.end(), old.metadata().begin(), - old.metadata().end()); + std::vector<OpMetadata> merged_metadata(std::move(to_merge->metadata())); + merged_metadata.reserve(merged_metadata.size() + old.metadata().size()); + const absl::flat_hash_set<OpMetadata, protobuf_util::ProtobufHashWrapper, + protobuf_util::ProtobufEqualsWrapper> + metadata_set(merged_metadata.begin(), merged_metadata.end()); + absl::c_copy_if(old.metadata(), std::back_inserter(merged_metadata), + [&metadata_set](const OpMetadata& data) { + return !ContainsKey(metadata_set, data); + }); if (replication == 1) { new_tile_dims.pop_back(); new_tile.Reshape(new_tile_dims); diff --git a/tensorflow/compiler/xla/service/sharding_propagation_test.cc b/tensorflow/compiler/xla/service/sharding_propagation_test.cc index 2dfbc20d741..85190ac41b4 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation_test.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation_test.cc @@ -4140,11 +4140,9 @@ ENTRY entry { op::Sharding("{devices=[2,2,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}")); if (GetParam().propagate_metadata && !GetParam().clear_metadata) { EXPECT_THAT(lhs->sharding(), - ShardingMetadata({CreateMetadata("b"), CreateMetadata("a"), - 
CreateMetadata("a")})); + ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")})); EXPECT_THAT(rhs->sharding(), - ShardingMetadata({CreateMetadata("b"), CreateMetadata("a"), - CreateMetadata("b")})); + ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")})); EXPECT_THAT(add->sharding(), ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")}));