[XLA:SPMD] Define manual sharding, instead of using fake replication
PiperOrigin-RevId: 344282109 Change-Id: I1bfe1cad2f04f427d33133e2670cf334b5f2d38c
This commit is contained in:
parent
29ff0a57fe
commit
488b5986b8
@ -59,7 +59,7 @@ class XlaSpmdFullToShardShapeOp : public XlaOpKernel {
|
||||
}
|
||||
xla::XlaOp input_annotation;
|
||||
{
|
||||
// Annotate the full-shape input with the manual sharding.
|
||||
// Annotate the full-shape input with the sharding.
|
||||
xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(),
|
||||
sharding);
|
||||
input_annotation =
|
||||
@ -68,12 +68,11 @@ class XlaSpmdFullToShardShapeOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
{
|
||||
// Annotate the shard-shape output with replicated sharding, so that the
|
||||
// Annotate the shard-shape output with manual sharding, so that the
|
||||
// partitioner will leave it as is.
|
||||
xla::OpSharding replicated;
|
||||
replicated.set_type(xla::OpSharding::REPLICATED);
|
||||
xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(),
|
||||
replicated);
|
||||
xla::OpSharding manual;
|
||||
manual.set_type(xla::OpSharding::MANUAL);
|
||||
xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(), manual);
|
||||
auto output = xla::CustomCall(ctx->builder(),
|
||||
/*call_target_name=*/"SPMDFullToShardShape",
|
||||
{input_annotation}, output_shape);
|
||||
@ -112,19 +111,18 @@ class XlaSpmdShardToFullShapeOp : public XlaOpKernel {
|
||||
}
|
||||
xla::XlaOp input_annotation;
|
||||
{
|
||||
// Annotate the shard-shape input with replicated sharding, so that the
|
||||
// Annotate the shard-shape input with manual sharding, so that the
|
||||
// partitioner will leave it as is.
|
||||
xla::OpSharding replicated;
|
||||
replicated.set_type(xla::OpSharding::REPLICATED);
|
||||
xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(),
|
||||
replicated);
|
||||
xla::OpSharding manual;
|
||||
manual.set_type(xla::OpSharding::MANUAL);
|
||||
xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(), manual);
|
||||
input_annotation =
|
||||
xla::CustomCall(ctx->builder(), /*call_target_name=*/"Sharding",
|
||||
{input}, input_shape_or.ValueOrDie());
|
||||
}
|
||||
|
||||
{
|
||||
// Annotate the full-shape output with the manual sharding.
|
||||
// Annotate the full-shape output with the sharding.
|
||||
xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(),
|
||||
sharding);
|
||||
ctx->SetOutput(
|
||||
|
@ -281,6 +281,7 @@ TokKind HloLexer::LexIdentifier() {
|
||||
KEYWORD(ROOT);
|
||||
KEYWORD(maximal);
|
||||
KEYWORD(replicated);
|
||||
KEYWORD(manual);
|
||||
KEYWORD(last_tile_dim_replicate);
|
||||
|
||||
#undef KEYWORD
|
||||
@ -502,6 +503,8 @@ string TokKindToString(TokKind kind) {
|
||||
return "kw_maximal";
|
||||
case TokKind::kw_replicated:
|
||||
return "kw_replicated";
|
||||
case TokKind::kw_manual:
|
||||
return "kw_manual";
|
||||
case TokKind::kw_last_tile_dim_replicate:
|
||||
return "kw_last_tile_dim_replicate";
|
||||
case TokKind::kw_nan:
|
||||
|
@ -61,6 +61,7 @@ enum class TokKind {
|
||||
kw_false,
|
||||
kw_maximal,
|
||||
kw_replicated,
|
||||
kw_manual,
|
||||
kw_last_tile_dim_replicate,
|
||||
kw_nan,
|
||||
kw_inf,
|
||||
|
@ -2252,7 +2252,7 @@ bool HloParserImpl::ParseFrontendAttributes(
|
||||
"expects '}' at the end of frontend attributes");
|
||||
}
|
||||
|
||||
// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape?
|
||||
// ::= '{' 'replicated'? 'manual'? 'maximal'? ('device=' int)? shape?
|
||||
// ('devices=' ('[' dims ']')* device_list)? '}'
|
||||
// dims ::= int_list device_list ::= int_list
|
||||
bool HloParserImpl::ParseSingleSharding(OpSharding* sharding,
|
||||
@ -2266,6 +2266,7 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding,
|
||||
LocTy loc = lexer_.GetLoc();
|
||||
bool maximal = false;
|
||||
bool replicated = false;
|
||||
bool manual = false;
|
||||
bool last_tile_dim_replicate = false;
|
||||
std::vector<int64> devices;
|
||||
std::vector<int64> tile_assignment_dimensions;
|
||||
@ -2279,6 +2280,10 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding,
|
||||
replicated = true;
|
||||
lexer_.Lex();
|
||||
break;
|
||||
case TokKind::kw_manual:
|
||||
manual = true;
|
||||
lexer_.Lex();
|
||||
break;
|
||||
case TokKind::kAttributeName: {
|
||||
if (lexer_.GetStrVal() == "device") {
|
||||
if (lexer_.Lex() != TokKind::kInt) {
|
||||
@ -2342,6 +2347,12 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding,
|
||||
}
|
||||
sharding->set_type(OpSharding::MAXIMAL);
|
||||
sharding->add_tile_assignment_devices(devices[0]);
|
||||
} else if (manual) {
|
||||
if (!devices.empty()) {
|
||||
return Error(loc,
|
||||
"manual shardings should not have any devices assigned");
|
||||
}
|
||||
sharding->set_type(OpSharding::MANUAL);
|
||||
} else {
|
||||
if (devices.size() <= 1) {
|
||||
return Error(
|
||||
|
@ -266,10 +266,10 @@ ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f
|
||||
R"(HloModule ShardedTupleCreate_module
|
||||
|
||||
ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
|
||||
%v1 = f32[] parameter(0)
|
||||
%v1 = f32[] parameter(0), sharding={manual}
|
||||
%v2 = f32[3]{0} parameter(1)
|
||||
%v3 = f32[2,3]{1,0} parameter(2)
|
||||
ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}}
|
||||
ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{manual}, {maximal device=0}, {replicated}}
|
||||
}
|
||||
|
||||
)"
|
||||
|
@ -152,6 +152,10 @@ string HloSharding::ToString() const {
|
||||
if (replicated_) {
|
||||
return "{replicated}";
|
||||
}
|
||||
|
||||
if (manual_) {
|
||||
return "{manual}";
|
||||
}
|
||||
if (maximal_) {
|
||||
return StrCat(
|
||||
"{maximal device=", static_cast<int64>(*tile_assignment_.begin()), "}");
|
||||
@ -169,7 +173,7 @@ bool HloSharding::UsesDevice(int64 device) const {
|
||||
});
|
||||
}
|
||||
const auto& devices = tile_assignment_;
|
||||
return replicated_ || absl::c_linear_search(devices, device);
|
||||
return replicated_ || manual_ || absl::c_linear_search(devices, device);
|
||||
}
|
||||
|
||||
std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
|
||||
@ -197,6 +201,7 @@ std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
|
||||
|
||||
std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
|
||||
CHECK(!maximal_);
|
||||
CHECK(!manual_);
|
||||
CHECK(!IsTuple());
|
||||
std::vector<int64> ret_index;
|
||||
tile_assignment_.Each([&](absl::Span<const int64> index, int64 d) {
|
||||
@ -213,6 +218,7 @@ std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
|
||||
|
||||
int64 HloSharding::DeviceForTileIndex(absl::Span<const int64> index) const {
|
||||
CHECK(!replicated_);
|
||||
CHECK(!manual_);
|
||||
CHECK(!IsTuple());
|
||||
if (maximal_) {
|
||||
return *tile_assignment_.begin();
|
||||
@ -229,6 +235,7 @@ int64 HloSharding::DeviceForTileIndex(absl::Span<const int64> index) const {
|
||||
std::vector<int64> HloSharding::TileOffsetForDevice(const Shape& shape,
|
||||
int64 device) const {
|
||||
CHECK(!IsTuple());
|
||||
CHECK(!manual_);
|
||||
|
||||
if (maximal_) {
|
||||
return std::vector<int64>(shape.dimensions_size(), 0);
|
||||
@ -250,6 +257,7 @@ std::vector<int64> HloSharding::TileOffsetForDevice(const Shape& shape,
|
||||
std::vector<int64> HloSharding::TileLimitForDevice(const Shape& shape,
|
||||
int64 device) const {
|
||||
CHECK(!IsTuple());
|
||||
CHECK(!manual_);
|
||||
|
||||
if (maximal_) {
|
||||
return std::vector<int64>(shape.dimensions().begin(),
|
||||
@ -410,7 +418,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
|
||||
return status;
|
||||
}
|
||||
|
||||
if (IsTileMaximal()) {
|
||||
if (IsTileMaximal() || IsManual()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -447,6 +455,8 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
|
||||
return HloSharding(tuple_shardings);
|
||||
} else if (proto.type() == OpSharding::REPLICATED) {
|
||||
return Replicate();
|
||||
} else if (proto.type() == OpSharding::MANUAL) {
|
||||
return Manual();
|
||||
} else if (proto.tile_assignment_devices().size() == 1) {
|
||||
return HloSharding(proto.tile_assignment_devices(0));
|
||||
}
|
||||
@ -503,6 +513,8 @@ OpSharding HloSharding::ToProto() const {
|
||||
result.set_type(OpSharding::REPLICATED);
|
||||
} else if (IsTileMaximal()) {
|
||||
result.set_type(OpSharding::MAXIMAL);
|
||||
} else if (IsManual()) {
|
||||
result.set_type(OpSharding::MANUAL);
|
||||
} else {
|
||||
result.set_type(OpSharding::OTHER);
|
||||
result.set_replicate_on_last_tile_dim(ReplicateOnLastTileDim());
|
||||
@ -511,7 +523,7 @@ OpSharding HloSharding::ToProto() const {
|
||||
}
|
||||
|
||||
Shape HloSharding::TileShape(const Shape& shape) const {
|
||||
if (IsTileMaximal()) {
|
||||
if (IsTileMaximal() || IsManual()) {
|
||||
return shape;
|
||||
}
|
||||
Shape result_shape = shape;
|
||||
@ -523,7 +535,7 @@ Shape HloSharding::TileShape(const Shape& shape) const {
|
||||
}
|
||||
|
||||
Shape HloSharding::TileShape(const Shape& shape, int64 device) const {
|
||||
if (IsTileMaximal()) {
|
||||
if (IsTileMaximal() || IsManual()) {
|
||||
return shape;
|
||||
}
|
||||
|
||||
@ -545,6 +557,7 @@ int64 HloSharding::NumTiles() const {
|
||||
if (IsTileMaximal()) {
|
||||
return 1;
|
||||
}
|
||||
CHECK(!IsManual());
|
||||
if (ReplicateOnLastTileDim()) {
|
||||
return tile_assignment().num_elements() /
|
||||
tile_assignment().dimensions().back();
|
||||
@ -600,6 +613,9 @@ size_t HloSharding::Hash() const {
|
||||
if (replicated_) {
|
||||
return 0;
|
||||
}
|
||||
if (manual_) {
|
||||
return 1;
|
||||
}
|
||||
size_t h = 0;
|
||||
for (uint32 v : tile_assignment_) {
|
||||
h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
|
||||
|
@ -42,7 +42,14 @@ class HloSharding {
|
||||
public:
|
||||
// Creates a trivial sharding that replicates a maximal tile across all
|
||||
// devices.
|
||||
static HloSharding Replicate() { return HloSharding(); }
|
||||
static HloSharding Replicate() {
|
||||
return HloSharding(/*manual=*/false, /*replicated=*/true);
|
||||
}
|
||||
|
||||
// Creates a sharding that represents the op is manually partitioned.
|
||||
static HloSharding Manual() {
|
||||
return HloSharding(/*manual=*/true, /*replicated=*/false);
|
||||
}
|
||||
|
||||
// Creates a sharding that emulates device placement; a tile shape equal to
|
||||
// the input shape (one tile) assigned to a single device.
|
||||
@ -128,6 +135,15 @@ class HloSharding {
|
||||
});
|
||||
}
|
||||
|
||||
// Returns whether the sharding represents manual partitioning.
|
||||
bool IsManual() const {
|
||||
if (!IsTuple()) {
|
||||
return manual_;
|
||||
}
|
||||
return absl::c_all_of(tuple_elements_,
|
||||
[](const HloSharding& s) { return s.IsManual(); });
|
||||
}
|
||||
|
||||
// Returns if the sharding has partial replication and partial sharding. If
|
||||
// true, data is sharded according to other dimensions of tile_assignment(),
|
||||
// but replicated across devices along the last dimension.
|
||||
@ -209,6 +225,7 @@ class HloSharding {
|
||||
|
||||
bool operator==(const HloSharding& other) const {
|
||||
return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
|
||||
manual_ == other.manual_ &&
|
||||
tile_assignment_ == other.tile_assignment_ &&
|
||||
tuple_elements_ == other.tuple_elements_ &&
|
||||
replicate_on_last_tile_dim_ == other.replicate_on_last_tile_dim_;
|
||||
@ -248,10 +265,11 @@ class HloSharding {
|
||||
int64 NumTiles() const;
|
||||
|
||||
private:
|
||||
HloSharding()
|
||||
: replicated_(true),
|
||||
maximal_(true),
|
||||
explicit HloSharding(bool manual, bool replicated)
|
||||
: replicated_(replicated),
|
||||
maximal_(replicated),
|
||||
tuple_(false),
|
||||
manual_(manual),
|
||||
tile_assignment_({0}),
|
||||
replicate_on_last_tile_dim_(false) {}
|
||||
// device_id values:
|
||||
@ -264,6 +282,7 @@ class HloSharding {
|
||||
: replicated_(false),
|
||||
maximal_(true),
|
||||
tuple_(false),
|
||||
manual_(false),
|
||||
tile_assignment_({1}, device_id),
|
||||
replicate_on_last_tile_dim_(false) {}
|
||||
explicit HloSharding(const Array<int64>& tile_assignment,
|
||||
@ -271,12 +290,14 @@ class HloSharding {
|
||||
: replicated_(false),
|
||||
maximal_(false),
|
||||
tuple_(false),
|
||||
manual_(false),
|
||||
tile_assignment_(tile_assignment),
|
||||
replicate_on_last_tile_dim_(replicate_on_last_tile_dim) {}
|
||||
explicit HloSharding(const std::vector<HloSharding>& tuple_shardings)
|
||||
: replicated_(false),
|
||||
maximal_(false),
|
||||
tuple_(true),
|
||||
manual_(false),
|
||||
tile_assignment_({0}),
|
||||
tuple_elements_(tuple_shardings),
|
||||
replicate_on_last_tile_dim_(false) {}
|
||||
@ -297,6 +318,7 @@ class HloSharding {
|
||||
bool replicated_;
|
||||
bool maximal_;
|
||||
bool tuple_;
|
||||
bool manual_;
|
||||
// This field is only used if replicated_ is false. If maximal_ is true, then
|
||||
// the field contains a rank 1 array with a single element, which is the
|
||||
// device the HLO is assigned to. If maximal_ is false, the field contains an
|
||||
|
@ -680,6 +680,18 @@ bool InferShardingFromOperands(HloInstruction* instruction,
|
||||
if (!CanPropagateThroughAtAgressiveLevel(*instruction, aggressiveness)) {
|
||||
return false;
|
||||
}
|
||||
// Do not change manual sharding.
|
||||
if (instruction->has_sharding() && instruction->sharding().IsManual()) {
|
||||
return false;
|
||||
}
|
||||
// Propagate manual sharding.
|
||||
if (!instruction->has_sharding() &&
|
||||
absl::c_any_of(instruction->operands(), [](const HloInstruction* op) {
|
||||
return op->has_sharding() && op->sharding().IsManual();
|
||||
})) {
|
||||
instruction->set_sharding(HloSharding::Manual());
|
||||
return true;
|
||||
}
|
||||
const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0;
|
||||
if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) {
|
||||
// If an array shaped HLO doesn't support spatial partitioning but at least
|
||||
@ -1457,6 +1469,19 @@ bool InferShardingFromUsers(HloInstruction* instruction,
|
||||
if (aggressiveness < 2 && instruction->opcode() == HloOpcode::kBroadcast) {
|
||||
return false;
|
||||
}
|
||||
// Do not change manual sharding.
|
||||
if (instruction->has_sharding() && instruction->sharding().IsManual()) {
|
||||
return false;
|
||||
}
|
||||
// Propagate manual sharding.
|
||||
if (!instruction->has_sharding() &&
|
||||
absl::c_any_of(instruction->users(), [](const HloInstruction* user) {
|
||||
return user->has_sharding() && user->sharding().IsManual() &&
|
||||
!user->IsCustomCall("SPMDFullToShardShape");
|
||||
})) {
|
||||
instruction->set_sharding(HloSharding::Manual());
|
||||
return true;
|
||||
}
|
||||
if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -1287,16 +1287,19 @@ Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) {
|
||||
}
|
||||
}
|
||||
|
||||
HloSharding sharding = hlo->sharding().HasUniqueDevice()
|
||||
? hlo->sharding()
|
||||
: HloSharding::Replicate();
|
||||
|
||||
// If the instruction cannot be partitioned, replicate the instruction unless
|
||||
// the instruction has side-effect.
|
||||
std::vector<HloInstruction*> new_operands;
|
||||
for (HloInstruction* operand : hlo->operands()) {
|
||||
new_operands.push_back(
|
||||
GetPartitionedHlo(operand).Reshard(HloSharding::Replicate()).hlo());
|
||||
new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
|
||||
}
|
||||
auto clone =
|
||||
b_.AddInstruction(hlo->CloneWithNewOperands(hlo->shape(), new_operands));
|
||||
clone->set_sharding(HloSharding::Replicate());
|
||||
clone->set_sharding(sharding);
|
||||
clone->set_metadata(hlo->metadata());
|
||||
SetPartitionedHlo(hlo,
|
||||
PartitionedHlo(clone, hlo->shape(), MakePartitioningState())
|
||||
@ -1307,6 +1310,43 @@ Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) {
|
||||
Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) {
|
||||
visiting_hlo_ = hlo;
|
||||
b_.set_visiting_hlo(hlo);
|
||||
// Temporarily replace manual sharding to one-device sharding so that the
|
||||
// partitioner will not change the HLOs.
|
||||
auto manual_to_onedevice = [&](const Shape& shape,
|
||||
const HloSharding& sharding) {
|
||||
if (sharding.IsManual()) {
|
||||
return HloSharding::AssignDevice(0);
|
||||
}
|
||||
if (sharding.IsTuple()) {
|
||||
std::vector<HloSharding> subshardings = sharding.tuple_elements();
|
||||
for (HloSharding& subsharding : subshardings) {
|
||||
if (subsharding.IsManual()) {
|
||||
subsharding = HloSharding::AssignDevice(0);
|
||||
}
|
||||
}
|
||||
return HloSharding::Tuple(shape, subshardings);
|
||||
}
|
||||
return sharding;
|
||||
};
|
||||
const bool has_manual_sharding =
|
||||
hlo->sharding().IsManual() ||
|
||||
(hlo->sharding().IsTuple() &&
|
||||
absl::c_any_of(
|
||||
hlo->sharding().tuple_elements(),
|
||||
[](const HloSharding& sharding) { return sharding.IsManual(); }));
|
||||
if (has_manual_sharding && !hlo->IsCustomCall("SPMDFullToShardShape")) {
|
||||
visiting_hlo_sharding_ = hlo->sharding();
|
||||
hlo->set_sharding(
|
||||
manual_to_onedevice(hlo->shape(), *visiting_hlo_sharding_));
|
||||
|
||||
visiting_hlo_operand_shardings_.reserve(hlo->operand_count());
|
||||
for (auto operand : hlo->operands()) {
|
||||
visiting_hlo_operand_shardings_.push_back(operand->sharding());
|
||||
operand->set_sharding(
|
||||
manual_to_onedevice(operand->shape(), operand->sharding()));
|
||||
GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1315,6 +1355,18 @@ Status SpmdPartitioningVisitor::Postprocess(HloInstruction* hlo) {
|
||||
b_.derived_instructions(hlo));
|
||||
visiting_hlo_ = nullptr;
|
||||
b_.set_visiting_hlo(nullptr);
|
||||
// Revert fake one-device shardings for manually partitioned ops.
|
||||
if (visiting_hlo_sharding_) {
|
||||
hlo->set_sharding(*visiting_hlo_sharding_);
|
||||
GetPartitionedHlo(hlo).hlo()->set_sharding(*visiting_hlo_sharding_);
|
||||
for (int64 i = 0; i < hlo->operand_count(); ++i) {
|
||||
auto operand = hlo->mutable_operand(i);
|
||||
operand->set_sharding(visiting_hlo_operand_shardings_[i]);
|
||||
GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
|
||||
}
|
||||
visiting_hlo_sharding_.reset();
|
||||
visiting_hlo_operand_shardings_.clear();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1865,7 +1917,7 @@ Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
|
||||
CreateR0WithType(hlo->shape().element_type(), 0, &b_));
|
||||
}
|
||||
auto input = input_partitioned.hlo();
|
||||
CHECK(hlo->sharding().IsReplicated());
|
||||
CHECK(hlo->sharding().IsManual());
|
||||
CHECK(ShapeUtil::Compatible(input->shape(), hlo->shape()));
|
||||
auto copy = b_.AddInstruction(
|
||||
HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
|
||||
@ -1875,7 +1927,7 @@ Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
|
||||
if (hlo->custom_call_target() == "SPMDShardToFullShape") {
|
||||
// This op switches from manual partitioning to auto partitioning.
|
||||
auto input = GetPartitionedHlo(hlo->operand(0)).hlo();
|
||||
CHECK(input->sharding().IsReplicated());
|
||||
CHECK(input->sharding().IsManual());
|
||||
auto copy = b_.AddInstruction(
|
||||
HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
|
||||
CHECK(ShapeUtil::Compatible(
|
||||
@ -3927,7 +3979,8 @@ Status SpmdPartitioner::PreprocessSharding(HloModule* module) {
|
||||
hlo->set_sharding(
|
||||
HloSharding::Single(hlo->shape(), HloSharding::Replicate()));
|
||||
}
|
||||
} else if (!hlo->sharding().IsTileMaximal()) {
|
||||
} else if (!hlo->sharding().IsTileMaximal() &&
|
||||
!hlo->sharding().IsManual()) {
|
||||
std::vector<int64> available(num_partitions_);
|
||||
std::iota(available.begin(), available.end(), 0);
|
||||
TF_RET_CHECK(num_partitions_ == hlo_sharding_util::DevicesForSharding(
|
||||
|
@ -511,6 +511,8 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
|
||||
SpmdLogger* logger_;
|
||||
const SpmdPartitionerOptions options_;
|
||||
SpmdPartitioner* partitioner_;
|
||||
std::vector<HloSharding> visiting_hlo_operand_shardings_;
|
||||
absl::optional<HloSharding> visiting_hlo_sharding_;
|
||||
};
|
||||
|
||||
} // namespace spmd
|
||||
|
@ -4830,20 +4830,23 @@ TEST_F(SpmdPartitioningTest, MixWithManualPartitioning) {
|
||||
HloModule module
|
||||
|
||||
ENTRY entry {
|
||||
param = f32[8,2] parameter(0), sharding={devices=[2,1]0,1}
|
||||
to_shard = f32[4,2] custom-call(param), custom_call_target="SPMDFullToShardShape", sharding={replicated}
|
||||
add = f32[4,2] add(to_shard, to_shard), sharding={replicated}
|
||||
param = (f32[8,2], f32[4,2]) parameter(0), sharding={{devices=[2,1]0,1},{manual}}
|
||||
param0 = f32[8,2] get-tuple-element(param), index=0, sharding={devices=[2,1]0,1}
|
||||
param1 = f32[4,2] get-tuple-element(param), index=1, sharding={manual}
|
||||
to_shard = f32[4,2] custom-call(param0), custom_call_target="SPMDFullToShardShape", sharding={manual}
|
||||
add = f32[4,2] add(to_shard, param1), sharding={manual}
|
||||
to_full = f32[8,2] custom-call(add), custom_call_target="SPMDShardToFullShape", sharding={devices=[2,1]0,1}
|
||||
ROOT mul = f32[8,2] multiply(to_full, param), sharding={devices=[2,1]0,1}
|
||||
ROOT mul = f32[8,2] multiply(to_full, param0), sharding={devices=[2,1]0,1}
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
PartitionComputation(hlo_string, /*num_devices=*/2));
|
||||
VLOG(1) << module->ToString();
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
auto to_shard = op::Copy(op::Parameter(0));
|
||||
auto p0 = op::GetTupleElement(op::Parameter(0));
|
||||
auto to_shard = op::Copy(p0);
|
||||
auto p1 = op::GetTupleElement(op::Parameter(0));
|
||||
EXPECT_THAT(root, AllOf(op::Shape("f32[4,2]"),
|
||||
op::Multiply(op::Copy(op::Add(to_shard, to_shard)),
|
||||
op::Parameter(0))));
|
||||
op::Multiply(op::Copy(op::Add(to_shard, p1)), p0)));
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard) {
|
||||
|
@ -626,6 +626,9 @@ message OpSharding {
|
||||
TUPLE = 2;
|
||||
// None of the above; tile_shape and tile_assignment are both used.
|
||||
OTHER = 3;
|
||||
// This op is manually sharded: the shapes are already partitioned and the
|
||||
// partitioner should not change this op.
|
||||
MANUAL = 4;
|
||||
}
|
||||
Type type = 1;
|
||||
// The shape of the sharded tile.
|
||||
|
Loading…
Reference in New Issue
Block a user