A. Unique TensorFlower 629a241b55 [XLA:SPMD] Apply all to all when resharding from/to partial replicate at some
special cases.

PiperOrigin-RevId: 348091566
Change-Id: I178f5dc9fe83cb9bcf45a87998cb755691935b4f
2020-12-17 14:19:58 -08:00

4145 lines
174 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include <float.h>
#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/numbers.h"
namespace xla {
namespace spmd {
// Builds a text report of the logged entries. Entries are ordered by
// decreasing recorded size and at most report_instruction_count_ of them are
// printed. Note that entries_ itself is reordered by this call.
string SpmdLogger::MakeReport() {
  string report;
  absl::StrAppend(&report,
                  "\n\n***** SPMD memory during transformation *****\n");

  // Largest recorded size first.
  std::sort(entries_.begin(), entries_.end(),
            [](auto const& lhs, auto const& rhs) {
              return lhs.first > rhs.first;
            });

  const int64 num_reported =
      std::min<int64>(report_instruction_count_, entries_.size());
  for (int64 idx = 0; idx < num_reported; ++idx) {
    const auto& entry = entries_[idx];
    absl::StrAppend(&report, "\n ",
                    tensorflow::strings::HumanReadableNumBytes(entry.first),
                    " : ", entry.second, "\n");
  }
  return report;
}
// Records one log entry for `hlo`. The entry text starts with hlo->ToString()
// and gets one extra line per array-shaped instruction in `group`. The entry's
// size key is the largest byte size among those array-shaped instructions, or
// -1 when `group` contains none.
void SpmdLogger::RegisterLogEntry(HloInstruction* hlo,
                                  const std::vector<HloInstruction*>& group) {
  string entry_text = hlo->ToString();
  int64 largest_bytes = -1;
  for (HloInstruction* member : group) {
    // Only array-shaped instructions contribute to the entry.
    if (member->shape().IsArray()) {
      largest_bytes =
          std::max<int64>(largest_bytes, ShapeSizeInBytes(member->shape()));
      absl::StrAppend(&entry_text, " * ", member->ToString(), "\n");
    }
  }
  entries_.push_back(std::make_pair(largest_bytes, std::move(entry_text)));
}
// Produces the pre-partitioning memory report for `module`. It has two
// sections, each listing up to `report_instruction_count` instructions: first
// the instructions that have no sharding or a replicated sharding, then all
// instructions.
/* static */ string SpmdLogger::ReportBeforePartition(
    const HloModule& module, int64 report_instruction_count) {
  // Accepts instructions that are unsharded or explicitly replicated.
  const auto is_replicated = [](const HloInstruction* hlo) {
    return !hlo->has_sharding() || hlo->sharding().IsReplicated();
  };
  // Accepts every instruction.
  const auto accept_all = [](const HloInstruction* hlo) { return true; };

  string report;
  absl::StrAppend(&report,
                  "\n\n***** SPMD memory usage before partition *****\n");
  absl::StrAppend(&report, "\n ** Replicated instructions\n");
  absl::StrAppend(&report, ReportMemoryUsage(module, is_replicated,
                                             report_instruction_count));
  absl::StrAppend(&report, "\n ** All instructions\n");
  absl::StrAppend(
      &report, ReportMemoryUsage(module, accept_all, report_instruction_count));
  return report;
}
// Produces the post-partitioning memory report for `module`: a header followed
// by up to `report_instruction_count` instructions taken from all instructions
// in the module (no sharding-based filtering).
/* static */ string SpmdLogger::ReportAfterPartition(
    const HloModule& module, int64 report_instruction_count) {
  const auto accept_all = [](const HloInstruction* hlo) { return true; };

  string report = "\n\n***** SPMD memory usage after partition *****\n";
  absl::StrAppend(
      &report, ReportMemoryUsage(module, accept_all, report_instruction_count));
  return report;
}
// Lists the largest instructions of `module` that pass `filter`.
//
// Instructions inside fusion computations, tuple-shaped instructions and
// effective scalars are never reported. The remaining instructions accepted by
// `filter` are sorted by decreasing shape size in bytes, and at most
// `report_instruction_count` of them are written, one per line.
template <typename F>
/* static */ string SpmdLogger::ReportMemoryUsage(
    const HloModule& module, const F& filter, int64 report_instruction_count) {
  std::vector<HloInstruction*> candidates;
  candidates.reserve(module.instruction_count());
  for (auto computation : module.computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    for (auto hlo : computation->instructions()) {
      const bool is_reportable = !hlo->shape().IsTuple() &&
                                 !ShapeUtil::IsEffectiveScalar(hlo->shape());
      if (is_reportable && filter(hlo)) {
        candidates.push_back(hlo);
      }
    }
  }

  // Biggest shapes first.
  std::sort(candidates.begin(), candidates.end(),
            [](const HloInstruction* lhs, const HloInstruction* rhs) {
              return ShapeSizeInBytes(lhs->shape()) >
                     ShapeSizeInBytes(rhs->shape());
            });

  string report;
  const int64 num_reported =
      std::min<int64>(report_instruction_count, candidates.size());
  for (int64 idx = 0; idx < num_reported; ++idx) {
    const HloInstruction* inst = candidates[idx];
    absl::StrAppend(&report, " ",
                    tensorflow::strings::HumanReadableNumBytes(
                        ShapeSizeInBytes(inst->shape())),
                    " : ", inst->ToString(), "\n");
  }
  return report;
}
namespace {
// Removes the sharding annotation from the instructions of `module`. Call this
// only once every SPMD transformation has finished.
//
// Infeed instructions and the parameters of the entry computation keep their
// sharding, because HloReplicationAnalysis reads it later (for ArCrsCombiner).
// Always returns Status::OK().
Status ClearShardingAttributes(HloModule* module) {
  for (HloComputation* computation : module->computations()) {
    const bool is_entry = computation == module->entry_computation();
    for (HloInstruction* hlo : computation->instructions()) {
      const bool keeps_sharding =
          hlo->opcode() == HloOpcode::kInfeed ||
          (hlo->opcode() == HloOpcode::kParameter && is_entry);
      if (!keeps_sharding) {
        hlo->clear_sharding();
      }
    }
  }
  return Status::OK();
}
std::vector<std::vector<int64>> GetPartitionGroupsForReplication(
const HloSharding& sharding, absl::Span<const int64> replication_dims) {
int64 group_size = 1;
for (int64 i : replication_dims) {
group_size *= sharding.tile_assignment().dim(i);
}
std::vector<std::vector<int64>> partition_groups(
sharding.tile_assignment().num_elements() / group_size);
sharding.tile_assignment().Each(
[&](absl::Span<const int64> indices, int64 partition) {
int64 group_id = 0;
for (int64 i = 0; i < indices.size(); ++i) {
if (!absl::c_linear_search(replication_dims, i)) {
group_id *= sharding.tile_assignment().dim(i);
group_id += indices[i];
}
}
partition_groups[group_id].push_back(partition);
});
return partition_groups;
}
} // namespace
// Adds `instruction` to the underlying builder and returns the created HLO.
//
// Besides forwarding to HloComputation::Builder::AddInstruction, this:
//  * records the new HLO under visiting_hlo_ (when one is set), and
//  * maintains broadcast_dims_, a per-instruction set of dimensions tracked as
//    "broadcast" dimensions. The set is seeded by kBroadcast and propagated
//    through elementwise ops, transpose, reshape, slice/dynamic-slice and pad.
HloInstruction* SpmdBuilder::AddInstruction(
    std::unique_ptr<HloInstruction> instruction) {
  HloInstruction* hlo =
      HloComputation::Builder::AddInstruction(std::move(instruction));
  if (visiting_hlo_) {
    instructions_[visiting_hlo_].push_back(hlo);
  }

  // Broadcast: every output dimension that is not listed in
  // hlo->dimensions() is a broadcast dimension.
  if (hlo->opcode() == HloOpcode::kBroadcast) {
    const int64 rank = hlo->shape().rank();
    for (int64 dim = 0; dim < rank; ++dim) {
      if (absl::c_linear_search(hlo->dimensions(), dim)) {
        continue;
      }
      broadcast_dims_[hlo].insert(dim);
    }
  }

  // Elementwise: keep only the dimensions that are broadcast dimensions in
  // every operand. If any operand has no recorded set, nothing is kept.
  if (hlo->IsElementwise() && hlo->operand_count() > 0) {
    const int64 rank = hlo->shape().rank();
    absl::flat_hash_set<int64> common_dims;
    for (int64 dim = 0; dim < rank; ++dim) {
      common_dims.insert(dim);
    }
    for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
         ++operand_idx) {
      auto operand_it = broadcast_dims_.find(hlo->operand(operand_idx));
      if (operand_it == broadcast_dims_.end()) {
        common_dims.clear();
        break;
      }
      for (int64 dim = 0; dim < rank; ++dim) {
        if (!operand_it->second.contains(dim)) {
          common_dims.erase(dim);
        }
      }
    }
    if (!common_dims.empty()) {
      broadcast_dims_[hlo] = std::move(common_dims);
    }
  }

  // Transpose: remap the operand's broadcast dimensions through the inverse
  // of hlo->dimensions().
  if (hlo->opcode() == HloOpcode::kTranspose) {
    auto operand_it = broadcast_dims_.find(hlo->operand(0));
    if (operand_it != broadcast_dims_.end()) {
      // inverse_perm[hlo->dimensions(i)] == i.
      std::vector<int64> inverse_perm(hlo->shape().rank());
      for (int64 out_dim = 0; out_dim < inverse_perm.size(); ++out_dim) {
        inverse_perm[hlo->dimensions(out_dim)] = out_dim;
      }
      absl::flat_hash_set<int64> transposed_dims;
      for (int64 operand_dim : operand_it->second) {
        transposed_dims.insert(inverse_perm[operand_dim]);
      }
      broadcast_dims_[hlo] = std::move(transposed_dims);
    }
  }

  // Reshape (non-empty results only): walk operand and result dimensions in
  // lock-step, matching sizes by splitting or merging, and drop every result
  // dimension that is paired with a non-broadcast operand dimension.
  if (hlo->opcode() == HloOpcode::kReshape &&
      Product(hlo->shape().dimensions()) > 0) {
    auto operand_it = broadcast_dims_.find(hlo->operand(0));
    if (operand_it != broadcast_dims_.end()) {
      const Shape& in_shape = hlo->operand(0)->shape();
      const Shape& out_shape = hlo->shape();
      const int64 in_rank = in_shape.rank();
      const int64 out_rank = out_shape.rank();

      // Start by assuming every result dimension is a broadcast dimension.
      absl::flat_hash_set<int64> reshaped_dims;
      for (int64 dim = 0; dim < out_rank; ++dim) {
        reshaped_dims.insert(dim);
      }

      // Stacks of sizes still to be matched. Sizes are pushed from the last
      // dimension to the first, so back() is the lowest remaining dimension.
      std::vector<int64> pending_in_sizes;
      std::vector<int64> pending_out_sizes;
      for (int64 dim = in_rank - 1; dim >= 0; --dim) {
        pending_in_sizes.push_back(in_shape.dimensions(dim));
      }
      for (int64 dim = out_rank - 1; dim >= 0; --dim) {
        pending_out_sizes.push_back(out_shape.dimensions(dim));
      }

      while (!pending_in_sizes.empty() && !pending_out_sizes.empty()) {
        const int64 in_size = pending_in_sizes.back();
        const int64 out_size = pending_out_sizes.back();
        const int64 in_dim = in_rank - pending_in_sizes.size();
        const int64 out_dim = out_rank - pending_out_sizes.size();
        pending_in_sizes.pop_back();
        pending_out_sizes.pop_back();

        if (!operand_it->second.contains(in_dim)) {
          reshaped_dims.erase(out_dim);
        }
        if (in_size == out_size) {
          continue;
        }
        if (in_size % out_size == 0) {
          // Operand dimension is split: keep its remainder for the next
          // result dimension.
          pending_in_sizes.push_back(in_size / out_size);
        } else if (out_size % in_size == 0) {
          // Operand dimensions are merged: keep the remainder of the result
          // dimension for the next operand dimension.
          pending_out_sizes.push_back(out_size / in_size);
        } else {
          // Neither a split nor a merge: treat all remaining result
          // dimensions as non-broadcast.
          for (int64 dim = out_dim; dim < out_rank; ++dim) {
            reshaped_dims.erase(dim);
          }
          break;
        }
      }

      // Leftover sizes on either side mean the walk did not finish cleanly.
      if (!pending_in_sizes.empty() || !pending_out_sizes.empty()) {
        reshaped_dims.clear();
      }
      if (!reshaped_dims.empty()) {
        broadcast_dims_[hlo] = std::move(reshaped_dims);
      }
    }
  }

  // Slice / dynamic-slice: inherit the operand's broadcast dimensions as-is.
  if (hlo->opcode() == HloOpcode::kSlice ||
      hlo->opcode() == HloOpcode::kDynamicSlice) {
    auto operand_it = broadcast_dims_.find(hlo->operand(0));
    if (operand_it != broadcast_dims_.end()) {
      // Take a copy before indexing broadcast_dims_ with a new key; the
      // lookup below may presumably insert and disturb `operand_it`.
      auto inherited_dims = operand_it->second;
      broadcast_dims_[hlo] = std::move(inherited_dims);
    }
  }

