From 551297fd69c6efcf737821cd70d5f7f242542f76 Mon Sep 17 00:00:00 2001 From: Marcello Maggioni <maggioni@google.com> Date: Wed, 27 Jan 2021 10:08:45 -0800 Subject: [PATCH] [XLA] Uniquify sharding metadata when propagated. Having the same metadata multiple times doesn't add any additional information and is a side-effect of merging it without checking for duplicates. Update the tests that relied on that. PiperOrigin-RevId: 354116186 Change-Id: Ie0e5544a913480076bfee11dd04126c36ce14c6c --- tensorflow/compiler/xla/BUILD | 1 + tensorflow/compiler/xla/protobuf_util.cc | 10 +++++++++ tensorflow/compiler/xla/protobuf_util.h | 21 +++++++++++++++++++ .../xla/service/sharding_propagation.cc | 15 +++++++------ .../xla/service/sharding_propagation_test.cc | 6 ++---- 5 files changed, 43 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 0643af0ba48..f9d0f667006 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -264,6 +264,7 @@ cc_library( ":types", ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/hash", "@com_google_absl//absl/time", ], ) diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc index b7c30531923..036bc09ab4d 100644 --- a/tensorflow/compiler/xla/protobuf_util.cc +++ b/tensorflow/compiler/xla/protobuf_util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/protobuf_util.h" +#include "absl/hash/hash.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -38,6 +39,15 @@ bool ProtobufEquals(const tensorflow::protobuf::Message& m1, return (serialized1 == serialized2); } +size_t ProtobufHash(const tensorflow::protobuf::Message& m) { + // This is a bit fast and loose, but avoids introducing a dependency on + // the much more complex protobuf::util::MessageDifferencer class. 
+ // We perform the hash on its serialized representation. + string serialized; + m.AppendToString(&serialized); + return absl::Hash<string>()(serialized); +} + Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, const string& directory, const string& file_name, string* full_path) { diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h index 7db020982b9..29fea36a041 100644 --- a/tensorflow/compiler/xla/protobuf_util.h +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -33,6 +33,27 @@ namespace protobuf_util { extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1, const tensorflow::protobuf::Message& m2); +// Return the hash of the message "m". +// +// WARNING: This uses the same serialization approach used by ProtobufEquals, +// so the WARNING for that function applies here. +size_t ProtobufHash(const tensorflow::protobuf::Message& m); + +// Wrappers for the above methods so that they can be used in containers. +class ProtobufEqualsWrapper { + public: + bool operator()(const tensorflow::protobuf::Message& m1, + const tensorflow::protobuf::Message& m2) const { + return ProtobufEquals(m1, m2); + } +}; + +class ProtobufHashWrapper { + public: + size_t operator()(const tensorflow::protobuf::Message& m) const { + return ProtobufHash(m); + } +}; // Writes the given message in binary proto to the path formed by joining // 'directory/file_name.pb'. 
The 'directory' is recursively created if it // doesn't already exist, and the 'file_name' is sanitized by replacing diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc index a032d33d4b5..af338f5e1fa 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation.cc @@ -212,12 +212,15 @@ bool MergeSharding(const HloSharding& old, HloSharding* to_merge, new_group_members[new_group_id].erase(*device); }); if (compatible) { - std::vector<OpMetadata> merged_metadata; - std::swap(merged_metadata, to_merge->metadata()); - merged_metadata.reserve(to_merge->metadata().size() + - old.metadata().size()); - merged_metadata.insert(merged_metadata.end(), old.metadata().begin(), - old.metadata().end()); + std::vector<OpMetadata> merged_metadata(std::move(to_merge->metadata())); + merged_metadata.reserve(merged_metadata.size() + old.metadata().size()); + const absl::flat_hash_set<OpMetadata, protobuf_util::ProtobufHashWrapper, + protobuf_util::ProtobufEqualsWrapper> + metadata_set(merged_metadata.begin(), merged_metadata.end()); + absl::c_copy_if(old.metadata(), std::back_inserter(merged_metadata), + [&metadata_set](const OpMetadata& data) { + return !ContainsKey(metadata_set, data); + }); if (replication == 1) { new_tile_dims.pop_back(); new_tile.Reshape(new_tile_dims); diff --git a/tensorflow/compiler/xla/service/sharding_propagation_test.cc b/tensorflow/compiler/xla/service/sharding_propagation_test.cc index 2dfbc20d741..85190ac41b4 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation_test.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation_test.cc @@ -4140,11 +4140,9 @@ ENTRY entry { op::Sharding("{devices=[2,2,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}")); if (GetParam().propagate_metadata && !GetParam().clear_metadata) { EXPECT_THAT(lhs->sharding(), - ShardingMetadata({CreateMetadata("b"), CreateMetadata("a"), - 
CreateMetadata("a")})); + ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")})); EXPECT_THAT(rhs->sharding(), - ShardingMetadata({CreateMetadata("b"), CreateMetadata("a"), - CreateMetadata("b")})); + ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")})); EXPECT_THAT(add->sharding(), ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")}));