[XLA] Uniquify sharding metadata when propagated.

Having the same metadata multiple times doesn't add any additional
information and is a side-effect of merging it without checking for
duplicates.

Update the tests that relied on the duplicated metadata entries.

PiperOrigin-RevId: 354116186
Change-Id: Ie0e5544a913480076bfee11dd04126c36ce14c6c
This commit is contained in:
Marcello Maggioni 2021-01-27 10:08:45 -08:00 committed by TensorFlower Gardener
parent 5f3ab2cb75
commit 551297fd69
5 changed files with 43 additions and 10 deletions

View File

@ -264,6 +264,7 @@ cc_library(
":types",
":util",
"//tensorflow/core:lib",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/time",
],
)

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "absl/hash/hash.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
@ -38,6 +39,15 @@ bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
return (serialized1 == serialized2);
}
size_t ProtobufHash(const tensorflow::protobuf::Message& m) {
  // Fast-and-loose hashing: rather than pulling in the much heavier
  // protobuf::util::MessageDifferencer machinery, serialize the message and
  // hash the resulting bytes. ProtobufEquals compares messages by their
  // serialized strings as well.
  string bytes;
  m.AppendToString(&bytes);
  const absl::Hash<string> hasher{};
  return hasher(bytes);
}
Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
const string& directory, const string& file_name,
string* full_path) {

View File

@ -33,6 +33,27 @@ namespace protobuf_util {
extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
const tensorflow::protobuf::Message& m2);
// Return the hash of the message "m".
//
// WARNING: This uses the same serialization approach used by ProtobufEquals,
// so the WARNING for that function applies here.
size_t ProtobufHash(const tensorflow::protobuf::Message& m);
// Wrappers for above methods so that they can be used in containers.
// Equality functor over protobuf messages. Forwards to ProtobufEquals so the
// comparison can be supplied as the equality template argument of a container.
class ProtobufEqualsWrapper {
 public:
  bool operator()(const tensorflow::protobuf::Message& m1,
                  const tensorflow::protobuf::Message& m2) const {
    const bool same = ProtobufEquals(m1, m2);
    return same;
  }
};
// Hash functor over protobuf messages. Forwards to ProtobufHash so the hash
// can be supplied as the hasher template argument of a container.
class ProtobufHashWrapper {
 public:
  size_t operator()(const tensorflow::protobuf::Message& m) const {
    const size_t digest = ProtobufHash(m);
    return digest;
  }
};
// Writes the given message in binary proto to the path formed by joining
// 'directory/file_name.pb'. The 'directory' is recursively created if it
// doesn't already exist, and the 'file_name' is sanitized by replacing

View File

@ -212,12 +212,15 @@ bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
new_group_members[new_group_id].erase(*device);
});
if (compatible) {
std::vector<OpMetadata> merged_metadata;
std::swap(merged_metadata, to_merge->metadata());
merged_metadata.reserve(to_merge->metadata().size() +
old.metadata().size());
merged_metadata.insert(merged_metadata.end(), old.metadata().begin(),
old.metadata().end());
std::vector<OpMetadata> merged_metadata(std::move(to_merge->metadata()));
merged_metadata.reserve(merged_metadata.size() + old.metadata().size());
const absl::flat_hash_set<OpMetadata, protobuf_util::ProtobufHashWrapper,
protobuf_util::ProtobufEqualsWrapper>
metadata_set(merged_metadata.begin(), merged_metadata.end());
absl::c_copy_if(old.metadata(), std::back_inserter(merged_metadata),
[&metadata_set](const OpMetadata& data) {
return !ContainsKey(metadata_set, data);
});
if (replication == 1) {
new_tile_dims.pop_back();
new_tile.Reshape(new_tile_dims);

View File

@ -4140,11 +4140,9 @@ ENTRY entry {
op::Sharding("{devices=[2,2,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}"));
if (GetParam().propagate_metadata && !GetParam().clear_metadata) {
EXPECT_THAT(lhs->sharding(),
ShardingMetadata({CreateMetadata("b"), CreateMetadata("a"),
CreateMetadata("a")}));
ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")}));
EXPECT_THAT(rhs->sharding(),
ShardingMetadata({CreateMetadata("b"), CreateMetadata("a"),
CreateMetadata("b")}));
ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")}));
EXPECT_THAT(add->sharding(),
ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")}));