From 488b5986b88cc3136e6aaa8e844af86b698ffc58 Mon Sep 17 00:00:00 2001 From: Yuanzhong Xu Date: Wed, 25 Nov 2020 10:58:46 -0800 Subject: [PATCH] [XLA:SPMD] Define manual sharding, instead of using fake replication PiperOrigin-RevId: 344282109 Change-Id: I1bfe1cad2f04f427d33133e2670cf334b5f2d38c --- .../kernels/spmd_manual_sharding_ops.cc | 22 +++---- tensorflow/compiler/xla/service/hlo_lexer.cc | 3 + tensorflow/compiler/xla/service/hlo_lexer.h | 1 + tensorflow/compiler/xla/service/hlo_parser.cc | 13 +++- .../compiler/xla/service/hlo_parser_test.cc | 4 +- .../compiler/xla/service/hlo_sharding.cc | 24 +++++-- .../compiler/xla/service/hlo_sharding.h | 30 +++++++-- .../xla/service/sharding_propagation.cc | 25 +++++++ .../xla/service/spmd/spmd_partitioner.cc | 65 +++++++++++++++++-- .../xla/service/spmd/spmd_partitioner.h | 2 + .../xla/service/spmd/spmd_partitioner_test.cc | 17 +++-- tensorflow/compiler/xla/xla_data.proto | 3 + 12 files changed, 173 insertions(+), 36 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc b/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc index cd28fe8fa3f..330a11e160f 100644 --- a/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc @@ -59,7 +59,7 @@ class XlaSpmdFullToShardShapeOp : public XlaOpKernel { } xla::XlaOp input_annotation; { - // Annotate the full-shape input with the manual sharding. + // Annotate the full-shape input with the sharding. xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(), sharding); input_annotation = @@ -68,12 +68,11 @@ class XlaSpmdFullToShardShapeOp : public XlaOpKernel { } { - // Annotate the shard-shape output with replicated sharding, so that the + // Annotate the shard-shape output with manual sharding, so that the // partitioner will leave it as is. - xla::OpSharding replicated; - replicated.set_type(xla::OpSharding::REPLICATED); - xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(), - replicated); + xla::OpSharding manual; + manual.set_type(xla::OpSharding::MANUAL); + xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(), manual); auto output = xla::CustomCall(ctx->builder(), /*call_target_name=*/"SPMDFullToShardShape", {input_annotation}, output_shape); @@ -112,19 +111,18 @@ class XlaSpmdShardToFullShapeOp : public XlaOpKernel { } xla::XlaOp input_annotation; { - // Annotate the shard-shape input with replicated sharding, so that the + // Annotate the shard-shape input with manual sharding, so that the // partitioner will leave it as is. - xla::OpSharding replicated; - replicated.set_type(xla::OpSharding::REPLICATED); - xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(), - replicated); + xla::OpSharding manual; + manual.set_type(xla::OpSharding::MANUAL); + xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(), manual); input_annotation = xla::CustomCall(ctx->builder(), /*call_target_name=*/"Sharding", {input}, input_shape_or.ValueOrDie()); } { - // Annotate the full-shape output with the manual sharding. + // Annotate the full-shape output with the sharding. xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(), sharding); ctx->SetOutput( diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 3c44b390969..84b766b057c 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -281,6 +281,7 @@ TokKind HloLexer::LexIdentifier() { KEYWORD(ROOT); KEYWORD(maximal); KEYWORD(replicated); + KEYWORD(manual); KEYWORD(last_tile_dim_replicate); #undef KEYWORD @@ -502,6 +503,8 @@ string TokKindToString(TokKind kind) { return "kw_maximal"; case TokKind::kw_replicated: return "kw_replicated"; + case TokKind::kw_manual: + return "kw_manual"; case TokKind::kw_last_tile_dim_replicate: return "kw_last_tile_dim_replicate"; case TokKind::kw_nan: diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index 4068ad76581..e2ff65dd8ac 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -61,6 +61,7 @@ enum class TokKind { kw_false, kw_maximal, kw_replicated, + kw_manual, kw_last_tile_dim_replicate, kw_nan, kw_inf, diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index cc2442a1475..558a5029960 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -2252,7 +2252,7 @@ bool HloParserImpl::ParseFrontendAttributes( "expects '}' at the end of frontend attributes"); } -// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? +// ::= '{' 'replicated'? 'manual'? 'maximal'? ('device=' int)? shape? // ('devices=' ('[' dims ']')* device_list)? '}' // dims ::= int_list device_list ::= int_list bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, @@ -2266,6 +2266,7 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, LocTy loc = lexer_.GetLoc(); bool maximal = false; bool replicated = false; + bool manual = false; bool last_tile_dim_replicate = false; std::vector devices; std::vector tile_assignment_dimensions; @@ -2279,6 +2280,10 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, replicated = true; lexer_.Lex(); break; + case TokKind::kw_manual: + manual = true; + lexer_.Lex(); + break; case TokKind::kAttributeName: { if (lexer_.GetStrVal() == "device") { if (lexer_.Lex() != TokKind::kInt) { @@ -2342,6 +2347,12 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, } sharding->set_type(OpSharding::MAXIMAL); sharding->add_tile_assignment_devices(devices[0]); + } else if (manual) { + if (!devices.empty()) { + return Error(loc, + "manual shardings should not have any devices assigned"); + } + sharding->set_type(OpSharding::MANUAL); } else { if (devices.size() <= 1) { return Error( diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 536433c44e7..dc94e30c847 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -266,10 +266,10 @@ ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f R"(HloModule ShardedTupleCreate_module ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { - %v1 = f32[] parameter(0) + %v1 = f32[] parameter(0), sharding={manual} %v2 = f32[3]{0} parameter(1) %v3 = f32[2,3]{1,0} parameter(2) - ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}} + ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{manual}, {maximal device=0}, {replicated}} } )" diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 6ccfdcc50b8..60d9f7d94fc 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -152,6 +152,10 @@ string HloSharding::ToString() const { if (replicated_) { return "{replicated}"; } + + if (manual_) { + return "{manual}"; + } if (maximal_) { return StrCat( "{maximal device=", static_cast(*tile_assignment_.begin()), "}"); @@ -169,7 +173,7 @@ bool HloSharding::UsesDevice(int64 device) const { }); } const auto& devices = tile_assignment_; - return replicated_ || absl::c_linear_search(devices, device); + return replicated_ || manual_ || absl::c_linear_search(devices, device); } std::map HloSharding::UsedDevices(int64* count) const { @@ -197,6 +201,7 @@ std::map HloSharding::UsedDevices(int64* count) const { std::vector HloSharding::TileIndexForDevice(int64 device) const { CHECK(!maximal_); + CHECK(!manual_); CHECK(!IsTuple()); std::vector ret_index; tile_assignment_.Each([&](absl::Span index, int64 d) { @@ -213,6 +218,7 @@ std::vector HloSharding::TileIndexForDevice(int64 device) const { int64 HloSharding::DeviceForTileIndex(absl::Span index) const { CHECK(!replicated_); + CHECK(!manual_); CHECK(!IsTuple()); if (maximal_) { return *tile_assignment_.begin(); @@ -229,6 +235,7 @@ int64 HloSharding::DeviceForTileIndex(absl::Span index) const { std::vector HloSharding::TileOffsetForDevice(const Shape& shape, int64 device) const { CHECK(!IsTuple()); + CHECK(!manual_); if (maximal_) { return std::vector(shape.dimensions_size(), 0); @@ -250,6 +257,7 @@ std::vector HloSharding::TileOffsetForDevice(const Shape& shape, std::vector HloSharding::TileLimitForDevice(const Shape& shape, int64 device) const { CHECK(!IsTuple()); + CHECK(!manual_); if (maximal_) { return std::vector(shape.dimensions().begin(), @@ -410,7 +418,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, return status; } - if (IsTileMaximal()) { + if (IsTileMaximal() || IsManual()) { return Status::OK(); } @@ -447,6 +455,8 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, return HloSharding(tuple_shardings); } else if (proto.type() == OpSharding::REPLICATED) { return Replicate(); + } else if (proto.type() == OpSharding::MANUAL) { + return Manual(); } else if (proto.tile_assignment_devices().size() == 1) { return HloSharding(proto.tile_assignment_devices(0)); } @@ -503,6 +513,8 @@ OpSharding HloSharding::ToProto() const { result.set_type(OpSharding::REPLICATED); } else if (IsTileMaximal()) { result.set_type(OpSharding::MAXIMAL); + } else if (IsManual()) { + result.set_type(OpSharding::MANUAL); } else { result.set_type(OpSharding::OTHER); result.set_replicate_on_last_tile_dim(ReplicateOnLastTileDim()); @@ -511,7 +523,7 @@ OpSharding HloSharding::ToProto() const { } Shape HloSharding::TileShape(const Shape& shape) const { - if (IsTileMaximal()) { + if (IsTileMaximal() || IsManual()) { return shape; } Shape result_shape = shape; @@ -523,7 +535,7 @@ Shape HloSharding::TileShape(const Shape& shape) const { } Shape HloSharding::TileShape(const Shape& shape, int64 device) const { - if (IsTileMaximal()) { + if (IsTileMaximal() || IsManual()) { return shape; } @@ -545,6 +557,7 @@ int64 HloSharding::NumTiles() const { if (IsTileMaximal()) { return 1; } + CHECK(!IsManual()); if (ReplicateOnLastTileDim()) { return tile_assignment().num_elements() / tile_assignment().dimensions().back(); @@ -600,6 +613,9 @@ size_t HloSharding::Hash() const { if (replicated_) { return 0; } + if (manual_) { + return 1; + } size_t h = 0; for (uint32 v : tile_assignment_) { h = tensorflow::Hash64Combine(h, std::hash{}(v)); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index e7ba2bc0680..2b9c8fc74c3 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -42,7 +42,14 @@ class HloSharding { public: // Creates a trivial sharding that replicates a maximal tile across all // devices. - static HloSharding Replicate() { return HloSharding(); } + static HloSharding Replicate() { + return HloSharding(/*manual=*/false, /*replicated=*/true); + } + + // Creates a sharding that represents the op is manually partitioned. + static HloSharding Manual() { + return HloSharding(/*manual=*/true, /*replicated=*/false); + } // Creates a sharding that emulates device placement; a tile shape equal to // the input shape (one tile) assigned to a single device. @@ -128,6 +135,15 @@ class HloSharding { }); } + // Returns whether the sharding represents manual partitioning. + bool IsManual() const { + if (!IsTuple()) { + return manual_; + } + return absl::c_all_of(tuple_elements_, + [](const HloSharding& s) { return s.IsManual(); }); + } + // Returns if the sharding has partial replication and partial sharding. If // true, data is sharded according to other dimensions of tile_assignment(), // but replicated across devices along the last dimension. @@ -209,6 +225,7 @@ class HloSharding { bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && + manual_ == other.manual_ && tile_assignment_ == other.tile_assignment_ && tuple_elements_ == other.tuple_elements_ && replicate_on_last_tile_dim_ == other.replicate_on_last_tile_dim_; @@ -248,10 +265,11 @@ class HloSharding { int64 NumTiles() const; private: - HloSharding() - : replicated_(true), - maximal_(true), + explicit HloSharding(bool manual, bool replicated) + : replicated_(replicated), + maximal_(replicated), tuple_(false), + manual_(manual), tile_assignment_({0}), replicate_on_last_tile_dim_(false) {} // device_id values: @@ -264,6 +282,7 @@ class HloSharding { : replicated_(false), maximal_(true), tuple_(false), + manual_(false), tile_assignment_({1}, device_id), replicate_on_last_tile_dim_(false) {} explicit HloSharding(const Array& tile_assignment, @@ -271,12 +290,14 @@ class HloSharding { : replicated_(false), maximal_(false), tuple_(false), + manual_(false), tile_assignment_(tile_assignment), replicate_on_last_tile_dim_(replicate_on_last_tile_dim) {} explicit HloSharding(const std::vector& tuple_shardings) : replicated_(false), maximal_(false), tuple_(true), + manual_(false), tile_assignment_({0}), tuple_elements_(tuple_shardings), replicate_on_last_tile_dim_(false) {} @@ -297,6 +318,7 @@ class HloSharding { bool replicated_; bool maximal_; bool tuple_; + bool manual_; // This field is only used if replicated_ is false. If maximal_ is true, then // the field contains a rank 1 array with a single element, which is the // device the HLO is assigned to. If maximal_ is false, the field contains an diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc index b67d671f377..1542368add8 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation.cc @@ -680,6 +680,18 @@ bool InferShardingFromOperands(HloInstruction* instruction, if (!CanPropagateThroughAtAgressiveLevel(*instruction, aggressiveness)) { return false; } + // Do not change manual sharding. + if (instruction->has_sharding() && instruction->sharding().IsManual()) { + return false; + } + // Propagate manual sharding. + if (!instruction->has_sharding() && + absl::c_any_of(instruction->operands(), [](const HloInstruction* op) { + return op->has_sharding() && op->sharding().IsManual(); + })) { + instruction->set_sharding(HloSharding::Manual()); + return true; + } const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0; if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) { // If an array shaped HLO doesn't support spatial partitioning but at least @@ -1457,6 +1469,19 @@ bool InferShardingFromUsers(HloInstruction* instruction, if (aggressiveness < 2 && instruction->opcode() == HloOpcode::kBroadcast) { return false; } + // Do not change manual sharding. + if (instruction->has_sharding() && instruction->sharding().IsManual()) { + return false; + } + // Propagate manual sharding. + if (!instruction->has_sharding() && + absl::c_any_of(instruction->users(), [](const HloInstruction* user) { + return user->has_sharding() && user->sharding().IsManual() && + !user->IsCustomCall("SPMDFullToShardShape"); + })) { + instruction->set_sharding(HloSharding::Manual()); + return true; + } if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) { return false; } diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc index c63ab75cd01..ebdedd8f95c 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -1287,16 +1287,19 @@ Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) { } } + HloSharding sharding = hlo->sharding().HasUniqueDevice() + ? hlo->sharding() + : HloSharding::Replicate(); + // If the instruction cannot be partitioned, replicate the instruction unless // the instruction has side-effect. std::vector new_operands; for (HloInstruction* operand : hlo->operands()) { - new_operands.push_back( - GetPartitionedHlo(operand).Reshard(HloSharding::Replicate()).hlo()); + new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo()); } auto clone = b_.AddInstruction(hlo->CloneWithNewOperands(hlo->shape(), new_operands)); - clone->set_sharding(HloSharding::Replicate()); + clone->set_sharding(sharding); clone->set_metadata(hlo->metadata()); SetPartitionedHlo(hlo, PartitionedHlo(clone, hlo->shape(), MakePartitioningState()) @@ -1307,6 +1310,43 @@ Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) { Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) { visiting_hlo_ = hlo; b_.set_visiting_hlo(hlo); + // Temporarily replace manual sharding to one-device sharding so that the + // partitioner will not change the HLOs. + auto manual_to_onedevice = [&](const Shape& shape, + const HloSharding& sharding) { + if (sharding.IsManual()) { + return HloSharding::AssignDevice(0); + } + if (sharding.IsTuple()) { + std::vector subshardings = sharding.tuple_elements(); + for (HloSharding& subsharding : subshardings) { + if (subsharding.IsManual()) { + subsharding = HloSharding::AssignDevice(0); + } + } + return HloSharding::Tuple(shape, subshardings); + } + return sharding; + }; + const bool has_manual_sharding = + hlo->sharding().IsManual() || + (hlo->sharding().IsTuple() && + absl::c_any_of( + hlo->sharding().tuple_elements(), + [](const HloSharding& sharding) { return sharding.IsManual(); })); + if (has_manual_sharding && !hlo->IsCustomCall("SPMDFullToShardShape")) { + visiting_hlo_sharding_ = hlo->sharding(); + hlo->set_sharding( + manual_to_onedevice(hlo->shape(), *visiting_hlo_sharding_)); + + visiting_hlo_operand_shardings_.reserve(hlo->operand_count()); + for (auto operand : hlo->operands()) { + visiting_hlo_operand_shardings_.push_back(operand->sharding()); + operand->set_sharding( + manual_to_onedevice(operand->shape(), operand->sharding())); + GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding()); + } + } return Status::OK(); } @@ -1315,6 +1355,18 @@ Status SpmdPartitioningVisitor::Postprocess(HloInstruction* hlo) { b_.derived_instructions(hlo)); visiting_hlo_ = nullptr; b_.set_visiting_hlo(nullptr); + // Revert fake one-device shardings for manually partitioned ops. + if (visiting_hlo_sharding_) { + hlo->set_sharding(*visiting_hlo_sharding_); + GetPartitionedHlo(hlo).hlo()->set_sharding(*visiting_hlo_sharding_); + for (int64 i = 0; i < hlo->operand_count(); ++i) { + auto operand = hlo->mutable_operand(i); + operand->set_sharding(visiting_hlo_operand_shardings_[i]); + GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding()); + } + visiting_hlo_sharding_.reset(); + visiting_hlo_operand_shardings_.clear(); + } return Status::OK(); } @@ -1865,7 +1917,7 @@ Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) { CreateR0WithType(hlo->shape().element_type(), 0, &b_)); } auto input = input_partitioned.hlo(); - CHECK(hlo->sharding().IsReplicated()); + CHECK(hlo->sharding().IsManual()); CHECK(ShapeUtil::Compatible(input->shape(), hlo->shape())); auto copy = b_.AddInstruction( HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input)); @@ -1875,7 +1927,7 @@ Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) { if (hlo->custom_call_target() == "SPMDShardToFullShape") { // This op switches from manual partitioning to auto partitioning. auto input = GetPartitionedHlo(hlo->operand(0)).hlo(); - CHECK(input->sharding().IsReplicated()); + CHECK(input->sharding().IsManual()); auto copy = b_.AddInstruction( HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input)); CHECK(ShapeUtil::Compatible( @@ -3927,7 +3979,8 @@ Status SpmdPartitioner::PreprocessSharding(HloModule* module) { hlo->set_sharding( HloSharding::Single(hlo->shape(), HloSharding::Replicate())); } - } else if (!hlo->sharding().IsTileMaximal()) { + } else if (!hlo->sharding().IsTileMaximal() && + !hlo->sharding().IsManual()) { std::vector available(num_partitions_); std::iota(available.begin(), available.end(), 0); TF_RET_CHECK(num_partitions_ == hlo_sharding_util::DevicesForSharding( diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h index 453eba8bc67..d5a2efd9fc0 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h @@ -511,6 +511,8 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { SpmdLogger* logger_; const SpmdPartitionerOptions options_; SpmdPartitioner* partitioner_; + std::vector visiting_hlo_operand_shardings_; + absl::optional visiting_hlo_sharding_; }; } // namespace spmd diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc index 4c1fb336439..318898f7a5b 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -4830,20 +4830,23 @@ TEST_F(SpmdPartitioningTest, MixWithManualPartitioning) { HloModule module ENTRY entry { - param = f32[8,2] parameter(0), sharding={devices=[2,1]0,1} - to_shard = f32[4,2] custom-call(param), custom_call_target="SPMDFullToShardShape", sharding={replicated} - add = f32[4,2] add(to_shard, to_shard), sharding={replicated} + param = (f32[8,2], f32[4,2]) parameter(0), sharding={{devices=[2,1]0,1},{manual}} + param0 = f32[8,2] get-tuple-element(param), index=0, sharding={devices=[2,1]0,1} + param1 = f32[4,2] get-tuple-element(param), index=1, sharding={manual} + to_shard = f32[4,2] custom-call(param0), custom_call_target="SPMDFullToShardShape", sharding={manual} + add = f32[4,2] add(to_shard, param1), sharding={manual} to_full = f32[8,2] custom-call(add), custom_call_target="SPMDShardToFullShape", sharding={devices=[2,1]0,1} - ROOT mul = f32[8,2] multiply(to_full, param), sharding={devices=[2,1]0,1} + ROOT mul = f32[8,2] multiply(to_full, param0), sharding={devices=[2,1]0,1} })"; TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, /*num_devices=*/2)); VLOG(1) << module->ToString(); HloInstruction* root = module->entry_computation()->root_instruction(); - auto to_shard = op::Copy(op::Parameter(0)); + auto p0 = op::GetTupleElement(op::Parameter(0)); + auto to_shard = op::Copy(p0); + auto p1 = op::GetTupleElement(op::Parameter(0)); EXPECT_THAT(root, AllOf(op::Shape("f32[4,2]"), - op::Multiply(op::Copy(op::Add(to_shard, to_shard)), - op::Parameter(0)))); + op::Multiply(op::Copy(op::Add(to_shard, p1)), p0))); } TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard) { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 01de56bf85d..11b39be32ad 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -626,6 +626,9 @@ message OpSharding { TUPLE = 2; // None of the above; tile_shape and tile_assignment are both used. OTHER = 3; + // This op is manually sharded: the shapes are already partitioned and the + // partitioner should not change this op. + MANUAL = 4; } Type type = 1; // The shape of the sharded tile.