  // Pad: keep the operand's broadcast dimensions that receive no edge or
  // interior padding.
  if (hlo->opcode() == HloOpcode::kPad) {
    auto operand_it = broadcast_dims_.find(hlo->operand(0));
    if (operand_it != broadcast_dims_.end()) {
      absl::flat_hash_set<int64> unpadded_dims;
      const int64 rank = hlo->shape().rank();
      for (int64 dim = 0; dim < rank; ++dim) {
        const auto& padding = hlo->padding_config().dimensions(dim);
        const bool is_padded = padding.edge_padding_low() != 0 ||
                               padding.edge_padding_high() != 0 ||
                               padding.interior_padding() != 0;
        if (!is_padded && operand_it->second.contains(dim)) {
          unpadded_dims.insert(dim);
        }
      }
      if (!unpadded_dims.empty()) {
        broadcast_dims_[hlo] = std::move(unpadded_dims);
      }
    }
  }
  return hlo;
}
// Returns this HLO resharded to `target`, going through the per-HLO reshard
// cache.
//
// A reshard that reduces the number of tiles of an array-shaped HLO is only
// looked up in / stored into the cache when the partitioner option
// cache_all_gather is set. Independently of that, the reverse mapping
// (resharded HLO -> this HLO with its current sharding) is always recorded.
PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target) {
  if (sharding() == target) {
    return *this;
  }

  auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
  const bool is_to_replicate =
      hlo_->shape().IsArray() && target.NumTiles() < sharding().NumTiles();
  const bool use_cache =
      !is_to_replicate || state_.partitioner->options().cache_all_gather;

  if (use_cache) {
    auto cached = absl::c_find_if(
        cache, [&](const auto& entry) { return entry.first == target; });
    if (cached != cache.end()) {
      return cached->second;
    }
  }

  auto resharded = ReshardNoCache(target);
  // Remember how to get back from the resharded HLO to this one.
  state_.reshard_cache->per_hlo_cache[resharded.hlo()]
      .reshard_cache.emplace_back(sharding(), *this);

  if (!use_cache) {
    return resharded;
  }
  cache.emplace_back(target, std::move(resharded));
  return cache.back().second;
}
// Reshards this HLO to `target` without consulting the reshard cache of this
// HLO (nested Reshard()/Replicate() calls may still use their caches).
//
// The strategies are tried in this order:
//   1. Token-typed shapes are returned unchanged.
//   2. Tuple shapes are resharded leaf by leaf and re-tupled.
//   3. Collective permute, when the two shardings allow it.
//   4. All-to-all, when source/target dimensions can be derived.
//   5. Dedicated paths from / to partially replicated shardings.
//   6. Otherwise replicate first, then derive the target from the replicated
//      value (copy, per-group slice, or pad + dynamic slice).
PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
  VLOG(2) << "Resharding " << hlo_->ToString() << " from "
          << hlo_->sharding().ToString() << " to " << target.ToString();
  const Shape& shape = hlo_->shape();
  if (shape.element_type() == TOKEN) {
    return *this;
  }
  CHECK(shape.IsTuple() || !target.IsTuple());

  if (shape.IsTuple()) {
    // A non-tuple sharding on a tuple-shaped instruction means the same
    // sharding applies to every leaf; expand it and start over.
    if (!target.IsTuple()) {
      return Reshard(target.GetTupleSharding(shape).ValueOrDie());
    }

    // Reshard each leaf separately and assemble the results into a tuple.
    const int64 element_count = ShapeUtil::TupleElementCount(shape);
    std::vector<HloInstruction*> resharded_elements;
    for (int64 idx = 0; idx < element_count; ++idx) {
      const Shape& element_shape = ShapeUtil::GetTupleElementShape(shape, idx);
      HloInstruction* element = state_.b->AddInstruction(
          HloInstruction::CreateGetTupleElement(element_shape, hlo(), idx));
      element->set_sharding(sharding().GetSubSharding(shape, {idx}));
      PartitionedHlo partitioned_element(
          element, ShapeUtil::GetTupleElementShape(base_shape_, idx), state_);
      resharded_elements.push_back(
          partitioned_element.Reshard(target.GetSubSharding(shape, {idx}))
              .hlo());
    }
    HloInstruction* tuple = state_.b->AddInstruction(
        HloInstruction::CreateTuple(resharded_elements));
    tuple->set_sharding(target);
    return PartitionedHlo(tuple, base_shape_, state_);
  }

  if (sharding() == target) {
    return *this;
  }

  if (CanReshardWithCollectivePermute(sharding(), target)) {
    return ReshardWithCollectivePermute(target);
  }

  if (auto src_tgt_dims =
          GetReshardAllToAllSourceTargetDims(sharding(), target)) {
    return ReshardWithAllToAll(target, *src_tgt_dims);
  }

  // Source is partially replicated and the target is tiled in some way.
  if (!target.IsTileMaximal() && sharding().ReplicateOnLastTileDim()) {
    auto via_dynamic_slice =
        ReshardFromPartialReplicateWithDynamicSlice(target);
    if (via_dynamic_slice.has_value()) {
      return via_dynamic_slice.value();
    }
    auto via_all_to_all = ReshardPartialReplicateWithAllToAll(target);
    if (via_all_to_all.has_value()) {
      return via_all_to_all.value();
    }
  }

  // Source is tiled in some way and the target is partially replicated.
  if (!sharding().IsTileMaximal() && target.ReplicateOnLastTileDim()) {
    auto via_all_gather = ReshardToPartialReplicateWithAllGather(target);
    if (via_all_gather.has_value()) {
      return via_all_gather.value();
    }
    auto via_all_to_all = ReshardPartialReplicateWithAllToAll(target);
    if (via_all_to_all.has_value()) {
      return via_all_to_all.value();
    }
  }

  // Everything below starts from a replicated value, so replicate first when
  // necessary and retry.
  if (!sharding().IsReplicated()) {
    return Replicate().Reshard(target);
  }

  // Replicated -> single device: a copy carrying the target sharding.
  if (target.IsTileMaximal()) {
    HloInstruction* copy = state_.b->AddInstruction(
        HloInstruction::CreateUnary(hlo_->shape(), HloOpcode::kCopy, hlo_));
    copy->set_sharding(target);
    return PartitionedHlo(copy, base_shape_, state_);
  }

  // Replicated -> partially replicated: group on every tile dimension except
  // the last one and slice per group.
  if (target.ReplicateOnLastTileDim()) {
    std::vector<int64> group_dims(target.tile_assignment().num_dimensions() -
                                  1);
    std::iota(group_dims.begin(), group_dims.end(), 0);
    auto grouped_target = GroupShardingOnDims(target, group_dims);
    HloInstruction* partially_sharded = PerGroupSliceFromReplicated(
        hlo_, state_.partition_id, grouped_target.device_groups, group_dims,
        grouped_target.group_dim_sizes, state_.b);
    partially_sharded->set_sharding(target);
    return PartitionedHlo(partially_sharded, base_shape(), state_);
  }

  // Replicated -> tiled: pad for uneven tiling, then dynamic-slice out this
  // partition's shard.
  HloInstruction* padded_hlo =
      PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b);
  auto shard_shape = MakePartitionedShape(shape, target);
  auto shard_offsets =
      MakePartitionOffsets(shape, target, state_.partition_id, state_.b);
  HloInstruction* shard =
      state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
          shard_shape, padded_hlo, shard_offsets, shard_shape.dimensions()));
  shard->set_sharding(target);
  return PartitionedHlo(shard, base_shape_, state_);
}
// Returns a copy of this shard in which the elements that fall outside the
// base shape (because the base shape is not evenly partitioned) are replaced
// by `pad_value`.
//
// `left_padded_dims` lists dimensions whose out-of-range region is at the low
// end instead of the high end. Dimensions in `skipped_dims`, and dimensions
// that are evenly divided by the tile count, are not masked. When nothing
// needs masking, *this is returned unchanged.
PartitionedHlo PartitionedHlo::PadWithValue(
    HloInstruction* pad_value, absl::Span<const int64> left_padded_dims,
    absl::Span<const int64> skipped_dims) const {
  const HloSharding& sharding = hlo_->sharding();
  const Shape& shape = hlo_->shape();
  CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
  if (sharding.IsReplicated() || EvenlyPartitions(base_shape_, sharding)) {
    return *this;
  }
  CHECK(!sharding.IsTileMaximal());

  const Shape s32_shape = ShapeUtil::ChangeElementType(shape, S32);
  const Shape pred_shape = ShapeUtil::ChangeElementType(s32_shape, PRED);

  // Builds a PRED mask that is true where the element's index along `dim` in
  // the full shape (iota + shard_start) is valid: `< base size` normally, or
  // `>= padded size - base size` for left-padded dimensions.
  auto build_dim_mask = [&](int64 dim, HloInstruction* shard_start) {
    HloInstruction* iota =
        state_.b->AddInstruction(HloInstruction::CreateIota(s32_shape, dim));
    HloInstruction* start_bcast = state_.b->AddInstruction(
        HloInstruction::CreateBroadcast(s32_shape, shard_start, {}));
    HloInstruction* full_index =
        state_.b->AddInstruction(HloInstruction::CreateBinary(
            s32_shape, HloOpcode::kAdd, iota, start_bcast));

    const bool is_left_padded = absl::c_linear_search(left_padded_dims, dim);
    const ComparisonDirection direction =
        is_left_padded ? ComparisonDirection::kGe : ComparisonDirection::kLt;
    const int64 base_size = base_shape_.dimensions(dim);
    const int64 index_limit =
        is_left_padded
            ? s32_shape.dimensions(dim) * sharding.tile_assignment().dim(dim) -
                  base_size
            : base_size;

    HloInstruction* limit = state_.b->AddInstruction(
        HloInstruction::CreateConstant(
            LiteralUtil::CreateR0<int32>(index_limit)));
    HloInstruction* limit_bcast = state_.b->AddInstruction(
        HloInstruction::CreateBroadcast(s32_shape, limit, {}));
    return state_.b->AddInstruction(HloInstruction::CreateCompare(
        pred_shape, full_index, limit_bcast, direction));
  };

  auto offsets = MakePartitionOffsets(base_shape_, sharding,
                                      state_.partition_id, state_.b);

  // AND together the masks of all dimensions that need one.
  HloInstruction* mask = nullptr;
  for (int64 dim = 0; dim < shape.rank(); ++dim) {
    const bool evenly_divided =
        base_shape_.dimensions(dim) % sharding.tile_assignment().dim(dim) == 0;
    if (evenly_divided || absl::c_linear_search(skipped_dims, dim)) {
      continue;
    }
    HloInstruction* dim_mask = build_dim_mask(dim, offsets[dim]);
    if (mask == nullptr) {
      mask = dim_mask;
    } else {
      mask = state_.b->AddInstruction(HloInstruction::CreateBinary(
          mask->shape(), HloOpcode::kAnd, mask, dim_mask));
    }
  }
  if (mask == nullptr) {
    return *this;
  }

  // select(mask, data, pad_value).
  HloInstruction* pad_bcast = state_.b->AddInstruction(
      HloInstruction::CreateBroadcast(shape, pad_value, {}));
  HloInstruction* masked = state_.b->AddInstruction(
      HloInstruction::CreateTernary(shape, HloOpcode::kSelect, mask, hlo_,
                                    pad_bcast));
  masked->set_sharding(sharding);
  return PartitionedHlo(masked, base_shape_, state_);
}
// Reshards this HLO so that each partition holds the part of the input it
// needs to evaluate `window` on its shard under the `target` sharding.
//
// Parameters:
//   window              - the window applied to the unpartitioned operand.
//   target              - desired sharding; must not be tile-maximal (CHECKed).
//   pad_value           - used as the padding operand of the Pad instruction
//                         in the replicated path and handed to
//                         ExchangeHaloAndGetValidData in the halo path.
//   mask_invalid_region - forwarded to ExchangeHaloAndGetValidData.
//
// Returns a WindowedInputShardReturnValue built from the sharded input, the
// per-shard window, and (when needed) the dynamic-slice offsets to apply on the
// output; or absl::nullopt when the window is larger than the padded base size
// or the stride/base-dilation combination is not supported.
//
// Three paths exist: (1) replicated input -> pad + dynamic-slice, (2) sharding
// differs from `target` -> Reshard(target) and recurse, (3) sharding equals
// `target` -> halo exchange per partitioned dimension, falling back to
// Replicate() and recursing when the halo exchange fails.
//
// NOTE(review): results are cached per (target, window) only; `pad_value` and
// `mask_invalid_region` are not part of the cache key — confirm callers never
// vary them for the same HLO/target/window.
// NOTE(review): the two recursive calls below do not pass
// `mask_invalid_region`, so they use whatever default the declaration
// provides — confirm this is intended.
absl::optional<PartitionedHlo::WindowedInputShardReturnValue>
PartitionedHlo::ReshardAsWindowedInput(const Window& window,
const HloSharding& target,
HloInstruction* pad_value,
bool mask_invalid_region) {
// Return a previously computed result for the same target sharding and window.
auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].window_reshard_cache;
for (auto& entry : cache) {
if (std::get<0>(entry) == target &&
protobuf_util::ProtobufEquals(std::get<1>(entry), window)) {
return std::get<2>(entry);
}
}
// Stores `result` in the cache and returns the cached copy.
auto update_cache = [&](WindowedInputShardReturnValue result) {
cache.emplace_back(target, window, std::move(result));
return std::get<2>(cache.back());
};
VLOG(2) << "ReshardAsWindowedInput()\n"
<< "\twindow:" << window_util::ToString(window)
<< "\ttarget sharding:" << target.ToString();
CHECK(!target.IsTileMaximal());
auto partition_ordinals =
MakeTiledPartitionOrdinals(target, state_.partition_id, state_.b);
// Per-dimension bookkeeping, all indexed by dimension of base_shape_.
// shard_shape / padded_shape start as the base shape and are updated only on
// partitioned dimensions.
auto shard_shape = base_shape_;
std::vector<MultiplyAddDivideOffsetCalculation> start_on_padded_calculations(
base_shape_.rank());
std::vector<MultiplyAddDivideOffsetCalculation> limit_on_padded_calculations(
base_shape_.rank());
std::vector<HloInstruction*> dynamic_slice_offset_on_output(
base_shape_.rank(), nullptr);
Window shard_window = window;
auto padded_shape = base_shape_;
std::vector<HloInstruction*> offsets_on_padded_shape(base_shape_.rank());
std::vector<int64> per_shard_window_counts(base_shape_.rank());
std::vector<int64> explicit_left_padding(base_shape_.rank());
for (int64 i = 0; i < base_shape_.rank(); ++i) {
// Do not pad non-partitioned dimensions.
int64 shard_count = target.tile_assignment().dim(i);
if (shard_count == 1) {
offsets_on_padded_shape[i] = state_.b->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
continue;
}
const auto& wd = window.dimensions(i);
// Window extent after window dilation, and base extent after base dilation
// plus low/high padding.
const auto dilated_size = 1 + (wd.size() - 1) * wd.window_dilation();
int64 full_size =
base_shape_.dimensions(i) +
(wd.base_dilation() - 1) * (base_shape_.dimensions(i) - 1) +
wd.padding_high() + wd.padding_low();
if (full_size < dilated_size) {
VLOG(2) << "Failed to reshard window operand because the window size is "
"larger than padded base size";
return absl::nullopt;
}
// Total number of window positions, split evenly (rounding up) over shards.
int64 window_count = (full_size - dilated_size) / wd.stride() + 1;
per_shard_window_counts[i] = CeilOfRatio(window_count, shard_count);
if (wd.stride() != 1 &&
(wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() != 0) {
// TODO(yuanzx): Support this case.
VLOG(2) << "Failed to reshard window operand due to non-trivial dilation";
return absl::nullopt;
}
// We use explicit padding for full dilations, then use padding_low and
// padding_high on the sharded op for the remaining. padding_low and
// padding_high are now given initial values, which will be later updated if
// dilation is not 1.
auto swd = shard_window.mutable_dimensions(i);
explicit_left_padding[i] = wd.padding_low() / wd.base_dilation();
swd->set_padding_low(wd.padding_low() % wd.base_dilation());
swd->set_padding_high(0);
// Calculation for the first element needed on the 'padded-but-not-dilated'
// shape. The start on the dilated shape could be a hole, so we add
// wd.base_dilation() - 1 to the constant term to skip the leading holes.
start_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation(
wd.stride() * per_shard_window_counts[i],
wd.base_dilation() - 1 - swd->padding_low(), wd.base_dilation());
int64 dilated_shard_size =
wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
limit_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation(
wd.stride() * per_shard_window_counts[i],
dilated_shard_size + wd.base_dilation() - 1 - swd->padding_low(),
wd.base_dilation());
offsets_on_padded_shape[i] = start_on_padded_calculations[i].Calculate(
partition_ordinals[i], state_.b);
// The shard size is the largest (limit - start) over all shard ordinals; the
// padded size is the limit of the last shard.
auto shard_size_function =
limit_on_padded_calculations[i] - start_on_padded_calculations[i];
int64 max_shard_size = shard_size_function.MaxInRange(0, shard_count);
shard_shape.set_dimensions(i, max_shard_size);
padded_shape.set_dimensions(
i, limit_on_padded_calculations[i].Calculate(shard_count - 1));
// For base dilation, calculate the needed padding_low and padding_high, as
// well as the offset for the output if a dynamic slice is needed after the
// sharded op.
if (wd.base_dilation() != 1) {
// Returns the offset of a shard's first valid element in the dilated
// shard.
auto get_first_valid_element_offset_on_dilated_shard =
[&](int64 shard_ordinal) {
return start_on_padded_calculations[i].Calculate(shard_ordinal) *
wd.base_dilation() +
swd->padding_low() -
wd.stride() * per_shard_window_counts[i] * shard_ordinal;
};
CHECK_EQ(get_first_valid_element_offset_on_dilated_shard(0),
swd->padding_low());
// Determine swd->padding_high.
// It is the maximum shortfall, over all shards, between the dilated extent
// a shard needs and the extent it has without high padding.
for (int64 shard_ordinal = 0; shard_ordinal < shard_count;
++shard_ordinal) {
int64 wanted_limit_on_dilated_shard =
wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
int64 actual_limit_on_dilated_shard_without_pad_high =
get_first_valid_element_offset_on_dilated_shard(shard_ordinal) +
(max_shard_size - 1) * wd.base_dilation() + 1;
swd->set_padding_high(std::max<int64>(
swd->padding_high(),
wanted_limit_on_dilated_shard -
actual_limit_on_dilated_shard_without_pad_high));
}
// Determine swd->padding_low and output dynamic slice index.
if (wd.stride() == 1) {
// Use the largest first-valid-element offset over all shards as the shared
// padding_low; if shards differ, the output must be dynamically sliced.
int64 max_pad_low = get_first_valid_element_offset_on_dilated_shard(0);
bool all_same = true;
for (int64 shard_ordinal = 1; shard_ordinal < shard_count;
++shard_ordinal) {
int64 start =
get_first_valid_element_offset_on_dilated_shard(shard_ordinal);
if (start != swd->padding_low()) {
all_same = false;
}
max_pad_low = std::max(max_pad_low, start);
}
if (!all_same) {
auto start_on_padded_input =
start_on_padded_calculations[i].Calculate(partition_ordinals[i],
state_.b);
// We will calculate
// max_pad_low - (first_window - required_first_window)
// which equals
// required_first_window - (first_window - max_pad_low)
auto first_window_minus_max_pad_low =
MultiplyAddDivideOffsetCalculation(
wd.base_dilation(), swd->padding_low() - max_pad_low, 1)
.Calculate(start_on_padded_input, state_.b);
auto required_first_window =
MultiplyAddDivideOffsetCalculation(per_shard_window_counts[i], 0,
1)
.Calculate(partition_ordinals[i], state_.b);
dynamic_slice_offset_on_output[i] =
state_.b->AddInstruction(HloInstruction::CreateBinary(
required_first_window->shape(), HloOpcode::kSubtract,
required_first_window, first_window_minus_max_pad_low));
}
swd->set_padding_low(max_pad_low);
} else {
if ((wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() !=
0) {
// General base dilation not yet implemented.
return absl::nullopt;
}
// padding_low on all shards should equal the initially assigned
// swd->padding_low(), i.e., the padding_low() on the original window.
}
}
}
// Returns the output dynamic slice offset when needed, and absl::nullopt
// otherwise.
// Dimensions without an offset are filled in with a shared zero constant.
auto get_dynamic_slice_offset_on_output_if_needed =
[&]() -> absl::optional<std::vector<HloInstruction*>> {
if (absl::c_all_of(
dynamic_slice_offset_on_output,
[](HloInstruction* offset) { return offset == nullptr; })) {
return absl::nullopt;
}
auto zero = state_.b->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
for (int64 i = 0; i < dynamic_slice_offset_on_output.size(); ++i) {
if (dynamic_slice_offset_on_output[i] == nullptr) {
dynamic_slice_offset_on_output[i] = zero;
}
}
return dynamic_slice_offset_on_output;
};
// If the current HLO is replicated, pad then slice.
if (sharding().IsReplicated()) {
PaddingConfig padding_config;
for (int64 i = 0; i < base_shape_.rank(); ++i) {
auto padding_config_dim = padding_config.add_dimensions();
padding_config_dim->set_interior_padding(0);
// Do not pad non-partitioned dimensions.
if (target.tile_assignment().dim(i) == 1) {
padding_config_dim->set_edge_padding_low(0);
padding_config_dim->set_edge_padding_high(0);
continue;
}
padding_config_dim->set_edge_padding_low(explicit_left_padding[i]);
padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) -
explicit_left_padding[i] -
base_shape_.dimensions(i));
}
// Skip the Pad instruction entirely when the padded shape is compatible with
// the base shape.
auto padded_hlo = ShapeUtil::Compatible(padded_shape, base_shape_)
? hlo_
: state_.b->AddInstruction(HloInstruction::CreatePad(
padded_shape, hlo_, pad_value, padding_config));
auto sharded_input =
state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
shard_shape, padded_hlo, offsets_on_padded_shape,
shard_shape.dimensions()));
return update_cache(WindowedInputShardReturnValue{
sharded_input, shard_window,
get_dynamic_slice_offset_on_output_if_needed()});
}
// Not replicated and not yet in the target sharding: reshard first, then
// retry. (mask_invalid_region is not forwarded here; see NOTE(review) above.)
if (target != sharding()) {
return Reshard(target).ReshardAsWindowedInput(window, target, pad_value);
}
// Halo exchange.
HloInstruction* visiting_hlo = hlo_;
auto original_shard_shape = MakePartitionedShape(base_shape_, target);
std::vector<OffsetCalculation> left_halo_size_functions(base_shape_.rank());
std::vector<OffsetCalculation> right_halo_size_functions(base_shape_.rank());
// TODO(yuanzx): We are concatenating on each sharded dimension one at time,
// and in the second dimension (and beyond) we create halos by slicing the
// concat in the previous dimension, which is not optimal. We should generate
// halos only concating slices, instead of slicing concats.
for (int dim = 0; dim < base_shape_.rank(); ++dim) {
int64 shard_count = target.tile_assignment().dim(dim);
if (shard_count == 1) {
continue;
}
int64 input_shard_size =
CeilOfRatio(base_shape_.dimensions(dim), shard_count);
// Left halo. The size of the halo is derived by subtracting the first read
// element offset of the i'th partition from the limit of the (i-1)'th
// partition.
MultiplyAddDivideOffsetCalculation shard_limit_of_previous_on_padded(
input_shard_size, explicit_left_padding[dim], 1);
left_halo_size_functions[dim] =
shard_limit_of_previous_on_padded - start_on_padded_calculations[dim];
// Right halo.
MultiplyAddDivideOffsetCalculation shard_start_of_next_on_padded(
input_shard_size, input_shard_size + explicit_left_padding[dim], 1);
right_halo_size_functions[dim] =
limit_on_padded_calculations[dim] - shard_start_of_next_on_padded;
auto resharded = ExchangeHaloAndGetValidData(
visiting_hlo, base_shape_, left_halo_size_functions[dim],
right_halo_size_functions[dim], explicit_left_padding[dim],
padded_shape.dimensions(dim), shard_shape.dimensions(dim), dim, target,
offsets_on_padded_shape[dim], pad_value, partition_ordinals[dim],
state_.collective_ops_creator, state_.next_channel_id, state_.b,
mask_invalid_region);
// If the halo exchange is not possible, fall back to the replicated path.
// (mask_invalid_region is not forwarded here; see NOTE(review) above.)
if (!resharded) {
VLOG(1) << "ReshardAsWindowedInput failed without replicate first: halo "
"is beyond the neighbor.";
return Replicate().ReshardAsWindowedInput(window, target, pad_value);
}
// Feed the result of this dimension into the exchange for the next one.
visiting_hlo = *resharded;
}
return update_cache(WindowedInputShardReturnValue{
visiting_hlo, shard_window,
get_dynamic_slice_offset_on_output_if_needed()});
}
// Returns a replicated version of this HLO.
//
// The per-HLO reshard cache is consulted first (before any validity checks
// when the partitioner option cache_all_gather is set, and again afterwards
// regardless of the option). The shape must be neither a tuple nor a token.
// Newly created results always record the reverse mapping back to this HLO,
// and are stored in this HLO's cache only when cache_all_gather is set.
PartitionedHlo PartitionedHlo::Replicate() {
  auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;

  // Returns the first cached entry whose sharding is replicated, if any.
  const auto find_replicated_entry = [&cache]() -> const PartitionedHlo* {
    for (auto& entry : cache) {
      if (entry.first.IsReplicated()) {
        return &entry.second;
      }
    }
    return nullptr;
  };

  if (state_.partitioner->options().cache_all_gather) {
    if (const PartitionedHlo* cached = find_replicated_entry()) {
      return *cached;
    }
  }

  const HloSharding& current_sharding = hlo_->sharding();
  const Shape& shape = hlo_->shape();
  CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
  if (current_sharding.IsReplicated()) {
    return *this;
  }
  if (const PartitionedHlo* cached = find_replicated_entry()) {
    return *cached;
  }

  // Records `replicated` in the caches and returns the value to hand back.
  const auto remember = [&](PartitionedHlo replicated) {
    state_.reshard_cache->per_hlo_cache[replicated.hlo()]
        .reshard_cache.emplace_back(current_sharding, *this);
    if (!state_.partitioner->options().cache_all_gather) {
      return replicated;
    }
    cache.emplace_back(HloSharding::Replicate(), std::move(replicated));
    return cache.back().second;
  };

  // 'Single Device' to 'Replicated'.
  if (current_sharding.IsTileMaximal()) {
    return remember(Broadcast());
  }

  // 'Tiled' to 'Replicated': gather along every dimension.
  std::vector<int64> every_dim(shape.rank());
  std::iota(every_dim.begin(), every_dim.end(), 0);
  HloInstruction* gathered = ReplicatePartial(every_dim);
  gathered->set_sharding(HloSharding::Replicate());
  return remember(PartitionedHlo(gathered, base_shape_, state_));
}
// Gathers the shards of this HLO along `dims` and returns the resulting
// instruction, whose size along each of `dims` equals the base shape's size.
// The sharding must not be tile-maximal. No sharding is set on the result.
//
// An all-gather is used when the collective-ops creator provides one and
// AllGatherShards returns a result. Otherwise each partition writes its shard
// into a zero-filled buffer at its own offset and the buffers are summed with
// a cross-partition all-reduce. Padding introduced by uneven sharding is
// sliced off at the end.
HloInstruction* PartitionedHlo::ReplicatePartial(absl::Span<const int64> dims) {
  CHECK(!sharding().IsTileMaximal());
  const Shape& shard_shape = hlo()->shape();

  // `gathered_shape` is shard size * tile count on `dims` (may include
  // padding); `final_shape` uses the base shape's size on `dims`.
  Shape final_shape = shard_shape;
  Shape gathered_shape = shard_shape;
  for (int64 dim : dims) {
    gathered_shape.set_dimensions(
        dim,
        shard_shape.dimensions(dim) * sharding().tile_assignment().dim(dim));
    final_shape.set_dimensions(dim, base_shape().dimensions(dim));
  }

  HloInstruction* gathered = nullptr;
  if (state_.collective_ops_creator.create_cross_partition_all_gather) {
    gathered = state_.partitioner->AllGatherShards(
        state_.b, hlo_, sharding(), NewChannel(), dims,
        state_.collective_ops_creator);
  }

  if (gathered == nullptr) {
    // Fallback: dynamic-update-slice into zeros, then all-reduce (add).
    HloInstruction* zero = state_.b->AddInstruction(
        HloInstruction::CreateConstant(
            LiteralUtil::Zero(shard_shape.element_type())));
    HloInstruction* zero_buffer = state_.b->AddInstruction(
        HloInstruction::CreateBroadcast(gathered_shape, zero, {}));
    auto shard_offsets = MakePartitionOffsets(
        gathered_shape, sharding(), state_.partition_id, state_.b, dims);
    HloInstruction* placed_shard =
        state_.b->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
            gathered_shape, zero_buffer, hlo_, shard_offsets));
    HloComputation* add_computation =
        MakeBinaryAdd(shard_shape.element_type(), state_.module);
    gathered = state_.collective_ops_creator.create_cross_partition_all_reduce(
        state_.b, placed_shard, add_computation,
        GetPartitionGroupsForReplication(sharding(), dims), NewChannel());
  }

  if (ShapeUtil::Compatible(final_shape, gathered_shape)) {
    return gathered;
  }
  // Drop the padding: slice from the origin with unit strides.
  std::vector<int64> slice_starts(final_shape.rank(), 0);
  std::vector<int64> slice_strides(final_shape.rank(), 1);
  return state_.b->AddInstruction(
      HloInstruction::CreateSlice(final_shape, gathered, slice_starts,
                                  final_shape.dimensions(), slice_strides));
}
// Tries to reshard from the current (tiled or partially replicated) sharding
// to the partially replicated `target` by replicating along some tiled
// dimensions. Returns nullopt when `target` is not partially replicated, when
// no compatible intermediate sharding exists, or when the halo exchange
// helper cannot produce a result.
absl::optional<PartitionedHlo>
PartitionedHlo::ReshardToPartialReplicateWithAllGather(
    const HloSharding& target) {
  if (!target.ReplicateOnLastTileDim()) {
    return absl::nullopt;
  }
  // Look for a sharding compatible with `target` that the current sharding
  // can be brought to.
  auto maybe_intermediate =
      PartialReplicateReshardCompatibleSharding(target, sharding());
  if (!maybe_intermediate.has_value()) {
    return absl::nullopt;
  }
  const auto& intermediate = maybe_intermediate.value();

  // Adjust the device assignment with a collective permute when possible.
  auto source = *this;
  if (CanReshardWithCollectivePermute(sharding(), intermediate)) {
    source = source.ReshardWithCollectivePermute(intermediate);
  }

  // Collect the dimensions that have to be replicated, together with how many
  // intermediate tiles map onto one target tile on each of them.
  const int64 rank = hlo_->shape().rank();
  std::vector<int64> replicate_dims;
  std::vector<int64> replicate_factors;
  for (int64 dim = 0; dim < rank; ++dim) {
    const int64 factor = intermediate.tile_assignment().dim(dim) /
                         target.tile_assignment().dim(dim);
    if (factor <= 1) {
      continue;
    }
    replicate_dims.push_back(dim);
    replicate_factors.push_back(factor);
  }

  // Do a left halo exchange if replicating directly would drop useful data
  // from the source.
  auto halo_exchange = TileToPartialReplicateHaloExchange(
      source.hlo_, base_shape_, intermediate, target, replicate_dims,
      source.state().collective_ops_creator, source.state().next_channel_id,
      source.state().partition_id, source.state().b);
  if (!halo_exchange.has_value()) {
    return absl::nullopt;
  }
  auto halo_exchange_hlo = halo_exchange.value();

  // Group the devices on the replicated dimensions and build the matching
  // per-group partitioning state.
  auto sharding_grouped =
      GroupShardingOnDims(intermediate, replicate_dims, replicate_factors);
  auto per_group_state = CreatePerGroupPartitioningState(
      source.state(), sharding_grouped.device_groups, source.state().b);
  auto per_group_base_shape = MakePartitionedShape(base_shape_, target);

  // halo_exchange_hlo may be the very same instruction as source.hlo(), so
  // keep a copy of its sharding now and put it back once the per-group
  // replication below has been built.
  const auto saved_sharding = source.sharding();
  halo_exchange_hlo->set_sharding(sharding_grouped.sharding);
  auto grouped_hlo = PartitionedHlo(halo_exchange_hlo, per_group_base_shape,
                                    per_group_state);
  HloInstruction* replicated = grouped_hlo.ReplicatePartial(replicate_dims);
  source.hlo()->set_sharding(saved_sharding);

  replicated->set_sharding(target);
  return PartitionedHlo(replicated, base_shape_, source.state());
}
// Tries to reshard from the current partially replicated sharding to `target`
// by slicing each partition's tile out of the data it already holds. Returns
// nullopt when the current sharding is not partially replicated, when no
// compatible intermediate sharding exists, or when padding the source fails.
absl::optional<PartitionedHlo>
PartitionedHlo::ReshardFromPartialReplicateWithDynamicSlice(
    const HloSharding& target) {
  if (!sharding().ReplicateOnLastTileDim()) {
    return absl::nullopt;
  }

  // Get the temp sharding target from partial replicate to target tile dims.
  // target_compatible_sharding has the same tile_assignment dimensions
  // as the target and can reshard to target by collective permute.
  // target_compatible_sharding could have a different device assignment from
  // target. sharding() can reshard to target_compatible_sharding by
  // dynamic slice.
  auto target_compatible_sharding =
      PartialReplicateReshardCompatibleSharding(sharding(), target);
  if (!target_compatible_sharding.has_value()) {
    return absl::nullopt;
  }
  const auto& temp_target_sharding = target_compatible_sharding.value();

  // 1. Collect the dimensions on which the temp target has more tiles than
  // the current sharding.
  std::vector<int64> expand_tile_dims;
  const int64 rank = hlo_->shape().rank();
  for (int64 dim = 0; dim < rank; dim++) {
    if (temp_target_sharding.tile_assignment().dim(dim) >
        sharding().tile_assignment().dim(dim)) {
      expand_tile_dims.push_back(dim);
    }
  }

  // 2. Get the padded_hlo, do right halo exchange if needed.
  auto padded_hlo = PadFromPartialReplicateShape(
      hlo_, base_shape_, sharding(), temp_target_sharding, expand_tile_dims,
      state_.collective_ops_creator, state_.next_channel_id,
      state_.partition_id, state_.b);
  if (!padded_hlo.has_value()) {
    return absl::nullopt;
  }

  // 3. Slice out the tile from replicate ones.
  auto shard_shape = MakePartitionedShape(base_shape_, temp_target_sharding);
  // Since we are just slicing, we can just use the differences between the new
  // and old offsets in the full shape as the dynamic-slice offsets.
  auto padded_base_shape = shard_shape;
  for (int64 i = 0; i < padded_base_shape.rank(); ++i) {
    padded_base_shape.set_dimensions(
        i, padded_base_shape.dimensions(i) *
               temp_target_sharding.tile_assignment().dim(i));
  }
  auto offsets = MakePartitionOffsets(padded_base_shape, temp_target_sharding,
                                      state_.partition_id, state_.b);
  auto old_offsets = MakePartitionOffsets(padded_base_shape, sharding(),
                                          state_.partition_id, state_.b);
  for (size_t i = 0; i < offsets.size(); ++i) {
    offsets[i] = state_.b->AddInstruction(HloInstruction::CreateBinary(
        offsets[i]->shape(), HloOpcode::kSubtract, offsets[i], old_offsets[i]));
  }
  auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
      shard_shape, padded_hlo.value(), offsets, shard_shape.dimensions()));
  slice->set_sharding(temp_target_sharding);
  auto result = PartitionedHlo(slice, base_shape_, state_);

  // If temp_target_sharding's device assignment is different from target,
  // use collective permute to reshard.
  if (CanReshardWithCollectivePermute(temp_target_sharding, target)) {
    return result.ReshardWithCollectivePermute(target);
  }
  // If device assignment in temp_target_sharding and target are the same,
  // return result directly.
  return result;
}
// Turns a value whose sharding names a single device into a replicated value:
// every partition other than the source contributes zeros, and a
// cross-partition all-reduce (sum) delivers the source's data to all of them.
// Tuple- and token-shaped values are rejected by CHECK.
PartitionedHlo PartitionedHlo::Broadcast() const {
  const Shape& shape = hlo_->shape();
  const HloSharding& sharding = hlo_->sharding();
  CHECK(sharding.HasUniqueDevice());
  CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);

  auto* builder = state_.b;

  // Predicate, broadcast to the data shape, that is true only on the source
  // partition.
  auto* source_id = builder->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR0<uint32>(sharding.GetUniqueDevice())));
  const Shape mask_shape = ShapeUtil::ChangeElementType(shape, PRED);
  auto* on_source = builder->AddInstruction(HloInstruction::CreateCompare(
      ShapeUtil::MakeShape(PRED, {}), state_.partition_id, source_id,
      ComparisonDirection::kEq));
  auto* on_source_mask = builder->AddInstruction(
      HloInstruction::CreateBroadcast(mask_shape, on_source, {}));

  // Keep the local data on the source partition, zeros everywhere else.
  auto* zero_scalar = builder->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
  auto* zeros = builder->AddInstruction(
      HloInstruction::CreateBroadcast(shape, zero_scalar, {}));
  auto* contribution = builder->AddInstruction(HloInstruction::CreateTernary(
      shape, HloOpcode::kSelect, on_source_mask, hlo(), zeros));

  // Sum the per-partition contributions.
  HloComputation* add_computation =
      MakeBinaryAdd(shape.element_type(), state_.module);
  auto* replicated =
      state_.collective_ops_creator.create_cross_partition_all_reduce(
          builder, contribution, add_computation, {}, NewChannel());
  replicated->set_sharding(HloSharding::Replicate());
  return PartitionedHlo(replicated, base_shape_, state_);
}
// Reshards towards `target` with one all-to-all per (source_dim, target_dim)
// pair in `source_target_dims`. Each call handles the first pair, producing an
// intermediate sharding (temp_target) in which the tile counts of the two
// dimensions are swapped, and then recurses on the remaining pairs. Once no
// pairs are left, the result is returned as-is if its sharding already equals
// `target`; otherwise ReshardWithCollectivePermute fixes the device order.
PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
const HloSharding& target,
absl::Span<const std::pair<int64, int64>> source_target_dims) const {
if (source_target_dims.empty()) {
if (target == sharding()) {
return *this;
}
// If the device order is different in the target, fix the order with
// ReshardWithCollectivePermute.
return ReshardWithCollectivePermute(target);
}
// Swap one pair of dimensions.
int64 source_dim = source_target_dims[0].first;
int64 target_dim = source_target_dims[0].second;
// Ratio of the tile counts on source_dim and target_dim. It is used below as
// the number of devices per all-to-all group and as the split count.
const int64 group_size = sharding().tile_assignment().dim(source_dim) /
sharding().tile_assignment().dim(target_dim);
// Build the tile assignment of the intermediate sharding, starting from a
// copy of the current tile assignment.
auto temp_target_tile = sharding().tile_assignment();
{
// Reshape so that source_dim becomes [dim / group_size, group_size] and
// target_dim becomes [dim, 1]; every other dimension is kept as is.
// added_source_dim / added_target_dim record the positions of the two
// newly inserted entries (group_size and 1 respectively).
std::vector<int64> reshape_tile_dims(temp_target_tile.num_dimensions() + 2);
int64 i = 0;
int64 added_source_dim = -1;
int64 added_target_dim = -1;
for (int64 j = 0; j < temp_target_tile.num_dimensions(); ++j) {
if (source_dim == j) {
reshape_tile_dims[i] = temp_target_tile.dim(j) / group_size;
reshape_tile_dims[++i] = group_size;
added_source_dim = i;
} else if (target_dim == j) {
reshape_tile_dims[i] = temp_target_tile.dim(j);
reshape_tile_dims[++i] = 1;
added_target_dim = i;
} else {
reshape_tile_dims[i] = temp_target_tile.dim(j);
}
++i;
}
temp_target_tile.Reshape(reshape_tile_dims);
// Swap the two inserted dimensions with a transpose; all other positions
// keep the identity permutation.
std::vector<int64> xpose_dims(temp_target_tile.num_dimensions());
std::iota(xpose_dims.begin(), xpose_dims.end(), 0);
xpose_dims[added_source_dim] = added_target_dim;
xpose_dims[added_target_dim] = added_source_dim;
temp_target_tile = hlo_sharding_util::TransposeSharding(
HloSharding::Tile(temp_target_tile), xpose_dims)
.tile_assignment();
// Collapse back to the original number of tile dimensions, with the tile
// counts of source_dim and target_dim exchanged.
auto temp_target_tile_dims = sharding().tile_assignment().dimensions();
temp_target_tile_dims[source_dim] =
sharding().tile_assignment().dim(target_dim);
temp_target_tile_dims[target_dim] =
sharding().tile_assignment().dim(source_dim);
temp_target_tile.Reshape(temp_target_tile_dims);
}
// The intermediate sharding is partially replicated iff `target` is.
auto temp_target = target.ReplicateOnLastTileDim()
? HloSharding::PartialTile(temp_target_tile)
: HloSharding::Tile(temp_target_tile);
// Pad the local shard along target_dim up to a multiple of the intermediate
// sharding's tile count on that dimension.
auto padded_shape = hlo_->shape();
padded_shape.set_dimensions(
target_dim,
RoundUpToNearest(padded_shape.dimensions(target_dim),
temp_target.tile_assignment().dim(target_dim)));
auto padded_hlo = PadToShape(hlo_, padded_shape, state_.b);
// The order of ids in the group must follow the temp_target sharding.
std::vector<std::vector<int64>> groups(
temp_target.tile_assignment().num_elements() / group_size);
temp_target.tile_assignment().Each(
[&](absl::Span<const int64> indices, int64 device) {
// Linearize the tile indices into a group id; along target_dim the index
// is first divided by group_size, so group_size consecutive tiles on that
// dimension land in the same group.
int64 group_id = 0;
for (int64 dim = 0; dim < indices.size(); ++dim) {
if (dim == target_dim) {
group_id *= temp_target.tile_assignment().dim(dim) / group_size;
group_id += indices[dim] / group_size;
} else {
group_id *= temp_target.tile_assignment().dim(dim);
group_id += indices[dim];
}
}
groups[group_id].push_back(device);
});
HloInstruction* result = nullptr;
// Split along the split dimension (target_dim) of the all-to-all
// output.
std::vector<int64> dimensions;
for (int64 i = 0; i < base_shape_.rank(); ++i) {
if (i == target_dim) {
dimensions.push_back(group_size);
dimensions.push_back(padded_hlo->shape().dimensions(i) / group_size);
} else {
dimensions.push_back(padded_hlo->shape().dimensions(i));
}
}
auto reshape = state_.b->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(base_shape_.element_type(), dimensions),
padded_hlo));
// After the reshape, it is guaranteed to have at least 3 dimensions.
auto all_to_all =
state_.collective_ops_creator.create_cross_partition_all_to_all(
state_.b, {reshape}, groups, (*state_.next_channel_id)++, target_dim);
// Reorder the split dimension of the reshape to be located in front of the
// input partition dimension, so the two dimensions can be combined.
// The reshape inserted one extra dimension at target_dim, so source_dim's
// position shifts by one when it comes after target_dim.
int64 new_source_dim =
(target_dim < source_dim) ? source_dim + 1 : source_dim;
std::vector<int64> permutation;
for (int64 i = 0; i < all_to_all->shape().rank(); ++i) {
if (i == target_dim) {
continue;
}
if (i == new_source_dim) {
permutation.push_back(target_dim);
}
permutation.push_back(i);
}
auto transpose = state_.b->AddInstruction(HloInstruction::CreateTranspose(
ShapeInference::InferTransposeShape(all_to_all->shape(), permutation)
.ValueOrDie(),
all_to_all, permutation));
// Combine the split dimension and the input partition dimension.
auto new_shape = ShapeInference::InferAllToAllShape(
padded_hlo->shape(), target_dim, source_dim, group_size)
.ValueOrDie();
result = state_.b->AddInstruction(
HloInstruction::CreateReshape(new_shape, transpose));
// Slice from the origin when the combined shape differs from the shard shape
// implied by the intermediate sharding.
const Shape result_shape = MakePartitionedShape(base_shape_, temp_target);
if (result_shape != result->shape()) {
result = state_.b->AddInstruction(HloInstruction::CreateSlice(
result_shape, result, std::vector<int64>(result_shape.rank(), 0),
result_shape.dimensions(), std::vector<int64>(result_shape.rank(), 1)));
}
result->set_sharding(temp_target);
// Handle the remaining (source_dim, target_dim) pairs recursively.
auto remaining_source_target_dims = source_target_dims;
remaining_source_target_dims.remove_prefix(1);
return PartitionedHlo(result, base_shape_, state_)
.ReshardWithAllToAll(target, remaining_source_target_dims);
}
// Tries to reshard between a tiled sharding and a partially replicated
// sharding (in either direction) by combining an all-to-all with a reshard
// to or from an intermediate partially replicated sharding. Returns nullopt
// when the source/target pair does not match the supported pattern described
// below.
absl::optional<PartitionedHlo>
PartitionedHlo::ReshardPartialReplicateWithAllToAll(const HloSharding& target) {
  bool source_is_partial_replicate = sharding().ReplicateOnLastTileDim();
  const auto& partial_replicate_sharding =
      source_is_partial_replicate ? sharding() : target;
  // If neither the source nor the target is partial replicate, return null.
  if (!partial_replicate_sharding.ReplicateOnLastTileDim()) {
    return absl::nullopt;
  }
  const auto& tile_sharding = source_is_partial_replicate ? target : sharding();
  // If both source and target are partial replicate, should be supported in
  // Reshard with AllToAll already.
  if (tile_sharding.ReplicateOnLastTileDim() || tile_sharding.IsTileMaximal()) {
    return absl::nullopt;
  }

  // Only support resharding from sharding={devices=[2,3]0,1,2,3,4,5}
  // to sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}, where
  // the last tile dim will be replicate first before all-to-all.
  // Or resharding from
  // sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
  // to sharding={devices=[2,3]0,1,2,3,4,5}, where
  // the last tile dim will be sharded after all-to-all.
  //
  // Tile dimensions are int64, so the locals derived from them are int64 as
  // well to avoid implicit narrowing.
  const int64 num_replicas =
      partial_replicate_sharding.tile_assignment().dimensions().back();
  if (((tile_sharding.tile_assignment().num_dimensions() + 1) !=
       partial_replicate_sharding.tile_assignment().num_dimensions()) ||
      (partial_replicate_sharding.tile_assignment().dim(0) != 1)) {
    return absl::nullopt;
  }
  // Scan from the last tile dimension: the first dimension found with more
  // than one tile must have exactly num_replicas tiles, and every tile
  // dimension i must match dimension i + 1 of the partial replicate sharding.
  int64 to_replicate_dim = -1;
  for (int64 i = tile_sharding.tile_assignment().num_dimensions() - 1; i >= 0;
       --i) {
    if (tile_sharding.tile_assignment().dim(i) > 1 &&
        (to_replicate_dim == -1)) {
      if (tile_sharding.tile_assignment().dim(i) != num_replicas) {
        return absl::nullopt;
      }
      to_replicate_dim = i;
    }
    if (tile_sharding.tile_assignment().dim(i) !=
        partial_replicate_sharding.tile_assignment().dim(i + 1)) {
      return absl::nullopt;
    }
  }
  if (to_replicate_dim == -1) {
    return absl::nullopt;
  }

  // Check if core assignments for source and the target are the same.
  auto reshape_tile_assignment = partial_replicate_sharding.tile_assignment();
  reshape_tile_assignment.Reshape(tile_sharding.tile_assignment().dimensions());
  if (reshape_tile_assignment != tile_sharding.tile_assignment()) {
    return absl::nullopt;
  }

  // Intermediate partially replicated sharding: the tiled sharding with
  // to_replicate_dim collapsed to one tile and a trailing dimension of
  // num_replicas appended.
  auto tmp_tile_assignment = tile_sharding.tile_assignment();
  auto tmp_tile_assignment_dimensions =
      tile_sharding.tile_assignment().dimensions();
  tmp_tile_assignment_dimensions[to_replicate_dim] = 1;
  tmp_tile_assignment_dimensions.push_back(num_replicas);
  tmp_tile_assignment.Reshape(tmp_tile_assignment_dimensions);
  auto tmp_partial_replicate_sharding =
      HloSharding::PartialTile(tmp_tile_assignment);

  if (source_is_partial_replicate) {
    // Partial replicate -> tiled: all-to-all to the intermediate sharding
    // first, then reshard to the target.
    if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
            sharding(), tmp_partial_replicate_sharding)) {
      auto partitioned_hlo =
          ReshardWithAllToAll(tmp_partial_replicate_sharding, *src_tgt_dims);
      return partitioned_hlo.Reshard(target);
    }
  } else {
    // Tiled -> partial replicate: reshard to the intermediate sharding first,
    // then all-to-all to the target.
    auto partitioned_hlo = Reshard(tmp_partial_replicate_sharding);
    if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
            partitioned_hlo.sharding(), target)) {
      return partitioned_hlo.ReshardWithAllToAll(target, *src_tgt_dims);
    }
  }
  return absl::nullopt;
}
// Reshards to `target` by moving each shard from the device that holds it now
// to the device that owns the same tile index in `target`. CHECK-fails unless
// CanReshardWithCollectivePermute(sharding(), target) holds.
PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
    const HloSharding& target) const {
  CHECK(CanReshardWithCollectivePermute(sharding(), target))
      << sharding().ToString() << " to " << target.ToString();

  // If hlo() has broadcast dims, check if data is already the same between
  // source/destination pairs; in that case a plain copy carrying the new
  // sharding is enough.
  auto broadcast_dims = state_.b->BroadcastDimsForCreatedHlo(hlo());
  if (broadcast_dims && !(*broadcast_dims)->empty()) {
    std::vector<int64> broadcast_dims_vector;
    for (int64 dim = 0; dim < hlo()->shape().rank(); ++dim) {
      if ((*broadcast_dims)->contains(dim)) {
        broadcast_dims_vector.push_back(dim);
      }
    }
    const bool same_after_replicating_broadcast_dims =
        hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
            sharding(), broadcast_dims_vector) ==
        hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
            target, broadcast_dims_vector);
    if (same_after_replicating_broadcast_dims) {
      auto copy = state_.b->AddInstruction(HloInstruction::CreateUnary(
          hlo()->shape(), HloOpcode::kCopy, hlo()));
      copy->set_sharding(target);
      return PartitionedHlo(copy, base_shape_, state_);
    }
  }

  // Pair every current device with the device at the same tile index in the
  // target tile assignment.
  std::vector<std::pair<int64, int64>> src_dst_pairs;
  sharding().tile_assignment().Each(
      [&](absl::Span<const int64> tile_index, int64 src_device) {
        src_dst_pairs.emplace_back(src_device,
                                   target.tile_assignment()(tile_index));
      });
  auto permuted =
      state_.collective_ops_creator.create_cross_partition_collective_permute(
          state_.b, hlo(), src_dst_pairs, (*state_.next_channel_id)++);
  permuted->set_sharding(target);
  return PartitionedHlo(permuted, base_shape_, state_);
}
// Constructs a visitor for `computation`. The arguments are stored in the
// corresponding members; the owning module is taken from
// computation->parent(). The builder b_ is named after the computation with a
// "_spmd" suffix and starts with no HLO being visited, and partition_id_ is
// created in that builder through collective_ops_creator_.
// NOTE(review): the partition_id_ initializer reads collective_ops_creator_
// and b_, so it relies on both being declared before partition_id_ in the
// class — confirm against the header.
SpmdPartitioningVisitor::SpmdPartitioningVisitor(
HloComputation* computation, int64 num_partitions, int64 num_replicas,
const SPMDCollectiveOpsCreator& collective_ops_creator,
int64* next_channel_id, SpmdLogger* logger, SpmdPartitionerOptions options,
SpmdPartitioner* partitioner)
: changed_(false),
module_(computation->parent()),
num_partitions_(num_partitions),
num_replicas_(num_replicas),
collective_ops_creator_(collective_ops_creator),
next_channel_id_(next_channel_id),
b_(SpmdBuilder(computation->name() + "_spmd", /*hlo=*/nullptr)),
partition_id_(collective_ops_creator_.create_partition_id(&b_)),
logger_(logger),
options_(std::move(options)),
partitioner_(partitioner) {}
// Fallback handler for instructions without a dedicated partitioning rule.
// Side-effecting ops are rejected, elementwise ops with operands are forwarded
// to HandleElementwise, and everything else is cloned on replicated (or
// single-device) operands and then resharded to the requested sharding.
Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) {
  if (hlo->HasSideEffect()) {
    return Unimplemented("Side-effect ops cannot be replicated: %s",
                         hlo->ToString());
  }

  const bool elementwise_with_operands =
      hlo->IsElementwise() && hlo->operand_count() > 0;
  if (elementwise_with_operands) {
    return HandleElementwise(hlo);
  }

  // Log instructions that asked for a tiled sharding but end up here.
  if (!hlo->sharding().IsTileMaximal()) {
    VLOG(1) << "Not partitioned in SPMD mode (DefaultAction):"
            << hlo->ToString();
    for (int64 operand_index = 0; operand_index < hlo->operand_count();
         ++operand_index) {
      VLOG(1) << " operand " << operand_index << " sharding:"
              << hlo->operand(operand_index)->sharding().ToString();
    }
  }

  // The instruction cannot be partitioned, so compute it on a single device
  // when its sharding names one, and replicated otherwise.
  HloSharding fallback_sharding = hlo->sharding().HasUniqueDevice()
                                      ? hlo->sharding()
                                      : HloSharding::Replicate();

  std::vector<HloInstruction*> resharded_operands;
  resharded_operands.reserve(hlo->operand_count());
  for (HloInstruction* operand : hlo->operands()) {
    auto resharded = GetPartitionedHlo(operand).Reshard(fallback_sharding);
    resharded_operands.push_back(resharded.hlo());
  }

  auto clone = b_.AddInstruction(
      hlo->CloneWithNewOperands(hlo->shape(), resharded_operands));
  clone->set_sharding(fallback_sharding);
  clone->set_metadata(hlo->metadata());

  // Bring the result to the sharding the original instruction requested.
  SetPartitionedHlo(hlo,
                    PartitionedHlo(clone, hlo->shape(), MakePartitioningState())
                        .Reshard(hlo->sharding()));
  return Status::OK();
}
// Runs before each instruction is handled: records the instruction being
// visited and, for instructions with manual sharding (other than the
// "SPMDFullToShardShape" custom call), temporarily swaps manual shardings on
// the instruction and its operands for device-0 shardings. The original
// shardings are saved in visiting_hlo_sharding_ and
// visiting_hlo_operand_shardings_.
Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) {
  visiting_hlo_ = hlo;
  b_.set_visiting_hlo(hlo);

  auto is_manual = [](const HloSharding& sharding) {
    return sharding.IsManual();
  };

  // Temporarily replace manual sharding to one-device sharding so that the
  // partitioner will not change the HLOs.
  auto manual_to_onedevice = [&](const Shape& shape,
                                 const HloSharding& sharding) -> HloSharding {
    // If a tuple's elements are all manual, then sharding.IsManual() == True,
    // so the tuple case has to be tested first.
    if (sharding.IsTuple()) {
      std::vector<HloSharding> element_shardings = sharding.tuple_elements();
      for (HloSharding& element_sharding : element_shardings) {
        if (element_sharding.IsManual()) {
          element_sharding = HloSharding::AssignDevice(0);
        }
      }
      return HloSharding::Tuple(shape, element_shardings);
    }
    if (sharding.IsManual()) {
      return HloSharding::AssignDevice(0);
    }
    return sharding;
  };

  const bool has_manual_sharding =
      hlo->sharding().IsManual() ||
      (hlo->sharding().IsTuple() &&
       absl::c_any_of(hlo->sharding().tuple_elements(), is_manual));
  if (!has_manual_sharding || hlo->IsCustomCall("SPMDFullToShardShape")) {
    return Status::OK();
  }

  // Save the real shardings, then install the one-device stand-ins on the
  // instruction, its operands, and the operands' partitioned HLOs.
  visiting_hlo_sharding_ = hlo->sharding();
  hlo->set_sharding(manual_to_onedevice(hlo->shape(), *visiting_hlo_sharding_));

  visiting_hlo_operand_shardings_.reserve(hlo->operand_count());
  for (HloInstruction* operand : hlo->operands()) {
    visiting_hlo_operand_shardings_.push_back(operand->sharding());
    operand->set_sharding(
        manual_to_onedevice(operand->shape(), operand->sharding()));
    GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
  }
  return Status::OK();
}
// Runs after each instruction is handled: registers a log entry for the
// partitioned instruction, clears the "currently visiting" markers, and puts
// back any shardings that Preprocess replaced with one-device stand-ins.
Status SpmdPartitioningVisitor::Postprocess(HloInstruction* hlo) {
  logger_->RegisterLogEntry(GetPartitionedHlo(hlo).hlo(),
                            b_.derived_instructions(hlo));
  visiting_hlo_ = nullptr;
  b_.set_visiting_hlo(nullptr);

  // No saved sharding means there is nothing to revert.
  if (!visiting_hlo_sharding_.has_value()) {
    return Status::OK();
  }

  // Revert fake one-device shardings for manually partitioned ops.
  hlo->set_sharding(*visiting_hlo_sharding_);
  GetPartitionedHlo(hlo).hlo()->set_sharding(*visiting_hlo_sharding_);
  const int64 operand_count = hlo->operand_count();
  for (int64 operand_index = 0; operand_index < operand_count;
       ++operand_index) {
    HloInstruction* operand = hlo->mutable_operand(operand_index);
    operand->set_sharding(visiting_hlo_operand_shardings_[operand_index]);
    GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
  }
  visiting_hlo_sharding_.reset();
  visiting_hlo_operand_shardings_.clear();
  return Status::OK();
}
// Partitions an elementwise instruction: every operand is resharded to the
// instruction's own sharding, and the instruction is cloned on the resulting
// shards with the per-partition output shape.
Status SpmdPartitioningVisitor::HandleElementwise(HloInstruction* hlo) {
  const HloSharding& output_sharding = hlo->sharding();

  std::vector<HloInstruction*> sharded_operands;
  for (HloInstruction* operand : hlo->operands()) {
    auto resharded = GetPartitionedHlo(operand).Reshard(output_sharding);
    sharded_operands.push_back(resharded.hlo());
  }

  auto build_partitioned_clone = [&] {
    const Shape shard_shape =
        MakePartitionedShape(hlo->shape(), hlo->sharding());
    return b_.AddInstruction(
        hlo->CloneWithNewOperands(shard_shape, sharded_operands));
  };
  SetPartitionedHlo(hlo, build_partitioned_clone);
  return Status::OK();
}
// Partitions a concatenate.
//  - Tile-maximal sharding: falls back to DefaultAction.
//  - Concatenate dimension not partitioned: operands are resharded to the
//    output sharding and concatenated shard by shard.
//  - Concatenate dimension partitioned across all partitions: every partition
//    writes its operand shards into a zero-initialized full-size buffer, the
//    buffers are summed with an all-reduce, and each partition slices out its
//    own output shard.
//  - Anything else: falls back to DefaultAction.
Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) {
  const HloSharding& sharding = hlo->sharding();
  if (sharding.IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  const Shape shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
  const int64 dimension = hlo->concatenate_dimension();
  if (sharding.tile_assignment().dim(dimension) == 1) {
    std::vector<HloInstruction*> new_operands;
    for (HloInstruction* operand : hlo->operands()) {
      new_operands.push_back(
          GetPartitionedHlo(operand).Reshard(sharding).hlo());
    }
    SetPartitionedHlo(hlo, [&] {
      return b_.AddInstruction(
          hlo->CloneWithNewOperands(shard_shape, new_operands));
    });
    return Status::OK();
  }

  // If the concatenate dimension is along one of the partitioned dimensions,
  // allocate the full output shape, each partition updates its owned region,
  // all-reduce across partitions, and then slice its output region.
  // We currently don't support subgroup all-reduce along partitions, so more
  // than 1 partitioned dimensions is not supported.
  if (sharding.tile_assignment().dim(dimension) != num_partitions_) {
    return DefaultAction(hlo);
  }

  // temp_output_shape is the output shape where the concatenate dimension
  // is changed to the full (and padded to shard count) dimension size.
  auto temp_output_shape = MakePartitionedShape(hlo->shape(), sharding);
  auto last_operand_padded_shape =
      MakePartitionedShape(hlo->operands().back()->shape(), sharding);
  // If the last operand has more padding than the temp_output padding, needs to
  // add extra padding to avoid dynamic update slice out of bound.
  // Dimension sizes are int64, so the padding amounts are kept in int64 too
  // rather than being narrowed to int.
  const int64 last_operand_padding =
      last_operand_padded_shape.dimensions(dimension) *
          sharding.tile_assignment().dim(dimension) -
      hlo->operands().back()->shape().dimensions(dimension);
  const int64 temp_output_padding =
      temp_output_shape.dimensions(dimension) *
          sharding.tile_assignment().dim(dimension) -
      hlo->shape().dimensions(dimension);
  const int64 padding_for_last_operand =
      last_operand_padding < temp_output_padding
          ? 0
          : last_operand_padding - temp_output_padding;
  temp_output_shape.set_dimensions(
      dimension, temp_output_shape.dimensions(dimension) *
                         sharding.tile_assignment().dim(dimension) +
                     padding_for_last_operand);
  auto temp_output = CreateZero(temp_output_shape, &b_);

  // Offset of each operand along the concatenate dimension.
  int64 offset = 0;
  for (HloInstruction* operand : hlo->operands()) {
    auto spmd_operand = GetPartitionedHlo(operand).Reshard(sharding).hlo();
    std::vector<HloInstruction*> start_indices(
        hlo->shape().rank(), b_.AddInstruction(HloInstruction::CreateConstant(
                                 LiteralUtil::Zero(S32))));
    // Start position of this partition's operand shard in the full buffer:
    // partition ordinal * shard size + offset of the operand.
    start_indices[dimension] =
        MultiplyAddDivideOffsetCalculation(
            spmd_operand->shape().dimensions(dimension), offset, 1)
            .Calculate(MakeTiledPartitionOrdinals(sharding, partition_id_,
                                                  &b_)[dimension],
                       &b_);
    temp_output = b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
        temp_output_shape, temp_output, spmd_operand, start_indices));
    offset += operand->shape().dimensions(dimension);
  }

  // Regions not written by a partition stay zero, so summing the buffers
  // assembles the full result on every partition.
  auto all_reduce = collective_ops_creator_.create_cross_partition_all_reduce(
      &b_, temp_output, MakeBinaryAdd(hlo->shape().element_type(), module_), {},
      NewChannel());

  // Each partition keeps only its own shard of the full result.
  SetPartitionedHlo(hlo, [&] {
    auto start_indices =
        MakeTiledPartitionOrdinals(hlo->sharding(), partition_id_, &b_);
    start_indices[dimension] = MultiplyAddDivideOffsetCalculation(
                                   shard_shape.dimensions(dimension), 0, 1)
                                   .Calculate(start_indices[dimension], &b_);
    return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
        shard_shape, all_reduce, start_indices, shard_shape.dimensions()));
  });
  return Status::OK();
}
namespace {
// Returns whether partitioning in the operand only happens in dimensions with
// gather/scatter slice size 1.
bool GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
const PartitionedHlo& operand, absl::Span<const int64> index_map,
absl::Span<const int64> slice_size) {
if (operand.sharding().IsTileMaximal()) {
return false;
}
int64 trivial_slice_dims_partitions = 1;
for (int64 dim : index_map) {
if (slice_size[dim] == 1) {
trivial_slice_dims_partitions *=
operand.sharding().tile_assignment().dim(dim);
}
}
return trivial_slice_dims_partitions == operand.sharding().NumTiles();
}
// Returns the min and max for the indices (replicated) in a scatter/gather
// which has the operand partitioned on trivial slice dimensions (slice size 1).
std::pair<HloInstruction*, HloInstruction*>
IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims(
const PartitionedHlo& operand, const PartitionedHlo& replicated_indices,
HloInstruction* partition_id, absl::Span<const int64> index_map,
int64 index_vector_dim, SpmdBuilder* b) {
auto operand_offsets = MakePartitionOffsets(
operand.base_shape(), operand.sharding(), partition_id, b);
// Find the per-dimension index bounds.
std::vector<HloInstruction*> min_indices;
std::vector<HloInstruction*> max_indices;
for (int64 i = 0; i < index_map.size(); ++i) {
int64 dim = index_map[i];
int64 partitions = operand.sharding().tile_assignment().dim(dim);
if (partitions == 1) {
min_indices.push_back(CreateR0WithType<int32>(
replicated_indices.base_shape().element_type(), 0, b));
max_indices.push_back(CreateR0WithType<int32>(
replicated_indices.base_shape().element_type(),
operand.base_shape().dimensions(dim), b));
continue;
}
auto offset = operand_offsets[dim];
if (offset->shape().element_type() !=
replicated_indices.base_shape().element_type()) {
offset = b->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::MakeShape(replicated_indices.base_shape().element_type(),
{}),
offset));
}
min_indices.push_back(offset);
auto partition_size_minus_1 =
CreateR0WithType<int32>(replicated_indices.base_shape().element_type(),
operand.hlo()->shape().dimensions(dim) - 1, b);
max_indices.push_back(b->AddInstruction(HloInstruction::CreateBinary(
offset->shape(), HloOpcode::kAdd, offset, partition_size_minus_1)));
}
// Broadcast the index bounds to the same shape as the indices.
HloInstruction* broadcast_min;
HloInstruction* broadcast_max;
if (index_vector_dim < replicated_indices.base_shape().rank()) {
// The index vector is an R1, we need to reshape individual bounds to
// [1], and concat them if there are more than one.
for (int64 i = 0; i < min_indices.size(); ++i) {
min_indices[i] = b->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(min_indices[i]->shape().element_type(), {1}),
min_indices[i]));
max_indices[i] = b->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(max_indices[i]->shape().element_type(), {1}),
max_indices[i]));
}
int64 slice_dims = max_indices.size();
if (slice_dims > 1) {
min_indices[0] = b->AddInstruction(HloInstruction::CreateConcatenate(
ShapeUtil::MakeShape(min_indices[0]->shape().element_type(),
{slice_dims}),
min_indices, 0));
max_indices[0] = b->AddInstruction(HloInstruction::CreateConcatenate(
min_indices[0]->shape(), max_indices, 0));
}
broadcast_min = b->AddInstruction(HloInstruction::CreateBroadcast(
replicated_indices.base_shape(), min_indices[0], {index_vector_dim}));
broadcast_max = b->AddInstruction(HloInstruction::CreateBroadcast(
replicated_indices.base_shape(), max_indices[0], {index_vector_dim}));
} else {
CHECK_EQ(max_indices.size(), 1);
broadcast_min = b->AddInstruction(HloInstruction::CreateBroadcast(
replicated_indices.base_shape(), min_indices[0], {}));
broadcast_max = b->AddInstruction(HloInstruction::CreateBroadcast(
replicated_indices.base_shape(), max_indices[0], {}));
}
return {broadcast_min, broadcast_max};
}
} // namespace
Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
auto scatter = Cast<HloScatterInstruction>(hlo);
auto dnums = scatter->scatter_dimension_numbers();
auto operand = GetPartitionedHlo(scatter->operand(0));
auto indices = GetPartitionedHlo(scatter->operand(1));
auto updates = GetPartitionedHlo(scatter->operand(2));
std::vector<int64> slice_size(operand.base_shape().rank(), 1);
int64 num_update_window_dims = 0;
for (int64 i = 0; i < operand.base_shape().rank(); ++i) {
if (absl::c_linear_search(dnums.inserted_window_dims(), i)) {
continue;
}
slice_size[i] = updates.base_shape().dimensions(
dnums.update_window_dims(num_update_window_dims++));
}
std::vector<int64> scatter_dims_to_operand_dims(
dnums.scatter_dims_to_operand_dims().begin(),
dnums.scatter_dims_to_operand_dims().end());
std::vector<int64> update_scatter_dims;
for (int64 i = 0; i < updates.base_shape().rank(); ++i) {
if (!absl::c_linear_search(dnums.update_window_dims(), i)) {
update_scatter_dims.push_back(i);
}
}
if (operand.sharding().IsTileMaximal()) {
if (!indices.sharding().IsTileMaximal() &&
(dnums.index_vector_dim() == indices.base_shape().rank() ||
indices.sharding().tile_assignment().dim(dnums.index_vector_dim()) ==
1)) {
auto reduction_opcode = ParseReductionComputation(scatter->to_apply());
if (!reduction_opcode.has_value()) {
return DefaultAction(hlo);
}
HloInstruction* identity;
switch (*reduction_opcode) {
case HloOpcode::kAdd:
case HloOpcode::kOr:
identity = CreateZero(operand.hlo()->shape(), &b_);
break;
case HloOpcode::kMultiply:
case HloOpcode::kAnd:
identity = CreateOne(operand.hlo()->shape(), &b_);
break;
case HloOpcode::kMinimum:
identity = CreateConstant(
operand.hlo()->shape(),
LiteralUtil::MaxValue(hlo->shape().element_type()), &b_);
break;
case HloOpcode::kMaximum:
identity = CreateConstant(
operand.hlo()->shape(),
LiteralUtil::MinValue(hlo->shape().element_type()), &b_);
break;
default:
return DefaultAction(hlo);
}
std::vector<int64> update_dim_to_index_dim(updates.base_shape().rank(),
-1);
std::vector<int64> index_dim_to_update_dim(indices.base_shape().rank(),
-1);
for (int64 i = 0; i < update_scatter_dims.size(); ++i) {
int64 indices_scatter_dim = i < dnums.index_vector_dim() ? i : i + 1;
update_dim_to_index_dim[update_scatter_dims[i]] = indices_scatter_dim;
index_dim_to_update_dim[indices_scatter_dim] = update_scatter_dims[i];
}
auto new_updates_sharding =
hlo_sharding_util::TransposeShardingWithCollapsedDims(
indices.sharding(), index_dim_to_update_dim,
update_dim_to_index_dim);
CHECK(new_updates_sharding.has_value());
updates = updates.Reshard(*new_updates_sharding);
// Update collective_ops_creator and partition_id for partial replicate.
auto collective_ops_creator = collective_ops_creator_;
auto partition_id = partition_id_;
if (indices.sharding().ReplicateOnLastTileDim()) {
auto sharding_grouped = GroupShardingOnDims(
indices.sharding(),
{indices.sharding().tile_assignment().num_dimensions() - 1});
auto per_group_partitioner_state = CreatePerGroupPartitioningState(
indices.state(), sharding_grouped.device_groups, &b_);
collective_ops_creator =
per_group_partitioner_state.collective_ops_creator;
partition_id = per_group_partitioner_state.partition_id;
}
// To avoid accumulating the initial operand multiple times during
// all-reduce, we use identity operands for all non-zero partitions.
auto not_partition_zero = b_.AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::MakeScalarShape(PRED), partition_id));
not_partition_zero = b_.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::ChangeElementType(identity->shape(), PRED),
not_partition_zero, {}));
auto select_operand =
b_.AddInstruction(HloInstruction::HloInstruction::CreateTernary(
identity->shape(), HloOpcode::kSelect, not_partition_zero,
identity, operand.Replicate().hlo()));
auto pscatter = b_.AddInstruction(scatter->CloneWithNewOperands(
scatter->shape(), {select_operand, indices.hlo(), updates.hlo()}));
auto all_reduce =
collective_ops_creator.create_cross_partition_all_reduce(
&b_, pscatter, scatter->to_apply(), {}, NewChannel());
all_reduce->set_sharding(HloSharding::Replicate());
SetPartitionedHlo(hlo, [&]() {
return PartitionedHlo(all_reduce, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
} else {
auto maybe_passthrough = hlo_sharding_util::ScatterUpdateShardingFromOutput(
operand.sharding(), *hlo);
// Handle pass through cases if we can use compatible sharding for update.
if (maybe_passthrough.has_value()) {
indices = indices.Reshard(HloSharding::Replicate());
updates = updates.Reshard(*maybe_passthrough);
auto pscatter = b_.AddInstruction(HloInstruction::CreateScatter(
operand.hlo()->shape(), operand.hlo(), indices.hlo(), updates.hlo(),
scatter->to_apply(), dnums, scatter->indices_are_sorted(),
scatter->unique_indices()));
pscatter->set_sharding(*maybe_passthrough);
SetPartitionedHlo(hlo, [&]() {
return PartitionedHlo(pscatter, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
operand, scatter_dims_to_operand_dims, slice_size) &&
ShapeSizeInBytes(updates.base_shape()) <
ShapeSizeInBytes(scatter->shape())) {
// Operand is sharded on trivial slice dims (update slice size 1). We can
// adjust the indices on each partition by subtracting the offsets. Then
// we execute a scatter on full updated indices, and out-of-bound accesses
// will have no effect on the result as guaranteed by the scatter
// semantics.
indices = indices.Reshard(HloSharding::Replicate());
updates = updates.Reshard(HloSharding::Replicate());
HloInstruction* indices_min;
HloInstruction* indices_max_unused;
std::tie(indices_min, indices_max_unused) =
IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims(
operand, indices, partition_id_, scatter_dims_to_operand_dims,
dnums.index_vector_dim(), &b_);
auto adjusted_indices = b_.AddInstruction(HloInstruction::CreateBinary(
indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(),
indices_min));
auto pscatter = b_.AddInstruction(HloInstruction::CreateScatter(
operand.hlo()->shape(), operand.hlo(), adjusted_indices,
updates.hlo(), scatter->to_apply(), dnums,
scatter->indices_are_sorted(), scatter->unique_indices()));
pscatter->set_sharding(operand.sharding());
SetPartitionedHlo(hlo, [&]() {
return PartitionedHlo(pscatter, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
}
return DefaultAction(hlo);
}
// Partitions a slice whose output is tiled.
//
// The operand is first resharded to the output sharding. The slice is then
// expressed as a size-1 window (stride = slice stride, negative low padding =
// slice start, high padding = slice limit minus operand size) so that
// ReshardAsWindowedInput can fetch whatever data each shard needs. A local
// slice is emitted only when the windowed input is not already exactly the
// shard; otherwise a copy is emitted. Tile-maximal shardings and unsupported
// window reshards fall back to DefaultAction.
Status SpmdPartitioningVisitor::HandleSlice(HloInstruction* hlo) {
  const HloSharding& out_sharding = hlo->sharding();
  if (out_sharding.IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  const int64 rank = hlo->shape().rank();
  const Shape& full_operand_shape = hlo->operand(0)->shape();
  auto operand = GetPartitionedHlo(hlo->operand(0)).Reshard(out_sharding);

  // Describe the slice as a window over the full operand.
  Window slice_window;
  for (int64 d = 0; d < rank; ++d) {
    WindowDimension* wd = slice_window.add_dimensions();
    wd->set_size(1);
    wd->set_stride(hlo->slice_strides(d));
    wd->set_base_dilation(1);
    wd->set_window_dilation(1);
    wd->set_window_reversal(false);
    wd->set_padding_low(-hlo->slice_starts(d));
    wd->set_padding_high(hlo->slice_limits(d) -
                         full_operand_shape.dimensions(d));
  }

  HloInstruction* pad_value = CreateZero(
      ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_);
  auto windowed = operand.ReshardAsWindowedInput(
      slice_window, out_sharding, pad_value, /*mask_invalid_region=*/false);
  if (!windowed.has_value()) {
    return DefaultAction(hlo);
  }
  TF_RET_CHECK(!windowed->dynamic_slice_index_on_output.has_value());

  // Translate the per-shard window padding back into local slice bounds.
  HloInstruction* local_input = windowed->sharded_input;
  const Shape& local_shape = local_input->shape();
  const std::vector<int64>& strides = hlo->slice_strides();
  std::vector<int64> starts(rank);
  std::vector<int64> limits(rank);
  bool need_slice = false;
  for (int64 d = 0; d < rank; ++d) {
    const WindowDimension& wd = windowed->shard_window.dimensions(d);
    starts[d] = -wd.padding_low();
    limits[d] = local_shape.dimensions(d) + wd.padding_high();
    const bool covers_whole_dim = starts[d] == 0 && strides[d] == 1 &&
                                  limits[d] == local_shape.dimensions(d);
    need_slice = need_slice || !covers_whole_dim;
  }

  SetPartitionedHlo(hlo, [&] {
    if (!need_slice) {
      // Emit a copy so the result does not share the resharding cache entry
      // of the windowed input.
      return b_.AddInstruction(HloInstruction::CreateUnary(
          local_input->shape(), HloOpcode::kCopy, local_input));
    }
    auto shard_shape = MakePartitionedShape(hlo->shape(), out_sharding);
    return b_.AddInstruction(HloInstruction::CreateSlice(
        shard_shape, local_input, starts, limits, strides));
  });
  return Status::OK();
}
// Partitions a sort instruction.
//
// Two strategies are implemented:
//  * TopK pattern with the first operand partitioned along the sort
//    dimension (detected by GetKValueInTopKWhenPartitionSortDim): every
//    partition sorts its own padded shard, the first k entries of each shard
//    are gathered into a replicated tensor, and a final replicated sort runs
//    over those partition_count * k candidates.
//  * Otherwise, if all outputs share one tiled sharding that does not tile
//    the sort dimension, the operands are resharded to it and the sort is
//    cloned on the shard shape.
// Everything else goes through DefaultAction.
Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) {
  HloSharding sharding = hlo->sharding();

  auto topk_k = GetKValueInTopKWhenPartitionSortDim(hlo);
  if (topk_k.has_value()) {
    // The GetTupleElement and Slice users of the original sort are left
    // untouched here; HandleGetTupleElement and HandleSlice rewrite them.
    const auto k = topk_k.value();
    HloSortInstruction* sort = DynCast<HloSortInstruction>(hlo);
    const int64 sort_dim = sort->sort_dimension();
    const HloInstruction* values = hlo->operand(0);
    const HloInstruction* indices = hlo->operand(1);
    const HloSharding& values_sharding = values->sharding();
    const int64 partition_count =
        values_sharding.tile_assignment().dim(sort_dim);
    const int64 full_size = values->shape().dimensions(sort_dim);
    const int64 per_partition_size = CeilOfRatio(full_size, partition_count);
    const auto value_type = values->shape().element_type();
    const auto index_type = indices->shape().element_type();

    // Pad the value shards with the "first" value of the type and the index
    // shards (resharded like the values) with the "last" value of the type.
    auto local_values = GetPartitionedHlo(values).PadWithValue(
        CreateFirstWithType(value_type, &b_));
    auto local_indices =
        GetPartitionedHlo(indices)
            .Reshard(values_sharding)
            .PadWithValue(CreateLastWithType(index_type, &b_));

    // The per-partition sort operates on the padded shape, so that is the
    // base shape of the first-stage sort.
    std::vector<int64> padded_dims(values->shape().dimensions().begin(),
                                   values->shape().dimensions().end());
    padded_dims[sort_dim] = per_partition_size * partition_count;
    const Shape padded_tuple_shape = ShapeUtil::MakeTupleShape(
        {ShapeUtil::MakeShape(value_type, padded_dims),
         ShapeUtil::MakeShape(index_type, padded_dims)});

    // First stage: the original sort, run independently on every shard.
    auto first_stage_sharding =
        values_sharding.GetTupleSharding(padded_tuple_shape).ValueOrDie();
    auto first_stage_shard_shape =
        MakePartitionedShape(padded_tuple_shape, first_stage_sharding);
    auto local_sort = b_.AddInstruction(hlo->CloneWithNewOperands(
        first_stage_shard_shape, {local_values.hlo(), local_indices.hlo()}));
    HloInstruction* local_sorted_values =
        b_.AddInstruction(HloInstruction::CreateGetTupleElement(
            local_sort->shape().tuple_shapes(0), local_sort, 0));
    HloInstruction* local_sorted_indices =
        b_.AddInstruction(HloInstruction::CreateGetTupleElement(
            local_sort->shape().tuple_shapes(1), local_sort, 1));

    // After keeping k entries per shard, the gathered tensors have
    // k * partition_count entries along the sort dimension.
    std::vector<int64> gathered_dims = padded_dims;
    gathered_dims[sort_dim] = k * partition_count;

    // Keep the first k values of each shard and replicate them.
    auto top_values = SliceFirstK(local_sorted_values, &b_, sort_dim, k);
    top_values->set_sharding(values_sharding);
    PartitionedHlo partitioned_top_values(
        top_values, ShapeUtil::MakeShape(value_type, gathered_dims),
        MakePartitioningState());
    auto replicated_values =
        partitioned_top_values.Reshard(HloSharding::Replicate()).hlo();

    // Keep the first k indices of each shard and replicate them.
    auto top_indices = SliceFirstK(local_sorted_indices, &b_, sort_dim, k);
    top_indices->set_sharding(values_sharding);
    PartitionedHlo partitioned_top_indices(
        top_indices, ShapeUtil::MakeShape(index_type, gathered_dims),
        MakePartitioningState());
    auto replicated_indices =
        partitioned_top_indices.Reshard(HloSharding::Replicate()).hlo();

    // Second stage: a replicated sort over the candidates gathered from all
    // partitions.
    const Shape second_stage_shape = ShapeUtil::MakeTupleShape(
        {ShapeUtil::MakeShape(value_type, gathered_dims),
         ShapeUtil::MakeShape(index_type, gathered_dims)});
    auto second_stage_sort = b_.AddInstruction(HloInstruction::CreateSort(
        second_stage_shape, sort_dim, {replicated_values, replicated_indices},
        sort->to_apply(), sort->is_stable()));
    second_stage_sort->set_sharding(
        HloSharding::Replicate()
            .GetTupleSharding(second_stage_sort->shape())
            .ValueOrDie());
    PartitionedHlo replicated_result(second_stage_sort, second_stage_shape,
                                     MakePartitioningState());
    SetPartitionedHlo(hlo, replicated_result.Reshard(hlo->sharding()));
    return Status::OK();
  }

  if (hlo->shape().IsTuple()) {
    // Every tuple element must carry the same sharding as element 0.
    if (hlo->shape().tuple_shapes_size() == 0) {
      return DefaultAction(hlo);
    }
    sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
    for (int64 i = 1; i < hlo->operand_count(); ++i) {
      if (sharding != hlo->sharding().GetSubSharding(hlo->shape(), {i})) {
        return DefaultAction(hlo);
      }
    }
  }
  if (sharding.IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  // A tiled sort dimension is not handled by this path.
  const bool sort_dim_is_tiled =
      absl::c_any_of(hlo->dimensions(), [&](int64 dim) {
        return sharding.tile_assignment().dim(dim) > 1;
      });
  if (sort_dim_is_tiled) {
    return DefaultAction(hlo);
  }

  // Bring every operand to the common output sharding.
  std::vector<HloInstruction*> resharded_operands;
  for (HloInstruction* original_operand : hlo->operands()) {
    resharded_operands.push_back(
        GetPartitionedHlo(original_operand).Reshard(sharding).hlo());
  }
  SetPartitionedHlo(hlo, [&] {
    return b_.AddInstruction(hlo->CloneWithNewOperands(
        MakePartitionedShape(hlo->shape(), hlo->sharding()),
        resharded_operands));
  });
  return Status::OK();
}
// Partitions the custom calls this visitor knows about.
//
//  * "SPMDFullToShardShape": switches from auto to manual partitioning. The
//    (padded, if uneven) shard of the operand is copied as the result.
//  * "SPMDShardToFullShape": switches from manual to auto partitioning. The
//    manually sharded operand is copied as the shard of the result.
//  * "TopK" with operand 0 tiled along dimension 1: each shard runs TopK
//    locally, the local indices are shifted into global index space, the
//    per-shard candidates are replicated, and a replicated stable sort
//    followed by a slice of the first k entries produces the final result.
// Any other target, or an unsupported TopK configuration, is forwarded to
// DefaultAction.
Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
  if (hlo->custom_call_target() == "SPMDFullToShardShape") {
    // This op switches from auto partitioning to manual partitioning.
    auto input_partitioned = GetPartitionedHlo(hlo->operand(0));
    if (!EvenlyPartitions(hlo->shape(), input_partitioned.sharding())) {
      input_partitioned = input_partitioned.PadWithValue(
          CreateR0WithType(hlo->shape().element_type(), 0, &b_));
    }
    auto input = input_partitioned.hlo();
    CHECK(hlo->sharding().IsManual());
    CHECK(ShapeUtil::Compatible(input->shape(), hlo->shape()));
    auto copy = b_.AddInstruction(
        HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
    SetPartitionedHlo(hlo, [&] { return copy; });
    return Status::OK();
  }
  if (hlo->custom_call_target() == "SPMDShardToFullShape") {
    // This op switches from manual partitioning to auto partitioning.
    auto input = GetPartitionedHlo(hlo->operand(0)).hlo();
    CHECK(input->sharding().IsManual());
    auto copy = b_.AddInstruction(
        HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
    CHECK(ShapeUtil::Compatible(
        copy->shape(), MakePartitionedShape(hlo->shape(), hlo->sharding())));
    SetPartitionedHlo(hlo, [&] { return copy; });
    return Status::OK();
  }
  if (hlo->custom_call_target() != "TopK") {
    return DefaultAction(hlo);
  }

  if (!hlo->operand(0)->has_sharding()) {
    return DefaultAction(hlo);
  }
  const HloSharding& sharding = hlo->operand(0)->sharding();
  if (sharding.IsTileMaximal() || sharding.IsReplicated()) {
    return DefaultAction(hlo);
  }

  // Only a sort dimension (dimension 1) that is actually split is handled.
  const int64 sort_dim = 1;
  const int64 shard_count = sharding.tile_assignment().dim(sort_dim);
  if (shard_count <= 1) {
    return DefaultAction(hlo);
  }

  const int64 input_size = hlo->operand(0)->shape().dimensions(sort_dim);
  const int64 batch_size = hlo->shape().tuple_shapes(0).dimensions(0);
  const int64 k = hlo->shape().tuple_shapes(0).dimensions(sort_dim);
  const int64 per_partition_size = CeilOfRatio(input_size, shard_count);
  // Each shard must hold more than k elements for the local TopK.
  if (k >= per_partition_size) {
    return DefaultAction(hlo);
  }

  auto input = hlo->operand(0);
  const auto element_type = input->shape().element_type();
  // Pad with the "first" value of the element type.
  auto partitioned_input = GetPartitionedHlo(input).PadWithValue(
      CreateFirstWithType(element_type, &b_));

  // Each partition needs to do TopK separately, thus the base shape
  // becomes [batch_size, k * shard_count].
  const Shape replicated_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(hlo->operand(0)->shape().element_type(),
                            {batch_size, k * shard_count}),
       ShapeUtil::MakeShape(S32, {batch_size, k * shard_count})});
  auto custom_call_sharding =
      sharding.GetTupleSharding(replicated_shape).ValueOrDie();
  auto shard_shape =
      MakePartitionedShape(replicated_shape, custom_call_sharding);
  auto topk = b_.AddInstruction(
      hlo->CloneWithNewOperands(shard_shape, {partitioned_input.hlo()}));
  topk->set_sharding(custom_call_sharding);
  // Partition customcall.
  PartitionedHlo partitioned_topk(topk, replicated_shape,
                                  MakePartitioningState());
  topk = partitioned_topk.hlo();

  // Get value from TopK.
  HloInstruction* value_gte =
      b_.AddInstruction(HloInstruction::CreateGetTupleElement(
          topk->shape().tuple_shapes(0), topk, 0));
  value_gte->set_sharding(sharding);
  // Partition GetTupleElement of value.
  PartitionedHlo value_partitioned_gte(
      value_gte, partitioned_topk.base_shape().tuple_shapes(0),
      MakePartitioningState());
  // Reshard value to be replicated.
  auto replicated_value_gte =
      value_partitioned_gte.Reshard(HloSharding::Replicate()).hlo();

  // Get index from TopK.
  HloInstruction* index_gte =
      b_.AddInstruction(HloInstruction::CreateGetTupleElement(
          topk->shape().tuple_shapes(1), topk, 1));
  // The index returned by the per-shard custom call always starts from 0, so
  // shift it by the start of this shard along the sort dimension. The shard
  // position must be the ordinal of this partition along sort_dim in the
  // tile assignment rather than the raw partition id: the two differ when
  // another dimension is tiled as well or when the devices in the tile
  // assignment are not listed in increasing order.
  HloInstruction* sort_dim_ordinal =
      MakeTiledPartitionOrdinals(sharding, partition_id_, &b_)[sort_dim];
  auto index_offset = b_.AddInstruction(HloInstruction::CreateBroadcast(
      index_gte->shape(),
      b_.AddInstruction(HloInstruction::CreateBinary(
          sort_dim_ordinal->shape(), HloOpcode::kMultiply, sort_dim_ordinal,
          b_.AddInstruction(HloInstruction::CreateConstant(
              LiteralUtil::CreateR0<int32>(per_partition_size))))),
      {}));
  index_gte = b_.AddInstruction(HloInstruction::CreateBinary(
      index_offset->shape(), HloOpcode::kAdd, index_gte, index_offset));
  index_gte->set_sharding(sharding);
  // Partition GetTupleElement of index.
  PartitionedHlo index_partitioned_gte(
      index_gte, partitioned_topk.base_shape().tuple_shapes(1),
      MakePartitioningState());
  // Reshard index to be replicated.
  auto replicated_index_gte =
      index_partitioned_gte.Reshard(HloSharding::Replicate()).hlo();

  // Creates replicated sort to do TopK, the input is value and index pairs
  // from all the partitions. The reason to use Sort instead of CustomCall TopK
  // is CustomCall only takes value as input. There will be an extra Gather
  // to get the correct index if CustomCall is used here.

  // Create comparator for the sort.
  XlaBuilder b("Sort.Compare");
  XlaComputation comparator = CreateScalarComparisonComputation(
      "compare-value-and-index", {input->shape().element_type(), S32}, {Gt, Lt},
      &b);
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape());
  HloModuleConfig config(program_shape);
  TF_ASSIGN_OR_RETURN(auto new_module,
                      HloModule::CreateFromProto(comparator.proto(), config));
  // Clone the comparator into the module being partitioned.
  HloCloneContext context(module_);
  auto compare_computation =
      module_->DeepCloneComputation(new_module->entry_computation(), &context);
  auto sort = b_.AddInstruction(HloInstruction::CreateSort(
      replicated_shape, sort_dim, {replicated_value_gte, replicated_index_gte},
      compare_computation, true));
  sort->set_sharding(
      HloSharding::Replicate().GetTupleSharding(sort->shape()).ValueOrDie());
  PartitionedHlo replicated_sort(sort, replicated_shape,
                                 MakePartitioningState());

  // Slice value and index from top-k for output.
  HloInstruction* sort_value_gte =
      b_.AddInstruction(HloInstruction::CreateGetTupleElement(
          replicated_sort.hlo()->shape().tuple_shapes(0), replicated_sort.hlo(),
          0));
  HloInstruction* sort_index_gte =
      b_.AddInstruction(HloInstruction::CreateGetTupleElement(
          replicated_sort.hlo()->shape().tuple_shapes(1), replicated_sort.hlo(),
          1));
  // Slice value from final sort.
  HloInstruction* slice_sort_value =
      SliceFirstK(sort_value_gte, &b_, sort_dim, k);
  // Slice index from final sort.
  HloInstruction* slice_index_value =
      SliceFirstK(sort_index_gte, &b_, sort_dim, k);
  auto create_tuple = b_.AddInstruction(
      HloInstruction::CreateTuple({slice_sort_value, slice_index_value}));
  create_tuple->set_sharding(HloSharding::Replicate());
  SetPartitionedHlo(hlo, PartitionedHlo(create_tuple, create_tuple->shape(),
                                        MakePartitioningState())
                             .Reshard(hlo->sharding()));

  return Status::OK();
}
// Partitions a transpose with a tiled output: the operand is resharded with
// the output sharding permuted through the inverse of the transpose
// permutation, and the transpose is then cloned on the shard shape.
// Tile-maximal outputs go through DefaultAction.
Status SpmdPartitioningVisitor::HandleTranspose(HloInstruction* hlo) {
  const HloSharding& out_sharding = hlo->sharding();
  if (out_sharding.IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  // inverse_perm[operand_dim] = output_dim.
  const int64 rank = hlo->shape().rank();
  std::vector<int64> inverse_perm(rank);
  for (int64 out_dim = 0; out_dim < rank; ++out_dim) {
    inverse_perm[hlo->dimensions(out_dim)] = out_dim;
  }

  auto operand_sharding =
      hlo_sharding_util::TransposeSharding(out_sharding, inverse_perm);
  auto local_operand =
      GetPartitionedHlo(hlo->operand(0)).Reshard(operand_sharding).hlo();

  SetPartitionedHlo(hlo, [&] {
    auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
    return b_.AddInstruction(
        hlo->CloneWithNewOperands(shard_shape, {local_operand}));
  });
  return Status::OK();
}
// Partitions a reshape whose output is tiled.
//
// Strategy, in order:
//  1. If hlo_sharding_util::ReshapeSharding can map the output sharding onto
//     the operand shape, reshard the operand accordingly and reshape locally.
//  2. Otherwise, if operand and output are each tiled on exactly one
//     dimension, the dimensions in front of it have the same total size, and
//     one sharded dimension size divides the other, use a windowed reshard
//     (halo exchange) to realign the data around a local reshape:
//       - split dim  (input size % output size == 0): realign the operand,
//         then reshape locally;
//       - merge dims (output size % input size == 0): reshape locally, then
//         realign the result.
//  3. Everything else goes through DefaultAction.
Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
  const HloSharding& sharding = hlo->sharding();
  if (sharding.IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  auto operand = GetPartitionedHlo(hlo->operand(0));
  // The output shape is the source and the operand shape is the target to get
  // the aligned sharding for the operand.
  auto desired_operand_sharding = hlo_sharding_util::ReshapeSharding(
      hlo->shape(), hlo->operand(0)->shape(), hlo->sharding());
  if (desired_operand_sharding.has_value()) {
    auto operand_hlo = operand.Reshard(*desired_operand_sharding).hlo();
    SetPartitionedHlo(hlo, [&] {
      return b_.AddInstruction(hlo->CloneWithNewOperands(
          MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand_hlo}));
    });
    return Status::OK();
  }

  // Check if operand sharding and sharding are both tiled or partial replicate.
  // If both of them are partial replicate, check num_replications are the same.
  if (operand.sharding().ReplicateOnLastTileDim() !=
          sharding.ReplicateOnLastTileDim() ||
      (sharding.ReplicateOnLastTileDim() &&
       (operand.sharding().tile_assignment().dimensions().back() !=
        sharding.tile_assignment().dimensions().back()))) {
    return DefaultAction(hlo);
  }

  // Try use halo exchange for certain split-dim/merge-dims cases.
  // ReshapeSharding failed in these cases probably due to uneven partitioning,
  // where halo exchange could help. Specifically we check the following
  // conditions to detect supported cases:
  // 1) Both input and output are partitioned on one dimension.
  // 2) The combined size of dimensions before the partitioned dimension are the
  // same on input and output. This means we don't need to consider the major
  // dimensions.
  // 3) Let A = the input size on the partitioned dimension, and
  //        B = the output size on the partitioned dimension; then
  //    either A % B == 0 (split dim) or B % A == 0 (merge dims).
  auto maybe_input_sharded_dim = UniqueTiledDim(operand.sharding());
  auto maybe_output_sharded_dim = UniqueTiledDim(sharding);
  if (!maybe_input_sharded_dim || !maybe_output_sharded_dim) {
    return DefaultAction(hlo);
  }
  int64 input_sharded_dim = *maybe_input_sharded_dim;
  int64 output_sharded_dim = *maybe_output_sharded_dim;
  // Check that the major dims before the sharded dim have the same total size
  // for input and output.
  int64 input_major_dims_size = 1;
  for (int64 i = 0; i < input_sharded_dim; ++i) {
    input_major_dims_size *= operand.base_shape().dimensions(i);
  }
  int64 output_major_dims_size = 1;
  for (int64 i = 0; i < output_sharded_dim; ++i) {
    output_major_dims_size *= hlo->shape().dimensions(i);
  }
  if (input_major_dims_size != output_major_dims_size) {
    return DefaultAction(hlo);
  }

  // Fix potential device ordering mismatch in tile assignment: reuse the
  // output's device order, reshaped to the operand's tile dimensions.
  Array<int64> new_input_tile_assignment = sharding.tile_assignment();
  new_input_tile_assignment.Reshape(
      operand.sharding().tile_assignment().dimensions());
  auto aligned_sharding =
      sharding.ReplicateOnLastTileDim()
          ? HloSharding::PartialTile(new_input_tile_assignment)
          : HloSharding::Tile(new_input_tile_assignment);
  operand = operand.Reshard(aligned_sharding);
  // Number of partitions holding the same data (1 when not partially
  // replicated); num_partitions_ / replication_count is the shard count.
  auto replication_count = sharding.ReplicateOnLastTileDim()
                               ? sharding.tile_assignment().dimensions().back()
                               : 1;

  int64 input_dim_size = operand.base_shape().dimensions(input_sharded_dim);
  int64 output_dim_size = hlo->shape().dimensions(output_sharded_dim);
  auto input_shard_shape =
      MakePartitionedShape(operand.base_shape(), operand.sharding());
  auto output_shard_shape = MakePartitionedShape(hlo->shape(), sharding);
  if (input_dim_size % output_dim_size == 0) {
    // Split dim.
    int64 split_factor = input_dim_size / output_dim_size;
    int64 output_shard_size = output_shard_shape.dimensions(output_sharded_dim);
    // Use halo exchange to fix misaligned data.
    // The window describes the operand, so it needs one dimension per
    // operand dimension. The reshape may change the rank, so the output rank
    // is not a valid substitute: it yields too few window dimensions whenever
    // the output has a lower rank than the operand.
    Window window;
    for (int64 i = 0; i < operand.base_shape().rank(); ++i) {
      WindowDimension* dim = window.add_dimensions();
      dim->set_size(1);
      dim->set_stride(1);
      dim->set_window_dilation(1);
      dim->set_window_reversal(false);
      dim->set_base_dilation(1);
      dim->set_padding_low(0);
      if (i == input_sharded_dim) {
        // Pad the sharded dimension so that every shard holds exactly
        // output_shard_size * split_factor elements.
        dim->set_padding_high(output_shard_size * split_factor *
                                  num_partitions_ / replication_count -
                              input_dim_size);
      } else {
        dim->set_padding_high(0);
      }
    }

    auto reshard_operand = operand.ReshardAsWindowedInput(
        window, operand.sharding(),
        CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
        /*mask_invalid_region=*/false);
    if (!reshard_operand.has_value()) {
      return DefaultAction(hlo);
    }
    TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value());
    CHECK_EQ(
        reshard_operand->sharded_input->shape().dimensions(input_sharded_dim),
        output_shard_size * split_factor);
    SetPartitionedHlo(hlo, [&] {
      // Do a local reshape.
      return b_.AddInstruction(HloInstruction::CreateReshape(
          output_shard_shape, reshard_operand->sharded_input));
    });
    return Status::OK();
  } else if (output_dim_size % input_dim_size == 0) {
    // Merge dims.
    int64 merge_factor = output_dim_size / input_dim_size;
    // First reshape locally. (The sharded dimension could include padded data.)
    auto tmp_shard_shape = output_shard_shape;
    tmp_shard_shape.set_dimensions(
        output_sharded_dim,
        input_shard_shape.dimensions(input_sharded_dim) * merge_factor);
    auto tmp_reshape = b_.AddInstruction(
        HloInstruction::CreateReshape(tmp_shard_shape, operand.hlo()));
    tmp_reshape->set_metadata(hlo->metadata());
    tmp_reshape->set_sharding(hlo->sharding());
    // Full shape corresponding to the locally reshaped (possibly padded)
    // shards.
    auto tmp_full_shape = tmp_shard_shape;
    tmp_full_shape.set_dimensions(
        output_sharded_dim, tmp_shard_shape.dimensions(output_sharded_dim) *
                                num_partitions_ / replication_count);
    auto tmp_output =
        PartitionedHlo(tmp_reshape, tmp_full_shape, MakePartitioningState());

    // Use halo exchange to fix misaligned data.
    Window window;
    for (int64 i = 0; i < tmp_shard_shape.rank(); ++i) {
      WindowDimension* dim = window.add_dimensions();
      dim->set_size(1);
      dim->set_stride(1);
      dim->set_window_dilation(1);
      dim->set_window_reversal(false);
      dim->set_base_dilation(1);
      dim->set_padding_low(0);
      if (i == output_sharded_dim) {
        // Padding that takes the temporary full size to the real output size.
        dim->set_padding_high(output_dim_size -
                              tmp_shard_shape.dimensions(output_sharded_dim) *
                                  num_partitions_ / replication_count);
      } else {
        dim->set_padding_high(0);
      }
    }

    auto reshard_output = tmp_output.ReshardAsWindowedInput(
        window, sharding,
        CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
        /*mask_invalid_region=*/false);
    if (!reshard_output.has_value()) {
      return DefaultAction(hlo);
    }
    TF_RET_CHECK(!reshard_output->dynamic_slice_index_on_output.has_value());
    CHECK_EQ(
        reshard_output->sharded_input->shape().dimensions(output_sharded_dim),
        output_shard_shape.dimensions(output_sharded_dim));
    SetPartitionedHlo(hlo, [&] { return reshard_output->sharded_input; });
    return Status::OK();
  }
  return DefaultAction(hlo);
}
// Partitions an iota with a tiled output: every partition creates an iota of
// the shard shape and, if the iota dimension itself is tiled, adds
// (partition ordinal along that dimension) * (shard size) to it.
// Tile-maximal outputs go through DefaultAction.
Status SpmdPartitioningVisitor::HandleIota(HloInstruction* hlo) {
  const HloSharding& sharding = hlo->sharding();
  if (sharding.IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  SetPartitionedHlo(hlo, [&] {
    const int64 iota_dim = Cast<HloIotaInstruction>(hlo)->iota_dimension();
    auto local_iota = b_.AddInstruction(HloInstruction::CreateIota(
        MakePartitionedShape(hlo->shape(), sharding), iota_dim));
    if (sharding.tile_assignment().dim(iota_dim) <= 1) {
      // The iota dimension is not split, so the local iota is already right.
      return local_iota;
    }

    // shard_start = ordinal along iota_dim * shard size along iota_dim.
    auto ordinals = MakeTiledPartitionOrdinals(sharding, partition_id_, &b_);
    auto shard_size = b_.AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR0<int32>(local_iota->shape().dimensions(iota_dim))));
    auto shard_start = b_.AddInstruction(HloInstruction::CreateBinary(
        ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply,
        ordinals[iota_dim], shard_size));
    const auto iota_type = local_iota->shape().element_type();
    if (iota_type != S32) {
      shard_start = b_.AddInstruction(HloInstruction::CreateConvert(
          ShapeUtil::MakeShape(iota_type, {}), shard_start));
    }
    auto shard_start_bcast = b_.AddInstruction(
        HloInstruction::CreateBroadcast(local_iota->shape(), shard_start, {}));
    return b_.AddInstruction(HloInstruction::CreateBinary(
        local_iota->shape(), HloOpcode::kAdd, local_iota, shard_start_bcast));
  });
  return Status::OK();
}
// Partitions an instruction that is assigned to exactly one device.
//
// All operands are resharded onto that device and packed into a tuple. A
// conditional on (partition_id == device) then runs a clone of the
// instruction in its true branch and produces zeros of the result shape in
// its false branch.
Status SpmdPartitioningVisitor::HandleSingleDevice(const HloInstruction* hlo) {
  TF_RET_CHECK(hlo->sharding().HasUniqueDevice());
  int64 target_device = hlo->sharding().GetUniqueDevice();
  const HloSharding device_sharding = HloSharding::AssignDevice(target_device);

  // Move every operand to the target device and remember the full shapes.
  std::vector<HloInstruction*> device_operands;
  std::vector<Shape> full_operand_shapes;
  for (const HloInstruction* original_operand : hlo->operands()) {
    device_operands.push_back(
        GetPartitionedHlo(original_operand).Reshard(device_sharding).hlo());
    full_operand_shapes.push_back(original_operand->shape());
  }
  auto packed_operands =
      b_.AddInstruction(HloInstruction::CreateTuple(device_operands));
  auto packed_shape = ShapeUtil::MakeTupleShape(full_operand_shapes);

  // Predicate: is this partition the target device?
  auto device_constant = b_.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR0<uint32>(target_device)));
  auto is_target_device = b_.AddInstruction(HloInstruction::CreateCompare(
      ShapeUtil::MakeShape(PRED, {}), partition_id_, device_constant,
      ComparisonDirection::kEq));

  // True branch: unpack the tuple and run the original instruction.
  SpmdBuilder true_b("true_computation", visiting_hlo_);
  HloComputation* true_computation;
  {
    auto branch_param = true_b.AddInstruction(HloInstruction::CreateParameter(
        /*parameter_number=*/0, packed_shape, "true_branch_param"));
    std::vector<HloInstruction*> unpacked_operands;
    for (int64 i = 0; i < device_operands.size(); ++i) {
      unpacked_operands.push_back(
          true_b.AddInstruction(HloInstruction::CreateGetTupleElement(
              full_operand_shapes[i], branch_param, i)));
    }
    auto branch_root = true_b.AddInstruction(
        hlo->CloneWithNewOperands(hlo->shape(), unpacked_operands));
    true_computation =
        module_->AddEmbeddedComputation(true_b.Build(branch_root));
  }

  // False branch: ignore the operands and return zeros.
  SpmdBuilder false_b("false_computation", visiting_hlo_);
  HloComputation* false_computation;
  {
    false_b.AddInstruction(HloInstruction::CreateParameter(
        /*parameter_number=*/0, packed_shape, "false_branch_param"));
    auto branch_root = CreateZero(hlo->shape(), &false_b);
    false_computation =
        module_->AddEmbeddedComputation(false_b.Build(branch_root));
  }

  SetPartitionedHlo(hlo, [&]() {
    return b_.AddInstruction(HloInstruction::CreateConditional(
        hlo->shape(), is_target_device, packed_operands, true_computation,
        packed_operands, false_computation));
  });
  return Status::OK();
}
// A cross-replica all-reduce with a single operand is partitioned like an
// elementwise op; every other all-reduce goes through DefaultAction.
Status SpmdPartitioningVisitor::HandleAllReduce(HloInstruction* hlo) {
  if (!hlo->IsCrossReplicaAllReduce() || hlo->operand_count() != 1) {
    return DefaultAction(hlo);
  }
  return HandleElementwise(hlo);
}
// Partitions a broadcast with a tiled output. The operand sharding is derived
// from the output sharding by partially replicating and then removing the
// output dimensions that the broadcast introduces; the broadcast is then
// cloned on the output shard shape. Tile-maximal outputs go through
// DefaultAction.
Status SpmdPartitioningVisitor::HandleBroadcast(HloInstruction* hlo) {
  if (hlo->sharding().IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  auto& operand = GetPartitionedHlo(hlo->operand(0));

  // Output dimensions that do not come from the operand.
  std::vector<int64> broadcast_only_dims;
  const int64 out_rank = hlo->shape().rank();
  for (int64 out_dim = 0; out_dim < out_rank; ++out_dim) {
    const bool from_operand =
        absl::c_linear_search(hlo->dimensions(), out_dim);
    if (!from_operand) {
      broadcast_only_dims.push_back(out_dim);
    }
  }

  auto replicated_on_new_dims =
      hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
          hlo->sharding(), broadcast_only_dims);
  auto operand_sharding = hlo_sharding_util::RemoveShapeDimensions(
      replicated_on_new_dims, broadcast_only_dims);
  auto local_operand = operand.Reshard(operand_sharding).hlo();
  auto local_output_shape =
      MakePartitionedShape(hlo->shape(), hlo->sharding());

  SetPartitionedHlo(hlo, [&] {
    return b_.AddInstruction(
        hlo->CloneWithNewOperands(local_output_shape, {local_operand}));
  });
  return Status::OK();
}
// Partitions a constant by emitting, on every partition, the slice of the
// literal that starts at the origin and has the shard shape. Tuple literals,
// and tiled constants that are either unevenly partitioned or fail
// Literal::IsAllFirst, go through DefaultAction.
Status SpmdPartitioningVisitor::HandleConstant(HloInstruction* hlo) {
  const Literal& literal = hlo->literal();
  if (literal.shape().IsTuple()) {
    return DefaultAction(hlo);
  }
  if (!hlo->sharding().IsTileMaximal()) {
    if (!EvenlyPartitions(hlo->shape(), hlo->sharding()) ||
        !literal.IsAllFirst()) {
      return DefaultAction(hlo);
    }
  }

  SetPartitionedHlo(hlo, [&]() {
    auto local_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
    std::vector<int64> origin(hlo->shape().rank(), 0);
    auto local_constant = b_.AddInstruction(HloInstruction::CreateConstant(
        literal.Slice(origin, local_shape.dimensions())));
    // Overwrite the shape of the new constant with the shard shape.
    *local_constant->mutable_shape() = local_shape;
    return local_constant;
  });
  return Status::OK();
}
// Partitions a dynamic-slice with a tiled output. A dimension may only be
// tiled if it is not actually sliced, i.e. the slice size equals the output
// dimension and the start index is a constant zero. The input is resharded
// like the output, the indices are replicated, and a dynamic-slice of the
// shard shape is emitted. Anything else goes through DefaultAction.
Status SpmdPartitioningVisitor::HandleDynamicSlice(HloInstruction* hlo) {
  if (hlo->sharding().IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  const int64 rank = hlo->shape().rank();
  for (int64 d = 0; d < rank; ++d) {
    if (hlo->sharding().tile_assignment().dim(d) == 1) {
      continue;
    }
    // We currently do not partition the sliced dimensions.
    const bool takes_whole_dim =
        hlo->dynamic_slice_sizes()[d] == hlo->shape().dimensions(d) &&
        hlo->operand(d + 1)->IsConstant() &&
        hlo->operand(d + 1)->literal().IsZero({});
    if (!takes_whole_dim) {
      return DefaultAction(hlo);
    }
  }

  std::vector<HloInstruction*> replicated_indices(rank);
  auto local_input =
      GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo();
  for (int64 d = 0; d < replicated_indices.size(); ++d) {
    // Every partition needs the full start index.
    replicated_indices[d] = GetPartitionedHlo(hlo->operand(d + 1))
                                .Reshard(HloSharding::Replicate())
                                .hlo();
  }

  SetPartitionedHlo(hlo, [&]() {
    auto local_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
    return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
        local_shape, local_input, replicated_indices,
        local_shape.dimensions()));
  });
  return Status::OK();
}
// Partitions a dynamic-update-slice whose result is tiled.
//
// Each dimension is classified as a "slice dim" (the update, operand 1, is
// smaller/different in size than the result in that dimension) or a non-slice
// dim. Two strategies are implemented:
//  * If some slice dim is partitioned, its start index must be a constant and
//    the update must fall entirely inside one partition along that dim. Every
//    partition then performs the update at a locally adjusted offset and a
//    select keeps the updated value only on the partition that owns the
//    update region.
//  * Otherwise only non-slice dims are partitioned (their start index must be
//    a constant zero); input and update are resharded to the result sharding
//    and a dynamic-update-slice is emitted on the shard shape.
// Unsupported cases, and tile-maximal shardings, fall back to DefaultAction.
Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(HloInstruction* hlo) {
if (hlo->sharding().IsTileMaximal()) {
return DefaultAction(hlo);
}
// Dimensions that are both sliced and partitioned.
std::vector<int64> partitioned_slice_dims;
// All sliced dimensions, partitioned or not.
std::vector<int64> slice_dims;
// Partitioned dimensions in which the update covers the full extent.
std::vector<int64> partitioned_non_slice_dims;
// Constant start offsets, parallel to partitioned_slice_dims.
std::vector<int64> partitioned_slice_offsets;
for (int64 i = 0; i < hlo->shape().rank(); ++i) {
if (hlo->operand(1)->shape().dimensions(i) != hlo->shape().dimensions(i)) {
slice_dims.push_back(i);
if (hlo->sharding().tile_assignment().dim(i) != 1) {
// The start index of dimension i is operand i + 2; it must be known at
// compile time to decide which partition owns the update.
if (!hlo->operand(i + 2)->IsConstant()) {
return DefaultAction(hlo);
}
partitioned_slice_dims.push_back(i);
// NOTE(review): the offset is read as `int`; presumably this assumes an
// S32 index operand - confirm for other index element types.
partitioned_slice_offsets.push_back(
hlo->operand(i + 2)->literal().Get<int>({}));
}
} else if (hlo->sharding().tile_assignment().dim(i) != 1) {
// A partitioned non-slice dim is only supported with a constant zero
// start index.
if (!hlo->operand(i + 2)->IsConstant() ||
!hlo->operand(i + 2)->literal().IsZero({})) {
return DefaultAction(hlo);
}
partitioned_non_slice_dims.push_back(i);
}
}
// Handle when there is slice dim partitioned.
if (!partitioned_slice_dims.empty()) {
auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) {
return b_.AddInstruction(std::move(to_add));
};
std::vector<HloInstruction*> new_indices(hlo->shape().rank());
for (int64 i = 0; i < new_indices.size(); ++i) {
// Replicate the indices.
new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
.Reshard(HloSharding::Replicate())
.hlo();
}
// Get partitioned input.
const auto& dus_sharding = hlo->sharding();
const auto& partitioned_input =
GetPartitionedHlo(hlo->operand(0)).Reshard(dus_sharding).hlo();
// Get replicate update.
auto update_sharding = HloSharding::Replicate();
if (!partitioned_non_slice_dims.empty()) {
// Do partial replicate for update if non slice dims are partitioned.
update_sharding =
hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(dus_sharding,
slice_dims);
}
HloInstruction* replicate_update =
GetPartitionedHlo(hlo->operand(1)).Reshard(update_sharding).hlo();
const auto& update_shape = replicate_update->shape();
const auto& partitioned_shape = partitioned_input->shape();
auto partition_ordinals =
MakeTiledPartitionOrdinals(hlo->sharding(), partition_id_, &b_);
// Running conjunction of the per-dimension "update is in this partition"
// predicates; starts out as true.
HloInstruction* all_dims_within_partition = add_hlo(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
for (int i = 0; i < partitioned_slice_dims.size(); ++i) {
int dim = partitioned_slice_dims[i];
// Calculate per partition size.
const int64 per_partition_size = partitioned_shape.dimensions(dim);
// Only update within a single partition is supported.
// The first and the last updated element must map to the same partition.
if ((partitioned_slice_offsets[i] / per_partition_size) !=
((partitioned_slice_offsets[i] + update_shape.dimensions(dim) - 1) /
per_partition_size)) {
return DefaultAction(hlo);
}
// within_partition = (offset >= partition_id * per_partition_size) &&
// (offset < (partition_id + 1) * per_partition_size)
const Shape& compare_shape =
ShapeUtil::ChangeElementType(partition_id_->shape(), PRED);
auto per_partition_size_hlo = add_hlo(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<int>(per_partition_size)));
const Shape& offset_shape = per_partition_size_hlo->shape();
// First global element owned by this partition along `dim`.
auto partition_offset = add_hlo(HloInstruction::CreateBinary(
offset_shape, HloOpcode::kMultiply, partition_ordinals[dim],
per_partition_size_hlo));
// offset >= partition_id * per_partition_size
auto offset_ge = add_hlo(HloInstruction::CreateCompare(
compare_shape, new_indices[dim], partition_offset,
ComparisonDirection::kGe));
// offset < (partition_id + 1) * per_partition_size
auto offset_lt = add_hlo(HloInstruction::CreateCompare(
compare_shape, new_indices[dim],
add_hlo(HloInstruction::CreateBinary(
offset_shape, HloOpcode::kMultiply,
add_hlo(HloInstruction::CreateBinary(
offset_shape, HloOpcode::kAdd, partition_ordinals[dim],
add_hlo(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<int>(1))))),
per_partition_size_hlo)),
ComparisonDirection::kLt));
auto update_within_partition = add_hlo(HloInstruction::CreateBinary(
compare_shape, HloOpcode::kAnd, offset_ge, offset_lt));
all_dims_within_partition = add_hlo(HloInstruction::CreateBinary(
compare_shape, HloOpcode::kAnd, all_dims_within_partition,
update_within_partition));
// Calculate offset.
// slice dim offset =
// within_partition ?
// offset - partition_id * per_partition_size : 0
new_indices[dim] = add_hlo(HloInstruction::CreateTernary(
new_indices[dim]->shape(), HloOpcode::kSelect,
update_within_partition,
add_hlo(HloInstruction::CreateBinary(
new_indices[dim]->shape(), HloOpcode::kSubtract, new_indices[dim],
partition_offset)),
add_hlo(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)))));
}
// Create dynamic update slice.
// It is emitted on every partition; partitions that do not own the update
// region discard its result in the select below.
auto dus = add_hlo(HloInstruction::CreateDynamicUpdateSlice(
partitioned_shape, partitioned_input, replicate_update, new_indices));
SetPartitionedHlo(hlo, [&]() {
// Select if update is needed.
return add_hlo(HloInstruction::CreateTernary(
dus->shape(), HloOpcode::kSelect,
add_hlo(HloInstruction::CreateBroadcast(
ShapeUtil::ChangeElementType(dus->shape(), PRED),
all_dims_within_partition, {})),
dus, partitioned_input));
});
return Status::OK();
}
// Partition non slice dims only.
std::vector<HloInstruction*> new_indices(hlo->shape().rank());
auto new_input =
GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo();
auto new_update =
GetPartitionedHlo(hlo->operand(1)).Reshard(hlo->sharding()).hlo();
for (int64 i = 0; i < new_indices.size(); ++i) {
// Replicate the indices.
new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
.Reshard(HloSharding::Replicate())
.hlo();
}
SetPartitionedHlo(hlo, [&]() {
auto partitioned_shape =
MakePartitionedShape(hlo->shape(), hlo->sharding());
return b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
partitioned_shape, new_input, new_update, new_indices));
});
return Status::OK();
}
// Partitions a gather. Three strategies are tried in order; if none applies
// the instruction is handled by DefaultAction:
//  1. Operand is tile-maximal and the indices are tiled, but not along the
//     index vector dimension: replicate the operand and gather locally with
//     each shard of the indices; the result sharding is derived from the
//     indices sharding.
//  2. Operand is tiled and GatherOutputShardingFromDataOperand yields an
//     output sharding: replicate the indices and gather on the operand shard,
//     using the shard size as slice size on partitioned operand dimensions.
//  3. Operand is partitioned only on trivial slice dimensions and the gather
//     result is smaller (in bytes) than the operand: replicate the indices,
//     gather locally with clamped indices, zero out results whose indices
//     fall outside this partition, and all-reduce across partitions.
// In every strategy the locally produced result is finally resharded to the
// sharding requested for `hlo`.
Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
auto gather = Cast<HloGatherInstruction>(hlo);
const auto& dnums = gather->gather_dimension_numbers();
auto operand = GetPartitionedHlo(gather->operand(0));
auto indices = GetPartitionedHlo(gather->operand(1));
std::vector<int64> start_index_map(dnums.start_index_map().begin(),
dnums.start_index_map().end());
// Output dimensions that are not offset dimensions.
std::vector<int64> batch_dims;
for (int64 i = 0; i < gather->shape().rank(); ++i) {
if (!absl::c_linear_search(dnums.offset_dims(), i)) {
batch_dims.push_back(i);
}
}
if (operand.sharding().IsTileMaximal()) {
// Strategy 1: only applicable when the index vector dimension is either
// implicit (equal to the indices rank) or not partitioned.
if (!indices.sharding().IsTileMaximal() &&
(dnums.index_vector_dim() == indices.base_shape().rank() ||
indices.sharding().tile_assignment().dim(dnums.index_vector_dim()) ==
1)) {
auto replicated_operand = operand.Replicate();
TF_ASSIGN_OR_RETURN(
Shape partitioned_output_shape,
ShapeInference::InferGatherShape(replicated_operand.hlo()->shape(),
indices.hlo()->shape(), dnums,
gather->gather_slice_sizes()));
auto pgather = b_.AddInstruction(gather->CloneWithNewOperands(
partitioned_output_shape, {replicated_operand.hlo(), indices.hlo()}));
// Map output batch dimensions to indices dimensions and back; -1 marks
// dimensions without a counterpart.
std::vector<int64> output_dim_to_index_dim(pgather->shape().rank(), -1);
std::vector<int64> index_dim_to_output_dim(indices.base_shape().rank(),
-1);
for (int64 i = 0; i < batch_dims.size(); ++i) {
// Skip over the index vector dimension in the indices shape.
int64 indices_batch_dim = i < dnums.index_vector_dim() ? i : i + 1;
output_dim_to_index_dim[batch_dims[i]] = indices_batch_dim;
index_dim_to_output_dim[indices_batch_dim] = batch_dims[i];
}
auto pgather_sharding =
hlo_sharding_util::TransposeShardingWithCollapsedDims(
indices.sharding(), index_dim_to_output_dim,
output_dim_to_index_dim);
CHECK(pgather_sharding.has_value());
pgather->set_sharding(*pgather_sharding);
SetPartitionedHlo(hlo, [&]() {
return PartitionedHlo(pgather, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
} else {
// Strategy 2: pass the operand sharding through to the output.
auto maybe_passthrough =
hlo_sharding_util::GatherOutputShardingFromDataOperand(
operand.sharding(), *hlo);
if (maybe_passthrough.has_value()) {
indices = indices.Reshard(HloSharding::Replicate());
auto pshape = MakePartitionedShape(gather->shape(), *maybe_passthrough);
std::vector<int64> pslice_sizes(gather->gather_slice_sizes().begin(),
gather->gather_slice_sizes().end());
for (int64 i = 0; i < pslice_sizes.size(); ++i) {
// On partitioned operand dimensions the slice covers the local shard.
if (operand.sharding().tile_assignment().dim(i) > 1) {
pslice_sizes[i] = operand.hlo()->shape().dimensions(i);
}
}
auto pgather = b_.AddInstruction(HloInstruction::CreateGather(
pshape, operand.hlo(), indices.hlo(), dnums, pslice_sizes,
gather->indices_are_sorted()));
pgather->set_sharding(*maybe_passthrough);
SetPartitionedHlo(hlo, [&]() {
return PartitionedHlo(pgather, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
// Strategy 3: taken only when the output is smaller than the operand in
// bytes.
if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
operand, start_index_map, gather->gather_slice_sizes()) &&
ShapeSizeInBytes(gather->shape()) <
ShapeSizeInBytes(gather->operand(0)->shape())) {
indices = indices.Reshard(HloSharding::Replicate());
// Now the operand is partitioned in trivial slice dimensions, and the
// indices are replicated. We execute a gather on partitioned operand,
// with full number of indices, where out-of-bounds indices are clamped,
// and masked out with 0 in the result; then we use all-reduce to combine
// results. Although gather will not get faster, we avoided the need to
// replicate the operand.
HloInstruction* indices_min;
HloInstruction* indices_max;
std::tie(indices_min, indices_max) =
IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims(
operand, indices, partition_id_, start_index_map,
dnums.index_vector_dim(), &b_);
// Clamp the indices.
auto adjusted_indices = b_.AddInstruction(HloInstruction::CreateTernary(
indices.base_shape(), HloOpcode::kClamp, indices_min, indices.hlo(),
indices_max));
// Adjust the indices by subtracting the offset.
adjusted_indices = b_.AddInstruction(HloInstruction::CreateBinary(
indices.base_shape(), HloOpcode::kSubtract, adjusted_indices,
indices_min));
// Gather on adjusted indices.
auto pgather = b_.AddInstruction(HloInstruction::CreateGather(
gather->shape(), operand.hlo(), adjusted_indices, dnums,
gather->gather_slice_sizes(), gather->indices_are_sorted()));
// Mask out invalid results.
// filter is true where the original index is below indices_min or above
// indices_max.
auto filter = b_.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::ChangeElementType(indices.base_shape(), PRED),
indices.hlo(), indices_min, ComparisonDirection::kLt));
filter = b_.AddInstruction(HloInstruction::CreateBinary(
filter->shape(), HloOpcode::kOr, filter,
b_.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::ChangeElementType(indices.base_shape(), PRED),
indices.hlo(), indices_max, ComparisonDirection::kGt))));
if (dnums.index_vector_dim() < indices.base_shape().rank()) {
// Collapse the explicit index vector dimension of the filter with a
// reduce over that dimension.
std::vector<int64> reduced_filter_dims;
for (int64 i = 0; i < filter->shape().rank(); ++i) {
if (i != dnums.index_vector_dim()) {
reduced_filter_dims.push_back(filter->shape().dimensions(i));
}
}
filter = b_.AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeShape(PRED, reduced_filter_dims), filter,
CreateR0WithType(PRED, false, &b_), {dnums.index_vector_dim()},
MakeBinaryAdd(PRED, module_)));
}
// NOTE: shadows the function-level batch_dims; it is computed the same way
// from pgather, which was created with gather->shape().
std::vector<int64> batch_dims;
for (int64 i = 0; i < pgather->shape().rank(); ++i) {
if (!absl::c_linear_search(dnums.offset_dims(), i)) {
batch_dims.push_back(i);
}
}
auto broadcast_filter = b_.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::ChangeElementType(pgather->shape(), PRED), filter,
batch_dims));
// Filtered-out positions contribute zero to the all-reduce below.
auto filtered = b_.AddInstruction(HloInstruction::CreateTernary(
pgather->shape(), HloOpcode::kSelect, broadcast_filter,
CreateZero(pgather->shape(), &b_), pgather));
// Combine from different partitions.
auto collective_ops_creator = collective_ops_creator_;
if (operand.sharding().ReplicateOnLastTileDim()) {
// With partial replication, use the per-group collective creator obtained
// by grouping on the last tile dimension.
auto sharding_grouped = GroupShardingOnDims(
operand.sharding(),
{operand.sharding().tile_assignment().num_dimensions() - 1});
auto per_group_partitioner_state = CreatePerGroupPartitioningState(
operand.state(), sharding_grouped.device_groups, &b_);
collective_ops_creator =
per_group_partitioner_state.collective_ops_creator;
}
auto ar = collective_ops_creator.create_cross_partition_all_reduce(
&b_, filtered,
MakeBinaryAdd(filtered->shape().element_type(), module_), {},
NewChannel());
ar->set_sharding(HloSharding::Replicate());
SetPartitionedHlo(hlo, [&]() {
return PartitionedHlo(ar, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
}
return DefaultAction(hlo);
}
// Partitions a get-tuple-element: extracts the requested element from the
// partitioned tuple, tags it with the corresponding sub-sharding of the
// tuple's sharding, and reshards it to the sharding requested for `hlo`.
Status SpmdPartitioningVisitor::HandleGetTupleElement(HloInstruction* hlo) {
  const auto& partitioned_tuple = GetPartitionedHlo(hlo->operand(0));
  const auto element_index = hlo->tuple_index();
  auto* tuple_hlo = partitioned_tuple.hlo();

  // The element keeps whatever shard shape it has inside the tuple.
  auto* element = b_.AddInstruction(HloInstruction::CreateGetTupleElement(
      ShapeUtil::GetTupleElementShape(tuple_hlo->shape(), element_index),
      tuple_hlo, element_index));

  SetPartitionedHlo(hlo, [&]() {
    // The element is currently sharded the way the tuple says it is.
    element->set_sharding(partitioned_tuple.sharding().GetSubSharding(
        partitioned_tuple.base_shape(), {element_index}));
    return PartitionedHlo(element, hlo->shape(), MakePartitioningState())
        .Reshard(hlo->sharding())
        .hlo();
  });
  return Status::OK();
}
// Partitions an infeed. The infeed result is a tuple whose element 0 is the
// data and whose operand 0 is the input token.
//  * Data shape without leaves: an infeed with the original data shape is
//    emitted (workaround, see TODO below).
//  * Sharding partitions the data evenly: a single infeed with the shard shape
//    is emitted.
//  * Sharding has a unique device: delegated to HandleSingleDevice.
//  * Otherwise partitions may receive differently sized (non-padded) data. A
//    conditional with one branch per distinct non-padded shape is built; each
//    branch infeeds its shape and pads the data up to the common shard shape.
Status SpmdPartitioningVisitor::HandleInfeed(HloInstruction* hlo) {
const Shape& shape = ShapeUtil::GetTupleElementShape(hlo->shape(), 0);
auto token = GetPartitionedHlo(hlo->operand(0)).hlo();
if (ShapeUtil::GetLeafCount(shape) == 0) {
// TODO(b/155819021): HloSharding has issues with tuple-shaped sharding: it
// requires one element for an empty tuple, but leaf-count number of
// elements for non-empty tuple. So if it has a nested empty tuple, we
// cannot invoke GetSubSharding() since it expects a sharding for the empty
// tuple. This is a workaround for that case.
SetPartitionedHlo(hlo, [&]() {
return b_.AddInstruction(
HloInstruction::CreateInfeed(shape, token, hlo->infeed_config()));
});
return Status::OK();
}
// Sharding and shard shape of the data element only (tuple index 0).
auto sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
auto shard_shape = MakePartitionedShape(shape, sharding);
if (EvenlyPartitions(shape, sharding)) {
SetPartitionedHlo(hlo, [&]() {
return b_.AddInstruction(HloInstruction::CreateInfeed(
shard_shape, token, hlo->infeed_config()));
});
return Status::OK();
}
if (hlo->sharding().HasUniqueDevice()) {
return HandleSingleDevice(hlo);
}
// Create a branch for each unique partitioned shape.
std::vector<Shape> per_branch_partitioned_shapes;
// conditional_branch_indices[p] is the branch taken by partition p.
std::vector<int32> conditional_branch_indices(num_partitions_);
for (int64 i = 0; i < num_partitions_; ++i) {
auto partitioned_shape =
MakeNonPaddedShapeForGivenPartition(shape, sharding, i);
// Linear search for an already registered compatible shape.
int64 matching_existing_index = 0;
for (; matching_existing_index < per_branch_partitioned_shapes.size();
++matching_existing_index) {
if (ShapeUtil::Compatible(
partitioned_shape,
per_branch_partitioned_shapes[matching_existing_index])) {
break;
}
}
if (matching_existing_index < per_branch_partitioned_shapes.size()) {
conditional_branch_indices[i] = matching_existing_index;
} else {
conditional_branch_indices[i] = per_branch_partitioned_shapes.size();
per_branch_partitioned_shapes.push_back(std::move(partitioned_shape));
}
}
HloInstruction* branch_index;
if (per_branch_partitioned_shapes.size() == num_partitions_) {
// Use partition ID as the branch index if each partition has its own
// branch.
branch_index = partition_id_;
// PartitionId's output is U32 but conditional requires S32.
if (branch_index->shape().element_type() != S32) {
branch_index = b_.AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::ChangeElementType(branch_index->shape(), S32),
branch_index));
}
} else {
// Otherwise, use a constant table to look up the branch index.
auto branch_index_table = b_.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<int32>(conditional_branch_indices)));
branch_index = b_.AddInstruction(HloInstruction::CreateDynamicSlice(
ShapeUtil::MakeShape(S32, {1}), branch_index_table, {partition_id_},
{1}));
// Turn the one-element slice into a scalar.
branch_index = b_.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {}), branch_index));
}
std::vector<HloComputation*> branches(per_branch_partitioned_shapes.size());
for (int64 i = 0; i < branches.size(); ++i) {
SpmdBuilder branch_b(absl::StrCat("infeed_branch_", i), visiting_hlo_);
auto param = branch_b.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, token->shape(), "infeed_token_param"));
auto infeed = branch_b.AddInstruction(HloInstruction::CreateInfeed(
per_branch_partitioned_shapes[i], param, hlo->infeed_config()));
if (!ShapeUtil::Compatible(per_branch_partitioned_shapes[i], shard_shape)) {
// Recursively walks the infeed result and pads every array leaf of the data
// up to the corresponding leaf of shard_shape; `index` addresses the
// current sub-shape within the infeed result.
std::function<HloInstruction*(const ShapeIndex&, HloInstruction*)>
pad_infeed = [&](const ShapeIndex& index,
HloInstruction* infeed_element) -> HloInstruction* {
if (index == ShapeIndex({1})) {
// Token.
return infeed_element;
}
const Shape& element_shape =
ShapeUtil::GetSubshape(infeed->shape(), index);
if (element_shape.IsTuple() && element_shape.tuple_shapes_size() > 0) {
std::vector<HloInstruction*> padded_elements(
element_shape.tuple_shapes_size());
for (int64 i = 0; i < padded_elements.size(); ++i) {
auto sub_index = index;
sub_index.push_back(i);
padded_elements[i] = pad_infeed(
sub_index,
branch_b.AddInstruction(HloInstruction::CreateGetTupleElement(
ShapeUtil::GetSubshape(element_shape, {i}), infeed_element,
i)));
}
return branch_b.AddInstruction(
HloInstruction::CreateTuple(padded_elements));
}
// shard_shape describes only the data element, so the leading index
// component (the position inside the infeed result tuple) is dropped.
const Shape& pad_shape =
ShapeUtil::GetSubshape(shard_shape, ShapeIndexView(index, 1));
if (ShapeUtil::Compatible(element_shape, pad_shape)) {
return infeed_element;
}
if (element_shape.IsArray()) {
CHECK(pad_shape.IsArray());
return PadToShape(infeed_element, pad_shape, &branch_b);
}
CHECK(element_shape.IsTuple());
CHECK(element_shape.tuple_shapes().empty());
return CreateZero(pad_shape, &branch_b);
};
// NOTE(review): the returned instruction is not used here; presumably
// branch_b.Build() picks the last added instruction as the root - confirm.
pad_infeed({}, infeed);
}
branches[i] = module_->AddEmbeddedComputation(branch_b.Build());
}
SetPartitionedHlo(hlo, [&]() {
// Every branch receives the same token as its argument.
return b_.AddInstruction(HloInstruction::CreateConditional(
ShapeUtil::MakeTupleShape({shard_shape, token->shape()}), branch_index,
branches, std::vector<HloInstruction*>(branches.size(), token)));
});
return Status::OK();
}
// Partitions a pad whose result is tiled.
//
// The pad is described as a size-1, stride-1 window whose low/high padding are
// the edge paddings and whose base dilation is interior padding + 1. The
// operand is resharded as a windowed input for that window; if that is not
// possible, or the result sharding is tile-maximal, DefaultAction is used.
// The shard is then padded locally where the per-shard window requires it and,
// if requested by the reshard result, dynamically sliced to the shard shape.
Status SpmdPartitioningVisitor::HandlePad(HloInstruction* hlo) {
  if (hlo->sharding().IsTileMaximal()) {
    return DefaultAction(hlo);
  }
  auto operand = GetPartitionedHlo(hlo->operand(0));
  const int64 rank = hlo->shape().rank();

  // Describe the pad as a window over the operand.
  Window pad_window;
  for (int64 d = 0; d < rank; ++d) {
    const auto& pad_dim = hlo->padding_config().dimensions(d);
    WindowDimension* window_dim = pad_window.add_dimensions();
    window_dim->set_size(1);
    window_dim->set_stride(1);
    window_dim->set_window_dilation(1);
    window_dim->set_window_reversal(false);
    window_dim->set_base_dilation(pad_dim.interior_padding() + 1);
    window_dim->set_padding_low(pad_dim.edge_padding_low());
    window_dim->set_padding_high(pad_dim.edge_padding_high());
  }

  // The padding value is needed on every partition.
  auto pad_value = GetPartitionedHlo(hlo->operand(1))
                       .Reshard(HloSharding::Replicate())
                       .hlo();
  auto windowed =
      operand.ReshardAsWindowedInput(pad_window, hlo->sharding(), pad_value,
                                     /*mask_invalid_region=*/false);
  if (!windowed.has_value()) {
    return DefaultAction(hlo);
  }

  // Translate the per-shard window back into a padding config and find out
  // whether any local padding is left to do.
  PaddingConfig shard_padding;
  bool shard_needs_pad = false;
  for (int64 d = 0; d < rank; ++d) {
    const auto& shard_dim = windowed->shard_window.dimensions(d);
    auto* out_dim = shard_padding.add_dimensions();
    out_dim->set_edge_padding_low(shard_dim.padding_low());
    out_dim->set_edge_padding_high(shard_dim.padding_high());
    out_dim->set_interior_padding(shard_dim.base_dilation() - 1);
    shard_needs_pad = shard_needs_pad || shard_dim.padding_low() != 0 ||
                      shard_dim.padding_high() != 0 ||
                      shard_dim.base_dilation() != 1;
  }

  auto padded_shard = windowed->sharded_input;
  if (shard_needs_pad) {
    TF_ASSIGN_OR_RETURN(
        auto padded_shard_shape,
        ShapeInference::InferPadShape(padded_shard->shape(),
                                      pad_value->shape(), shard_padding));
    padded_shard = b_.AddInstruction(HloInstruction::CreatePad(
        padded_shard_shape, padded_shard, pad_value, shard_padding));
  }

  SetPartitionedHlo(hlo, [&]() {
    if (windowed->dynamic_slice_index_on_output.has_value()) {
      // Cut the locally padded data down to this partition's shard.
      auto result_shard_shape =
          MakePartitionedShape(hlo->shape(), hlo->sharding());
      return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
          result_shard_shape, padded_shard,
          *windowed->dynamic_slice_index_on_output,
          result_shard_shape.dimensions()));
    }
    return padded_shard;
  });
  return Status::OK();
}
// Partitions a parameter: creates a parameter with the same number and the
// shard shape implied by the sharding of `hlo`, carrying over the
// replicated-at-leaf-buffers annotation when the original has one.
Status SpmdPartitioningVisitor::HandleParameter(HloInstruction* hlo) {
  SetPartitionedHlo(hlo, [&]() {
    auto* sharded_param = b_.AddInstruction(HloInstruction::CreateParameter(
        hlo->parameter_number(),
        MakePartitionedShape(hlo->shape(), hlo->sharding()), "param"));
    const auto& replicated_at_leaves =
        hlo->parameter_replicated_at_leaf_buffers();
    if (replicated_at_leaves) {
      sharded_param->set_parameter_replicated_at_leaf_buffers(
          *replicated_at_leaves);
    }
    return sharded_param;
  });
  return Status::OK();
}
// Partitions a (possibly variadic) reduce.
//
// All inputs are brought to the sharding of the first input, padded with
// their init values (except along preserved dimensions) when that sharding is
// tiled, and reduced locally. If a reduced dimension is partitioned, the
// partial results are combined across the partitions that share the same
// preserved-dimension tile: with an all-reduce for an array result, or by
// replicating the partial results within the group and reducing again for a
// tuple result. The result is finally resharded to the sharding of `hlo`.
Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) {
int64 input_count = 1;
// NOTE(review): per_input_sharding is assigned here but not read again in
// this function.
auto per_input_sharding = hlo->sharding();
if (hlo->shape().IsTuple()) {
input_count = hlo->shape().tuple_shapes_size();
CHECK_GT(input_count, 0);
per_input_sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
}
std::vector<PartitionedHlo> inputs;
std::vector<HloInstruction*> inits;
// Operand dimensions that are not reduced.
std::vector<int64> preserved_dims;
for (int64 i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
if (!absl::c_linear_search(hlo->dimensions(), i)) {
preserved_dims.push_back(i);
}
}
// Operands [0, input_count) are the inputs, operands
// [input_count, 2 * input_count) the matching init values.
for (int64 operand_id = 0; operand_id < input_count; ++operand_id) {
inits.push_back(GetPartitionedHlo(hlo->operand(operand_id + input_count))
.Reshard(HloSharding::Replicate())
.hlo());
inputs.push_back(GetPartitionedHlo(hlo->operand(operand_id)));
if (operand_id > 0) {
// Make sure all operands are sharded in the same way.
inputs.back() = inputs.back().Reshard(inputs[0].sharding());
}
if (!inputs[0].sharding().IsTileMaximal()) {
// Pad with the init value; preserved dimensions are skipped.
inputs.back() =
inputs.back().PadWithValue(inits[operand_id], /*left_padded_dims=*/{},
/*skipped_dims=*/preserved_dims);
}
}
std::vector<Shape*> new_operand_shapes(input_count * 2);
for (int64 i = 0; i < input_count; ++i) {
new_operand_shapes[i] = inputs[i].hlo()->mutable_shape();
new_operand_shapes[i + input_count] = inits[i]->mutable_shape();
}
// Create the shard shape of the reduce result.
TF_ASSIGN_OR_RETURN(
auto reduce_shape,
ShapeInference::InferReduceShape(new_operand_shapes, hlo->dimensions(),
hlo->to_apply()->ComputeProgramShape()));
std::vector<HloInstruction*> input_hlos(input_count);
for (int64 i = 0; i < input_count; ++i) {
input_hlos[i] = inputs[i].hlo();
}
// Reduce the local shard.
auto local_reduce = b_.AddInstruction(HloInstruction::CreateReduce(
reduce_shape, input_hlos, inits, hlo->dimensions(), hlo->to_apply()));
local_reduce->set_metadata(hlo->metadata());
SetPartitionedHlo(hlo, [&]() {
HloInstruction* reduce = local_reduce;
// True if at least one reduced dimension is split across partitions.
const bool reduce_sharded_dimension =
!inputs[0].sharding().IsTileMaximal() &&
absl::c_any_of(hlo->dimensions(), [&](int64 i) {
return inputs[0].sharding().tile_assignment().dim(i) > 1;
});
if (reduce_sharded_dimension) {
if (inputs[0].sharding().ReplicateOnLastTileDim()) {
// Also group on the trailing replication dimension of the tile
// assignment. Note that this modifies the captured preserved_dims.
preserved_dims.push_back(inputs[0].base_shape().rank());
}
auto grouped = GroupShardingOnDims(inputs[0].sharding(), preserved_dims);
auto grouped_state = CreatePerGroupPartitioningState(
inputs[0].state(), grouped.device_groups, &b_);
if (local_reduce->shape().IsArray()) {
reduce = grouped_state.collective_ops_creator
.create_cross_partition_all_reduce(
&b_, local_reduce, hlo->to_apply(), {}, NewChannel());
} else {
// Tuple result: gather the partial results of every input within the
// group and reduce them once more.
std::vector<HloInstruction*> all_gathered_partial_results(input_count);
for (int64 i = 0; i < input_count; ++i) {
auto gte = b_.AddInstruction(HloInstruction::CreateGetTupleElement(
ShapeUtil::GetTupleElementShape(reduce_shape, i), local_reduce,
i));
// expanded_shape: the partial result with size-1 reduced dimensions.
// all_gather_shape: the same with one entry per partition along each
// reduced dimension.
auto expanded_shape = input_hlos[i]->shape();
auto all_gather_shape = input_hlos[i]->shape();
for (int64 dim : hlo->dimensions()) {
expanded_shape.set_dimensions(dim, 1);
all_gather_shape.set_dimensions(
dim, inputs[0].sharding().tile_assignment().dim(dim));
}
auto reshape = b_.AddInstruction(
HloInstruction::CreateReshape(expanded_shape, gte));
// Replicate per group.
reshape->set_sharding(grouped.sharding);
all_gathered_partial_results[i] =
PartitionedHlo(reshape, all_gather_shape, grouped_state)
.Replicate()
.hlo();
}
reduce = b_.AddInstruction(HloInstruction::CreateReduce(
reduce_shape, all_gathered_partial_results, inits,
hlo->dimensions(), hlo->to_apply()));
}
}
// Sharding of the produced result: the input sharding, partially replicated
// on the reduced dimensions and with those dimensions removed.
auto sharding = hlo_sharding_util::RemoveShapeDimensions(
hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
inputs[0].sharding(), hlo->dimensions()),
hlo->dimensions());
if (local_reduce->shape().IsArray()) {
reduce->set_sharding(sharding);
} else {
reduce->set_sharding(HloSharding::Tuple(
reduce->shape(), std::vector<HloSharding>(input_count, sharding)));
}
return PartitionedHlo(reduce, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
// Partitions a reverse whose result is tiled. The operand is resharded to the
// reversed form of the result sharding, halo-exchanged so that padding ends up
// on the left of the reversed dimensions, and then reversed locally.
// Tile-maximal shardings and a failed halo exchange use DefaultAction.
Status SpmdPartitioningVisitor::HandleReverse(HloInstruction* hlo) {
  auto reverse = Cast<HloReverseInstruction>(hlo);
  if (reverse->sharding().IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  const auto operand_sharding = hlo_sharding_util::ReverseSharding(
      reverse->sharding(), reverse->dimensions());
  auto resharded_operand =
      GetPartitionedHlo(reverse->operand(0)).Reshard(operand_sharding);

  auto left_padded =
      HaloExchangeToPadOnLeft(resharded_operand, reverse->dimensions());
  if (!left_padded) {
    return DefaultAction(hlo);
  }

  SetPartitionedHlo(hlo, [&] {
    // The local reverse has the shape of the left-padded shard.
    return b_.AddInstruction(
        hlo->CloneWithNewOperands(left_padded->shape(), {left_padded}));
  });
  return Status::OK();
}
// Partitions a while loop. The parameters of the condition and the body get
// the sharding of the loop, both computations are partitioned (the condition
// with a replicated root, the body with the loop sharding as root sharding),
// and a while over the resharded init value is emitted.
Status SpmdPartitioningVisitor::HandleWhile(HloInstruction* hlo) {
  const HloSharding& loop_sharding = hlo->sharding();

  // Body parameter, body root and condition parameter must all share the loop
  // sharding.
  hlo->while_body()->parameter_instruction(0)->set_sharding(loop_sharding);
  hlo->while_condition()->parameter_instruction(0)->set_sharding(
      loop_sharding);

  // The condition root is replicated so that every partition takes the same
  // control flow decision.
  // NOTE(review): while_condition()/while_body() are re-read at each use
  // instead of being cached; presumably PartitionComputation may replace the
  // computation objects - confirm before caching them.
  TF_RETURN_IF_ERROR(
      partitioner_
          ->PartitionComputation(hlo->while_condition(),
                                 HloSharding::Replicate(), next_channel_id_,
                                 logger_)
          .status());
  TF_RETURN_IF_ERROR(
      partitioner_
          ->PartitionComputation(hlo->while_body(), loop_sharding,
                                 next_channel_id_, logger_)
          .status());

  SetPartitionedHlo(hlo, [&] {
    auto* init = GetPartitionedHlo(hlo->operand(0)).Reshard(loop_sharding).hlo();
    return b_.AddInstruction(HloInstruction::CreateWhile(
        MakePartitionedShape(hlo->shape(), loop_sharding),
        hlo->while_condition(), hlo->while_body(), init));
  });
  return Status::OK();
}
// Partitions a conditional. Every branch parameter takes the sharding of its
// argument, every branch computation is partitioned with the sharding of the
// conditional as root sharding, and the branch selector (operand 0) is
// replicated.
Status SpmdPartitioningVisitor::HandleConditional(HloInstruction* hlo) {
  const int64 num_branches = hlo->branch_count();

  // Operand k + 1 is the argument of branch k; parameter and argument must be
  // sharded identically.
  std::vector<HloInstruction*> branch_args;
  for (int64 branch = 0; branch < num_branches; ++branch) {
    const HloInstruction* argument = hlo->operand(branch + 1);
    hlo->branch_computation(branch)->parameter_instruction(0)->set_sharding(
        argument->sharding());
    branch_args.push_back(GetPartitionedHlo(argument).hlo());
  }

  // The root of each branch follows the sharding of the conditional. The
  // branch computation is looked up anew for every call.
  for (int64 branch = 0; branch < num_branches; ++branch) {
    TF_RETURN_IF_ERROR(
        partitioner_
            ->PartitionComputation(hlo->branch_computation(branch),
                                   hlo->sharding(), next_channel_id_, logger_)
            .status());
  }

  SetPartitionedHlo(hlo, [&] {
    // A replicated selector makes all partitions take the same branch.
    auto* selector = GetPartitionedHlo(hlo->operand(0))
                         .Reshard(HloSharding::Replicate())
                         .hlo();
    return b_.AddInstruction(HloInstruction::CreateConditional(
        MakePartitionedShape(hlo->shape(), hlo->sharding()), selector,
        hlo->called_computations(), branch_args));
  });
  return Status::OK();
}
// Partitions an outfeed. Operand 0 is the data, operand 1 the token.
//  * Sharding has a unique device: delegated to HandleSingleDevice.
//  * Sharding partitions the data evenly: a single outfeed of the resharded
//    operand is emitted.
//  * Otherwise partitions hold differently sized valid data. A conditional
//    with one branch per distinct non-padded shape is built; each branch
//    slices the valid region out of the (padded) shard and outfeeds it.
Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) {
if (hlo->sharding().HasUniqueDevice()) {
return HandleSingleDevice(hlo);
}
const auto& sharding = hlo->sharding();
// Unpartitioned data shape.
const Shape& shape = hlo->operand(0)->shape();
auto partitioned_operand =
GetPartitionedHlo(hlo->operand(0)).Reshard(sharding);
const auto& shard_shape = partitioned_operand.hlo()->shape();
const auto& operand = partitioned_operand.hlo();
auto token = GetPartitionedHlo(hlo->operand(1)).hlo();
if (EvenlyPartitions(shape, sharding)) {
SetPartitionedHlo(hlo, [&]() {
return b_.AddInstruction(HloInstruction::CreateOutfeed(
operand->shape(), operand, token, hlo->outfeed_config()));
});
return Status::OK();
}
// Create a branch for each unique partitioned shape.
std::vector<Shape> per_branch_partitioned_shapes;
// conditional_branch_indices[p] is the branch taken by partition p.
std::vector<int32> conditional_branch_indices(num_partitions_);
for (int64 i = 0; i < num_partitions_; ++i) {
auto partitioned_shape =
MakeNonPaddedShapeForGivenPartition(shape, sharding, i);
// Linear search for an already registered compatible shape.
int64 matching_existing_index = 0;
for (; matching_existing_index < per_branch_partitioned_shapes.size();
++matching_existing_index) {
if (ShapeUtil::Compatible(
partitioned_shape,
per_branch_partitioned_shapes[matching_existing_index])) {
break;
}
}
if (matching_existing_index < per_branch_partitioned_shapes.size()) {
conditional_branch_indices[i] = matching_existing_index;
} else {
conditional_branch_indices[i] = per_branch_partitioned_shapes.size();
per_branch_partitioned_shapes.push_back(std::move(partitioned_shape));
}
}
// Get branch index for this partition.
HloInstruction* branch_index;
if (per_branch_partitioned_shapes.size() == num_partitions_) {
// Use partition ID as the branch index if each partition has its own
// branch.
branch_index = partition_id_;
// PartitionId's output is U32 but conditional requires S32.
if (branch_index->shape().element_type() != S32) {
branch_index = b_.AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::ChangeElementType(branch_index->shape(), S32),
branch_index));
}
} else {
// Otherwise, use a constant table to look up the branch index.
auto branch_index_table = b_.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<int32>(conditional_branch_indices)));
branch_index = b_.AddInstruction(HloInstruction::CreateDynamicSlice(
ShapeUtil::MakeShape(S32, {1}), branch_index_table, {partition_id_},
{1}));
// Turn the one-element slice into a scalar.
branch_index = b_.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {}), branch_index));
}
// Create conditional for the outfeed.
std::vector<HloComputation*> branches(per_branch_partitioned_shapes.size());
for (int64 i = 0; i < branches.size(); ++i) {
SpmdBuilder branch_b(absl::StrCat("outfeed_branch_", i), visiting_hlo_);
// Create tuple param within the branch.
auto param = branch_b.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0,
ShapeUtil::MakeTupleShape({operand->shape(), token->shape()}),
"outfeed_token_param"));
auto outfeed_data = branch_b.AddInstruction(
HloInstruction::CreateGetTupleElement(operand->shape(), param, 0));
auto outfeed_token = branch_b.AddInstruction(
HloInstruction::CreateGetTupleElement(token->shape(), param, 1));
if (!ShapeUtil::Compatible(per_branch_partitioned_shapes[i], shard_shape)) {
// Recursively walks the outfeed data and slices every array leaf down to the
// corresponding leaf of this branch's non-padded shape; `index` addresses
// the current sub-shape within the data.
std::function<HloInstruction*(const ShapeIndex&, HloInstruction*)>
slice_outfeed =
[&](const ShapeIndex& index,
HloInstruction* outfeed_operand) -> HloInstruction* {
// Get outfeed element shape.
const Shape& element_shape =
ShapeUtil::GetSubshape(outfeed_data->shape(), index);
// Recursively call slice_outfeed for tuple shapes.
if (element_shape.IsTuple() && element_shape.tuple_shapes_size() > 0) {
std::vector<HloInstruction*> slice_elements(
element_shape.tuple_shapes_size());
for (int64 i = 0; i < slice_elements.size(); ++i) {
auto sub_index = index;
sub_index.push_back(i);
slice_elements[i] = slice_outfeed(
sub_index,
branch_b.AddInstruction(HloInstruction::CreateGetTupleElement(
ShapeUtil::GetSubshape(element_shape, {i}), outfeed_operand,
i)));
}
return branch_b.AddInstruction(
HloInstruction::CreateTuple(slice_elements));
}
// Get the slice shape.
// `i` here is the branch index of the enclosing loop.
const Shape& slice_shape = ShapeUtil::GetSubshape(
per_branch_partitioned_shapes[i], ShapeIndexView(index));
if (ShapeUtil::Compatible(element_shape, slice_shape)) {
return outfeed_operand;
}
// Slice out useful data.
// The valid region starts at index 0 in every dimension, with stride 1.
if (element_shape.IsArray()) {
CHECK(slice_shape.IsArray());
std::vector<int64> start_indices(slice_shape.rank(), 0);
std::vector<int64> slice_strides(slice_shape.rank(), 1);
return branch_b.AddInstruction(HloInstruction::CreateSlice(
slice_shape, outfeed_operand, start_indices,
slice_shape.dimensions(), slice_strides));
}
CHECK(element_shape.IsTuple());
CHECK(element_shape.tuple_shapes().empty());
return outfeed_operand;
};
outfeed_data = slice_outfeed({}, outfeed_data);
}
branch_b.AddInstruction(HloInstruction::CreateOutfeed(
per_branch_partitioned_shapes[i], outfeed_data, outfeed_token,
hlo->outfeed_config()));
branches[i] = module_->AddEmbeddedComputation(branch_b.Build());
}
SetPartitionedHlo(hlo, [&]() {
// Every branch receives the same (data shard, token) tuple as its argument.
return b_.AddInstruction(HloInstruction::CreateConditional(
token->shape(), branch_index, branches,
std::vector<HloInstruction*>(
branches.size(),
b_.AddInstruction(HloInstruction::CreateTuple({operand, token})))));
});
return Status::OK();
}
// Partitions an rng.
//  * Unique-device sharding: delegated to HandleSingleDevice.
//  * Replicated sharding: the rng runs on device 0 only and its result is
//    resharded to replicated.
//  * Tiled sharding: operands are replicated and every partition generates its
//    own shard. With partial replication (ReplicateOnLastTileDim) the shard is
//    additionally replicated within each replication group.
Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) {
  if (hlo->sharding().HasUniqueDevice()) {
    return HandleSingleDevice(hlo);
  }

  // Reshards every operand of `hlo` to `target`, in operand order.
  auto reshard_operands = [&](const HloSharding& target) {
    std::vector<HloInstruction*> resharded;
    for (int64 i = 0; i < hlo->operand_count(); ++i) {
      resharded.push_back(
          GetPartitionedHlo(hlo->operand(i)).Reshard(target).hlo());
    }
    return resharded;
  };

  if (hlo->sharding().IsReplicated()) {
    SetPartitionedHlo(hlo, [&] {
      // Run on a single device (0) and distribute the data to all other cores.
      std::vector<HloInstruction*> device0_operands =
          reshard_operands(HloSharding::AssignDevice(0));
      auto device0_rng = b_.AddInstruction(
          hlo->CloneWithNewOperands(hlo->shape(), device0_operands));
      device0_rng->set_sharding(HloSharding::AssignDevice(0));
      return PartitionedHlo(device0_rng, hlo->shape(),
                            MakePartitioningState())
          .Reshard(HloSharding::Replicate())
          .hlo();
    });
    return Status::OK();
  }

  TF_RET_CHECK(!hlo->sharding().IsTileMaximal());

  // Replicate the operands and run partitioned Rng on all devices.
  std::vector<HloInstruction*> replicated_operands =
      reshard_operands(HloSharding::Replicate());

  if (hlo->sharding().ReplicateOnLastTileDim()) {
    // Group on all data dimensions of the tile assignment (everything but the
    // trailing replication dimension).
    std::vector<int64> data_dims(
        hlo->sharding().tile_assignment().num_dimensions() - 1);
    std::iota(data_dims.begin(), data_dims.end(), 0);
    auto grouped = GroupShardingOnDims(hlo->sharding(), data_dims);
    auto group_state = CreatePerGroupPartitioningState(
        MakePartitioningState(), grouped.device_groups, &b_);
    auto shard_rng = b_.AddInstruction(HloInstruction::CreateRng(
        MakePartitionedShape(hlo->shape(), hlo->sharding()),
        hlo->random_distribution(), replicated_operands));
    // Within a group, treat the shard as living on device 0 and replicate it.
    shard_rng->set_sharding(HloSharding::AssignDevice(0));
    SetPartitionedHlo(hlo, [&]() {
      return PartitionedHlo(shard_rng, shard_rng->shape(), group_state)
          .Replicate()
          .hlo();
    });
    return Status::OK();
  }

  SetPartitionedHlo(hlo, [&] {
    return b_.AddInstruction(HloInstruction::CreateRng(
        MakePartitionedShape(hlo->shape(), hlo->sharding()),
        hlo->random_distribution(), replicated_operands));
  });
  return Status::OK();
}
// Partitions a non-variadic ReduceWindow. The operand is resharded as a
// windowed input (with the init value replicated); if that resharding is not
// available, or the op is variadic or tile-maximal, DefaultAction is used.
Status SpmdPartitioningVisitor::HandleReduceWindow(HloInstruction* hlo) {
  // TODO(b/73062247) Variadic reduce window not yet supported in partitioner.
  const bool is_variadic = hlo->shape().IsTuple();
  if (is_variadic) {
    return DefaultAction(hlo);
  }
  auto& input = GetPartitionedHlo(hlo->operand(0));
  if (hlo->sharding().IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  // The init value is needed in full on every partition.
  auto init = GetPartitionedHlo(hlo->mutable_operand(1))
                  .Reshard(HloSharding::Replicate());

  auto windowed_input =
      input.ReshardAsWindowedInput(hlo->window(), hlo->sharding(), init.hlo());
  if (!windowed_input.has_value()) {
    return DefaultAction(hlo);
  }

  // Infer the per-shard result shape, then adopt the layout of the expected
  // shard shape.
  TF_ASSIGN_OR_RETURN(Shape inferred_shard_shape,
                      ShapeInference::InferReduceWindowShape(
                          windowed_input->sharded_input->shape(),
                          init.hlo()->shape(), windowed_input->shard_window,
                          hlo->to_apply()->ComputeProgramShape()));
  auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
  *inferred_shard_shape.mutable_layout() = shard_shape.layout();

  SetPartitionedHlo(hlo, [&]() {
    auto sharded_rw = b_.AddInstruction(HloInstruction::CreateReduceWindow(
        inferred_shard_shape, windowed_input->sharded_input, init.hlo(),
        windowed_input->shard_window, hlo->to_apply()));
    const auto& output_slice_index =
        windowed_input->dynamic_slice_index_on_output;
    if (output_slice_index.has_value()) {
      // Slice the shard-shaped piece out of the reduce-window result.
      return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
          shard_shape, sharded_rw, *output_slice_index,
          shard_shape.dimensions()));
    }
    CHECK(ShapeUtil::Compatible(shard_shape, sharded_rw->shape()));
    return sharded_rw;
  });
  return Status::OK();
}
// Partitions a SelectAndScatter instruction whose sharding is tiled.
//
// The operand and source are resharded to the output sharding, then for every
// partitioned dimension both are extended via halo exchange so that each
// partition holds all windows overlapping its shard. A SelectAndScatter is
// emitted on the extended shards (with window padding cleared on partitioned
// dimensions) and, if needed, the shard-shaped result is dynamic-sliced out.
//
// Falls back to DefaultAction when the sharding is tile-maximal, the element
// type is neither F32 nor BF16, the select computation is not a direct compare
// of its two parameters with a Ge/Gt/Le/Lt direction, or a halo exchange
// cannot be built.
Status SpmdPartitioningVisitor::HandleSelectAndScatter(HloInstruction* hlo) {
if (hlo->sharding().IsTileMaximal()) {
return DefaultAction(hlo);
}
// Bring both the operand and the source to the output sharding.
auto operand = GetPartitionedHlo(hlo->operand(0));
auto source = GetPartitionedHlo(hlo->mutable_operand(1));
if (hlo->sharding() != operand.sharding()) {
operand = operand.Reshard(hlo->sharding());
}
if (hlo->sharding() != source.sharding()) {
source = source.Reshard(hlo->sharding());
}
// Only F32 and BF16 are handled: the data padding introduced below uses
// +/-infinity (see float_pad_value) to work around the issue with low/high
// padding. Other element types use the default action.
if (hlo->shape().element_type() != F32 &&
hlo->shape().element_type() != BF16) {
return DefaultAction(hlo);
}
// The select computation must be a compare whose two operands are the two
// parameters of the computation (parameter numbers 0 and 1, in either order).
auto select = hlo->called_computations()[0];
auto select_root = select->root_instruction();
if (select_root->opcode() != HloOpcode::kCompare ||
select_root->operand(0)->opcode() != HloOpcode::kParameter ||
select_root->operand(1)->opcode() != HloOpcode::kParameter ||
select_root->operand(0)->parameter_number() +
select_root->operand(1)->parameter_number() !=
1) {
return DefaultAction(hlo);
}
// Pick the sign of the infinity pad value from the comparison direction and
// from which parameter is the left-hand side of the compare; presumably this
// keeps padded elements from ever being selected.
float float_pad_value;
if (select_root->comparison_direction() == ComparisonDirection::kGe ||
select_root->comparison_direction() == ComparisonDirection::kGt) {
if (select_root->operand(0)->parameter_number() == 0) {
float_pad_value = -std::numeric_limits<float>::infinity();
} else {
float_pad_value = std::numeric_limits<float>::infinity();
}
} else if (select_root->comparison_direction() == ComparisonDirection::kLe ||
select_root->comparison_direction() == ComparisonDirection::kLt) {
if (select_root->operand(0)->parameter_number() == 0) {
float_pad_value = std::numeric_limits<float>::infinity();
} else {
float_pad_value = -std::numeric_limits<float>::infinity();
}
} else {
return DefaultAction(hlo);
}
// Materialize the pad value as a scalar constant of the output element type.
auto pad_value = b_.AddInstruction(HloInstruction::CreateConstant(
hlo->shape().element_type() == BF16
? LiteralUtil::CreateR0<bfloat16>(
static_cast<bfloat16>(float_pad_value))
: LiteralUtil::CreateR0<float>(float_pad_value)));
// Replicate init
auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(2))
.Reshard(HloSharding::Replicate());
auto partition_ordinals =
MakeTiledPartitionOrdinals(hlo->sharding(), partition_id_, &b_);
// The first window for each dimension that overlaps with the shard area.
std::vector<MultiplyAddDivideOffsetCalculation> first_window(
hlo->shape().rank());
// The first window for each dimension that goes beyond the shard area.
std::vector<MultiplyAddDivideOffsetCalculation> limit_window(
hlo->shape().rank());
// Per-dimension halo sizes for the data (operand) and the source; entries for
// unpartitioned dimensions are left default-constructed.
std::vector<OffsetCalculation> data_left_halo_sizes(hlo->shape().rank());
std::vector<OffsetCalculation> data_right_halo_sizes(hlo->shape().rank());
std::vector<OffsetCalculation> source_left_halo_sizes(hlo->shape().rank());
std::vector<OffsetCalculation> source_right_halo_sizes(hlo->shape().rank());
auto unpadded_data_shard_shape =
MakePartitionedShape(hlo->shape(), hlo->sharding());
auto unpadded_source_shard_shape =
MakePartitionedShape(hlo->operand(1)->shape(), hlo->sharding());
// These are updated dimension by dimension as halos are exchanged.
auto source_shard_hlo = source.hlo();
auto data_shard_hlo = operand.hlo();
for (int64 i = 0; i < hlo->shape().rank(); ++i) {
int64 shard_count = hlo->sharding().tile_assignment().dim(i);
// Unpartitioned dimensions need no halo exchange.
if (shard_count == 1) {
continue;
}
// If stride > window_size, there will be gaps between windows. These gaps
// will also exist in the output, so we keep them during halo exchange.
//
// TODO(yuanzx): This could introduce overhead if partitions start at
// different offsets in a gap.
auto wd = hlo->window().dimensions(i);
if (wd.stride() > wd.size()) {
wd.set_size(wd.stride());
}
// shard_size * i < stride * k - pad_low + window_size =>
// k > (shard_size * i + pad_low - window_size) / stride =>
// first_k == (shard_size * i + pad_low - window_size + stride) / stride
first_window[i] = MultiplyAddDivideOffsetCalculation(
unpadded_data_shard_shape.dimensions(i),
wd.padding_low() - wd.size() + wd.stride(), wd.stride());
// shard_size * (i + 1) <= stride * k - pad_low =>
// k >= (shard_size * i + shard_size + pad_low) / stride =>
// limit_k == (shard_size * i + shard_size + pad_low + stride - 1) /
// stride
limit_window[i] = MultiplyAddDivideOffsetCalculation(
unpadded_data_shard_shape.dimensions(i),
unpadded_data_shard_shape.dimensions(i) + wd.padding_low() +
wd.stride() - 1,
wd.stride());
// Source halos: the source shard start minus the first needed window on the
// left, and the limit window minus the source shard end on the right.
source_left_halo_sizes[i] =
MultiplyAddDivideOffsetCalculation(
unpadded_source_shard_shape.dimensions(i), 0, 1) -
first_window[i];
source_right_halo_sizes[i] =
limit_window[i] - MultiplyAddDivideOffsetCalculation(
unpadded_source_shard_shape.dimensions(i),
unpadded_source_shard_shape.dimensions(i), 1);
// Data halos: the same window bounds scaled by the stride and compared
// against the (pad_low-shifted) data shard boundaries.
data_left_halo_sizes[i] =
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
unpadded_data_shard_shape.dimensions(i), wd.padding_low(), 1)) -
OffsetCalculation(
HloOpcode::kMultiply, first_window[i],
MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1));
data_right_halo_sizes[i] =
OffsetCalculation(
HloOpcode::kMultiply, limit_window[i],
MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1)) -
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
unpadded_data_shard_shape.dimensions(i),
unpadded_data_shard_shape.dimensions(i) + wd.stride() +
wd.padding_low() - wd.size(),
1));
// Largest number of windows any partition needs along this dimension.
int64 max_windows =
(limit_window[i] - first_window[i]).MaxInRange(0, shard_count);
auto first_window_hlo =
first_window[i].Calculate(partition_ordinals[i], &b_);
// Padding on the source is filled with the init value so they do not change
// the data on overlapping windows.
auto resharded_source = ExchangeHaloAndGetValidData(
source_shard_hlo, source.base_shape(), source_left_halo_sizes[i],
source_right_halo_sizes[i], 0,
limit_window[i].Calculate(shard_count - 1), max_windows, i,
hlo->sharding(), first_window_hlo, replicated_init.hlo(),
partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_);
if (!resharded_source) {
return DefaultAction(hlo);
}
source_shard_hlo = *resharded_source;
// Offset in the data of the first window used by this partition.
auto offset_start_in_data =
MultiplyAddDivideOffsetCalculation(wd.stride(), 0, 1)
.Calculate(first_window_hlo, &b_);
int64 padded_data_size =
(limit_window[i].Calculate(shard_count - 1) - 1) * wd.stride() +
wd.size();
int64 data_shard_size = (max_windows - 1) * wd.stride() + wd.size();
// Padding on the data is filled with pad_value (the +/-infinity constant).
auto resharded_data = ExchangeHaloAndGetValidData(
data_shard_hlo, operand.base_shape(), data_left_halo_sizes[i],
data_right_halo_sizes[i], wd.padding_low(), padded_data_size,
data_shard_size, i, hlo->sharding(), offset_start_in_data, pad_value,
partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_);
if (!resharded_data) {
return DefaultAction(hlo);
}
data_shard_hlo = *resharded_data;
}
// Build the per-shard window: same as the original window except that
// padding is removed on partitioned dimensions.
Window window_on_shard = hlo->window();
for (int64 i = 0; i < window_on_shard.dimensions_size(); ++i) {
int64 shard_count = hlo->sharding().tile_assignment().dim(i);
if (shard_count == 1) {
continue;
}
auto reshard_wd = window_on_shard.mutable_dimensions(i);
// The shards are already explicitly padded.
reshard_wd->set_padding_low(0);
reshard_wd->set_padding_high(0);
}
auto sharded_select_and_scatter =
b_.AddInstruction(HloInstruction::CreateSelectAndScatter(
data_shard_hlo->shape(), data_shard_hlo, select, window_on_shard,
source_shard_hlo, replicated_init.hlo(),
hlo->called_computations()[1]));
SetPartitionedHlo(hlo, [&]() {
auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
// If the result already has the shard shape, no slicing is needed.
if (ShapeUtil::Compatible(sharded_select_and_scatter->shape(),
shard_shape)) {
return sharded_select_and_scatter;
}
// Otherwise dynamic-slice the shard out of the larger result. Offsets
// default to zero and are overridden on partitioned dimensions.
auto zero = b_.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
std::vector<HloInstruction*> slice_offsets(shard_shape.rank(), zero);
for (int64 i = 0; i < window_on_shard.dimensions_size(); ++i) {
if (hlo->sharding().tile_assignment().dim(i) == 1) {
continue;
}
int64 pad_low = hlo->window().dimensions(i).padding_low();
auto left_halo_size =
data_left_halo_sizes[i].Calculate(partition_ordinals[i], &b_);
// When partition 0's left halo size equals pad_low, the computed halo size
// is used as the offset on every partition; otherwise partition 0 uses
// pad_low and the other partitions use their left halo size.
if (data_left_halo_sizes[i].Calculate(0) == pad_low) {
slice_offsets[i] = left_halo_size;
} else {
auto is_shard0 = b_.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {}), zero, partition_ordinals[i],
ComparisonDirection::kEq));
auto pad_low_hlo = b_.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<int32>(pad_low)));
slice_offsets[i] = b_.AddInstruction(HloInstruction::CreateTernary(
zero->shape(), HloOpcode::kSelect, is_shard0, pad_low_hlo,
left_halo_size));
}
}
return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
shard_shape, sharded_select_and_scatter, slice_offsets,
shard_shape.dimensions()));
});
return Status::OK();
}
// Partitions a Tuple: each operand is resharded to the sub-sharding the tuple's
// sharding assigns to that element, and a new tuple is built from the results.
Status SpmdPartitioningVisitor::HandleTuple(HloInstruction* hlo) {
  const int64 element_count = hlo->operand_count();
  std::vector<HloInstruction*> resharded_elements;
  resharded_elements.reserve(element_count);
  for (int64 element = 0; element < element_count; ++element) {
    const HloSharding element_sharding =
        hlo->sharding().GetSubSharding(hlo->shape(), {element});
    PartitionedHlo resharded =
        GetPartitionedHlo(hlo->operand(element)).Reshard(element_sharding);
    resharded_elements.push_back(resharded.hlo());
  }
  SetPartitionedHlo(hlo, [&]() {
    return b_.AddInstruction(HloInstruction::CreateTuple(resharded_elements));
  });
  return Status::OK();
}
// Visits `computation`, builds the partitioned computation whose root is
// resharded to `root_sharding`, runs code motion for windowed dot-general
// loops on it, and replaces `computation` with it in the parent module.
// Returns the visitor's `changed_` flag.
StatusOr<bool> SpmdPartitioningVisitor::DoPartition(
    HloComputation* computation, const HloSharding& root_sharding) {
  VLOG(2) << "Partitioning computation " << computation->name() << " for "
          << num_replicas_ << " replicas and " << num_partitions_
          << " partitions";
  TF_RETURN_IF_ERROR(computation->Accept(this));

  HloModule* parent_module = computation->parent();
  PartitionedHlo partitioned_root =
      GetPartitionedHlo(computation->root_instruction()).Reshard(root_sharding);
  HloComputation* spmd_computation =
      parent_module->AddEmbeddedComputation(b_.Build(partitioned_root.hlo()));
  TF_RETURN_IF_ERROR(DoCodeMotionForWindowedDotGeneralLoops(spmd_computation));

  // Replace the original computation with the new SPMD computation.
  std::unordered_map<HloComputation*, HloComputation*> replacements = {
      {computation, spmd_computation}};
  parent_module->ReplaceComputations(replacements);
  return changed_;
}
// PartitionId instructions in the input are rejected: this handler always
// returns an Unimplemented error and never inspects `hlo`.
Status SpmdPartitioningVisitor::HandlePartitionId(HloInstruction* hlo) {
  Status unsupported = Unimplemented(
      "PartitionId instruction is not supported for SPMD partitioning "
      "since the meaning is ambiguous -- whether the instruction is "
      "replicated or the data is replicated, and if the latter which data "
      "is replicated.");
  return unsupported;
}
// Builds the default set of collective-op factories. The returned callbacks
// are, in order: partition-id, cross-partition all-reduce, collective-permute,
// all-to-all and all-gather creators. Each callback adds the created
// instruction to the SpmdBuilder it is given.
SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64 num_partitions,
                                                        int64 num_replicas) {
  // Expands groups of partition ids into replica groups of global device ids:
  // one group per (replica, partition subgroup) pair, where a partition `pid`
  // in replica `r` maps to `r * num_partitions + pid`.
  auto to_global_device_groups =
      [num_replicas, num_partitions](
          const std::vector<std::vector<int64>>& partition_subgroups) {
        std::vector<ReplicaGroup> device_groups;
        device_groups.reserve(partition_subgroups.size() * num_replicas);
        for (int64 replica = 0; replica < num_replicas; ++replica) {
          for (const auto& subgroup : partition_subgroups) {
            device_groups.emplace_back();
            ReplicaGroup& group = device_groups.back();
            for (int64 partition : subgroup) {
              group.add_replica_ids(replica * num_partitions + partition);
            }
          }
        }
        return device_groups;
      };

  auto create_partition_id = [](SpmdBuilder* b) {
    return b->AddInstruction(HloInstruction::CreatePartitionId());
  };

  auto create_all_reduce =
      [num_replicas, to_global_device_groups](
          SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction,
          const std::vector<std::vector<int64>>& partition_subgroups,
          int64 channel_id) {
        if (partition_subgroups.size() <= 1) {
          // At most one subgroup: emit one single-element replica group per
          // replica and do not use global device ids.
          // TODO(yuanzx): Unify subgroup definition with AllToAll.
          std::vector<ReplicaGroup> per_replica_groups(num_replicas);
          for (int64 replica = 0; replica < num_replicas; ++replica) {
            per_replica_groups[replica].add_replica_ids(replica);
          }
          return b->AddInstruction(HloInstruction::CreateAllReduce(
              operand->shape(), {operand}, reduction, per_replica_groups,
              /*constrain_layout=*/false, channel_id,
              /*use_global_device_ids=*/false));
        }
        std::vector<ReplicaGroup> device_groups =
            to_global_device_groups(partition_subgroups);
        return b->AddInstruction(HloInstruction::CreateAllReduce(
            operand->shape(), {operand}, reduction, device_groups,
            /*constrain_layout=*/false, channel_id,
            /*use_global_device_ids=*/true));
      };

  auto create_collective_permute =
      [](SpmdBuilder* b, HloInstruction* operand,
         std::vector<std::pair<int64, int64>>& src_dst_pairs,
         int64 channel_id) {
        return b->AddInstruction(HloInstruction::CreateCollectivePermute(
            operand->shape(), operand, src_dst_pairs, channel_id));
      };

  auto create_all_to_all =
      [](SpmdBuilder* b, absl::Span<HloInstruction* const> operands,
         const std::vector<std::vector<int64>>& partition_subgroups,
         int64 channel_id, absl::optional<int64> split_dimension) {
        // Every output element takes the shape of the first operand; a single
        // operand yields an array shape, several yield a tuple shape.
        std::vector<Shape> shapes(operands.size(), operands[0]->shape());
        const Shape output_shape = (shapes.size() == 1)
                                       ? shapes[0]
                                       : ShapeUtil::MakeTupleShape(shapes);
        // Partition ids are used directly as the replica-group ids here.
        std::vector<ReplicaGroup> groups;
        groups.reserve(partition_subgroups.size());
        for (const auto& subgroup : partition_subgroups) {
          groups.emplace_back();
          for (int64 id : subgroup) {
            groups.back().add_replica_ids(id);
          }
        }
        return b->AddInstruction(HloInstruction::CreateAllToAll(
            output_shape, operands, groups,
            /*constrain_layout=*/false, channel_id, split_dimension));
      };

  auto create_all_gather =
      [to_global_device_groups](
          SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape,
          const std::vector<std::vector<int64>>& partition_subgroups,
          int64 channel_id, int64 all_gather_dimension) {
        std::vector<ReplicaGroup> device_groups =
            to_global_device_groups(partition_subgroups);
        return b->AddInstruction(HloInstruction::CreateAllGather(
            ag_shape, operand, all_gather_dimension, device_groups,
            /*constrain_layout=*/false, channel_id,
            /*use_global_device_ids=*/true));
      };

  return {
      create_partition_id, create_all_reduce, create_collective_permute,
      create_all_to_all,   create_all_gather,
  };
}
// Convenience constructor: delegates to the four-argument constructor, passing
// the collective-op factories from GetDefaultCollectiveOpsCreator for the given
// partition and replica counts.
SpmdPartitioner::SpmdPartitioner(int64 num_partitions, int64 num_replicas,
SpmdPartitionerOptions options)
: SpmdPartitioner(
num_partitions, num_replicas, std::move(options),
GetDefaultCollectiveOpsCreator(num_partitions, num_replicas)) {}
// All-gathers the shards of `operand` along the dimensions in `selected_dims`
// that are tiled by `sharding` (tile count > 1), using the all-gather factory
// from `collectives_creator`. Instructions are added to `b`.
//
// The returned instruction has the shape of `operand` with each gathered
// dimension multiplied by its tile count. `sharding` must not be tile-maximal
// (CHECK-enforced).
HloInstruction* SpmdPartitioner::AllGatherShards(
SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
int64 channel_id, absl::Span<const int64> selected_dims,
const SPMDCollectiveOpsCreator& collectives_creator) {
CHECK(!sharding.IsTileMaximal());
// Add one leading dimension to gather all partitions.
std::vector<int64> shape;
shape.push_back(1);
for (int64 dim : operand->shape().dimensions()) {
shape.push_back(dim);
}
auto reshape = b->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(operand->shape().element_type(), shape), operand));
auto partition_subgroups =
GetPartitionGroupsForReplication(sharding, selected_dims);
// The gathered leading dimension has one entry per partition in a subgroup.
shape[0] = partition_subgroups[0].size();
auto result = collectives_creator.create_cross_partition_all_gather(
b, reshape, ShapeUtil::MakeShape(operand->shape().element_type(), shape),
partition_subgroups, channel_id, /*all_gather_dimension=*/0);
// If n > 1 dimensions are partitioned, split the leading dimension to n.
// `tiled_dims` holds the selected dimensions that have more than one tile.
std::vector<int64> tiled_dims;
for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
if (sharding.tile_assignment().dim(i) > 1 &&
absl::c_linear_search(selected_dims, i)) {
tiled_dims.push_back(i);
}
}
if (tiled_dims.size() > 1) {
// New shape: one leading dimension per tiled dim (its tile count), followed
// by the operand's own dimensions.
std::vector<int64> split_dim_shape;
split_dim_shape.reserve(tiled_dims.size() + operand->shape().rank());
for (int64 i : tiled_dims) {
split_dim_shape.push_back(sharding.tile_assignment().dim(i));
}
for (int64 dim : operand->shape().dimensions()) {
split_dim_shape.push_back(dim);
}
result = b->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(operand->shape().element_type(), split_dim_shape),
result));
}
// Transpose the gathered dimensions to next to their corresponding
// partitioned dimensions.
//
// `i` walks the positions of the transposed result and `i - split_dims_added`
// is the operand dimension being placed. A dimension that was not gathered
// takes one position; a gathered dimension takes two consecutive positions
// (the leading gathered dimension, then the data dimension), hence the extra
// `i++`.
std::vector<int64> xpose_permutation(result->shape().rank());
int64 split_dims_added = 0;
for (int64 i = 0; i < xpose_permutation.size(); ++i) {
if (sharding.tile_assignment().dim(i - split_dims_added) == 1 ||
!absl::c_linear_search(selected_dims, i - split_dims_added)) {
xpose_permutation[i] = i + tiled_dims.size() - split_dims_added;
} else {
xpose_permutation[i] = split_dims_added;
xpose_permutation[i + 1] = i + tiled_dims.size() - split_dims_added;
split_dims_added++;
i++;
}
}
result = b->AddInstruction(HloInstruction::CreateTranspose(
ShapeInference::InferTransposeShape(result->shape(), xpose_permutation)
.ValueOrDie(),
result, xpose_permutation));
// Reshape to the desired shape.
// Each gathered dimension is merged with its data dimension, so its size
// becomes the shard size times the tile count.
auto ag_shape = operand->shape();
for (int64 i : tiled_dims) {
ag_shape.set_dimensions(
i, ag_shape.dimensions(i) * sharding.tile_assignment().dim(i));
}
result = b->AddInstruction(HloInstruction::CreateReshape(ag_shape, result));
return result;
}
// Partitions `computation` by creating a visitor via CreateVisitor and running
// its DoPartition with `root_sharding`. Returns DoPartition's result.
StatusOr<bool> SpmdPartitioner::PartitionComputation(
    HloComputation* computation, const HloSharding& root_sharding,
    int64* next_channel_id, SpmdLogger* logger) {
  std::unique_ptr<SpmdPartitioningVisitor> partitioning_visitor =
      CreateVisitor(computation, num_partitions_, num_replicas_,
                    collective_ops_creator_, next_channel_id, logger, options_);
  StatusOr<bool> partition_result =
      partitioning_visitor->DoPartition(computation, root_sharding);
  return partition_result;
}
// Factory for the visitor used to partition a computation: constructs a
// SpmdPartitioningVisitor from the given arguments plus this partitioner.
std::unique_ptr<SpmdPartitioningVisitor> SpmdPartitioner::CreateVisitor(
    HloComputation* computation, int64 num_partitions, int64 num_replicas,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdLogger* logger,
    SpmdPartitionerOptions options) {
  std::unique_ptr<SpmdPartitioningVisitor> visitor =
      absl::make_unique<SpmdPartitioningVisitor>(
          computation, num_partitions, num_replicas, collective_ops_creator,
          next_channel_id, logger, std::move(options), this);
  return visitor;
}
// Runs SPMD partitioning on `module`.
//
// Steps, in order: preprocess/validate shardings, record the entry
// parameters' and root's shardings on the module, flatten the call graph,
// partition the entry computation, then either verify that the entry signature
// is unchanged or (when options_.allow_module_signature_change is set) update
// the module's entry computation layout. If anything changed, a cleanup
// pipeline is run. Finally sharding attributes are cleared.
//
// Returns whether the module was changed, or an error from any step.
StatusOr<bool> SpmdPartitioner::Run(HloModule* module) {
TF_RETURN_IF_ERROR(PreprocessSharding(module));
XLA_VLOG_LINES(1, SpmdLogger::ReportBeforePartition(
*module, options_.report_instruction_count));
// Add the parameters' and output's shardings to the module.
std::vector<HloSharding> entry_params_shardings;
for (int64 i = 0; i < module->entry_computation()->num_parameters(); ++i) {
auto param = module->entry_computation()->parameter_instruction(i);
CHECK(param->has_sharding()) << "Missing sharding in entry parameter " << i;
entry_params_shardings.push_back(param->sharding());
}
module->set_spmd_parameters_shardings(entry_params_shardings);
auto entry_root = module->entry_computation()->root_instruction();
CHECK(entry_root->has_sharding()) << "Missing sharding in entry root.";
module->set_spmd_output_sharding(entry_root->sharding());
FlattenCallGraph flatten;
TF_ASSIGN_OR_RETURN(auto changed, flatten.Run(module));
SpmdLogger logger(options_.report_instruction_count);
// Snapshot the entry program shape before partitioning; it is compared
// against the post-partitioning shape below.
auto program_shape = module->entry_computation()->ComputeProgramShape();
int64 next_channel_id = hlo_query::NextChannelId(*module);
// Copy the root sharding since the partitioner visitor may temporarily change
// the sharding to work around manual sharding.
HloSharding root_sharding = entry_root->sharding();
TF_ASSIGN_OR_RETURN(
bool partition_changed,
PartitionComputation(module->entry_computation(), root_sharding,
&next_channel_id, &logger));
changed |= partition_changed;
// For the entry computation, make sure that the root instruction and the
// parameters preserve their signatures.
// Note: entry_computation() is queried again here rather than reusing a
// pointer obtained before partitioning.
auto new_program_shape = module->entry_computation()->ComputeProgramShape();
if (!options_.allow_module_signature_change) {
// Signature must be preserved: compare result, parameter count and each
// parameter shape against the pre-partitioning program shape.
TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
program_shape.result(), new_program_shape.result()))
<< "Result shape changed for the entry computation";
TF_RET_CHECK(program_shape.parameters_size() ==
new_program_shape.parameters_size())
<< "Parameter count changed for the entry computation";
for (int64 i = 0; i < program_shape.parameters_size(); ++i) {
TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
program_shape.parameters(i), new_program_shape.parameters(i)))
<< "Parameter shape changed for the entry computation";
}
} else {
const auto& old_entry_layout = module->entry_computation_layout();
// Shapes can change but the layout should still remain the same.
for (int64 i = 0; i < new_program_shape.parameters_size(); ++i) {
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
old_entry_layout.parameter_shape(i),
new_program_shape.mutable_parameters(i)));
}
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
old_entry_layout.result_shape(), new_program_shape.mutable_result()));
// Install the new program shape (with the old layouts) as the module's
// entry computation layout.
HloModuleConfig config = module->config();
*config.mutable_entry_computation_layout() =
ComputationLayout(new_program_shape, /*ignore_layouts=*/false);
module->set_config(config);
}
XLA_VLOG_LINES(1, SpmdLogger::ReportAfterPartition(
*module, options_.report_instruction_count));
XLA_VLOG_LINES(1, logger.MakeReport());
// Clean up after partitioning: simplify tuples, remove dead code, deduplicate
// common subexpressions, and flatten the call graph again.
if (changed) {
HloPassPipeline pass("spmd-cleanup");
pass.AddPass<TupleSimplifier>();
pass.AddPass<HloDCE>();
pass.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
pass.AddPass<FlattenCallGraph>();
TF_RETURN_IF_ERROR(pass.Run(module).status());
}
TF_RETURN_IF_ERROR(ClearShardingAttributes(module));
return changed;
}
// Validates and normalizes sharding annotations on every instruction of
// `module` before partitioning:
//  * side-effecting instructions other than Rng must carry a sharding, and it
//    may only be replicated for Infeed/Outfeed;
//  * instructions without a sharding get device 0 (Rng) or a replicated
//    sharding (everything else);
//  * tiled, non-manual shardings must cover all partitions.
// Unless module signature changes are allowed, the entry root and parameters
// must additionally be replicated or assigned to a unique device.
Status SpmdPartitioner::PreprocessSharding(HloModule* module) {
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* hlo : computation->instructions()) {
      if (hlo->HasSideEffectNoRecurse() && hlo->opcode() != HloOpcode::kRng) {
        TF_RET_CHECK(hlo->has_sharding())
            << "Side-effect HLO must have sharding: " << hlo->ToString();
        TF_RET_CHECK(!HasReplicatedSharding(hlo->sharding()) ||
                     hlo->opcode() == HloOpcode::kInfeed ||
                     hlo->opcode() == HloOpcode::kOutfeed)
            << "Non-infeed side-effect HLO cannot have a replicated sharding:"
            << hlo->ToString();
      }

      // For unassigned HLOs, annotate with replicated sharding.
      //
      // Among side-effecting ops, only Rng is allowed to omit the annotation.
      // In that case, we currently force it to run on core 0, since we don't
      // support partitioning or replicating the Rng op (the values depend on
      // the seed provided to each device).
      //
      // TODO(hyouklee): Should we also convert single-device shardings (without
      // side-effects) into replicated?
      if (!hlo->has_sharding()) {
        hlo->set_sharding(
            hlo->opcode() == HloOpcode::kRng
                ? HloSharding::AssignDevice(0)
                : HloSharding::Single(hlo->shape(), HloSharding::Replicate()));
        continue;
      }

      // Only tiled, non-manual shardings need the coverage check below.
      if (hlo->sharding().IsTileMaximal() || hlo->sharding().IsManual()) {
        continue;
      }
      std::vector<int64> available(num_partitions_);
      std::iota(available.begin(), available.end(), 0);
      TF_RET_CHECK(num_partitions_ == hlo_sharding_util::DevicesForSharding(
                                          hlo->sharding(), available)
                                          .size())
          << "num_partitions:" << num_partitions_ << "\n"
          << "SPMD partitioner only supports tile sharding that includes all "
             "partitions. If you didn't add this sharding annotation in the "
             "model, please file a bug to XLA team.\n"
          << hlo->ToString();
    }
  }

  if (options_.allow_module_signature_change) {
    return Status::OK();
  }

  // Entry computation's parameter and root sharding must be either all
  // replicated or all on a single device.
  const HloComputation* entry = module->entry_computation();
  TF_RET_CHECK(entry->root_instruction()->has_sharding());
  const HloSharding& root_sharding = entry->root_instruction()->sharding();
  TF_RET_CHECK(root_sharding.IsReplicated() ||
               root_sharding.UniqueDevice().has_value())
      << "Unsupported entry root sharding: " << root_sharding.ToString();
  for (const HloInstruction* param : entry->parameter_instructions()) {
    TF_RET_CHECK(param->has_sharding());
    TF_RET_CHECK(param->sharding().IsReplicated() ||
                 param->sharding().UniqueDevice().has_value())
        << "Unsupported entry parameter sharding:"
        << param->sharding().ToString();
  }
  return Status::OK();
}
} // namespace spmd
} // namespace xla