[XLA] Uniquify sharding metadata when propagated.
Having the same metadata multiple times doesn't add any additional information and is a side-effect of merging it without checking for duplicates. Update the tests that relied on that. PiperOrigin-RevId: 354116186 Change-Id: Ie0e5544a913480076bfee11dd04126c36ce14c6c
This commit is contained in:
parent
5f3ab2cb75
commit
551297fd69
@ -264,6 +264,7 @@ cc_library(
|
|||||||
":types",
|
":types",
|
||||||
":util",
|
":util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/hash",
|
||||||
"@com_google_absl//absl/time",
|
"@com_google_absl//absl/time",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/protobuf_util.h"
|
#include "tensorflow/compiler/xla/protobuf_util.h"
|
||||||
|
|
||||||
|
#include "absl/hash/hash.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
@ -38,6 +39,15 @@ bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
|
|||||||
return (serialized1 == serialized2);
|
return (serialized1 == serialized2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t ProtobufHash(const tensorflow::protobuf::Message& m) {
|
||||||
|
// This is a bit fast and loose, but avoids introducing a dependency on
|
||||||
|
// the much more complex protobuf::util::MessageDifferencer class.
|
||||||
|
// We perform the hash on their serialized representation.
|
||||||
|
string serialized;
|
||||||
|
m.AppendToString(&serialized);
|
||||||
|
return absl::Hash<string>()(serialized);
|
||||||
|
}
|
||||||
|
|
||||||
Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
|
Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
|
||||||
const string& directory, const string& file_name,
|
const string& directory, const string& file_name,
|
||||||
string* full_path) {
|
string* full_path) {
|
||||||
|
@ -33,6 +33,27 @@ namespace protobuf_util {
|
|||||||
extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
|
extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
|
||||||
const tensorflow::protobuf::Message& m2);
|
const tensorflow::protobuf::Message& m2);
|
||||||
|
|
||||||
|
// Return the hash of the message "m".
|
||||||
|
//
|
||||||
|
// WARNING: This uses the same serialization approach used by ProtobufEquals,
|
||||||
|
// so the WARNING for that function applies here.
|
||||||
|
size_t ProtobufHash(const tensorflow::protobuf::Message& m);
|
||||||
|
|
||||||
|
// Wrappers for above methods so that they can be used in containers.
|
||||||
|
class ProtobufEqualsWrapper {
|
||||||
|
public:
|
||||||
|
bool operator()(const tensorflow::protobuf::Message& m1,
|
||||||
|
const tensorflow::protobuf::Message& m2) const {
|
||||||
|
return ProtobufEquals(m1, m2);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class ProtobufHashWrapper {
|
||||||
|
public:
|
||||||
|
size_t operator()(const tensorflow::protobuf::Message& m) const {
|
||||||
|
return ProtobufHash(m);
|
||||||
|
}
|
||||||
|
};
|
||||||
// Writes the given message in binary proto to the path formed by joining
|
// Writes the given message in binary proto to the path formed by joining
|
||||||
// 'directory/file_name.pb'. The 'directory' is recursively created if it
|
// 'directory/file_name.pb'. The 'directory' is recursively created if it
|
||||||
// doesn't already exist, and the 'file_name' is sanitized by replacing
|
// doesn't already exist, and the 'file_name' is sanitized by replacing
|
||||||
|
@ -212,12 +212,15 @@ bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
|
|||||||
new_group_members[new_group_id].erase(*device);
|
new_group_members[new_group_id].erase(*device);
|
||||||
});
|
});
|
||||||
if (compatible) {
|
if (compatible) {
|
||||||
std::vector<OpMetadata> merged_metadata;
|
std::vector<OpMetadata> merged_metadata(std::move(to_merge->metadata()));
|
||||||
std::swap(merged_metadata, to_merge->metadata());
|
merged_metadata.reserve(merged_metadata.size() + old.metadata().size());
|
||||||
merged_metadata.reserve(to_merge->metadata().size() +
|
const absl::flat_hash_set<OpMetadata, protobuf_util::ProtobufHashWrapper,
|
||||||
old.metadata().size());
|
protobuf_util::ProtobufEqualsWrapper>
|
||||||
merged_metadata.insert(merged_metadata.end(), old.metadata().begin(),
|
metadata_set(merged_metadata.begin(), merged_metadata.end());
|
||||||
old.metadata().end());
|
absl::c_copy_if(old.metadata(), std::back_inserter(merged_metadata),
|
||||||
|
[&metadata_set](const OpMetadata& data) {
|
||||||
|
return !ContainsKey(metadata_set, data);
|
||||||
|
});
|
||||||
if (replication == 1) {
|
if (replication == 1) {
|
||||||
new_tile_dims.pop_back();
|
new_tile_dims.pop_back();
|
||||||
new_tile.Reshape(new_tile_dims);
|
new_tile.Reshape(new_tile_dims);
|
||||||
|
@ -4140,11 +4140,9 @@ ENTRY entry {
|
|||||||
op::Sharding("{devices=[2,2,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}"));
|
op::Sharding("{devices=[2,2,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}"));
|
||||||
if (GetParam().propagate_metadata && !GetParam().clear_metadata) {
|
if (GetParam().propagate_metadata && !GetParam().clear_metadata) {
|
||||||
EXPECT_THAT(lhs->sharding(),
|
EXPECT_THAT(lhs->sharding(),
|
||||||
ShardingMetadata({CreateMetadata("b"), CreateMetadata("a"),
|
ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")}));
|
||||||
CreateMetadata("a")}));
|
|
||||||
EXPECT_THAT(rhs->sharding(),
|
EXPECT_THAT(rhs->sharding(),
|
||||||
ShardingMetadata({CreateMetadata("b"), CreateMetadata("a"),
|
ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")}));
|
||||||
CreateMetadata("b")}));
|
|
||||||
EXPECT_THAT(add->sharding(),
|
EXPECT_THAT(add->sharding(),
|
||||||
ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")}));
|
ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")}));
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user