[XLA] Add option to propagate sharding metadata in HLO sharding propagation pass.

Sharding metadata is removed if metadata propagation is not enabled. Metadata is also assigned to shardings (with no metadata) attached to instructions, using the instruction's metadata if available.

Sharding propagation tests are updated to be parameterized, to test sharding metadata propagation.

PiperOrigin-RevId: 353716863
Change-Id: I55e79e6dec7eea5a2cd7ca718d861e17ba0f3ad2
Author: Andy Ly
Date: 2021-01-25 13:26:13 -08:00
Committed by: TensorFlower Gardener
Parent: d66729431d
Commit: 1071665bc3

7 changed files with 2996 additions and 1113 deletions
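Note: the new behavior is opt-in via a second constructor argument on ShardingPropagation (see sharding_propagation.h below). A minimal sketch of how a caller might enable it; the pipeline name and surrounding setup here are hypothetical, not part of this change:

    #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
    #include "tensorflow/compiler/xla/service/sharding_propagation.h"

    // Keep sharding metadata attached while propagating shardings.
    xla::HloPassPipeline pipeline("sharding-propagation");
    pipeline.AddPass<xla::ShardingPropagation>(/*is_spmd=*/true,
                                               /*propagate_metadata=*/true);
    TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));

With propagate_metadata=false (the default), Run() now strips any existing sharding metadata up front (see RemoveShardingMetadata below).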

View File: tensorflow/compiler/xla/service/BUILD

@@ -520,6 +520,7 @@ cc_library(
         ":hlo_graph_dumper",
         ":hlo_pass",
         ":hlo_sharding_util",
+        "//tensorflow/compiler/xla:protobuf_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
@@ -541,12 +542,16 @@ tf_cc_test(
         "sharding_propagation_test.cc",
     ],
     deps = [
-        "hlo_matchers",
+        ":hlo",
+        ":hlo_matchers",
         ":hlo_parser",
         ":sharding_propagation",
+        "//tensorflow/compiler/xla:protobuf_util",
         "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/strings",
     ],
 )

View File: tensorflow/compiler/xla/service/hlo_sharding.h

@@ -285,7 +285,8 @@ class HloSharding {
   int64 NumTiles(absl::Span<const int64> dims) const;
 
   // Gets metadata from sharding.
-  absl::Span<const OpMetadata> metadata() const { return metadata_; }
+  std::vector<OpMetadata>& metadata() { return metadata_; }
+  const std::vector<OpMetadata>& metadata() const { return metadata_; }
 
  private:
   explicit HloSharding(bool manual, bool replicated,

View File: tensorflow/compiler/xla/service/hlo_sharding_test.cc

@@ -440,7 +440,7 @@ TEST_F(HloShardingTest, WithMetadataNoOverwrite) {
     HloSharding sharding = HloSharding::Replicate();
     auto sharding_new_metadata =
         sharding.WithMetadata(SingleMetadata(), /*overwrite=*/false);
-    ASSERT_EQ(sharding_new_metadata.metadata().length(), 1);
+    ASSERT_EQ(sharding_new_metadata.metadata().size(), 1);
     EXPECT_TRUE(protobuf_util::ProtobufEquals(
         sharding_new_metadata.metadata().front(), SingleMetadata().front()));
   }
@@ -449,7 +449,7 @@ TEST_F(HloShardingTest, WithMetadataNoOverwrite) {
     HloSharding sharding = HloSharding::AssignDevice(7, SingleMetadata());
     auto sharding_new_metadata =
         sharding.WithMetadata(ListMetadata(), /*overwrite=*/false);
-    ASSERT_EQ(sharding_new_metadata.metadata().length(), 1);
+    ASSERT_EQ(sharding_new_metadata.metadata().size(), 1);
     EXPECT_TRUE(protobuf_util::ProtobufEquals(
         sharding.metadata().front(), sharding_new_metadata.metadata().front()));
   }
@@ -492,7 +492,7 @@ TEST_F(HloShardingTest, WithMetadataOverwrite) {
     HloSharding sharding = HloSharding::Replicate();
     auto sharding_new_metadata =
         sharding.WithMetadata(SingleMetadata(), /*overwrite=*/true);
-    ASSERT_EQ(sharding_new_metadata.metadata().length(), 1);
+    ASSERT_EQ(sharding_new_metadata.metadata().size(), 1);
     EXPECT_TRUE(protobuf_util::ProtobufEquals(
         sharding_new_metadata.metadata().front(), SingleMetadata().front()));
   }
@@ -501,7 +501,7 @@ TEST_F(HloShardingTest, WithMetadataOverwrite) {
     HloSharding sharding = HloSharding::AssignDevice(7, SingleMetadata());
     auto sharding_new_metadata =
         sharding.WithMetadata(ListMetadata(), /*overwrite=*/true);
-    ASSERT_EQ(sharding_new_metadata.metadata().length(), 2);
+    ASSERT_EQ(sharding_new_metadata.metadata().size(), 2);
     for (int i = 0; i < 2; ++i) {
       EXPECT_TRUE(protobuf_util::ProtobufEquals(
           sharding_new_metadata.metadata()[i], ListMetadata()[i]));
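For reference, the metadata being attached is XLA's OpMetadata proto from xla_data.proto (op_type, op_name, source_file, source_line). A small sketch of the metadata-aware factory overloads exercised by these tests; the field values are illustrative:

    xla::OpMetadata metadata;
    metadata.set_op_type("Dot");
    metadata.set_op_name("model/dense/matmul");
    // Replicate() and AssignDevice() accept a metadata span, as in the tests.
    xla::HloSharding sharding = xla::HloSharding::Replicate({metadata});
    ASSERT_EQ(sharding.metadata().size(), 1);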

View File: tensorflow/compiler/xla/service/hlo_sharding_util.cc

@@ -130,8 +130,8 @@ HloSharding TransposeSharding(const HloSharding& sharding,
     *value = sharding.tile_assignment()(src_indices);
   });
   return sharding.ReplicateOnLastTileDim()
-             ? HloSharding::PartialTile(tile_assignment)
-             : HloSharding::Tile(tile_assignment);
+             ? HloSharding::PartialTile(tile_assignment, sharding.metadata())
+             : HloSharding::Tile(tile_assignment, sharding.metadata());
 }
 
 absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
@@ -244,8 +244,9 @@ absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
   }
   new_tile_assignment.Reshape(target_tile_assignment_dimensions);
   return sharding.ReplicateOnLastTileDim()
-             ? HloSharding::PartialTile(new_tile_assignment)
-             : HloSharding::Tile(new_tile_assignment);
+             ? HloSharding::PartialTile(new_tile_assignment,
+                                        sharding.metadata())
+             : HloSharding::Tile(new_tile_assignment, sharding.metadata());
 }
 
 HloSharding ReverseSharding(const HloSharding& sharding,
@@ -264,8 +265,9 @@ HloSharding ReverseSharding(const HloSharding& sharding,
     *device = sharding.tile_assignment()(original_indices);
   });
   return sharding.ReplicateOnLastTileDim()
-             ? HloSharding::PartialTile(new_tile_assignment)
-             : HloSharding::Tile(new_tile_assignment);
+             ? HloSharding::PartialTile(new_tile_assignment,
+                                        sharding.metadata())
+             : HloSharding::Tile(new_tile_assignment, sharding.metadata());
 }
 
 HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
@@ -319,7 +321,7 @@ HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
   tile_dims[dim] = devices.size() / ignore_size;
   Array<int64> tile_assignment(tile_dims);
   tile_assignment.SetValues(devices);
-  return HloSharding::Tile(tile_assignment);
+  return HloSharding::Tile(tile_assignment, sharding.metadata());
 }
 
 bool ContainsTileSharding(const HloModule& module) {
@@ -362,12 +364,14 @@ HloSharding GatherOutputSharding(const HloSharding& index_sharding,
   Array<int64> new_tile_assignment = index_sharding.tile_assignment();
   if (new_tile_assignment.num_elements() !=
       Product(output_tile_assignment_dims)) {
-    return HloSharding::Replicate();
+    return HloSharding::Replicate(index_sharding.metadata());
   }
   new_tile_assignment.Reshape(output_tile_assignment_dims);
   return index_sharding.ReplicateOnLastTileDim()
-             ? HloSharding::PartialTile(new_tile_assignment)
-             : HloSharding::Tile(new_tile_assignment);
+             ? HloSharding::PartialTile(new_tile_assignment,
+                                        index_sharding.metadata())
+             : HloSharding::Tile(new_tile_assignment,
+                                 index_sharding.metadata());
 }
 
 HloSharding GatherIndexSharding(const HloSharding& output_sharding,
@@ -401,12 +405,14 @@ HloSharding GatherIndexSharding(const HloSharding& output_sharding,
   Array<int64> new_tile_assignment = output_sharding.tile_assignment();
   if (new_tile_assignment.num_elements() !=
       Product(index_tile_assignment_dims)) {
-    return HloSharding::Replicate();
+    return HloSharding::Replicate(output_sharding.metadata());
   }
   new_tile_assignment.Reshape(index_tile_assignment_dims);
   return output_sharding.ReplicateOnLastTileDim()
-             ? HloSharding::PartialTile(new_tile_assignment)
-             : HloSharding::Tile(new_tile_assignment);
+             ? HloSharding::PartialTile(new_tile_assignment,
+                                        output_sharding.metadata())
+             : HloSharding::Tile(new_tile_assignment,
+                                 output_sharding.metadata());
 }
 
 HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) {
@@ -435,7 +441,8 @@ HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) {
     // Output sharding is only on offset dimensions. We do not shard this gather
     // op. Return a tile maximal sharding with the first device in output
     // sharding tile assignment.
-    return HloSharding::AssignDevice(*hlo.sharding().tile_assignment().begin());
+    return HloSharding::AssignDevice(*hlo.sharding().tile_assignment().begin(),
+                                     hlo.sharding().metadata());
   }
 
   // Output sharding is on both offset and non offset dimensions. We shard the
@@ -456,7 +463,7 @@ HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) {
   }
   Array<int64> tile_assignment =
       hlo.sharding().tile_assignment().Slice(slice_starts, slice_limits);
-  return HloSharding::Tile(tile_assignment);
+  return HloSharding::Tile(tile_assignment, hlo.sharding().metadata());
 }
 
 HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
@@ -483,12 +490,13 @@ HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
   Array<int64> new_tile_assignment = data_sharding.tile_assignment();
   if (new_tile_assignment.num_elements() !=
       Product(index_tile_assignment_dims)) {
-    return HloSharding::Replicate();
+    return HloSharding::Replicate(data_sharding.metadata());
   }
   new_tile_assignment.Reshape(index_tile_assignment_dims);
   return data_sharding.ReplicateOnLastTileDim()
-             ? HloSharding::PartialTile(new_tile_assignment)
-             : HloSharding::Tile(new_tile_assignment);
+             ? HloSharding::PartialTile(new_tile_assignment,
+                                        data_sharding.metadata())
+             : HloSharding::Tile(new_tile_assignment, data_sharding.metadata());
 }
 
 HloSharding ScatterDataSharding(const HloSharding& index_sharding,
@@ -515,12 +523,14 @@ HloSharding ScatterDataSharding(const HloSharding& index_sharding,
   Array<int64> new_tile_assignment = index_sharding.tile_assignment();
   if (new_tile_assignment.num_elements() !=
       Product(data_tile_assignment_dims)) {
-    return HloSharding::Replicate();
+    return HloSharding::Replicate(index_sharding.metadata());
   }
   new_tile_assignment.Reshape(data_tile_assignment_dims);
   return index_sharding.ReplicateOnLastTileDim()
-             ? HloSharding::PartialTile(new_tile_assignment)
-             : HloSharding::Tile(new_tile_assignment);
+             ? HloSharding::PartialTile(new_tile_assignment,
+                                        index_sharding.metadata())
+             : HloSharding::Tile(new_tile_assignment,
+                                 index_sharding.metadata());
 }
 
 HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
@@ -549,7 +559,8 @@ HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
   // op. Return a tile maximal sharding with the first device in index sharding
   // tile assignment.
   if (num_elements == 1) {
-    return HloSharding::AssignDevice(*index_sharding.tile_assignment().begin());
+    return HloSharding::AssignDevice(*index_sharding.tile_assignment().begin(),
+                                     index_sharding.metadata());
   }
 
   const int64 index_rank = hlo.operand(1)->shape().rank();
@@ -563,7 +574,7 @@ HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
   }
   Array<int64> tile_assignment =
      index_sharding.tile_assignment().Slice(slice_starts, slice_limits);
-  return HloSharding::Tile(tile_assignment);
+  return HloSharding::Tile(tile_assignment, index_sharding.metadata());
 }
 
 HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
@@ -593,7 +604,8 @@ HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
     // Data sharding is only on update_window_dims. We do not shard this
     // scatter op. Return a tile maximal sharding with the first device in
     // data sharding tile assignment.
-    return HloSharding::AssignDevice(*data_sharding.tile_assignment().begin());
+    return HloSharding::AssignDevice(*data_sharding.tile_assignment().begin(),
+                                     data_sharding.metadata());
   }
 
   // Data sharding is on both update_window_dims and scatter_window_dims. We
@@ -605,7 +617,7 @@ HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
   std::vector<int64> slice_starts(data_rank, 0LL);
   Array<int64> tile_assignment =
       data_sharding.tile_assignment().Slice(slice_starts, tile_assignment_dims);
-  return HloSharding::Tile(tile_assignment);
+  return HloSharding::Tile(tile_assignment, data_sharding.metadata());
 }
 
 namespace {
@@ -654,8 +666,9 @@ absl::optional<HloSharding> PassthroughOperandToGatherOutputOrScatterUpdate(
   Array<int64> tile_assignment = operand_sharding.tile_assignment();
   tile_assignment.Reshape(passthrough_tile);
   return operand_sharding.ReplicateOnLastTileDim()
-             ? HloSharding::PartialTile(tile_assignment)
-             : HloSharding::Tile(tile_assignment);
+             ? HloSharding::PartialTile(tile_assignment,
+                                        operand_sharding.metadata())
+             : HloSharding::Tile(tile_assignment, operand_sharding.metadata());
 }
 
 // Inverse of PassthroughOperandToGatherOutputOrScatterUpdate.
@@ -700,8 +713,10 @@ absl::optional<HloSharding> PassthroughGatherOutputOrScatterUpdateToOperand(
   }
   tile_assignment.Reshape(passthrough_tile);
   return update_or_gather_sharding.ReplicateOnLastTileDim()
-             ? HloSharding::PartialTile(tile_assignment)
-             : HloSharding::Tile(tile_assignment);
+             ? HloSharding::PartialTile(tile_assignment,
+                                        update_or_gather_sharding.metadata())
+             : HloSharding::Tile(tile_assignment,
+                                 update_or_gather_sharding.metadata());
 }
 
 // Collect data operand sharding for a gather with parallel dimensions from
@@ -748,8 +763,9 @@ absl::optional<HloSharding> GatherParallelDataOperandSharding(
   }
   tile_assignment.Reshape(operand_tile_assignment);
   return output_sharding.ReplicateOnLastTileDim()
-             ? HloSharding::PartialTile(tile_assignment)
-             : HloSharding::Tile(tile_assignment);
+             ? HloSharding::PartialTile(tile_assignment,
+                                        output_sharding.metadata())
+             : HloSharding::Tile(tile_assignment, output_sharding.metadata());
 }
 
 }  // namespace
@@ -940,7 +956,7 @@ HloSharding PartiallyReplicateTiledShardingOnDims(
     return sharding;
   }
   if (group_count == sharding.NumTiles()) {
-    return HloSharding::Replicate();
+    return HloSharding::Replicate(sharding.metadata());
   }
   std::vector<int64> dim_permutation(
       sharding.tile_assignment().num_dimensions());
@@ -963,7 +979,7 @@ HloSharding PartiallyReplicateTiledShardingOnDims(
     new_tile_shape.push_back(group_count);
   }
   new_tile.Reshape(new_tile_shape);
-  return HloSharding::PartialTile(new_tile);
+  return HloSharding::PartialTile(new_tile, sharding.metadata());
 }
 
 HloSharding RemoveShapeDimensions(const HloSharding& sharding,
@@ -983,8 +999,9 @@ HloSharding RemoveShapeDimensions(const HloSharding& sharding,
   }
   auto new_tile = sharding.tile_assignment();
   new_tile.Reshape(new_tile_shape);
-  return sharding.ReplicateOnLastTileDim() ? HloSharding::PartialTile(new_tile)
-                                           : HloSharding::Tile(new_tile);
+  return sharding.ReplicateOnLastTileDim()
+             ? HloSharding::PartialTile(new_tile, sharding.metadata())
+             : HloSharding::Tile(new_tile, sharding.metadata());
 }
 
 absl::optional<HloSharding> TransposeShardingWithCollapsedDims(
@@ -1034,8 +1051,8 @@ absl::optional<HloSharding> TransposeShardingWithCollapsedDims(
   }
   reshape_tiles.Reshape(tgt_tiles);
   return source.ReplicateOnLastTileDim()
-             ? HloSharding::PartialTile(reshape_tiles)
-             : HloSharding::Tile(reshape_tiles);
+             ? HloSharding::PartialTile(reshape_tiles, source.metadata())
+             : HloSharding::Tile(reshape_tiles, source.metadata());
 }
 
 absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
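The changes in this file all follow one pattern: every helper that rebuilds a sharding (transpose, reshape, reverse, the gather/scatter helpers, partial replication) now threads the source sharding's metadata() into the Tile/PartialTile/Replicate/AssignDevice factory, so provenance survives the transformation. A test-style sketch of the resulting invariant, assuming the signatures above; the specific values are illustrative:

    xla::OpMetadata md;
    md.set_op_name("while/body/dot");
    xla::Array<xla::int64> assignment({2, 1});
    assignment.SetValues({0, 1});
    xla::HloSharding tiled = xla::HloSharding::Tile(assignment, {md});
    // After this change, TransposeSharding preserves the attached metadata.
    xla::HloSharding transposed =
        xla::hlo_sharding_util::TransposeSharding(tiled, {1, 0});
    EXPECT_EQ(transposed.metadata().size(), 1);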

View File: tensorflow/compiler/xla/service/sharding_propagation.cc

@@ -28,6 +28,7 @@ limitations under the License.
 #include "absl/status/status.h"
 #include "absl/strings/str_split.h"
 #include "absl/types/optional.h"
+#include "tensorflow/compiler/xla/protobuf_util.h"
 #include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
@@ -211,12 +212,18 @@ bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
         new_group_members[new_group_id].erase(*device);
       });
   if (compatible) {
+    std::vector<OpMetadata> merged_metadata;
+    std::swap(merged_metadata, to_merge->metadata());
+    merged_metadata.reserve(to_merge->metadata().size() +
+                            old.metadata().size());
+    merged_metadata.insert(merged_metadata.end(), old.metadata().begin(),
+                           old.metadata().end());
     if (replication == 1) {
       new_tile_dims.pop_back();
       new_tile.Reshape(new_tile_dims);
-      *to_merge = HloSharding::Tile(new_tile);
+      *to_merge = HloSharding::Tile(new_tile, merged_metadata);
     } else {
-      *to_merge = HloSharding::PartialTile(new_tile);
+      *to_merge = HloSharding::PartialTile(new_tile, merged_metadata);
     }
     return true;
   }
@@ -600,8 +607,11 @@ bool InferGatherParallelShardingFromOperands(
     auto output_tile_assignment = replicate_non_parallel_dims.tile_assignment();
     output_tile_assignment.Reshape(output_tile_dims);
     return replicate_non_parallel_dims.ReplicateOnLastTileDim()
-               ? HloSharding::PartialTile(output_tile_assignment)
-               : HloSharding::Tile(output_tile_assignment);
+               ? HloSharding::PartialTile(
+                     output_tile_assignment,
+                     replicate_non_parallel_dims.metadata())
+               : HloSharding::Tile(output_tile_assignment,
+                                   replicate_non_parallel_dims.metadata());
   };
 
   bool changed = false;
@@ -689,7 +699,8 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
   }
   if (lhs->sharding().IsReplicated()) {
     return MaybeImproveInstructionSharding(
-        HloSharding::Replicate(), instruction, may_combine_partial_sharding);
+        HloSharding::Replicate(lhs->sharding().metadata()), instruction,
+        may_combine_partial_sharding);
   }
 
   if (IsConvolutionKernelSmall(instruction)) {
@@ -744,13 +755,13 @@ bool InferShardingFromOperands(HloInstruction* instruction,
       (instruction->shape().IsArray() ||
        instruction->opcode() == HloOpcode::kReduce ||
       instruction->opcode() == HloOpcode::kSort ||
-       instruction->opcode() == HloOpcode::kReduceWindow) &&
-      absl::c_any_of(instruction->operands(), [](const HloInstruction* op) {
-        return op->has_sharding() && op->sharding().IsManual();
-      })) {
-    instruction->set_sharding(HloSharding::Manual());
-    return true;
+       instruction->opcode() == HloOpcode::kReduceWindow)) {
+    for (const HloInstruction* op : instruction->operands()) {
+      if (!op->has_sharding() || !op->sharding().IsManual()) continue;
+      instruction->set_sharding(HloSharding::Manual(op->sharding().metadata()));
+      return true;
+    }
   }
 
   const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0;
   if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) {
@@ -760,11 +771,12 @@ bool InferShardingFromOperands(HloInstruction* instruction,
       instruction->HasSideEffect()) {
     return false;
   }
-  if (absl::c_any_of(instruction->operands(), [](const HloInstruction* op) {
-        return op->has_sharding() && op->sharding().IsReplicated();
-      })) {
-    return MaybeImproveInstructionSharding(
-        HloSharding::Replicate(), instruction, may_combine_partial_sharding);
+  for (const HloInstruction* op : instruction->operands()) {
+    if (op->has_sharding() && op->sharding().IsReplicated()) {
+      return MaybeImproveInstructionSharding(
+          HloSharding::Replicate(op->sharding().metadata()), instruction,
+          may_combine_partial_sharding);
+    }
   }
   return false;
 }
@@ -827,7 +839,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
       HloSharding new_sharding = HloSharding::Tuple(shape, sub_shardings);
       if (new_sharding != instruction->sharding()) {
-        instruction->set_sharding(new_sharding);
+        instruction->set_sharding(std::move(new_sharding));
         return true;
       }
       return changed;
@@ -859,8 +871,9 @@ bool InferShardingFromOperands(HloInstruction* instruction,
         // We are reducing along one of the sharded dimensions. We only
         // support this in SPMD.
         changed |= MaybeImproveInstructionSharding(
-            get_maybe_tuple_sharding(HloSharding::Replicate()), instruction,
-            may_combine_partial_sharding);
+            get_maybe_tuple_sharding(
+                HloSharding::Replicate(operand->sharding().metadata())),
+            instruction, may_combine_partial_sharding);
         continue;
       }
       auto after_partial_replication =
@@ -870,7 +883,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
               operand->sharding(), instruction->dimensions());
       if (after_partial_replication.IsReplicated()) {
         changed |= MaybeImproveInstructionSharding(
-            get_maybe_tuple_sharding(HloSharding::Replicate()), instruction,
+            get_maybe_tuple_sharding(after_partial_replication), instruction,
             may_combine_partial_sharding);
         continue;
       }
@@ -922,8 +935,10 @@ bool InferShardingFromOperands(HloInstruction* instruction,
       new_tile_assignment.Reshape(target_tile_assignment_dimensions);
       HloSharding new_sharding =
           op->sharding().ReplicateOnLastTileDim()
-              ? HloSharding::PartialTile(new_tile_assignment)
-              : HloSharding::Tile(new_tile_assignment);
+              ? HloSharding::PartialTile(new_tile_assignment,
+                                         op->sharding().metadata())
+              : HloSharding::Tile(new_tile_assignment,
+                                  op->sharding().metadata());
       return MaybeImproveInstructionSharding(
           std::move(new_sharding), instruction, may_combine_partial_sharding);
     }
@@ -1071,9 +1086,9 @@ bool InferShardingFromOperands(HloInstruction* instruction,
       }
 
       if (operand->sharding().IsReplicated()) {
-        return MaybeImproveInstructionSharding(HloSharding::Replicate(),
-                                               instruction,
-                                               may_combine_partial_sharding);
+        return MaybeImproveInstructionSharding(
+            HloSharding::Replicate(operand->sharding().metadata()),
+            instruction, may_combine_partial_sharding);
       }
 
       const auto& tile_assignment = operand->sharding().tile_assignment();
@@ -1275,6 +1290,7 @@ absl::optional<HloSharding> GetShardingFromUser(
     return absl::nullopt;
   }
   const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0;
+
   switch (user.opcode()) {
     case HloOpcode::kBroadcast: {
       if (user.sharding().IsReplicated()) {
@@ -1338,9 +1354,10 @@ absl::optional<HloSharding> GetShardingFromUser(
       auto new_tile_assignment =
          tile_assignment.Slice(start_indices, end_indices);
       if (new_tile_assignment.num_elements() == 1) {
-        return HloSharding::AssignDevice(*new_tile_assignment.begin());
+        return HloSharding::AssignDevice(*new_tile_assignment.begin(),
+                                         user.sharding().metadata());
       }
-      return HloSharding::Tile(new_tile_assignment);
+      return HloSharding::Tile(new_tile_assignment, user.sharding().metadata());
     }
     case HloOpcode::kConvolution: {
       auto dot_dims = dot_as_convolution_util::ParseConvolutionDimsInfo(&user);
@@ -1378,8 +1395,9 @@ absl::optional<HloSharding> GetShardingFromUser(
     }
     case HloOpcode::kReduceWindow: {
       if (user.shape().IsTuple()) {
-        return user.sharding().GetSubSharding(
-            user.shape(), {user.operand_index(&instruction)});
+        auto sub_sharding = user.sharding().GetSubSharding(
+            user.shape(), {user.operand_index(&instruction)});
+        return sub_sharding;
       }
       if (&instruction != user.operand(0)) {
         return absl::nullopt;
@@ -1411,8 +1429,9 @@ absl::optional<HloSharding> GetShardingFromUser(
                                                 reverse_dimensions);
     }
     case HloOpcode::kTuple: {
-      return user.sharding().GetSubSharding(user.shape(),
-                                            {user.operand_index(&instruction)});
+      auto sub_sharding = user.sharding().GetSubSharding(
+          user.shape(), {user.operand_index(&instruction)});
+      return sub_sharding;
     }
     case HloOpcode::kGetTupleElement: {
       HloSharding new_sharding =
@@ -1475,16 +1494,17 @@ absl::optional<HloSharding> GetShardingFromUser(
       auto tile_assignment = user_sharding.tile_assignment();
       tile_assignment.Reshape(target_tile_assignment_dimensions);
       return user_sharding.ReplicateOnLastTileDim()
-                 ? HloSharding::PartialTile(tile_assignment)
-                 : HloSharding::Tile(tile_assignment);
+                 ? HloSharding::PartialTile(tile_assignment,
+                                            user_sharding.metadata())
+                 : HloSharding::Tile(tile_assignment, user_sharding.metadata());
     }
     case HloOpcode::kSort: {
-      if (user.sharding().IsTuple()) {
-        return user.sharding().GetSubSharding(
-            user.shape(), {user.operand_index(&instruction)});
-      } else {
-        return user.sharding();
-      }
+      HloSharding user_sharding = user.sharding();
+      if (user_sharding.IsTuple()) {
+        return user_sharding = user_sharding.GetSubSharding(
+                   user.shape(), {user.operand_index(&instruction)});
+      }
+      return user_sharding;
     }
     case HloOpcode::kReverse: {
       return hlo_sharding_util::ReverseSharding(user.sharding(),
@@ -1555,14 +1575,16 @@ bool InferShardingFromUsers(HloInstruction* instruction,
     return false;
   }
   // Propagate manual sharding.
-  if (!instruction->has_sharding() && instruction->shape().IsArray() &&
-      absl::c_any_of(instruction->users(), [](const HloInstruction* user) {
-        return user->has_sharding() && user->sharding().IsManual() &&
-               !user->IsCustomCall("SPMDFullToShardShape");
-      })) {
-    instruction->set_sharding(HloSharding::Manual());
-    return true;
+  if (!instruction->has_sharding() && instruction->shape().IsArray()) {
+    for (const HloInstruction* user : instruction->users()) {
+      if (!user->has_sharding() || !user->sharding().IsManual() ||
+          user->IsCustomCall("SPMDFullToShardShape"))
+        continue;
+      instruction->set_sharding(
+          HloSharding::Manual(user->sharding().metadata()));
+      return true;
+    }
   }
   if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) {
     return false;
   }
@@ -1579,8 +1601,80 @@ bool InferShardingFromUsers(HloInstruction* instruction,
   return improved_sharding;
 }
 
+// Checks if two HloShardings have the same metadata attached.
+bool SameShardingMetadata(const HloSharding& a, const HloSharding& b) {
+  DCHECK_EQ(a, b);
+
+  auto same_metadata = [](absl::Span<const OpMetadata> a,
+                          absl::Span<const OpMetadata> b) {
+    if (a.size() != b.size()) return false;
+    for (int i = 0, e = a.size(); i < e; ++i) {
+      if (!protobuf_util::ProtobufEquals(a[i], b[i])) {
+        return false;
+      }
+    }
+    return true;
+  };
+
+  if (a.IsTuple()) {
+    for (int i = 0, e = a.tuple_elements().size(); i < e; ++i) {
+      if (!same_metadata(a.tuple_elements()[i].metadata(),
+                         b.tuple_elements()[i].metadata())) {
+        return false;
+      }
+    }
+    return true;
+  } else {
+    return same_metadata(a.metadata(), b.metadata());
+  }
+}
+
+// Assigns metadata to optional sharding on instructions if instructions have
+// metadata. If sharding already has some metadata, no new metadata will be
+// added.
+bool AssignShardingMetadata(HloModule* module) {
+  bool changed = false;
+  for (HloComputation* computation : module->computations()) {
+    for (HloInstruction* instruction : computation->instructions()) {
+      const auto& metadata = instruction->metadata();
+      if (!instruction->has_sharding() || metadata.ByteSizeLong() == 0) {
+        continue;
+      }
+
+      HloSharding sharding_with_metadata =
+          instruction->sharding().WithMetadata({metadata}, /*overwrite=*/false);
+      if (!SameShardingMetadata(instruction->sharding(),
+                                sharding_with_metadata)) {
+        instruction->set_sharding(std::move(sharding_with_metadata));
+        changed = true;
+      }
+    }
+  }
+  return changed;
+}
+
+// Removes all sharding metadata from shardings on instructions.
+bool RemoveShardingMetadata(HloModule* module) {
+  bool changed = false;
+  for (HloComputation* computation : module->computations()) {
+    for (HloInstruction* instruction : computation->instructions()) {
+      if (!instruction->has_sharding()) {
+        continue;
+      }
+      HloSharding sharding_no_metadata =
+          instruction->sharding().WithoutMetadata();
+      if (!SameShardingMetadata(instruction->sharding(),
+                                sharding_no_metadata)) {
+        instruction->set_sharding(std::move(sharding_no_metadata));
+        changed = true;
+      }
+    }
  }
+  return changed;
+}
+
 // Remove Sharding custom-call instruction by folding the sharding attribute
-// to its operand. If the operand alreayd has a different sharding, insert a
+// to its operand. If the operand already has a different sharding, insert a
 // copy node for reshard.
 StatusOr<bool> ProcessShardingInstruction(HloModule* module) {
   bool changed = false;
@@ -1721,7 +1815,12 @@ Status CheckAndUpdateDeviceAssignmentsInWhileBody(
 }
 
 StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
-  TF_ASSIGN_OR_RETURN(bool any_changed, ProcessShardingInstruction(module));
+  bool any_changed = propagate_metadata_ ? AssignShardingMetadata(module)
+                                         : RemoveShardingMetadata(module);
+  auto status_or_changed = ProcessShardingInstruction(module);
+  if (!status_or_changed.ok()) return status_or_changed;
+  any_changed |= status_or_changed.ValueOrDie();
 
   // Association of partitionable embedded computations with their parent
   // instruction.
View File: tensorflow/compiler/xla/service/sharding_propagation.h

@@ -30,7 +30,9 @@ namespace xla {
 // a simple local greedy heuristic.
 class ShardingPropagation : public HloModulePass {
  public:
-  explicit ShardingPropagation(bool is_spmd = false) : is_spmd_(is_spmd) {}
+  explicit ShardingPropagation(bool is_spmd = false,
+                               bool propagate_metadata = false)
+      : is_spmd_(is_spmd), propagate_metadata_(propagate_metadata) {}
   absl::string_view name() const override { return "sharding-propagation"; }
   StatusOr<bool> Run(HloModule* module) override;
@@ -43,6 +45,7 @@ class ShardingPropagation : public HloModulePass {
  private:
   bool is_spmd_;
+  bool propagate_metadata_;
 };
 
 }  // namespace xla

View File: tensorflow/compiler/xla/service/sharding_propagation_test.cc

(File diff suppressed because it is too large.)