[XLA:SPMD] Minor fixes and utils for manual sharding
PiperOrigin-RevId: 345088593
Change-Id: I33021fb4d739dba8d3dbc9c87d99467ede51a310
parent fa00092f58
commit d0f7f671bb
@@ -738,14 +738,17 @@ REGISTER_OP("XlaSpmdFullToShardShape")
       }
       string sharding_attr;
       TF_RETURN_IF_ERROR(c->GetAttr("manual_sharding", &sharding_attr));
+      xla::OpSharding sharding;
+      sharding.ParseFromString(sharding_attr);
+      if (sharding.type() != xla::OpSharding::OTHER) {
+        return shape_inference::UnchangedShape(c);
+      }
       std::vector<shape_inference::DimensionHandle> dims;
       for (int64 i = 0; i < c->Rank(input_handle); ++i) {
         auto dim = c->Value(c->Dim(input_handle, i));
-        xla::OpSharding sharding;
-        sharding.ParseFromString(sharding_attr);
         int64 partitions_i = sharding.tile_assignment_dimensions(i);
         if (dim != shape_inference::InferenceContext::kUnknownDim &&
-            sharding.type() == xla::OpSharding::OTHER && partitions_i != 1) {
+            partitions_i != 1) {
           dim = (dim + partitions_i - 1) / partitions_i;
         }
         dims.push_back(c->MakeDim(dim));
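As an aside on the shape function above: for a known dimension, the per-shard size is the ceiling division of the full size by the partition count along that axis. A minimal Python sketch of that arithmetic (the helper name is illustrative, not from the patch):

def shard_dim(full_dim, partitions):
  # Mirrors `dim = (dim + partitions_i - 1) / partitions_i` in the shape fn.
  return (full_dim + partitions - 1) // partitions

# A [9, 128] tensor tiled with tile_assignment_dimensions [2, 1] has per-shard
# shape [5, 128]; uneven dimensions round up rather than truncate.
print([shard_dim(9, 2), shard_dim(128, 1)])  # [5, 128]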
@@ -140,7 +140,7 @@ Status RewriteLayoutWithShardedShape(
     const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
     XlaHelpers::ShapeRepresentationFn shape_representation_fn,
     xla::Shape* xla_shape) {
-  if (sharding && !sharding->IsTileMaximal()) {
+  if (sharding && !sharding->IsTileMaximal() && !sharding->IsManual()) {
     // After sharding, per core shape might have different layout. For example,
     // before sharding, a shape [128, 128] will be assigned default
     // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2,
@@ -24,6 +24,12 @@ OpSharding Replicate() {
   return result;
 }
 
+OpSharding Manual() {
+  OpSharding result;
+  result.set_type(OpSharding::MANUAL);
+  return result;
+}
+
 OpSharding AssignDevice(int device) {
   OpSharding result;
   result.set_type(OpSharding::MAXIMAL);
@@ -33,6 +33,9 @@ using TileAssignment = Array<int64>;
 // Creates a replicated sharding - replicate a tensor on every device.
 OpSharding Replicate();
 
+// Creates a manual sharding - the partitioner will not change the shape.
+OpSharding Manual();
+
 // Creates a sharding that assigns a tensor to just one device.
 OpSharding AssignDevice(int device);
 
@@ -46,6 +46,16 @@ class Sharding(object):
     return Sharding(
         proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED))
 
+  @classmethod
+  def manual(cls):
+    """Returns a manual sharding attribute.
+
+    This means the op is manually partitioned by the user and XLA will not
+    change the shapes.
+    """
+    return Sharding(
+        proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.MANUAL))
+
   @classmethod
   def assign_device(cls, core):
     """Returns an AssignDevice sharding attribute.
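A quick usage sketch of the new Python classmethod's proto, assuming a TensorFlow source tree where the generated xla_data_pb2 module is importable; the variable names are illustrative:

from tensorflow.compiler.xla import xla_data_pb2

# Build the same proto that Sharding.manual() wraps.
manual_proto = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.MANUAL)

# Serialized, this is the kind of string that ends up in attributes such as
# the `manual_sharding` attr read by XlaSpmdFullToShardShape above.
manual_attr = manual_proto.SerializeToString()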
@@ -684,8 +684,13 @@ bool InferShardingFromOperands(HloInstruction* instruction,
   if (instruction->has_sharding() && instruction->sharding().IsManual()) {
     return false;
   }
-  // Propagate manual sharding.
+  // Propagate manual sharding. Avoid tuple shaped HLOs that group independent
+  // data together. Reduce and Sort can be tuples but the elements are
+  // correlated, so we propagate manual sharding through them.
   if (!instruction->has_sharding() &&
+      (instruction->shape().IsArray() ||
+       instruction->opcode() == HloOpcode::kReduce ||
+       instruction->opcode() == HloOpcode::kSort) &&
       absl::c_any_of(instruction->operands(), [](const HloInstruction* op) {
         return op->has_sharding() && op->sharding().IsManual();
       })) {
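For orientation, a pseudocode-style Python sketch of the operand-side rule added above; every name here is a hypothetical stand-in for the HLO classes, not a real XLA API:

from dataclasses import dataclass, field
from typing import List

@dataclass
class Instr:  # hypothetical stand-in for an HloInstruction
  opcode: str
  shape_is_array: bool
  has_sharding: bool = False
  sharding_is_manual: bool = False
  operands: List["Instr"] = field(default_factory=list)

def should_inherit_manual(instr: Instr) -> bool:
  # Manual sharding only flows to an unannotated instruction that is
  # array-shaped, or a Reduce/Sort whose tuple elements are correlated.
  if instr.has_sharding:
    return False
  eligible = instr.shape_is_array or instr.opcode in ("reduce", "sort")
  return eligible and any(
      op.has_sharding and op.sharding_is_manual for op in instr.operands)

manual_param = Instr("parameter", True, has_sharding=True, sharding_is_manual=True)
print(should_inherit_manual(Instr("add", True, operands=[manual_param])))     # True
print(should_inherit_manual(Instr("tuple", False, operands=[manual_param])))  # False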
@@ -1474,7 +1479,7 @@ bool InferShardingFromUsers(HloInstruction* instruction,
     return false;
   }
   // Propagate manual sharding.
-  if (!instruction->has_sharding() &&
+  if (!instruction->has_sharding() && instruction->shape().IsArray() &&
       absl::c_any_of(instruction->users(), [](const HloInstruction* user) {
         return user->has_sharding() && user->sharding().IsManual() &&
                !user->IsCustomCall("SPMDFullToShardShape");
@@ -340,6 +340,9 @@ HloInstruction* SpmdBuilder::AddInstruction(
 }
 
 PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target) {
+  if (sharding() == target) {
+    return *this;
+  }
   auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
   const bool is_to_replicate =
       hlo_->shape().IsArray() && target.NumTiles() < sharding().NumTiles();
@@ -1314,9 +1317,8 @@ Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) {
   // partitioner will not change the HLOs.
   auto manual_to_onedevice = [&](const Shape& shape,
                                  const HloSharding& sharding) {
-    if (sharding.IsManual()) {
-      return HloSharding::AssignDevice(0);
-    }
+    // If a tuple's elements are all manual, then sharding.IsManual() == True,
+    // so we test whether it is tuple first.
     if (sharding.IsTuple()) {
       std::vector<HloSharding> subshardings = sharding.tuple_elements();
       for (HloSharding& subsharding : subshardings) {
@@ -1326,6 +1328,9 @@ Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) {
       }
       return HloSharding::Tuple(shape, subshardings);
     }
+    if (sharding.IsManual()) {
+      return HloSharding::AssignDevice(0);
+    }
     return sharding;
   };
   const bool has_manual_sharding =
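A hedged sketch of the substitution performed by manual_to_onedevice, using plain Python dicts as stand-ins for HloSharding (nothing here is a real XLA API); tuples are handled before the scalar manual case precisely because IsManual() is also true for an all-manual tuple:

def manual_to_onedevice(sharding):
  # A sharding is modeled as {"tuple": [...]}, {"manual": True}, or anything
  # else, which is returned unchanged.
  if "tuple" in sharding:
    return {"tuple": [manual_to_onedevice(s) for s in sharding["tuple"]]}
  if sharding.get("manual"):
    return {"assign_device": 0}  # mirrors HloSharding::AssignDevice(0)
  return sharding

print(manual_to_onedevice({"tuple": [{"manual": True}, {"devices": [2, 1]}]}))
# {'tuple': [{'assign_device': 0}, {'devices': [2, 1]}]}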
@@ -3896,12 +3901,13 @@ StatusOr<bool> SpmdPartitioner::Run(HloModule* module) {
   SpmdLogger logger(options_.report_instruction_count);
   auto program_shape = module->entry_computation()->ComputeProgramShape();
   int64 next_channel_id = hlo_query::NextChannelId(*module);
+  // Copy the root sharding since the partitioner visitor may temporarily change
+  // the sharding to work around manual sharding.
+  HloSharding root_sharding = entry_root->sharding();
   TF_ASSIGN_OR_RETURN(
       bool partition_changed,
-      PartitionComputation(
-          module->entry_computation(),
-          module->entry_computation()->root_instruction()->sharding(),
-          &next_channel_id, &logger));
+      PartitionComputation(module->entry_computation(), root_sharding,
+                           &next_channel_id, &logger));
   changed |= partition_changed;
 
   // For the entry computation, make sure that the root instruction and the
@@ -4836,7 +4836,9 @@ ENTRY entry {
   to_shard = f32[4,2] custom-call(param0), custom_call_target="SPMDFullToShardShape", sharding={manual}
   add = f32[4,2] add(to_shard, param1), sharding={manual}
   to_full = f32[8,2] custom-call(add), custom_call_target="SPMDShardToFullShape", sharding={devices=[2,1]0,1}
-  ROOT mul = f32[8,2] multiply(to_full, param0), sharding={devices=[2,1]0,1}
+  mul = f32[8,2] multiply(to_full, param0), sharding={devices=[2,1]0,1}
+  to_shard2 = f32[4,2] custom-call(mul), custom_call_target="SPMDFullToShardShape", sharding={manual}
+  ROOT tuple = (f32[4,2]) tuple(to_shard2), sharding={{manual}}
 })";
   TF_ASSERT_OK_AND_ASSIGN(auto module,
                           PartitionComputation(hlo_string, /*num_devices=*/2));
@@ -4845,8 +4847,9 @@ ENTRY entry {
   auto p0 = op::GetTupleElement(op::Parameter(0));
   auto to_shard = op::Copy(p0);
   auto p1 = op::GetTupleElement(op::Parameter(0));
-  EXPECT_THAT(root, AllOf(op::Shape("f32[4,2]"),
-                          op::Multiply(op::Copy(op::Add(to_shard, p1)), p0)));
+  auto mul = AllOf(op::Shape("f32[4,2]"),
+                   op::Multiply(op::Copy(op::Add(to_shard, p1)), p0));
+  EXPECT_THAT(root, op::Tuple(op::Copy(mul)));
 }
 
 TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard) {
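The per-device shapes asserted in this test follow the same ceiling-division arithmetic as the XlaSpmdFullToShardShape shape function earlier in the patch: f32[8,2] tiled as devices=[2,1] becomes f32[4,2] on each of the two devices. A one-line check, for illustration only:

full_shape, tile_dims = (8, 2), (2, 1)
print(tuple((d + t - 1) // t for d, t in zip(full_shape, tile_dims)))  # (4, 2)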