[XLA:SPMD] Minor fixes and utils for manual sharding

PiperOrigin-RevId: 345088593
Change-Id: I33021fb4d739dba8d3dbc9c87d99467ede51a310
Yuanzhong Xu 2020-12-01 13:33:32 -08:00 committed by TensorFlower Gardener
parent fa00092f58
commit d0f7f671bb
8 changed files with 52 additions and 16 deletions

View File

@@ -738,14 +738,17 @@ REGISTER_OP("XlaSpmdFullToShardShape")
       }
       string sharding_attr;
       TF_RETURN_IF_ERROR(c->GetAttr("manual_sharding", &sharding_attr));
+      xla::OpSharding sharding;
+      sharding.ParseFromString(sharding_attr);
+      if (sharding.type() != xla::OpSharding::OTHER) {
+        return shape_inference::UnchangedShape(c);
+      }
       std::vector<shape_inference::DimensionHandle> dims;
       for (int64 i = 0; i < c->Rank(input_handle); ++i) {
         auto dim = c->Value(c->Dim(input_handle, i));
-        xla::OpSharding sharding;
-        sharding.ParseFromString(sharding_attr);
         int64 partitions_i = sharding.tile_assignment_dimensions(i);
         if (dim != shape_inference::InferenceContext::kUnknownDim &&
-            sharding.type() == xla::OpSharding::OTHER && partitions_i != 1) {
+            partitions_i != 1) {
           dim = (dim + partitions_i - 1) / partitions_i;
         }
         dims.push_back(c->MakeDim(dim));
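
The new shape function parses the sharding once, returns the input shape unchanged for non-tiled shardings, and otherwise ceil-divides each dimension by its partition count. A minimal standalone sketch of that arithmetic, with hypothetical names and no TensorFlow dependencies:

#include <cstdint>
#include <vector>

// Sketch of the per-dimension math in XlaSpmdFullToShardShape's shape
// inference: each full dimension is ceil-divided by the partition count taken
// from the sharding's tile_assignment_dimensions.
std::vector<int64_t> ShardShape(const std::vector<int64_t>& full_dims,
                                const std::vector<int64_t>& partitions) {
  std::vector<int64_t> shard_dims;
  for (size_t i = 0; i < full_dims.size(); ++i) {
    // Ceiling division: uneven dimensions are padded on the last shard,
    // e.g. dim = 5 with 2 partitions gives a per-shard dim of 3.
    shard_dims.push_back((full_dims[i] + partitions[i] - 1) / partitions[i]);
  }
  return shard_dims;
}

// Example: ShardShape({8, 2}, {2, 1}) == {4, 2}, matching the
// f32[8,2] -> f32[4,2] pairs in the partitioner test below.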

View File

@@ -140,7 +140,7 @@ Status RewriteLayoutWithShardedShape(
     const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
     XlaHelpers::ShapeRepresentationFn shape_representation_fn,
     xla::Shape* xla_shape) {
-  if (sharding && !sharding->IsTileMaximal()) {
+  if (sharding && !sharding->IsTileMaximal() && !sharding->IsManual()) {
     // After sharding, per core shape might have different layout. For example,
     // before sharding, a shape [128, 128] will be assigned default
     // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2,
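
The guard now also skips manual shardings, since a manually partitioned op, like a tile-maximal one, keeps its shape and therefore its original layout. A condensed sketch of the decision, with simplified stand-in types rather than the real HloSharding API:

// Stand-in for the two HloSharding predicates used in the guard above.
struct ShardingFlags {
  bool is_tile_maximal;  // replicated or pinned to a single device
  bool is_manual;        // partitioned by hand; the partitioner keeps shapes
};

// Per-core layout only needs to be recomputed when the partitioner will
// actually tile the shape across cores.
bool NeedsShardedLayoutRewrite(const ShardingFlags* sharding) {
  return sharding != nullptr && !sharding->is_tile_maximal &&
         !sharding->is_manual;
}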

View File

@@ -24,6 +24,12 @@ OpSharding Replicate() {
   return result;
 }
 
+OpSharding Manual() {
+  OpSharding result;
+  result.set_type(OpSharding::MANUAL);
+  return result;
+}
+
 OpSharding AssignDevice(int device) {
   OpSharding result;
   result.set_type(OpSharding::MAXIMAL);
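
A short usage sketch for the new builder function; the xla::sharding_builder namespace and header path are assumed from where the existing Replicate() helper lives:

#include <string>

#include "tensorflow/compiler/xla/client/sharding_builder.h"  // path assumed

// Builds a MANUAL OpSharding proto and serializes it, mirroring how the
// "manual_sharding" string attribute is consumed by the op registration above.
std::string ManualShardingAttr() {
  xla::OpSharding manual = xla::sharding_builder::Manual();
  std::string attr;
  manual.SerializeToString(&attr);  // standard protobuf serialization
  return attr;
}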

View File

@@ -33,6 +33,9 @@ using TileAssignment = Array<int64>;
 
 // Creates a replicated sharding - replicate a tensor on every device.
 OpSharding Replicate();
 
+// Creates a manual sharding - the partitioner will not change the shape.
+OpSharding Manual();
+
 // Creates a sharding that assigns a tensor to just one device.
 OpSharding AssignDevice(int device);

View File

@@ -46,6 +46,16 @@ class Sharding(object):
     return Sharding(
         proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED))
 
+  @classmethod
+  def manual(cls):
+    """Returns a manual sharding attribute.
+
+    This means the op is manually partitioned by the user and XLA will not
+    change the shapes.
+    """
+    return Sharding(
+        proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.MANUAL))
+
   @classmethod
   def assign_device(cls, core):
     """Returns an AssignDevice sharding attribute.

View File

@@ -684,8 +684,13 @@ bool InferShardingFromOperands(HloInstruction* instruction,
   if (instruction->has_sharding() && instruction->sharding().IsManual()) {
     return false;
   }
-  // Propagate manual sharding.
+  // Propagate manual sharding. Avoid tuple shaped HLOs that group independent
+  // data together. Reduce and Sort can be tuples but the elements are
+  // correlated, so we propagate manual sharding through them.
   if (!instruction->has_sharding() &&
+      (instruction->shape().IsArray() ||
+       instruction->opcode() == HloOpcode::kReduce ||
+       instruction->opcode() == HloOpcode::kSort) &&
       absl::c_any_of(instruction->operands(), [](const HloInstruction* op) {
         return op->has_sharding() && op->sharding().IsManual();
       })) {
@@ -1474,7 +1479,7 @@ bool InferShardingFromUsers(HloInstruction* instruction,
     return false;
   }
   // Propagate manual sharding.
-  if (!instruction->has_sharding() &&
+  if (!instruction->has_sharding() && instruction->shape().IsArray() &&
       absl::c_any_of(instruction->users(), [](const HloInstruction* user) {
         return user->has_sharding() && user->sharding().IsManual() &&
                !user->IsCustomCall("SPMDFullToShardShape");
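
Taken together, the two sites propagate manual sharding forward from operands and backward from users, but only through HLOs where it is meaningful: array-shaped ops, plus Reduce and Sort in the forward direction because their tuple elements are correlated. A condensed sketch of the two predicates, using toy stand-ins for the HLO types:

#include <vector>

// Toy stand-ins; names are illustrative, not the real HLO API.
enum class Opcode { kReduce, kSort, kOther };
struct Hlo {
  Opcode opcode = Opcode::kOther;
  bool is_array_shape = false;        // false for tuple-shaped HLOs
  bool has_manual_sharding = false;
  std::vector<const Hlo*> operands;
  std::vector<const Hlo*> users;
};

// Forward: skip tuple-shaped HLOs that merely group independent data, but
// let Reduce and Sort through since their tuple elements are correlated.
bool PropagateManualFromOperands(const Hlo& hlo) {
  if (hlo.has_manual_sharding) return false;
  if (!hlo.is_array_shape && hlo.opcode != Opcode::kReduce &&
      hlo.opcode != Opcode::kSort) {
    return false;
  }
  for (const Hlo* op : hlo.operands) {
    if (op->has_manual_sharding) return true;
  }
  return false;
}

// Backward: only array-shaped HLOs pick up manual sharding from users; the
// real code additionally excludes SPMDFullToShardShape users, whose operand
// is still in the full, unpartitioned shape.
bool PropagateManualFromUsers(const Hlo& hlo) {
  if (hlo.has_manual_sharding || !hlo.is_array_shape) return false;
  for (const Hlo* user : hlo.users) {
    if (user->has_manual_sharding) return true;
  }
  return false;
}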

View File

@@ -340,6 +340,9 @@ HloInstruction* SpmdBuilder::AddInstruction(
 }
 
 PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target) {
   if (sharding() == target) {
     return *this;
   }
   auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
+  const bool is_to_replicate =
+      hlo_->shape().IsArray() && target.NumTiles() < sharding().NumTiles();
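
The new is_to_replicate flag classifies a reshard as moving toward replication by comparing tile counts; for example, {devices=[2,1]0,1} has NumTiles() == 2 while a replicated sharding has NumTiles() == 1. Reduced to plain integers, the check is:

#include <cstdint>

// A reshard "towards replication" targets fewer tiles than the current
// sharding; such reshards typically lower to all-gathers, which is why the
// partitioner tracks them separately.
bool IsToReplicate(bool is_array_shape, int64_t current_num_tiles,
                   int64_t target_num_tiles) {
  return is_array_shape && target_num_tiles < current_num_tiles;
}

// Example: IsToReplicate(true, /*current_num_tiles=*/2,
//                        /*target_num_tiles=*/1) == true.
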
@@ -1314,9 +1317,8 @@ Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) {
   // partitioner will not change the HLOs.
   auto manual_to_onedevice = [&](const Shape& shape,
                                  const HloSharding& sharding) {
-    if (sharding.IsManual()) {
-      return HloSharding::AssignDevice(0);
-    }
+    // If a tuple's elements are all manual, then sharding.IsManual() == true,
+    // so we test whether it is a tuple first.
     if (sharding.IsTuple()) {
       std::vector<HloSharding> subshardings = sharding.tuple_elements();
       for (HloSharding& subsharding : subshardings) {
@@ -1326,6 +1328,9 @@ Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) {
       }
       return HloSharding::Tuple(shape, subshardings);
     }
+    if (sharding.IsManual()) {
+      return HloSharding::AssignDevice(0);
+    }
     return sharding;
   };
   const bool has_manual_sharding =
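
The reordering matters because HloSharding::IsManual() also returns true for a tuple sharding whose elements are all manual; testing IsTuple() first unwraps the tuple and rewrites each element on its own. A toy model of that control flow:

#include <vector>

// Toy model: a tuple whose elements are all manual also reports
// is_manual == true, so the tuple branch must run first or the whole tuple
// would collapse into a single AssignDevice(0) sharding.
struct ToySharding {
  bool is_tuple = false;
  bool is_manual = false;
  std::vector<ToySharding> elements;  // used only when is_tuple
};

ToySharding AssignDevice0() { return ToySharding{}; }  // stand-in

ToySharding ManualToOneDevice(const ToySharding& sharding) {
  if (sharding.is_tuple) {
    ToySharding result = sharding;
    for (ToySharding& sub : result.elements) {
      sub = ManualToOneDevice(sub);  // rewrite each element independently
    }
    return result;
  }
  if (sharding.is_manual) {
    return AssignDevice0();  // the partitioner leaves such HLOs unchanged
  }
  return sharding;
}
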
@@ -3896,12 +3901,13 @@ StatusOr<bool> SpmdPartitioner::Run(HloModule* module) {
   SpmdLogger logger(options_.report_instruction_count);
   auto program_shape = module->entry_computation()->ComputeProgramShape();
   int64 next_channel_id = hlo_query::NextChannelId(*module);
+  // Copy the root sharding since the partitioner visitor may temporarily change
+  // the sharding to work around manual sharding.
+  HloSharding root_sharding = entry_root->sharding();
   TF_ASSIGN_OR_RETURN(
       bool partition_changed,
-      PartitionComputation(
-          module->entry_computation(),
-          module->entry_computation()->root_instruction()->sharding(),
-          &next_channel_id, &logger));
+      PartitionComputation(module->entry_computation(), root_sharding,
+                           &next_channel_id, &logger));
   changed |= partition_changed;
 
   // For the entry computation, make sure that the root instruction and the
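
Copying the root sharding up front matters because the visitor's Preprocess() above may temporarily rewrite a manual root sharding, and code holding a reference would observe that rewrite. A generic C++ illustration of the aliasing hazard, unrelated to the partitioner's actual types:

#include <cassert>
#include <string>

struct Node { std::string sharding; };

int main() {
  Node root{"manual"};
  const std::string& by_ref = root.sharding;  // aliases the mutable field
  std::string by_copy = root.sharding;        // snapshot taken before mutation
  root.sharding = "maximal device=0";         // visitor-style temporary rewrite
  assert(by_ref == "maximal device=0");       // the reference saw the change
  assert(by_copy == "manual");                // the copy kept the original
  return 0;
}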

View File

@@ -4836,7 +4836,9 @@ ENTRY entry {
   to_shard = f32[4,2] custom-call(param0), custom_call_target="SPMDFullToShardShape", sharding={manual}
   add = f32[4,2] add(to_shard, param1), sharding={manual}
   to_full = f32[8,2] custom-call(add), custom_call_target="SPMDShardToFullShape", sharding={devices=[2,1]0,1}
-  ROOT mul = f32[8,2] multiply(to_full, param0), sharding={devices=[2,1]0,1}
+  mul = f32[8,2] multiply(to_full, param0), sharding={devices=[2,1]0,1}
+  to_shard2 = f32[4,2] custom-call(mul), custom_call_target="SPMDFullToShardShape", sharding={manual}
+  ROOT tuple = (f32[4,2]) tuple(to_shard2), sharding={{manual}}
 })";
   TF_ASSERT_OK_AND_ASSIGN(auto module,
                           PartitionComputation(hlo_string, /*num_devices=*/2));
@@ -4845,8 +4847,9 @@ ENTRY entry {
   auto p0 = op::GetTupleElement(op::Parameter(0));
   auto to_shard = op::Copy(p0);
   auto p1 = op::GetTupleElement(op::Parameter(0));
-  EXPECT_THAT(root, AllOf(op::Shape("f32[4,2]"),
-                          op::Multiply(op::Copy(op::Add(to_shard, p1)), p0)));
+  auto mul = AllOf(op::Shape("f32[4,2]"),
+                   op::Multiply(op::Copy(op::Add(to_shard, p1)), p0));
+  EXPECT_THAT(root, op::Tuple(op::Copy(mul)));
 }
 
 TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard) {