[XLA] Uniquify sharding metadata when propagated.
Having the same metadata present multiple times adds no information; it was a side effect of merging metadata without checking for duplicates. Update the tests that relied on the duplicated metadata. PiperOrigin-RevId: 354116186 Change-Id: Ie0e5544a913480076bfee11dd04126c36ce14c6c
This commit is contained in:
parent
5f3ab2cb75
commit
551297fd69
tensorflow/compiler/xla
@ -264,6 +264,7 @@ cc_library(
|
||||
":types",
|
||||
":util",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/hash",
|
||||
"@com_google_absl//absl/time",
|
||||
],
|
||||
)
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/protobuf_util.h"
|
||||
|
||||
#include "absl/hash/hash.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
@ -38,6 +39,15 @@ bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
|
||||
return (serialized1 == serialized2);
|
||||
}
|
||||
|
||||
size_t ProtobufHash(const tensorflow::protobuf::Message& m) {
  // Deliberately fast and loose: we hash the message's serialized wire bytes
  // rather than introducing a dependency on the much more complex
  // protobuf::util::MessageDifferencer class. The WARNING documented for
  // ProtobufEquals (which serializes the same way) applies here as well.
  string wire_bytes;
  m.AppendToString(&wire_bytes);
  return absl::Hash<string>()(wire_bytes);
}
|
||||
|
||||
Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
|
||||
const string& directory, const string& file_name,
|
||||
string* full_path) {
|
||||
|
@ -33,6 +33,27 @@ namespace protobuf_util {
|
||||
extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
|
||||
const tensorflow::protobuf::Message& m2);
|
||||
|
||||
// Return the hash of the message "m".
|
||||
//
|
||||
// WARNING: This uses the same serialization approach used by ProtobufEquals,
|
||||
// so the WARNING for that function applies here.
|
||||
size_t ProtobufHash(const tensorflow::protobuf::Message& m);
|
||||
|
||||
// Wrappers for above methods so that they can be used in containers.
|
||||
class ProtobufEqualsWrapper {
|
||||
public:
|
||||
bool operator()(const tensorflow::protobuf::Message& m1,
|
||||
const tensorflow::protobuf::Message& m2) const {
|
||||
return ProtobufEquals(m1, m2);
|
||||
}
|
||||
};
|
||||
|
||||
class ProtobufHashWrapper {
|
||||
public:
|
||||
size_t operator()(const tensorflow::protobuf::Message& m) const {
|
||||
return ProtobufHash(m);
|
||||
}
|
||||
};
|
||||
// Writes the given message in binary proto to the path formed by joining
|
||||
// 'directory/file_name.pb'. The 'directory' is recursively created if it
|
||||
// doesn't already exist, and the 'file_name' is sanitized by replacing
|
||||
|
@ -212,12 +212,15 @@ bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
|
||||
new_group_members[new_group_id].erase(*device);
|
||||
});
|
||||
if (compatible) {
|
||||
std::vector<OpMetadata> merged_metadata;
|
||||
std::swap(merged_metadata, to_merge->metadata());
|
||||
merged_metadata.reserve(to_merge->metadata().size() +
|
||||
old.metadata().size());
|
||||
merged_metadata.insert(merged_metadata.end(), old.metadata().begin(),
|
||||
old.metadata().end());
|
||||
std::vector<OpMetadata> merged_metadata(std::move(to_merge->metadata()));
|
||||
merged_metadata.reserve(merged_metadata.size() + old.metadata().size());
|
||||
const absl::flat_hash_set<OpMetadata, protobuf_util::ProtobufHashWrapper,
|
||||
protobuf_util::ProtobufEqualsWrapper>
|
||||
metadata_set(merged_metadata.begin(), merged_metadata.end());
|
||||
absl::c_copy_if(old.metadata(), std::back_inserter(merged_metadata),
|
||||
[&metadata_set](const OpMetadata& data) {
|
||||
return !ContainsKey(metadata_set, data);
|
||||
});
|
||||
if (replication == 1) {
|
||||
new_tile_dims.pop_back();
|
||||
new_tile.Reshape(new_tile_dims);
|
||||
|
@ -4140,11 +4140,9 @@ ENTRY entry {
|
||||
op::Sharding("{devices=[2,2,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}"));
|
||||
if (GetParam().propagate_metadata && !GetParam().clear_metadata) {
|
||||
EXPECT_THAT(lhs->sharding(),
|
||||
ShardingMetadata({CreateMetadata("b"), CreateMetadata("a"),
|
||||
CreateMetadata("a")}));
|
||||
ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")}));
|
||||
EXPECT_THAT(rhs->sharding(),
|
||||
ShardingMetadata({CreateMetadata("b"), CreateMetadata("a"),
|
||||
CreateMetadata("b")}));
|
||||
ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")}));
|
||||
EXPECT_THAT(add->sharding(),
|
||||
ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")}));
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user