[XLA:SPMD] Define manual sharding, instead of using fake replication

PiperOrigin-RevId: 344282109
Change-Id: I1bfe1cad2f04f427d33133e2670cf334b5f2d38c
This commit is contained in:
Yuanzhong Xu 2020-11-25 10:58:46 -08:00 committed by TensorFlower Gardener
parent 29ff0a57fe
commit 488b5986b8
12 changed files with 173 additions and 36 deletions

View File

@@ -59,7 +59,7 @@ class XlaSpmdFullToShardShapeOp : public XlaOpKernel {
}
xla::XlaOp input_annotation;
{
// Annotate the full-shape input with the manual sharding.
// Annotate the full-shape input with the sharding.
xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(),
sharding);
input_annotation =
@@ -68,12 +68,11 @@ class XlaSpmdFullToShardShapeOp : public XlaOpKernel {
}
{
// Annotate the shard-shape output with replicated sharding, so that the
// Annotate the shard-shape output with manual sharding, so that the
// partitioner will leave it as is.
xla::OpSharding replicated;
replicated.set_type(xla::OpSharding::REPLICATED);
xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(),
replicated);
xla::OpSharding manual;
manual.set_type(xla::OpSharding::MANUAL);
xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(), manual);
auto output = xla::CustomCall(ctx->builder(),
/*call_target_name=*/"SPMDFullToShardShape",
{input_annotation}, output_shape);
@@ -112,19 +111,18 @@ class XlaSpmdShardToFullShapeOp : public XlaOpKernel {
}
xla::XlaOp input_annotation;
{
// Annotate the shard-shape input with replicated sharding, so that the
// Annotate the shard-shape input with manual sharding, so that the
// partitioner will leave it as is.
xla::OpSharding replicated;
replicated.set_type(xla::OpSharding::REPLICATED);
xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(),
replicated);
xla::OpSharding manual;
manual.set_type(xla::OpSharding::MANUAL);
xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(), manual);
input_annotation =
xla::CustomCall(ctx->builder(), /*call_target_name=*/"Sharding",
{input}, input_shape_or.ValueOrDie());
}
{
// Annotate the full-shape output with the manual sharding.
// Annotate the full-shape output with the sharding.
xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(),
sharding);
ctx->SetOutput(

View File

@@ -281,6 +281,7 @@ TokKind HloLexer::LexIdentifier() {
KEYWORD(ROOT);
KEYWORD(maximal);
KEYWORD(replicated);
KEYWORD(manual);
KEYWORD(last_tile_dim_replicate);
#undef KEYWORD
@@ -502,6 +503,8 @@ string TokKindToString(TokKind kind) {
return "kw_maximal";
case TokKind::kw_replicated:
return "kw_replicated";
case TokKind::kw_manual:
return "kw_manual";
case TokKind::kw_last_tile_dim_replicate:
return "kw_last_tile_dim_replicate";
case TokKind::kw_nan:

View File

@@ -61,6 +61,7 @@ enum class TokKind {
kw_false,
kw_maximal,
kw_replicated,
kw_manual,
kw_last_tile_dim_replicate,
kw_nan,
kw_inf,

View File

@@ -2252,7 +2252,7 @@ bool HloParserImpl::ParseFrontendAttributes(
"expects '}' at the end of frontend attributes");
}
// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape?
// ::= '{' 'replicated'? 'manual'? 'maximal'? ('device=' int)? shape?
// ('devices=' ('[' dims ']')* device_list)? '}'
// dims ::= int_list device_list ::= int_list
bool HloParserImpl::ParseSingleSharding(OpSharding* sharding,
@@ -2266,6 +2266,7 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding,
LocTy loc = lexer_.GetLoc();
bool maximal = false;
bool replicated = false;
bool manual = false;
bool last_tile_dim_replicate = false;
std::vector<int64> devices;
std::vector<int64> tile_assignment_dimensions;
@@ -2279,6 +2280,10 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding,
replicated = true;
lexer_.Lex();
break;
case TokKind::kw_manual:
manual = true;
lexer_.Lex();
break;
case TokKind::kAttributeName: {
if (lexer_.GetStrVal() == "device") {
if (lexer_.Lex() != TokKind::kInt) {
@@ -2342,6 +2347,12 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding,
}
sharding->set_type(OpSharding::MAXIMAL);
sharding->add_tile_assignment_devices(devices[0]);
} else if (manual) {
if (!devices.empty()) {
return Error(loc,
"manual shardings should not have any devices assigned");
}
sharding->set_type(OpSharding::MANUAL);
} else {
if (devices.size() <= 1) {
return Error(

View File

@@ -266,10 +266,10 @@ ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f
R"(HloModule ShardedTupleCreate_module
ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
%v1 = f32[] parameter(0)
%v1 = f32[] parameter(0), sharding={manual}
%v2 = f32[3]{0} parameter(1)
%v3 = f32[2,3]{1,0} parameter(2)
ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}}
ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{manual}, {maximal device=0}, {replicated}}
}
)"

View File

@@ -152,6 +152,10 @@ string HloSharding::ToString() const {
if (replicated_) {
return "{replicated}";
}
if (manual_) {
return "{manual}";
}
if (maximal_) {
return StrCat(
"{maximal device=", static_cast<int64>(*tile_assignment_.begin()), "}");
@@ -169,7 +173,7 @@ bool HloSharding::UsesDevice(int64 device) const {
});
}
const auto& devices = tile_assignment_;
return replicated_ || absl::c_linear_search(devices, device);
return replicated_ || manual_ || absl::c_linear_search(devices, device);
}
std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
@@ -197,6 +201,7 @@ std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
CHECK(!maximal_);
CHECK(!manual_);
CHECK(!IsTuple());
std::vector<int64> ret_index;
tile_assignment_.Each([&](absl::Span<const int64> index, int64 d) {
@@ -213,6 +218,7 @@ std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
int64 HloSharding::DeviceForTileIndex(absl::Span<const int64> index) const {
CHECK(!replicated_);
CHECK(!manual_);
CHECK(!IsTuple());
if (maximal_) {
return *tile_assignment_.begin();
@@ -229,6 +235,7 @@ int64 HloSharding::DeviceForTileIndex(absl::Span<const int64> index) const {
std::vector<int64> HloSharding::TileOffsetForDevice(const Shape& shape,
int64 device) const {
CHECK(!IsTuple());
CHECK(!manual_);
if (maximal_) {
return std::vector<int64>(shape.dimensions_size(), 0);
@@ -250,6 +257,7 @@ std::vector<int64> HloSharding::TileOffsetForDevice(const Shape& shape,
std::vector<int64> HloSharding::TileLimitForDevice(const Shape& shape,
int64 device) const {
CHECK(!IsTuple());
CHECK(!manual_);
if (maximal_) {
return std::vector<int64>(shape.dimensions().begin(),
@@ -410,7 +418,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
return status;
}
if (IsTileMaximal()) {
if (IsTileMaximal() || IsManual()) {
return Status::OK();
}
@@ -447,6 +455,8 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
return HloSharding(tuple_shardings);
} else if (proto.type() == OpSharding::REPLICATED) {
return Replicate();
} else if (proto.type() == OpSharding::MANUAL) {
return Manual();
} else if (proto.tile_assignment_devices().size() == 1) {
return HloSharding(proto.tile_assignment_devices(0));
}
@@ -503,6 +513,8 @@ OpSharding HloSharding::ToProto() const {
result.set_type(OpSharding::REPLICATED);
} else if (IsTileMaximal()) {
result.set_type(OpSharding::MAXIMAL);
} else if (IsManual()) {
result.set_type(OpSharding::MANUAL);
} else {
result.set_type(OpSharding::OTHER);
result.set_replicate_on_last_tile_dim(ReplicateOnLastTileDim());
@@ -511,7 +523,7 @@ OpSharding HloSharding::ToProto() const {
}
Shape HloSharding::TileShape(const Shape& shape) const {
if (IsTileMaximal()) {
if (IsTileMaximal() || IsManual()) {
return shape;
}
Shape result_shape = shape;
@@ -523,7 +535,7 @@ Shape HloSharding::TileShape(const Shape& shape) const {
}
Shape HloSharding::TileShape(const Shape& shape, int64 device) const {
if (IsTileMaximal()) {
if (IsTileMaximal() || IsManual()) {
return shape;
}
@@ -545,6 +557,7 @@ int64 HloSharding::NumTiles() const {
if (IsTileMaximal()) {
return 1;
}
CHECK(!IsManual());
if (ReplicateOnLastTileDim()) {
return tile_assignment().num_elements() /
tile_assignment().dimensions().back();
@@ -600,6 +613,9 @@ size_t HloSharding::Hash() const {
if (replicated_) {
return 0;
}
if (manual_) {
return 1;
}
size_t h = 0;
for (uint32 v : tile_assignment_) {
h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));

View File

@@ -42,7 +42,14 @@ class HloSharding {
public:
// Creates a trivial sharding that replicates a maximal tile across all
// devices.
static HloSharding Replicate() { return HloSharding(); }
static HloSharding Replicate() {
return HloSharding(/*manual=*/false, /*replicated=*/true);
}
// Creates a sharding that represents the op is manually partitioned.
static HloSharding Manual() {
return HloSharding(/*manual=*/true, /*replicated=*/false);
}
// Creates a sharding that emulates device placement; a tile shape equal to
// the input shape (one tile) assigned to a single device.
@@ -128,6 +135,15 @@ class HloSharding {
});
}
// Returns whether the sharding represents manual partitioning.
bool IsManual() const {
if (!IsTuple()) {
return manual_;
}
return absl::c_all_of(tuple_elements_,
[](const HloSharding& s) { return s.IsManual(); });
}
// Returns if the sharding has partial replication and partial sharding. If
// true, data is sharded according to other dimensions of tile_assignment(),
// but replicated across devices along the last dimension.
@@ -209,6 +225,7 @@ class HloSharding {
bool operator==(const HloSharding& other) const {
return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
manual_ == other.manual_ &&
tile_assignment_ == other.tile_assignment_ &&
tuple_elements_ == other.tuple_elements_ &&
replicate_on_last_tile_dim_ == other.replicate_on_last_tile_dim_;
@@ -248,10 +265,11 @@ class HloSharding {
int64 NumTiles() const;
private:
HloSharding()
: replicated_(true),
maximal_(true),
explicit HloSharding(bool manual, bool replicated)
: replicated_(replicated),
maximal_(replicated),
tuple_(false),
manual_(manual),
tile_assignment_({0}),
replicate_on_last_tile_dim_(false) {}
// device_id values:
@@ -264,6 +282,7 @@ class HloSharding {
: replicated_(false),
maximal_(true),
tuple_(false),
manual_(false),
tile_assignment_({1}, device_id),
replicate_on_last_tile_dim_(false) {}
explicit HloSharding(const Array<int64>& tile_assignment,
@@ -271,12 +290,14 @@ class HloSharding {
: replicated_(false),
maximal_(false),
tuple_(false),
manual_(false),
tile_assignment_(tile_assignment),
replicate_on_last_tile_dim_(replicate_on_last_tile_dim) {}
explicit HloSharding(const std::vector<HloSharding>& tuple_shardings)
: replicated_(false),
maximal_(false),
tuple_(true),
manual_(false),
tile_assignment_({0}),
tuple_elements_(tuple_shardings),
replicate_on_last_tile_dim_(false) {}
@@ -297,6 +318,7 @@ class HloSharding {
bool replicated_;
bool maximal_;
bool tuple_;
bool manual_;
// This field is only used if replicated_ is false. If maximal_ is true, then
// the field contains a rank 1 array with a single element, which is the
// device the HLO is assigned to. If maximal_ is false, the field contains an

View File

@@ -680,6 +680,18 @@ bool InferShardingFromOperands(HloInstruction* instruction,
if (!CanPropagateThroughAtAgressiveLevel(*instruction, aggressiveness)) {
return false;
}
// Do not change manual sharding.
if (instruction->has_sharding() && instruction->sharding().IsManual()) {
return false;
}
// Propagate manual sharding.
if (!instruction->has_sharding() &&
absl::c_any_of(instruction->operands(), [](const HloInstruction* op) {
return op->has_sharding() && op->sharding().IsManual();
})) {
instruction->set_sharding(HloSharding::Manual());
return true;
}
const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0;
if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) {
// If an array shaped HLO doesn't support spatial partitioning but at least
@@ -1457,6 +1469,19 @@ bool InferShardingFromUsers(HloInstruction* instruction,
if (aggressiveness < 2 && instruction->opcode() == HloOpcode::kBroadcast) {
return false;
}
// Do not change manual sharding.
if (instruction->has_sharding() && instruction->sharding().IsManual()) {
return false;
}
// Propagate manual sharding.
if (!instruction->has_sharding() &&
absl::c_any_of(instruction->users(), [](const HloInstruction* user) {
return user->has_sharding() && user->sharding().IsManual() &&
!user->IsCustomCall("SPMDFullToShardShape");
})) {
instruction->set_sharding(HloSharding::Manual());
return true;
}
if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) {
return false;
}

View File

@@ -1287,16 +1287,19 @@ Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) {
}
}
HloSharding sharding = hlo->sharding().HasUniqueDevice()
? hlo->sharding()
: HloSharding::Replicate();
// If the instruction cannot be partitioned, replicate the instruction unless
// the instruction has side-effect.
std::vector<HloInstruction*> new_operands;
for (HloInstruction* operand : hlo->operands()) {
new_operands.push_back(
GetPartitionedHlo(operand).Reshard(HloSharding::Replicate()).hlo());
new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
}
auto clone =
b_.AddInstruction(hlo->CloneWithNewOperands(hlo->shape(), new_operands));
clone->set_sharding(HloSharding::Replicate());
clone->set_sharding(sharding);
clone->set_metadata(hlo->metadata());
SetPartitionedHlo(hlo,
PartitionedHlo(clone, hlo->shape(), MakePartitioningState())
@@ -1307,6 +1310,43 @@ Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) {
Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) {
visiting_hlo_ = hlo;
b_.set_visiting_hlo(hlo);
// Temporarily replace manual sharding to one-device sharding so that the
// partitioner will not change the HLOs.
auto manual_to_onedevice = [&](const Shape& shape,
const HloSharding& sharding) {
if (sharding.IsManual()) {
return HloSharding::AssignDevice(0);
}
if (sharding.IsTuple()) {
std::vector<HloSharding> subshardings = sharding.tuple_elements();
for (HloSharding& subsharding : subshardings) {
if (subsharding.IsManual()) {
subsharding = HloSharding::AssignDevice(0);
}
}
return HloSharding::Tuple(shape, subshardings);
}
return sharding;
};
const bool has_manual_sharding =
hlo->sharding().IsManual() ||
(hlo->sharding().IsTuple() &&
absl::c_any_of(
hlo->sharding().tuple_elements(),
[](const HloSharding& sharding) { return sharding.IsManual(); }));
if (has_manual_sharding && !hlo->IsCustomCall("SPMDFullToShardShape")) {
visiting_hlo_sharding_ = hlo->sharding();
hlo->set_sharding(
manual_to_onedevice(hlo->shape(), *visiting_hlo_sharding_));
visiting_hlo_operand_shardings_.reserve(hlo->operand_count());
for (auto operand : hlo->operands()) {
visiting_hlo_operand_shardings_.push_back(operand->sharding());
operand->set_sharding(
manual_to_onedevice(operand->shape(), operand->sharding()));
GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
}
}
return Status::OK();
}
@@ -1315,6 +1355,18 @@ Status SpmdPartitioningVisitor::Postprocess(HloInstruction* hlo) {
b_.derived_instructions(hlo));
visiting_hlo_ = nullptr;
b_.set_visiting_hlo(nullptr);
// Revert fake one-device shardings for manually partitioned ops.
if (visiting_hlo_sharding_) {
hlo->set_sharding(*visiting_hlo_sharding_);
GetPartitionedHlo(hlo).hlo()->set_sharding(*visiting_hlo_sharding_);
for (int64 i = 0; i < hlo->operand_count(); ++i) {
auto operand = hlo->mutable_operand(i);
operand->set_sharding(visiting_hlo_operand_shardings_[i]);
GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
}
visiting_hlo_sharding_.reset();
visiting_hlo_operand_shardings_.clear();
}
return Status::OK();
}
@@ -1865,7 +1917,7 @@ Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
CreateR0WithType(hlo->shape().element_type(), 0, &b_));
}
auto input = input_partitioned.hlo();
CHECK(hlo->sharding().IsReplicated());
CHECK(hlo->sharding().IsManual());
CHECK(ShapeUtil::Compatible(input->shape(), hlo->shape()));
auto copy = b_.AddInstruction(
HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
@@ -1875,7 +1927,7 @@ Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
if (hlo->custom_call_target() == "SPMDShardToFullShape") {
// This op switches from manual partitioning to auto partitioning.
auto input = GetPartitionedHlo(hlo->operand(0)).hlo();
CHECK(input->sharding().IsReplicated());
CHECK(input->sharding().IsManual());
auto copy = b_.AddInstruction(
HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
CHECK(ShapeUtil::Compatible(
@@ -3927,7 +3979,8 @@ Status SpmdPartitioner::PreprocessSharding(HloModule* module) {
hlo->set_sharding(
HloSharding::Single(hlo->shape(), HloSharding::Replicate()));
}
} else if (!hlo->sharding().IsTileMaximal()) {
} else if (!hlo->sharding().IsTileMaximal() &&
!hlo->sharding().IsManual()) {
std::vector<int64> available(num_partitions_);
std::iota(available.begin(), available.end(), 0);
TF_RET_CHECK(num_partitions_ == hlo_sharding_util::DevicesForSharding(

View File

@@ -511,6 +511,8 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
SpmdLogger* logger_;
const SpmdPartitionerOptions options_;
SpmdPartitioner* partitioner_;
std::vector<HloSharding> visiting_hlo_operand_shardings_;
absl::optional<HloSharding> visiting_hlo_sharding_;
};
} // namespace spmd

View File

@@ -4830,20 +4830,23 @@ TEST_F(SpmdPartitioningTest, MixWithManualPartitioning) {
HloModule module
ENTRY entry {
param = f32[8,2] parameter(0), sharding={devices=[2,1]0,1}
to_shard = f32[4,2] custom-call(param), custom_call_target="SPMDFullToShardShape", sharding={replicated}
add = f32[4,2] add(to_shard, to_shard), sharding={replicated}
param = (f32[8,2], f32[4,2]) parameter(0), sharding={{devices=[2,1]0,1},{manual}}
param0 = f32[8,2] get-tuple-element(param), index=0, sharding={devices=[2,1]0,1}
param1 = f32[4,2] get-tuple-element(param), index=1, sharding={manual}
to_shard = f32[4,2] custom-call(param0), custom_call_target="SPMDFullToShardShape", sharding={manual}
add = f32[4,2] add(to_shard, param1), sharding={manual}
to_full = f32[8,2] custom-call(add), custom_call_target="SPMDShardToFullShape", sharding={devices=[2,1]0,1}
ROOT mul = f32[8,2] multiply(to_full, param), sharding={devices=[2,1]0,1}
ROOT mul = f32[8,2] multiply(to_full, param0), sharding={devices=[2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
auto to_shard = op::Copy(op::Parameter(0));
auto p0 = op::GetTupleElement(op::Parameter(0));
auto to_shard = op::Copy(p0);
auto p1 = op::GetTupleElement(op::Parameter(0));
EXPECT_THAT(root, AllOf(op::Shape("f32[4,2]"),
op::Multiply(op::Copy(op::Add(to_shard, to_shard)),
op::Parameter(0))));
op::Multiply(op::Copy(op::Add(to_shard, p1)), p0)));
}
TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard) {

View File

@@ -626,6 +626,9 @@ message OpSharding {
TUPLE = 2;
// None of the above; tile_shape and tile_assignment are both used.
OTHER = 3;
// This op is manually sharded: the shapes are already partitioned and the
// partitioner should not change this op.
MANUAL = 4;
}
Type type = 1;
// The shape of the sharded tile.