[XLA] Add option to propagate sharding metadata in Hlo sharding propagation pass.

Sharding metadata is retained if metadata propagation is not allowed. Metadata is also assigned to sharding (with no metadata) attached to instructions using the instruction's metadata if available.

Sharding propagation tests are updated to be parameterized, to test sharding metadata propagation.

PiperOrigin-RevId: 353716863
Change-Id: I55e79e6dec7eea5a2cd7ca718d861e17ba0f3ad2
This commit is contained in:
Andy Ly 2021-01-25 13:26:13 -08:00 committed by TensorFlower Gardener
parent d66729431d
commit 1071665bc3
7 changed files with 2996 additions and 1113 deletions

View File

@ -520,6 +520,7 @@ cc_library(
":hlo_graph_dumper",
":hlo_pass",
":hlo_sharding_util",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@ -541,12 +542,16 @@ tf_cc_test(
"sharding_propagation_test.cc",
],
deps = [
"hlo_matchers",
":hlo",
":hlo_matchers",
":hlo_parser",
":sharding_propagation",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/strings",
],
)

View File

@ -285,7 +285,8 @@ class HloSharding {
int64 NumTiles(absl::Span<const int64> dims) const;
// Gets metadata from sharding.
absl::Span<const OpMetadata> metadata() const { return metadata_; }
std::vector<OpMetadata>& metadata() { return metadata_; }
const std::vector<OpMetadata>& metadata() const { return metadata_; }
private:
explicit HloSharding(bool manual, bool replicated,

View File

@ -440,7 +440,7 @@ TEST_F(HloShardingTest, WithMetadataNoOverwrite) {
HloSharding sharding = HloSharding::Replicate();
auto sharding_new_metadata =
sharding.WithMetadata(SingleMetadata(), /*overwrite=*/false);
ASSERT_EQ(sharding_new_metadata.metadata().length(), 1);
ASSERT_EQ(sharding_new_metadata.metadata().size(), 1);
EXPECT_TRUE(protobuf_util::ProtobufEquals(
sharding_new_metadata.metadata().front(), SingleMetadata().front()));
}
@ -449,7 +449,7 @@ TEST_F(HloShardingTest, WithMetadataNoOverwrite) {
HloSharding sharding = HloSharding::AssignDevice(7, SingleMetadata());
auto sharding_new_metadata =
sharding.WithMetadata(ListMetadata(), /*overwrite=*/false);
ASSERT_EQ(sharding_new_metadata.metadata().length(), 1);
ASSERT_EQ(sharding_new_metadata.metadata().size(), 1);
EXPECT_TRUE(protobuf_util::ProtobufEquals(
sharding.metadata().front(), sharding_new_metadata.metadata().front()));
}
@ -492,7 +492,7 @@ TEST_F(HloShardingTest, WithMetadataOverwrite) {
HloSharding sharding = HloSharding::Replicate();
auto sharding_new_metadata =
sharding.WithMetadata(SingleMetadata(), /*overwrite=*/true);
ASSERT_EQ(sharding_new_metadata.metadata().length(), 1);
ASSERT_EQ(sharding_new_metadata.metadata().size(), 1);
EXPECT_TRUE(protobuf_util::ProtobufEquals(
sharding_new_metadata.metadata().front(), SingleMetadata().front()));
}
@ -501,7 +501,7 @@ TEST_F(HloShardingTest, WithMetadataOverwrite) {
HloSharding sharding = HloSharding::AssignDevice(7, SingleMetadata());
auto sharding_new_metadata =
sharding.WithMetadata(ListMetadata(), /*overwrite=*/true);
ASSERT_EQ(sharding_new_metadata.metadata().length(), 2);
ASSERT_EQ(sharding_new_metadata.metadata().size(), 2);
for (int i = 0; i < 2; ++i) {
EXPECT_TRUE(protobuf_util::ProtobufEquals(
sharding_new_metadata.metadata()[i], ListMetadata()[i]));

View File

@ -130,8 +130,8 @@ HloSharding TransposeSharding(const HloSharding& sharding,
*value = sharding.tile_assignment()(src_indices);
});
return sharding.ReplicateOnLastTileDim()
? HloSharding::PartialTile(tile_assignment)
: HloSharding::Tile(tile_assignment);
? HloSharding::PartialTile(tile_assignment, sharding.metadata())
: HloSharding::Tile(tile_assignment, sharding.metadata());
}
absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
@ -244,8 +244,9 @@ absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
}
new_tile_assignment.Reshape(target_tile_assignment_dimensions);
return sharding.ReplicateOnLastTileDim()
? HloSharding::PartialTile(new_tile_assignment)
: HloSharding::Tile(new_tile_assignment);
? HloSharding::PartialTile(new_tile_assignment,
sharding.metadata())
: HloSharding::Tile(new_tile_assignment, sharding.metadata());
}
HloSharding ReverseSharding(const HloSharding& sharding,
@ -264,8 +265,9 @@ HloSharding ReverseSharding(const HloSharding& sharding,
*device = sharding.tile_assignment()(original_indices);
});
return sharding.ReplicateOnLastTileDim()
? HloSharding::PartialTile(new_tile_assignment)
: HloSharding::Tile(new_tile_assignment);
? HloSharding::PartialTile(new_tile_assignment,
sharding.metadata())
: HloSharding::Tile(new_tile_assignment, sharding.metadata());
}
HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
@ -319,7 +321,7 @@ HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
tile_dims[dim] = devices.size() / ignore_size;
Array<int64> tile_assignment(tile_dims);
tile_assignment.SetValues(devices);
return HloSharding::Tile(tile_assignment);
return HloSharding::Tile(tile_assignment, sharding.metadata());
}
bool ContainsTileSharding(const HloModule& module) {
@ -362,12 +364,14 @@ HloSharding GatherOutputSharding(const HloSharding& index_sharding,
Array<int64> new_tile_assignment = index_sharding.tile_assignment();
if (new_tile_assignment.num_elements() !=
Product(output_tile_assignment_dims)) {
return HloSharding::Replicate();
return HloSharding::Replicate(index_sharding.metadata());
}
new_tile_assignment.Reshape(output_tile_assignment_dims);
return index_sharding.ReplicateOnLastTileDim()
? HloSharding::PartialTile(new_tile_assignment)
: HloSharding::Tile(new_tile_assignment);
? HloSharding::PartialTile(new_tile_assignment,
index_sharding.metadata())
: HloSharding::Tile(new_tile_assignment,
index_sharding.metadata());
}
HloSharding GatherIndexSharding(const HloSharding& output_sharding,
@ -401,12 +405,14 @@ HloSharding GatherIndexSharding(const HloSharding& output_sharding,
Array<int64> new_tile_assignment = output_sharding.tile_assignment();
if (new_tile_assignment.num_elements() !=
Product(index_tile_assignment_dims)) {
return HloSharding::Replicate();
return HloSharding::Replicate(output_sharding.metadata());
}
new_tile_assignment.Reshape(index_tile_assignment_dims);
return output_sharding.ReplicateOnLastTileDim()
? HloSharding::PartialTile(new_tile_assignment)
: HloSharding::Tile(new_tile_assignment);
? HloSharding::PartialTile(new_tile_assignment,
output_sharding.metadata())
: HloSharding::Tile(new_tile_assignment,
output_sharding.metadata());
}
HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) {
@ -435,7 +441,8 @@ HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) {
// Output sharding is only on offset dimensions. We do not shard this gather
// op. Return a tile maximal sharding with the first device in output
// sharding tile assignment.
return HloSharding::AssignDevice(*hlo.sharding().tile_assignment().begin());
return HloSharding::AssignDevice(*hlo.sharding().tile_assignment().begin(),
hlo.sharding().metadata());
}
// Output sharding is on both offset and non offset dimensions. We shard the
@ -456,7 +463,7 @@ HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) {
}
Array<int64> tile_assignment =
hlo.sharding().tile_assignment().Slice(slice_starts, slice_limits);
return HloSharding::Tile(tile_assignment);
return HloSharding::Tile(tile_assignment, hlo.sharding().metadata());
}
HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
@ -483,12 +490,13 @@ HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
Array<int64> new_tile_assignment = data_sharding.tile_assignment();
if (new_tile_assignment.num_elements() !=
Product(index_tile_assignment_dims)) {
return HloSharding::Replicate();
return HloSharding::Replicate(data_sharding.metadata());
}
new_tile_assignment.Reshape(index_tile_assignment_dims);
return data_sharding.ReplicateOnLastTileDim()
? HloSharding::PartialTile(new_tile_assignment)
: HloSharding::Tile(new_tile_assignment);
? HloSharding::PartialTile(new_tile_assignment,
data_sharding.metadata())
: HloSharding::Tile(new_tile_assignment, data_sharding.metadata());
}
HloSharding ScatterDataSharding(const HloSharding& index_sharding,
@ -515,12 +523,14 @@ HloSharding ScatterDataSharding(const HloSharding& index_sharding,
Array<int64> new_tile_assignment = index_sharding.tile_assignment();
if (new_tile_assignment.num_elements() !=
Product(data_tile_assignment_dims)) {
return HloSharding::Replicate();
return HloSharding::Replicate(index_sharding.metadata());
}
new_tile_assignment.Reshape(data_tile_assignment_dims);
return index_sharding.ReplicateOnLastTileDim()
? HloSharding::PartialTile(new_tile_assignment)
: HloSharding::Tile(new_tile_assignment);
? HloSharding::PartialTile(new_tile_assignment,
index_sharding.metadata())
: HloSharding::Tile(new_tile_assignment,
index_sharding.metadata());
}
HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
@ -549,7 +559,8 @@ HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
// op. Return a tile maximal sharding with the first device in index sharding
// tile assignment.
if (num_elements == 1) {
return HloSharding::AssignDevice(*index_sharding.tile_assignment().begin());
return HloSharding::AssignDevice(*index_sharding.tile_assignment().begin(),
index_sharding.metadata());
}
const int64 index_rank = hlo.operand(1)->shape().rank();
@ -563,7 +574,7 @@ HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
}
Array<int64> tile_assignment =
index_sharding.tile_assignment().Slice(slice_starts, slice_limits);
return HloSharding::Tile(tile_assignment);
return HloSharding::Tile(tile_assignment, index_sharding.metadata());
}
HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
@ -593,7 +604,8 @@ HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
// Data sharding is only on update_window_dims. We do not shard this
// scatter op. Return a tile maximal sharding with the first device in
// data sharding tile assignment.
return HloSharding::AssignDevice(*data_sharding.tile_assignment().begin());
return HloSharding::AssignDevice(*data_sharding.tile_assignment().begin(),
data_sharding.metadata());
}
// Data sharding is on both update_window_dims and scatter_window_dims. We
@ -605,7 +617,7 @@ HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
std::vector<int64> slice_starts(data_rank, 0LL);
Array<int64> tile_assignment =
data_sharding.tile_assignment().Slice(slice_starts, tile_assignment_dims);
return HloSharding::Tile(tile_assignment);
return HloSharding::Tile(tile_assignment, data_sharding.metadata());
}
namespace {
@ -654,8 +666,9 @@ absl::optional<HloSharding> PassthroughOperandToGatherOutputOrScatterUpdate(
Array<int64> tile_assignment = operand_sharding.tile_assignment();
tile_assignment.Reshape(passthrough_tile);
return operand_sharding.ReplicateOnLastTileDim()
? HloSharding::PartialTile(tile_assignment)
: HloSharding::Tile(tile_assignment);
? HloSharding::PartialTile(tile_assignment,
operand_sharding.metadata())
: HloSharding::Tile(tile_assignment, operand_sharding.metadata());
}
// Inverse of PassthroughOperandToGatherOutputOrScatterUpdate.
@ -700,8 +713,10 @@ absl::optional<HloSharding> PassthroughGatherOutputOrScatterUpdateToOperand(
}
tile_assignment.Reshape(passthrough_tile);
return update_or_gather_sharding.ReplicateOnLastTileDim()
? HloSharding::PartialTile(tile_assignment)
: HloSharding::Tile(tile_assignment);
? HloSharding::PartialTile(tile_assignment,
update_or_gather_sharding.metadata())
: HloSharding::Tile(tile_assignment,
update_or_gather_sharding.metadata());
}
// Collect data operand sharding for a gather with parallel dimensions from
@ -748,8 +763,9 @@ absl::optional<HloSharding> GatherParallelDataOperandSharding(
}
tile_assignment.Reshape(operand_tile_assignment);
return output_sharding.ReplicateOnLastTileDim()
? HloSharding::PartialTile(tile_assignment)
: HloSharding::Tile(tile_assignment);
? HloSharding::PartialTile(tile_assignment,
output_sharding.metadata())
: HloSharding::Tile(tile_assignment, output_sharding.metadata());
}
} // namespace
@ -940,7 +956,7 @@ HloSharding PartiallyReplicateTiledShardingOnDims(
return sharding;
}
if (group_count == sharding.NumTiles()) {
return HloSharding::Replicate();
return HloSharding::Replicate(sharding.metadata());
}
std::vector<int64> dim_permutation(
sharding.tile_assignment().num_dimensions());
@ -963,7 +979,7 @@ HloSharding PartiallyReplicateTiledShardingOnDims(
new_tile_shape.push_back(group_count);
}
new_tile.Reshape(new_tile_shape);
return HloSharding::PartialTile(new_tile);
return HloSharding::PartialTile(new_tile, sharding.metadata());
}
HloSharding RemoveShapeDimensions(const HloSharding& sharding,
@ -983,8 +999,9 @@ HloSharding RemoveShapeDimensions(const HloSharding& sharding,
}
auto new_tile = sharding.tile_assignment();
new_tile.Reshape(new_tile_shape);
return sharding.ReplicateOnLastTileDim() ? HloSharding::PartialTile(new_tile)
: HloSharding::Tile(new_tile);
return sharding.ReplicateOnLastTileDim()
? HloSharding::PartialTile(new_tile, sharding.metadata())
: HloSharding::Tile(new_tile, sharding.metadata());
}
absl::optional<HloSharding> TransposeShardingWithCollapsedDims(
@ -1034,8 +1051,8 @@ absl::optional<HloSharding> TransposeShardingWithCollapsedDims(
}
reshape_tiles.Reshape(tgt_tiles);
return source.ReplicateOnLastTileDim()
? HloSharding::PartialTile(reshape_tiles)
: HloSharding::Tile(reshape_tiles);
? HloSharding::PartialTile(reshape_tiles, source.metadata())
: HloSharding::Tile(reshape_tiles, source.metadata());
}
absl::optional<GatherParallelDims> GetGatherBatchParallelDims(

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/strings/str_split.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
@ -211,12 +212,18 @@ bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
new_group_members[new_group_id].erase(*device);
});
if (compatible) {
std::vector<OpMetadata> merged_metadata;
std::swap(merged_metadata, to_merge->metadata());
merged_metadata.reserve(to_merge->metadata().size() +
old.metadata().size());
merged_metadata.insert(merged_metadata.end(), old.metadata().begin(),
old.metadata().end());
if (replication == 1) {
new_tile_dims.pop_back();
new_tile.Reshape(new_tile_dims);
*to_merge = HloSharding::Tile(new_tile);
*to_merge = HloSharding::Tile(new_tile, merged_metadata);
} else {
*to_merge = HloSharding::PartialTile(new_tile);
*to_merge = HloSharding::PartialTile(new_tile, merged_metadata);
}
return true;
}
@ -600,8 +607,11 @@ bool InferGatherParallelShardingFromOperands(
auto output_tile_assignment = replicate_non_parallel_dims.tile_assignment();
output_tile_assignment.Reshape(output_tile_dims);
return replicate_non_parallel_dims.ReplicateOnLastTileDim()
? HloSharding::PartialTile(output_tile_assignment)
: HloSharding::Tile(output_tile_assignment);
? HloSharding::PartialTile(
output_tile_assignment,
replicate_non_parallel_dims.metadata())
: HloSharding::Tile(output_tile_assignment,
replicate_non_parallel_dims.metadata());
};
bool changed = false;
@ -689,7 +699,8 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
}
if (lhs->sharding().IsReplicated()) {
return MaybeImproveInstructionSharding(
HloSharding::Replicate(), instruction, may_combine_partial_sharding);
HloSharding::Replicate(lhs->sharding().metadata()), instruction,
may_combine_partial_sharding);
}
if (IsConvolutionKernelSmall(instruction)) {
@ -744,12 +755,12 @@ bool InferShardingFromOperands(HloInstruction* instruction,
(instruction->shape().IsArray() ||
instruction->opcode() == HloOpcode::kReduce ||
instruction->opcode() == HloOpcode::kSort ||
instruction->opcode() == HloOpcode::kReduceWindow) &&
absl::c_any_of(instruction->operands(), [](const HloInstruction* op) {
return op->has_sharding() && op->sharding().IsManual();
})) {
instruction->set_sharding(HloSharding::Manual());
return true;
instruction->opcode() == HloOpcode::kReduceWindow)) {
for (const HloInstruction* op : instruction->operands()) {
if (!op->has_sharding() || !op->sharding().IsManual()) continue;
instruction->set_sharding(HloSharding::Manual(op->sharding().metadata()));
return true;
}
}
const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0;
if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) {
@ -760,11 +771,12 @@ bool InferShardingFromOperands(HloInstruction* instruction,
instruction->HasSideEffect()) {
return false;
}
if (absl::c_any_of(instruction->operands(), [](const HloInstruction* op) {
return op->has_sharding() && op->sharding().IsReplicated();
})) {
return MaybeImproveInstructionSharding(
HloSharding::Replicate(), instruction, may_combine_partial_sharding);
for (const HloInstruction* op : instruction->operands()) {
if (op->has_sharding() && op->sharding().IsReplicated()) {
return MaybeImproveInstructionSharding(
HloSharding::Replicate(op->sharding().metadata()), instruction,
may_combine_partial_sharding);
}
}
return false;
}
@ -827,7 +839,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
HloSharding new_sharding = HloSharding::Tuple(shape, sub_shardings);
if (new_sharding != instruction->sharding()) {
instruction->set_sharding(new_sharding);
instruction->set_sharding(std::move(new_sharding));
return true;
}
return changed;
@ -859,8 +871,9 @@ bool InferShardingFromOperands(HloInstruction* instruction,
// We are reducing along one of the sharded dimensions. We only
// support this in SPMD.
changed |= MaybeImproveInstructionSharding(
get_maybe_tuple_sharding(HloSharding::Replicate()), instruction,
may_combine_partial_sharding);
get_maybe_tuple_sharding(
HloSharding::Replicate(operand->sharding().metadata())),
instruction, may_combine_partial_sharding);
continue;
}
auto after_partial_replication =
@ -870,7 +883,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
operand->sharding(), instruction->dimensions());
if (after_partial_replication.IsReplicated()) {
changed |= MaybeImproveInstructionSharding(
get_maybe_tuple_sharding(HloSharding::Replicate()), instruction,
get_maybe_tuple_sharding(after_partial_replication), instruction,
may_combine_partial_sharding);
continue;
}
@ -922,8 +935,10 @@ bool InferShardingFromOperands(HloInstruction* instruction,
new_tile_assignment.Reshape(target_tile_assignment_dimensions);
HloSharding new_sharding =
op->sharding().ReplicateOnLastTileDim()
? HloSharding::PartialTile(new_tile_assignment)
: HloSharding::Tile(new_tile_assignment);
? HloSharding::PartialTile(new_tile_assignment,
op->sharding().metadata())
: HloSharding::Tile(new_tile_assignment,
op->sharding().metadata());
return MaybeImproveInstructionSharding(
std::move(new_sharding), instruction, may_combine_partial_sharding);
}
@ -1071,9 +1086,9 @@ bool InferShardingFromOperands(HloInstruction* instruction,
}
if (operand->sharding().IsReplicated()) {
return MaybeImproveInstructionSharding(HloSharding::Replicate(),
instruction,
may_combine_partial_sharding);
return MaybeImproveInstructionSharding(
HloSharding::Replicate(operand->sharding().metadata()),
instruction, may_combine_partial_sharding);
}
const auto& tile_assignment = operand->sharding().tile_assignment();
@ -1275,6 +1290,7 @@ absl::optional<HloSharding> GetShardingFromUser(
return absl::nullopt;
}
const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0;
switch (user.opcode()) {
case HloOpcode::kBroadcast: {
if (user.sharding().IsReplicated()) {
@ -1338,9 +1354,10 @@ absl::optional<HloSharding> GetShardingFromUser(
auto new_tile_assignment =
tile_assignment.Slice(start_indices, end_indices);
if (new_tile_assignment.num_elements() == 1) {
return HloSharding::AssignDevice(*new_tile_assignment.begin());
return HloSharding::AssignDevice(*new_tile_assignment.begin(),
user.sharding().metadata());
}
return HloSharding::Tile(new_tile_assignment);
return HloSharding::Tile(new_tile_assignment, user.sharding().metadata());
}
case HloOpcode::kConvolution: {
auto dot_dims = dot_as_convolution_util::ParseConvolutionDimsInfo(&user);
@ -1378,8 +1395,9 @@ absl::optional<HloSharding> GetShardingFromUser(
}
case HloOpcode::kReduceWindow: {
if (user.shape().IsTuple()) {
return user.sharding().GetSubSharding(
auto sub_sharding = user.sharding().GetSubSharding(
user.shape(), {user.operand_index(&instruction)});
return sub_sharding;
}
if (&instruction != user.operand(0)) {
return absl::nullopt;
@ -1411,8 +1429,9 @@ absl::optional<HloSharding> GetShardingFromUser(
reverse_dimensions);
}
case HloOpcode::kTuple: {
return user.sharding().GetSubSharding(user.shape(),
{user.operand_index(&instruction)});
auto sub_sharding = user.sharding().GetSubSharding(
user.shape(), {user.operand_index(&instruction)});
return sub_sharding;
}
case HloOpcode::kGetTupleElement: {
HloSharding new_sharding =
@ -1475,16 +1494,17 @@ absl::optional<HloSharding> GetShardingFromUser(
auto tile_assignment = user_sharding.tile_assignment();
tile_assignment.Reshape(target_tile_assignment_dimensions);
return user_sharding.ReplicateOnLastTileDim()
? HloSharding::PartialTile(tile_assignment)
: HloSharding::Tile(tile_assignment);
? HloSharding::PartialTile(tile_assignment,
user_sharding.metadata())
: HloSharding::Tile(tile_assignment, user_sharding.metadata());
}
case HloOpcode::kSort: {
if (user.sharding().IsTuple()) {
return user.sharding().GetSubSharding(
user.shape(), {user.operand_index(&instruction)});
} else {
return user.sharding();
HloSharding user_sharding = user.sharding();
if (user_sharding.IsTuple()) {
return user_sharding = user_sharding.GetSubSharding(
user.shape(), {user.operand_index(&instruction)});
}
return user_sharding;
}
case HloOpcode::kReverse: {
return hlo_sharding_util::ReverseSharding(user.sharding(),
@ -1555,13 +1575,15 @@ bool InferShardingFromUsers(HloInstruction* instruction,
return false;
}
// Propagate manual sharding.
if (!instruction->has_sharding() && instruction->shape().IsArray() &&
absl::c_any_of(instruction->users(), [](const HloInstruction* user) {
return user->has_sharding() && user->sharding().IsManual() &&
!user->IsCustomCall("SPMDFullToShardShape");
})) {
instruction->set_sharding(HloSharding::Manual());
return true;
if (!instruction->has_sharding() && instruction->shape().IsArray()) {
for (const HloInstruction* user : instruction->users()) {
if (!user->has_sharding() || !user->sharding().IsManual() ||
user->IsCustomCall("SPMDFullToShardShape"))
continue;
instruction->set_sharding(
HloSharding::Manual(user->sharding().metadata()));
return true;
}
}
if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) {
return false;
@ -1579,8 +1601,80 @@ bool InferShardingFromUsers(HloInstruction* instruction,
return improved_sharding;
}
// Checks if two HloShardings have the same metadata attached.
bool SameShardingMetadata(const HloSharding& a, const HloSharding& b) {
DCHECK_EQ(a, b);
auto same_metadata = [](absl::Span<const OpMetadata> a,
absl::Span<const OpMetadata> b) {
if (a.size() != b.size()) return false;
for (int i = 0, e = a.size(); i < e; ++i) {
if (!protobuf_util::ProtobufEquals(a[i], b[i])) {
return false;
}
}
return true;
};
if (a.IsTuple()) {
for (int i = 0, e = a.tuple_elements().size(); i < e; ++i) {
if (!same_metadata(a.tuple_elements()[i].metadata(),
b.tuple_elements()[i].metadata())) {
return false;
}
}
return true;
} else {
return same_metadata(a.metadata(), b.metadata());
}
}
// Assigns metadata to optional sharding on instructions if instructions have
// metadata. If sharding already has some metadata, no new metadata will be
// added.
bool AssignShardingMetadata(HloModule* module) {
bool changed = false;
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
const auto& metadata = instruction->metadata();
if (!instruction->has_sharding() || metadata.ByteSizeLong() == 0) {
continue;
}
HloSharding sharding_with_metadata =
instruction->sharding().WithMetadata({metadata}, /*overwrite=*/false);
if (!SameShardingMetadata(instruction->sharding(),
sharding_with_metadata)) {
instruction->set_sharding(std::move(sharding_with_metadata));
changed = true;
}
}
}
return changed;
}
// Removes all sharding metadata from shardings on instructions.
bool RemoveShardingMetadata(HloModule* module) {
bool changed = false;
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
if (!instruction->has_sharding()) {
continue;
}
HloSharding sharding_no_metadata =
instruction->sharding().WithoutMetadata();
if (!SameShardingMetadata(instruction->sharding(),
sharding_no_metadata)) {
instruction->set_sharding(std::move(sharding_no_metadata));
changed = true;
}
}
}
return changed;
}
// Remove Sharding custom-call instruction by folding the sharding attribute
// to its operand. If the operand alreayd has a different sharding, insert a
// to its operand. If the operand already has a different sharding, insert a
// copy node for reshard.
StatusOr<bool> ProcessShardingInstruction(HloModule* module) {
bool changed = false;
@ -1721,7 +1815,12 @@ Status CheckAndUpdateDeviceAssignmentsInWhileBody(
}
StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
TF_ASSIGN_OR_RETURN(bool any_changed, ProcessShardingInstruction(module));
bool any_changed = propagate_metadata_ ? AssignShardingMetadata(module)
: RemoveShardingMetadata(module);
auto status_or_changed = ProcessShardingInstruction(module);
if (!status_or_changed.ok()) return status_or_changed;
any_changed |= status_or_changed.ValueOrDie();
// Association of partitionable embedded computations with their parent
// instruction.

View File

@ -30,7 +30,9 @@ namespace xla {
// a simple local greedy heuristic.
class ShardingPropagation : public HloModulePass {
public:
explicit ShardingPropagation(bool is_spmd = false) : is_spmd_(is_spmd) {}
explicit ShardingPropagation(bool is_spmd = false,
bool propagate_metadata = false)
: is_spmd_(is_spmd), propagate_metadata_(propagate_metadata) {}
absl::string_view name() const override { return "sharding-propagation"; }
StatusOr<bool> Run(HloModule* module) override;
@ -43,6 +45,7 @@ class ShardingPropagation : public HloModulePass {
private:
bool is_spmd_;
bool propagate_metadata_;
};
} // namespace xla

File diff suppressed because it is too large Load Diff