STT-tensorflow/tensorflow/compiler/xla/service/sharding_propagation.cc
Yuanzhong Xu 18f2dd4262 Allow more sharding propagation from broadcast operands
MaybeImproveInstructionSharding() already checks that the new sharding is compatible with the existing tiled sharding (if any).

PiperOrigin-RevId: 356375771
Change-Id: Ic5e65e00cea0a25cec2cf93027632c4c7783a20e
2021-02-08 16:13:46 -08:00

1882 lines
76 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/sharding_propagation.h"
#include <algorithm>
#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/strings/str_split.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace {
// Maps a computation to an instruction that calls it. In this file it is
// consulted to decide whether a computation's root/parameters may be sharded
// (e.g. while bodies, conditional branches).
using ComputationMap =
absl::flat_hash_map<const HloComputation*, HloInstruction*>;
// Returns true iff the specified hlo or sharding has a spatially partitioned
// sharding (tiled or replicated) what can be propagated by sharding
// propagation.
bool IsSpatiallyPartitioned(const HloSharding& sharding) {
if (sharding.IsTuple()) {
return absl::c_any_of(sharding.tuple_elements(), IsSpatiallyPartitioned);
} else {
return !sharding.IsTileMaximal() || sharding.IsReplicated();
}
}
// Returns true iff `hlo` has a sharding and that sharding is spatially
// partitioned (tiled or replicated).
bool IsSpatiallyPartitioned(const HloInstruction* hlo) {
  if (!hlo->has_sharding()) {
    return false;
  }
  return IsSpatiallyPartitioned(hlo->sharding());
}
// Updates the sharding of the specified instruction with the specified sharding
// if it is better than the current one, and returns true iff a new sharding has
// been applied. If may_combine_partial_sharding is true, this may combine the
// new and existing sharding if they are both partial tiling partial
// replication.
//
// Shardings rejected by IsSpatiallyPartitioned() (tile-maximal and not
// replicated) are never applied. An instruction without a sharding always
// accepts the new one. Otherwise the decision is delegated to
// hlo_sharding_util::MergeSharding(), followed by a compatibility check that
// prevents replacing an existing tiled array sharding with an unrelated one.
bool MaybeImproveInstructionSharding(HloSharding sharding,
HloInstruction* instruction,
bool may_combine_partial_sharding) {
// We don't want to propagate tile maximal shardings.
if (!IsSpatiallyPartitioned(sharding)) {
return false;
}
// Any sharding is better than no sharding.
if (!instruction->has_sharding()) {
instruction->set_sharding(std::move(sharding));
return true;
}
// Record the tile count before MergeSharding() (which receives `&sharding`
// and may rewrite it), so we can tell below whether the merge changed it.
int64 sharding_tiles = sharding.NumTiles();
if (hlo_sharding_util::MergeSharding(instruction->sharding(), &sharding,
may_combine_partial_sharding)) {
// Override existing tiled sharding only when the new sharding is compatible
// with the existing one. This avoids unexpected resharding when `sharding`
// just has more tiles than existing sharding but they are not mergeable.
if (instruction->shape().IsArray() &&
!instruction->sharding().IsTileMaximal() &&
sharding.NumTiles() == sharding_tiles) {
// Collect the dims whose tile counts differ. Each such dim must be untiled
// (tile count 1) in the existing sharding; otherwise reject the update.
std::vector<int64> diff_dims;
for (int64 i = 0; i < instruction->shape().rank(); ++i) {
if (instruction->sharding().tile_assignment().dim(i) ==
sharding.tile_assignment().dim(i)) {
continue;
}
if (instruction->sharding().tile_assignment().dim(i) != 1) {
return false;
}
diff_dims.push_back(i);
}
// Undoing the extra tiling (replicating on diff_dims) must give back exactly
// the existing sharding, i.e. the new sharding only refines untiled dims.
if (hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
sharding, diff_dims) != instruction->sharding()) {
return false;
}
}
instruction->set_sharding(std::move(sharding));
return true;
}
return false;
}
// Assigns the default (replicated) sharding to every element of the tuple
// shaped `instruction`. A tuple sharding cannot express the case where only
// some elements are sharded, so all elements are set at once.
void SetDefaultTupleSharding(HloInstruction* instruction) {
  const HloSharding replicated = HloSharding::Replicate();
  instruction->set_sharding(
      HloSharding::SingleTuple(instruction->shape(), replicated));
}
// We consider a convolution kernel to be small iff it is smaller along all
// spatial dimensions then the output of the convolution. The rational is that
// we can either shard the kernel or the output and we want to shard the larger
// one for better efficiency.
bool IsConvolutionKernelSmall(const HloInstruction* instruction) {
CHECK_EQ(instruction->opcode(), HloOpcode::kConvolution);
const HloInstruction* rhs = instruction->operand(1);
const auto& dnums = instruction->convolution_dimension_numbers();
for (int64 i = 0; i < dnums.input_spatial_dimensions().size(); ++i) {
int64 kernel_dim =
rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i));
int64 output_dim =
instruction->shape().dimensions(dnums.output_spatial_dimensions(i));
if (kernel_dim >= output_dim) {
return false;
}
}
return true;
}
// Returns the operand that is the most suitable for determining the sharding
// of the specified instruction, or nullptr if there isn't any suitable operand
// (including when the only candidate operand has no sharding yet).
const HloInstruction* PickRepresentativeOperand(
    const HloInstruction* instruction) {
  switch (instruction->opcode()) {
    case HloOpcode::kMap:
    case HloOpcode::kPad:
    case HloOpcode::kPower:
    case HloOpcode::kReverse:
    case HloOpcode::kSlice:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
      // For these opcodes the output sharding has to be determined by the
      // sharding of the first operand but we can only determine sharding based
      // on it if it already has a sharding.
      if (instruction->operand(0)->has_sharding()) {
        return instruction->operand(0);
      }
      return nullptr;
    case HloOpcode::kAbs:
    case HloOpcode::kAdd:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kConcatenate:
    case HloOpcode::kConvert:
    case HloOpcode::kCopy:
    case HloOpcode::kCos:
    case HloOpcode::kAllGather:
    case HloOpcode::kAllReduce:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kDivide:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kLogistic:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kReal:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kRemainder:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSelect:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSort:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
    case HloOpcode::kTupleSelect:
    case HloOpcode::kWhile:
    case HloOpcode::kXor: {
      // For these opcodes the output sharding can be determined by any operand
      // so we find the operand with the most specific sharding.
      const HloInstruction* best_operand = nullptr;
      for (const HloInstruction* operand : instruction->operands()) {
        if (operand->has_sharding() &&
            (best_operand == nullptr ||
             hlo_sharding_util::IsShardingMoreSpecific(
                 operand->sharding(), best_operand->sharding()))) {
          best_operand = operand;
        }
      }
      return best_operand;
    }
    // There is no suitable operand for the rest of the opcodes.
    case HloOpcode::kAddDependency:
    case HloOpcode::kAfterAll:
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBitcast:
    case HloOpcode::kBroadcast:
    case HloOpcode::kCall:
    case HloOpcode::kCholesky:
    case HloOpcode::kCollectivePermuteDone:
    case HloOpcode::kCollectivePermuteStart:
    case HloOpcode::kConditional:
    case HloOpcode::kConstant:
    case HloOpcode::kConvolution:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCopyStart:
    case HloOpcode::kCustomCall:
    case HloOpcode::kDomain:
    case HloOpcode::kDot:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kDynamicReshape:
    case HloOpcode::kFft:
    case HloOpcode::kFusion:
    case HloOpcode::kGather:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kInfeed:
    case HloOpcode::kIota:
    case HloOpcode::kOutfeed:
    case HloOpcode::kParameter:
    case HloOpcode::kPartitionId:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kReplicaId:
    case HloOpcode::kReshape:
    case HloOpcode::kRng:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kRngBitGenerator:
    case HloOpcode::kScatter:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kTrace:
    case HloOpcode::kTranspose:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kTuple:
    case HloOpcode::kGetDimensionSize:
    case HloOpcode::kSetDimensionSize:
      return nullptr;
  }
  // Every opcode is handled by the switch above (which deliberately has no
  // `default`, so -Wswitch still flags newly added opcodes). This return keeps
  // the function well-defined instead of flowing off the end of a non-void
  // function if an unlisted value is ever seen.
  return nullptr;
}
bool SupportSpatialPartitioning(const HloInstruction* instruction,
const ComputationMap& computation_map,
bool is_spmd) {
if (instruction->parent()->root_instruction() == instruction &&
computation_map.find(instruction->parent()) == computation_map.end()) {
// We don't support sharding the root instruction of a computation yet,
// unless the computation is a while body.
return false;
}
if (instruction->IsElementwise() &&
(instruction->opcode() != HloOpcode::kRng || is_spmd)) {
return true;
}
switch (instruction->opcode()) {
case HloOpcode::kBroadcast:
case HloOpcode::kConcatenate:
case HloOpcode::kConditional:
case HloOpcode::kConstant:
case HloOpcode::kConvolution:
case HloOpcode::kDot:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kGather:
case HloOpcode::kGetTupleElement:
case HloOpcode::kInfeed:
case HloOpcode::kIota:
case HloOpcode::kPad:
case HloOpcode::kReduceWindow:
case HloOpcode::kReshape:
case HloOpcode::kScatter:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kSlice:
case HloOpcode::kSort:
case HloOpcode::kTranspose:
case HloOpcode::kTuple:
case HloOpcode::kWhile:
case HloOpcode::kReduce:
return true;
case HloOpcode::kAllReduce:
// Only if channel_id is not specified.
return instruction->channel_id() == absl::nullopt;
case HloOpcode::kParameter:
return computation_map.find(instruction->parent()) !=
computation_map.end();
case HloOpcode::kReverse:
return is_spmd;
default:
return false;
}
}
// Infers a sharding for a dot (or a convolution handled like a dot) from the
// shardings of its two operands and applies it via
// MaybeImproveInstructionSharding(). The larger operand (by byte size) is
// tried first, so its sharding is applied before the smaller operand's.
// Returns true if the instruction's sharding changed.
bool InferDotShardingFromOperands(
    HloInstruction* instruction,
    const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
    bool may_combine_partial_sharding) {
  // Derives an output sharding from operand `operand_index` (0 = lhs,
  // 1 = rhs). The operand's dims that have no output counterpart are
  // replicated; its batch and non-contracting dims are mapped onto the
  // corresponding output dims.
  auto from_operand = [&](int64 operand_index) {
    const bool is_lhs = operand_index == 0;
    auto operand = instruction->operand(operand_index);
    const HloSharding& operand_sharding = operand->sharding();
    if (operand_sharding.IsTileMaximal()) {
      return operand_sharding;
    }
    std::vector<int64> contracting_dims;
    contracting_dims.reserve(dnums.contracting_dims.size());
    for (const auto& dim : dnums.contracting_dims) {
      contracting_dims.push_back(is_lhs ? dim.lhs : dim.rhs);
    }
    // It's possible that some size-1 spatial dims of convolutions are parsed as
    // non-contracting dims of the other operand. Such an entry may still name a
    // dim of this operand that does not appear in the output and might be
    // tiled, so replicate it like a contracting dim. A negative index means the
    // entry has no dim in this operand; 0 is a valid dim index, hence `>= 0`.
    for (const auto& dim : is_lhs ? dnums.rhs_non_contracting_dims
                                  : dnums.lhs_non_contracting_dims) {
      const int64 d = is_lhs ? dim.lhs : dim.rhs;
      if (d >= 0) {
        contracting_dims.push_back(d);
      }
    }
    auto replicate_contracting_dims =
        hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
            operand_sharding, contracting_dims);
    // -1 marks dims with no counterpart on the other side.
    std::vector<int64> out_dims_to_op_perm(instruction->shape().rank(), -1);
    std::vector<int64> op_dims_to_output_perm(operand->shape().rank(), -1);
    for (const auto& dim : dnums.batch_dims) {
      const int64 op_dim = is_lhs ? dim.lhs : dim.rhs;
      out_dims_to_op_perm[dim.output] = op_dim;
      op_dims_to_output_perm[op_dim] = dim.output;
    }
    for (const auto& dim : is_lhs ? dnums.lhs_non_contracting_dims
                                  : dnums.rhs_non_contracting_dims) {
      const int64 op_dim = is_lhs ? dim.lhs : dim.rhs;
      out_dims_to_op_perm[dim.output] = op_dim;
      op_dims_to_output_perm[op_dim] = dim.output;
    }
    return *hlo_sharding_util::TransposeShardingWithCollapsedDims(
        replicate_contracting_dims, op_dims_to_output_perm,
        out_dims_to_op_perm);
  };
  bool changed = false;
  const int64 larger_operand =
      ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) >=
              ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())
          ? 0
          : 1;
  if (IsSpatiallyPartitioned(instruction->operand(larger_operand))) {
    changed |= MaybeImproveInstructionSharding(from_operand(larger_operand),
                                               instruction,
                                               may_combine_partial_sharding);
  }
  if (IsSpatiallyPartitioned(instruction->operand(1 - larger_operand))) {
    changed |= MaybeImproveInstructionSharding(from_operand(1 - larger_operand),
                                               instruction,
                                               may_combine_partial_sharding);
  }
  return changed;
}
// Infers a sharding for a gather from its parallel (batch-like) dims. For each
// spatially partitioned operand (0 = data operand, 1 = indices), the operand's
// sharding is partially replicated on its non-parallel dims. Its tiling on the
// parallel dims is then moved to the matching output dims, and the result is
// applied via MaybeImproveInstructionSharding(). Returns true if the
// instruction's sharding changed.
bool InferGatherParallelShardingFromOperands(
    HloInstruction* instruction,
    const hlo_sharding_util::GatherParallelDims& parallel_dims,
    bool may_combine_partial_sharding) {
  // `output_aligned_parallel_dims[i]` is the operand dim that corresponds to
  // output dim `output_parallel_dims[i]`; both spans are in output order.
  auto from_operand = [instruction](
                          int64 operand_index,
                          absl::Span<const int64> output_aligned_parallel_dims,
                          absl::Span<const int64> output_parallel_dims) {
    const HloInstruction* operand = instruction->operand(operand_index);
    const HloSharding& operand_sharding = operand->sharding();
    if (operand_sharding.IsTileMaximal()) {
      return operand_sharding;
    }
    std::vector<int64> output_tile_dims(instruction->shape().rank(), 1);
    std::vector<int64> operand_non_parallel_dims;
    operand_non_parallel_dims.reserve(operand->shape().rank());
    // Detect the non-parallel dimensions of this operand.
    for (int64 i = 0; i < operand->shape().rank(); ++i) {
      if (!absl::c_linear_search(output_aligned_parallel_dims, i)) {
        operand_non_parallel_dims.push_back(i);
      }
    }
    // Copy the tile count of each parallel operand dim onto its output dim.
    for (size_t i = 0; i < output_aligned_parallel_dims.size(); ++i) {
      const int64 operand_dim = output_aligned_parallel_dims[i];
      const int64 output_dim = output_parallel_dims[i];
      output_tile_dims[output_dim] =
          operand_sharding.tile_assignment().dim(operand_dim);
    }
    HloSharding replicate_non_parallel_dims =
        hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
            operand_sharding, operand_non_parallel_dims);
    if (replicate_non_parallel_dims.IsTileMaximal()) {
      return replicate_non_parallel_dims;
    }
    // Keep the trailing replication dim of a partially replicated sharding.
    if (replicate_non_parallel_dims.ReplicateOnLastTileDim()) {
      output_tile_dims.push_back(
          replicate_non_parallel_dims.tile_assignment().dimensions().back());
    }
    auto output_tile_assignment = replicate_non_parallel_dims.tile_assignment();
    output_tile_assignment.Reshape(output_tile_dims);
    return replicate_non_parallel_dims.ReplicateOnLastTileDim()
               ? HloSharding::PartialTile(
                     output_tile_assignment,
                     replicate_non_parallel_dims.metadata())
               : HloSharding::Tile(output_tile_assignment,
                                   replicate_non_parallel_dims.metadata());
  };
  bool changed = false;
  auto output_parallel_dims =
      hlo_sharding_util::GatherParallelOutputDims(*instruction, parallel_dims);
  if (IsSpatiallyPartitioned(instruction->operand(0))) {
    // Named local so the span passed below clearly outlives its use.
    const auto operand_aligned_parallel_dims =
        hlo_sharding_util::GatherOutputAlignedOperandParallelDims(
            *instruction, parallel_dims);
    changed |= MaybeImproveInstructionSharding(
        from_operand(0, absl::MakeConstSpan(operand_aligned_parallel_dims),
                     absl::MakeConstSpan(output_parallel_dims)),
        instruction, may_combine_partial_sharding);
  }
  if (IsSpatiallyPartitioned(instruction->operand(1))) {
    changed |= MaybeImproveInstructionSharding(
        from_operand(1,
                     absl::MakeConstSpan(parallel_dims.indices_parallel_dims),
                     absl::MakeConstSpan(output_parallel_dims)),
        instruction, may_combine_partial_sharding);
  }
  return changed;
}
// Convolution handling for InferShardingFromOperands().
//
// The convolution is handled as a dot by InferDotShardingFromOperands() when it
// has no conv-spatial dims, or when neither operand is partitioned along them
// and both batch_group_count and feature_group_count are 1. Otherwise the
// output sharding is derived from the lhs only:
//   - lhs not spatially partitioned -> no change;
//   - replicated lhs -> replicated output;
//   - small kernel (IsConvolutionKernelSmall()) -> output tiled like the lhs,
//     unless the lhs is tiled on its input feature dimension (no change);
//   - otherwise (large kernel) -> replicated output.
// Returns true if the instruction's sharding changed.
// NOTE(review): `aggressiveness` is not read by this function.
bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
int64 aggressiveness,
bool may_combine_partial_sharding) {
// Returns the product of `inst`'s tile counts over the lhs-side
// (lhs_or_rhs == 0) or rhs-side (lhs_or_rhs == 1) dims listed in `dims`;
// returns 1 when `inst` has no sharding or a tile-maximal one.
auto get_partitions_for_dims =
[&](const HloInstruction* inst,
absl::Span<
const dot_as_convolution_util::DotConvolutionDimsInfo::DimNums>
dims,
int lhs_or_rhs) {
int64 partitions = 1;
if (!inst->has_sharding()) {
return partitions;
}
const auto& sharding = inst->sharding();
if (sharding.IsTileMaximal()) {
return partitions;
}
for (const auto& dim : dims) {
if (lhs_or_rhs == 0) {
partitions *= sharding.tile_assignment().dim(dim.lhs);
} else {
CHECK_EQ(lhs_or_rhs, 1);
partitions *= sharding.tile_assignment().dim(dim.rhs);
}
}
return partitions;
};
auto dot_dims =
dot_as_convolution_util::ParseConvolutionDimsInfo(instruction);
const int64 lhs_conv_spatial_partitions = get_partitions_for_dims(
instruction->operand(0), dot_dims.conv_spatial_dims, 0);
const int64 rhs_conv_spatial_partitions = get_partitions_for_dims(
instruction->operand(1), dot_dims.conv_spatial_dims, 1);
// Dot-like case: no (partitioned) conv-spatial dims and no grouping.
if (dot_dims.conv_spatial_dims.empty() ||
(lhs_conv_spatial_partitions == 1 && rhs_conv_spatial_partitions == 1 &&
instruction->batch_group_count() == 1 &&
instruction->feature_group_count() == 1)) {
return InferDotShardingFromOperands(instruction, dot_dims,
may_combine_partial_sharding);
}
const auto& dnums = instruction->convolution_dimension_numbers();
const HloInstruction* lhs = instruction->operand(0);
// Builds an output sharding that tiles each output dim (batch, feature,
// spatial) like the corresponding lhs input dim.
auto get_tiled_sharding_based_on_lhs = [&] {
CHECK(!lhs->sharding().IsTileMaximal());
std::vector<int64> output_to_lhs_indices(instruction->shape().rank());
output_to_lhs_indices[dnums.output_batch_dimension()] =
dnums.input_batch_dimension();
output_to_lhs_indices[dnums.output_feature_dimension()] =
dnums.input_feature_dimension();
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
output_to_lhs_indices[dnums.output_spatial_dimensions(i)] =
dnums.input_spatial_dimensions(i);
}
return hlo_sharding_util::TransposeSharding(lhs->sharding(),
output_to_lhs_indices);
};
if (!IsSpatiallyPartitioned(lhs)) {
return false;
}
if (lhs->sharding().IsReplicated()) {
return MaybeImproveInstructionSharding(
HloSharding::Replicate(lhs->sharding().metadata()), instruction,
may_combine_partial_sharding);
}
if (IsConvolutionKernelSmall(instruction)) {
// If the kernel is small compared to the input then we can generate an
// output that is sharded the same way as the input.
const auto& tile_assignment = lhs->sharding().tile_assignment();
// A lhs tiled on its input feature dimension is not handled here.
if (tile_assignment.dim(dnums.input_feature_dimension()) > 1) {
return false;
}
return MaybeImproveInstructionSharding(get_tiled_sharding_based_on_lhs(),
instruction,
may_combine_partial_sharding);
}
// If the kernel is large (e.g backward convolution) then we only support
// replicated output.
return MaybeImproveInstructionSharding(HloSharding::Replicate(), instruction,
may_combine_partial_sharding);
}
bool CanPropagateThroughAtAgressiveLevel(const HloInstruction& inst,
int64 aggressiveness) {
// At minimum agressiveness, only allow pass-through ops.
if (aggressiveness < 1 && !inst.IsElementwise() &&
inst.opcode() != HloOpcode::kTranspose &&
inst.opcode() != HloOpcode::kReshape) {
return false;
}
// Broadcast propagation should have at least aggressiveness 2.
if (aggressiveness < 2 && inst.opcode() == HloOpcode::kBroadcast) {
return false;
}
return true;
}
// Tries to update the sharding of the specified instruction based on its
// operands and returns true if the sharding of the instruction have been
// changed and false otherwise.
bool InferShardingFromOperands(HloInstruction* instruction,
const ComputationMap& computation_map,
bool is_spmd, int64 aggressiveness) {
if (!CanPropagateThroughAtAgressiveLevel(*instruction, aggressiveness)) {
return false;
}
// Do not change manual sharding.
if (instruction->has_sharding() && instruction->sharding().IsManual()) {
return false;
}
// Propagate manual sharding. Avoid tuple shaped HLOs that group independent
// together. Reduce, ReduceWindow, and Sort can be tuples but the elements
// are correlated, so we propagate manual sharding through them.
if (!instruction->has_sharding() &&
(instruction->shape().IsArray() ||
instruction->opcode() == HloOpcode::kReduce ||
instruction->opcode() == HloOpcode::kSort ||
instruction->opcode() == HloOpcode::kReduceWindow)) {
for (const HloInstruction* op : instruction->operands()) {
if (!op->has_sharding() || !op->sharding().IsManual()) continue;
instruction->set_sharding(HloSharding::Manual(op->sharding().metadata()));
return true;
}
}
const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0;
if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) {
// If an array shaped HLO doesn't support spatial partitioning but at least
// one of its operand is replicated then we make the HLO replicated as well.
if (instruction->shape().IsTuple() || instruction->operand_count() == 0 ||
instruction == instruction->parent()->root_instruction() ||
instruction->HasSideEffect()) {
return false;
}
for (const HloInstruction* op : instruction->operands()) {
if (op->has_sharding() && op->sharding().IsReplicated()) {
return MaybeImproveInstructionSharding(
HloSharding::Replicate(op->sharding().metadata()), instruction,
may_combine_partial_sharding);
}
}
return false;
}
switch (instruction->opcode()) {
case HloOpcode::kGetTupleElement: {
const HloInstruction* operand = instruction->operand(0);
if (!IsSpatiallyPartitioned(operand)) {
return false;
}
HloSharding new_sharding = operand->sharding().GetSubSharding(
operand->shape(), {instruction->tuple_index()});
return MaybeImproveInstructionSharding(
std::move(new_sharding), instruction, may_combine_partial_sharding);
}
case HloOpcode::kTuple: {
if (absl::c_none_of(instruction->operands(),
[](const HloInstruction* hlo) {
return IsSpatiallyPartitioned(hlo);
})) {
// None of the operands have a spatially partitioned sharding.
return false;
}
bool changed = false;
if (!instruction->has_sharding()) {
// Set the sharding for all elements in the tuple because it isn't
// possible to set a partial sharding.
SetDefaultTupleSharding(instruction);
changed = true;
}
// Go through each operand and if the operand has a sharding that is
// better than the current sharding for that tuple element then update
// it.
const Shape& shape = instruction->shape();
std::vector<HloSharding> sub_shardings =
instruction->sharding().tuple_elements();
int64 sub_sharding_index = 0;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
const HloInstruction* operand = instruction->operand(i);
if (operand->has_sharding()) {
if (operand->shape().IsTuple()) {
for (int64 i = 0, e = ShapeUtil::GetLeafCount(operand->shape());
i < e; ++i) {
if (hlo_sharding_util::IsShardingMoreSpecific(
operand->sharding().tuple_elements()[i],
sub_shardings[sub_sharding_index + i])) {
sub_shardings[sub_sharding_index + i] =
operand->sharding().tuple_elements()[i];
}
}
} else {
if (hlo_sharding_util::IsShardingMoreSpecific(
operand->sharding(), sub_shardings[sub_sharding_index])) {
sub_shardings[sub_sharding_index] = operand->sharding();
}
}
}
sub_sharding_index += ShapeUtil::GetLeafCount(operand->shape());
}
HloSharding new_sharding = HloSharding::Tuple(shape, sub_shardings);
if (new_sharding != instruction->sharding()) {
instruction->set_sharding(std::move(new_sharding));
return true;
}
return changed;
}
case HloOpcode::kReduce: {
// Reduce could have a tuple shape, where the first half of operands are
// the arrays to reduce, and the second half of operands are the init
// values.
bool changed = false;
for (int64 operand_id = 0; operand_id < instruction->operand_count() / 2;
++operand_id) {
const HloInstruction* operand = instruction->operand(operand_id);
if (!IsSpatiallyPartitioned(operand)) {
continue;
}
auto get_maybe_tuple_sharding = [&](HloSharding sharding) {
if (instruction->operand_count() == 2) {
return sharding;
}
std::vector<HloSharding> tuple(instruction->operand_count() / 2,
std::move(sharding));
return HloSharding::Tuple(instruction->shape(), tuple);
};
if (operand->sharding().IsReplicated() ||
(!is_spmd &&
absl::c_any_of(instruction->dimensions(), [operand](int64 dim) {
return operand->sharding().tile_assignment().dim(dim) > 1;
}))) {
// We are reducing along one of the sharded dimensions. We only
// support this in SPMD.
changed |= MaybeImproveInstructionSharding(
get_maybe_tuple_sharding(
HloSharding::Replicate(operand->sharding().metadata())),
instruction, may_combine_partial_sharding);
continue;
}
auto after_partial_replication =
operand->sharding().IsReplicated()
? operand->sharding()
: hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
operand->sharding(), instruction->dimensions());
if (after_partial_replication.IsReplicated()) {
changed |= MaybeImproveInstructionSharding(
get_maybe_tuple_sharding(after_partial_replication), instruction,
may_combine_partial_sharding);
continue;
}
// Use the same sharding for all tuple elements, because they are part
// of the same reduce instruction.
HloSharding new_sharding =
get_maybe_tuple_sharding(hlo_sharding_util::RemoveShapeDimensions(
after_partial_replication, instruction->dimensions()));
changed |= MaybeImproveInstructionSharding(
std::move(new_sharding), instruction, may_combine_partial_sharding);
}
return changed;
}
case HloOpcode::kBroadcast: {
// Make forward propagation through broadcast low priority to avoid
// resharding after broadcast.
if (aggressiveness < 3) {
return false;
}
const HloInstruction* op = instruction->operand(0);
if (!IsSpatiallyPartitioned(op) || op->sharding().IsReplicated()) {
return false;
}
// The output will be tiled along the broadcasted dimension the same way
// as the input for the broadcast while the other dimensions are kept
// non-tiled.
std::vector<int64> target_tile_assignment_dimensions;
const auto& dimensions = instruction->dimensions();
for (int64 i = 0; i < instruction->shape().rank(); ++i) {
auto it = absl::c_find(dimensions, i);
if (it == dimensions.end()) {
target_tile_assignment_dimensions.push_back(1);
} else {
const int64 source_dim = std::distance(dimensions.begin(), it);
target_tile_assignment_dimensions.push_back(
op->sharding().tile_assignment().dim(source_dim));
}
}
if (op->sharding().ReplicateOnLastTileDim()) {
target_tile_assignment_dimensions.push_back(
op->sharding().tile_assignment().dimensions().back());
}
Array<int64> new_tile_assignment = op->sharding().tile_assignment();
new_tile_assignment.Reshape(target_tile_assignment_dimensions);
HloSharding new_sharding =
op->sharding().ReplicateOnLastTileDim()
? HloSharding::PartialTile(new_tile_assignment,
op->sharding().metadata())
: HloSharding::Tile(new_tile_assignment,
op->sharding().metadata());
return MaybeImproveInstructionSharding(
std::move(new_sharding), instruction, may_combine_partial_sharding);
}
case HloOpcode::kConvolution:
return InferConvolutionShardingFromOperands(instruction, aggressiveness,
may_combine_partial_sharding);
case HloOpcode::kTranspose: {
const HloInstruction* input = instruction->operand(0);
if (!IsSpatiallyPartitioned(input)) {
return false;
}
HloSharding sharding = hlo_sharding_util::TransposeSharding(
input->sharding(), instruction->dimensions());
return MaybeImproveInstructionSharding(std::move(sharding), instruction,
may_combine_partial_sharding);
}
case HloOpcode::kReduceWindow: {
if (instruction->shape().IsTuple()) {
// TODO (b/73062247) variadic reduce window is not yet supported here.
return false;
}
const HloInstruction* lhs = instruction->operand(0);
if (!IsSpatiallyPartitioned(lhs)) {
return false;
}
auto has_dilation = [](const WindowDimension& dimensions) {
return dimensions.base_dilation() > 1 ||
dimensions.window_dilation() > 1;
};
if (absl::c_any_of(instruction->window().dimensions(), has_dilation)) {
VLOG(2) << "Not applying sharding to reduce window because dilatation "
"isn't supported yet: "
<< instruction->ToString();
return false;
}
return MaybeImproveInstructionSharding(lhs->sharding(), instruction,
may_combine_partial_sharding);
}
case HloOpcode::kSelectAndScatter: {
// Shard according to first operand, as output keeps the same shape.
const HloInstruction* lhs = instruction->operand(0);
if (!IsSpatiallyPartitioned(lhs)) {
return false;
}
auto has_base_dilation = [](const WindowDimension& dimensions) {
return dimensions.base_dilation() > 1;
};
if (absl::c_any_of(instruction->window().dimensions(),
has_base_dilation)) {
VLOG(2) << "Not applying sharding to select-and-scatter because "
"base dilation isn't supported yet: "
<< instruction->ToString();
return false;
}
return MaybeImproveInstructionSharding(lhs->sharding(), instruction,
may_combine_partial_sharding);
}
case HloOpcode::kReshape: {
if (!IsSpatiallyPartitioned(instruction->operand(0))) {
return false;
}
absl::optional<HloSharding> new_sharding =
hlo_sharding_util::ReshapeSharding(
instruction->operand(0)->shape(), instruction->shape(),
instruction->operand(0)->sharding());
if (new_sharding.has_value()) {
return MaybeImproveInstructionSharding(std::move(*new_sharding),
instruction,
may_combine_partial_sharding);
}
return false;
}
case HloOpcode::kReverse: {
if (!IsSpatiallyPartitioned(instruction->operand(0))) {
return false;
}
return MaybeImproveInstructionSharding(
hlo_sharding_util::ReverseSharding(
instruction->operand(0)->sharding(), instruction->dimensions()),
instruction, may_combine_partial_sharding);
}
case HloOpcode::kDot: {
const auto& dnums =
dot_as_convolution_util::ParseDotGeneralFromDot(instruction);
return InferDotShardingFromOperands(instruction, dnums,
may_combine_partial_sharding);
}
case HloOpcode::kParameter: {
auto parent_it = computation_map.find(instruction->parent());
if (parent_it == computation_map.end()) {
return false;
}
const HloInstruction* parent = parent_it->second;
switch (parent->opcode()) {
case HloOpcode::kConditional: {
for (int64 i = 1; i < parent->operand_count(); ++i) {
if (parent->called_computations()[i - 1] == instruction->parent()) {
if (parent->operand(i)->has_sharding()) {
return MaybeImproveInstructionSharding(
parent->operand(i)->sharding(), instruction,
may_combine_partial_sharding);
}
return false;
}
}
return false;
}
default:
return false;
}
}
case HloOpcode::kSort: {
const HloInstruction* operand = PickRepresentativeOperand(instruction);
if (!operand || !IsSpatiallyPartitioned(operand)) {
return false;
}
if (!operand->sharding().IsTileMaximal() &&
operand->sharding().tile_assignment().dim(
instruction->dimensions(0)) != 1) {
// Doesn't support sharding the sorting dimension.
return false;
}
if (instruction->shape().IsTuple()) {
return MaybeImproveInstructionSharding(
HloSharding::SingleTuple(instruction->shape(), operand->sharding()),
instruction, may_combine_partial_sharding);
} else {
return MaybeImproveInstructionSharding(operand->sharding(), instruction,
may_combine_partial_sharding);
}
}
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice: {
auto propagate_slicing = [&]() {
const HloInstruction* operand =
instruction->opcode() == HloOpcode::kDynamicSlice
? instruction->operand(0)
: instruction->operand(1);
if (!IsSpatiallyPartitioned(operand)) {
return false;
}
if (operand->sharding().IsReplicated()) {
return MaybeImproveInstructionSharding(
HloSharding::Replicate(operand->sharding().metadata()),
instruction, may_combine_partial_sharding);
}
const auto& tile_assignment = operand->sharding().tile_assignment();
for (int64 i = 0; i < instruction->shape().rank(); ++i) {
if (tile_assignment.dim(i) > 1 &&
instruction->shape().dimensions(i) !=
operand->shape().dimensions(i)) {
return false;
}
}
return MaybeImproveInstructionSharding(operand->sharding(), instruction,
may_combine_partial_sharding);
};
auto propagate_base = [&]() {
if (instruction->opcode() != HloOpcode::kDynamicUpdateSlice) {
return false;
}
if (!IsSpatiallyPartitioned(instruction->operand(0))) {
return false;
}
return MaybeImproveInstructionSharding(
instruction->operand(0)->sharding(), instruction,
may_combine_partial_sharding);
};
return propagate_slicing() || propagate_base();
}
case HloOpcode::kGather: {
bool changed = false;
if (IsSpatiallyPartitioned(instruction->operand(1))) {
HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding(
instruction->operand(1)->sharding(), instruction);
changed |= MaybeImproveInstructionSharding(
std::move(new_sharding), instruction, may_combine_partial_sharding);
}
if (is_spmd) {
auto gather_parallel_dims =
hlo_sharding_util::GetGatherBatchParallelDims(*instruction);
if (gather_parallel_dims) {
changed |= InferGatherParallelShardingFromOperands(
instruction, *gather_parallel_dims, may_combine_partial_sharding);
}
if (IsSpatiallyPartitioned(instruction->operand(0))) {
absl::Span<const int64> operand_parallel_dims;
if (gather_parallel_dims) {
operand_parallel_dims = absl::MakeConstSpan(
gather_parallel_dims->operand_parallel_dims);
}
HloSharding filtered_operand_sharding =
hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
instruction->operand(0)->sharding(), operand_parallel_dims);
auto maybe_from_data =
hlo_sharding_util::GatherOutputShardingFromDataOperand(
filtered_operand_sharding, *instruction, instruction->shape(),
instruction->operand(0)->shape());
if (maybe_from_data) {
changed |= MaybeImproveInstructionSharding(
std::move(*maybe_from_data), instruction,
may_combine_partial_sharding);
}
}
}
return changed;
}
case HloOpcode::kScatter: {
bool changed = false;
if (is_spmd && IsSpatiallyPartitioned(instruction->operand(0))) {
changed |= MaybeImproveInstructionSharding(
instruction->operand(0)->sharding(), instruction,
may_combine_partial_sharding);
}
if (!IsSpatiallyPartitioned(instruction->operand(1)) &&
!IsSpatiallyPartitioned(instruction->operand(2))) {
return false;
}
if (is_spmd && IsSpatiallyPartitioned(instruction->operand(2))) {
auto maybe_from_update =
hlo_sharding_util::ScatterOutputShardingFromUpdate(
instruction->operand(2)->sharding(), *instruction);
if (maybe_from_update) {
changed |= MaybeImproveInstructionSharding(
std::move(*maybe_from_update), instruction,
may_combine_partial_sharding);
}
}
changed |= MaybeImproveInstructionSharding(
HloSharding::Replicate(), instruction, may_combine_partial_sharding);
return changed;
}
case HloOpcode::kWhile: {
if (!instruction->operand(0)->has_sharding()) {
return false;
}
auto sharding = instruction->operand(0)->sharding();
if (instruction->has_sharding()) {
hlo_sharding_util::MergeSharding(instruction->sharding(), &sharding,
may_combine_partial_sharding);
}
return MaybeImproveInstructionSharding(std::move(sharding), instruction,
may_combine_partial_sharding);
}
default: {
if (instruction->IsElementwise() && may_combine_partial_sharding) {
bool changed = false;
for (auto operand : instruction->operands()) {
if (IsSpatiallyPartitioned(operand)) {
changed |= MaybeImproveInstructionSharding(
operand->sharding(), instruction, may_combine_partial_sharding);
}
}
return changed;
}
const HloInstruction* operand = PickRepresentativeOperand(instruction);
if (!operand || !IsSpatiallyPartitioned(operand)) {
return false;
}
return MaybeImproveInstructionSharding(operand->sharding(), instruction,
may_combine_partial_sharding);
}
}
return false;
}
// Infers a sharding for operand `operand_index` of the dot-like
// `instruction` (0 selects the lhs dimension numbers of `dnums`, 1 the rhs).
//
// The result is first derived from the instruction's own (output) sharding:
// output dimensions that have no counterpart in this operand are partially
// replicated, and the batch / own non-contracting dimensions are transposed
// into operand dimension order. If the other operand is spatially
// partitioned, its sharding is mapped through the batch and contracting
// dimensions the same way, and the two candidates are combined with
// hlo_sharding_util::MergeSharding.
//
// NOTE(review): instruction->sharding() is read unconditionally, so callers
// presumably guarantee the instruction is already sharded.
HloSharding InferDotOperandSharding(
    const HloInstruction* instruction,
    const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
    int64 operand_index, bool may_combine_partial_sharding) {
  const bool is_lhs = operand_index == 0;
  // Dimension of `dim` that belongs to the operand being inferred.
  auto own_dim = [is_lhs](const auto& dim) {
    return is_lhs ? dim.lhs : dim.rhs;
  };
  // Dimension of `dim` that belongs to the other operand.
  auto peer_dim = [is_lhs](const auto& dim) {
    return is_lhs ? dim.rhs : dim.lhs;
  };
  // Records the two-way dimension correspondence a <-> b.
  auto link = [](int64 a, int64 b, std::vector<int64>* a_to_b,
                 std::vector<int64>* b_to_a) {
    (*a_to_b)[a] = b;
    (*b_to_a)[b] = a;
  };
  const auto& own_non_contracting = is_lhs ? dnums.lhs_non_contracting_dims
                                           : dnums.rhs_non_contracting_dims;
  const auto& peer_non_contracting = is_lhs ? dnums.rhs_non_contracting_dims
                                            : dnums.lhs_non_contracting_dims;

  const HloInstruction* operand = instruction->operand(operand_index);
  const HloInstruction* other = instruction->operand(1 - operand_index);

  // The other operand's non-contracting dims exist only in the output and in
  // the other operand, never in this operand, so they get replicated.
  std::vector<int64> output_dims_to_replicate;
  std::vector<int64> other_operand_dims_to_replicate;
  for (const auto& dim : peer_non_contracting) {
    output_dims_to_replicate.push_back(dim.output);
    other_operand_dims_to_replicate.push_back(peer_dim(dim));
  }
  // If this dot is interpreted from a conv, then contracting dims may have
  // corresponding spatial dimensions in the output, and this operand's
  // non-contracting dims may have corresponding spatial dims in the other
  // operand.
  for (const auto& dim : dnums.contracting_dims) {
    if (dim.output >= 0) {
      output_dims_to_replicate.push_back(dim.output);
    }
  }
  for (const auto& dim : own_non_contracting) {
    const int64 other_dim = peer_dim(dim);
    if (other_dim >= 0) {
      other_operand_dims_to_replicate.push_back(other_dim);
    }
  }

  // Candidate derived from the output sharding.
  std::vector<int64> output_to_operand_dims(instruction->shape().rank(), -1);
  std::vector<int64> operand_to_output_dims(operand->shape().rank(), -1);
  for (const auto& dim : dnums.batch_dims) {
    link(dim.output, own_dim(dim), &output_to_operand_dims,
         &operand_to_output_dims);
  }
  for (const auto& dim : own_non_contracting) {
    link(dim.output, own_dim(dim), &output_to_operand_dims,
         &operand_to_output_dims);
  }
  HloSharding sharding =
      *hlo_sharding_util::TransposeShardingWithCollapsedDims(
          hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
              instruction->sharding(), output_dims_to_replicate),
          output_to_operand_dims, operand_to_output_dims);

  if (!IsSpatiallyPartitioned(other)) {
    return sharding;
  }

  // Candidate derived from the other operand, through batch and contracting
  // dims, merged into the output-derived candidate.
  std::vector<int64> other_to_operand_dims(other->shape().rank(), -1);
  std::vector<int64> operand_to_other_dims(operand->shape().rank(), -1);
  for (const auto& dim : dnums.batch_dims) {
    link(peer_dim(dim), own_dim(dim), &other_to_operand_dims,
         &operand_to_other_dims);
  }
  for (const auto& dim : dnums.contracting_dims) {
    link(peer_dim(dim), own_dim(dim), &other_to_operand_dims,
         &operand_to_other_dims);
  }
  HloSharding sharding_from_other =
      *hlo_sharding_util::TransposeShardingWithCollapsedDims(
          hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
              other->sharding(), other_operand_dims_to_replicate),
          other_to_operand_dims, operand_to_other_dims);
  if (hlo_sharding_util::MergeSharding(sharding, &sharding_from_other,
                                       may_combine_partial_sharding)) {
    sharding = std::move(sharding_from_other);
  }
  return sharding;
}
// Returns the sharding that should be propagated from `user` to
// `instruction` (one of `user`'s operands), or absl::nullopt when no sharding
// can be derived. Nothing is returned if `user` cannot be propagated through
// at `aggressiveness` or if `user` is not spatially partitioned. `is_spmd`
// enables rules that are only valid under SPMD partitioning, e.g. partial
// replication of broadcast-created dimensions and gather/scatter data-operand
// rules.
absl::optional<HloSharding> GetShardingFromUser(
    const HloInstruction& instruction, const HloInstruction& user,
    int64 aggressiveness, bool is_spmd) {
  if (!CanPropagateThroughAtAgressiveLevel(user, aggressiveness)) {
    return absl::nullopt;
  }
  if (!IsSpatiallyPartitioned(&user)) {
    return absl::nullopt;
  }
  const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0;
  switch (user.opcode()) {
    case HloOpcode::kBroadcast: {
      if (user.sharding().IsReplicated()) {
        return user.sharding();
      }
      // Output dims not listed in `dimensions()` are created by the
      // broadcast and have no counterpart in the operand.
      std::vector<int64> dims_to_replicate;
      bool needs_replication = false;
      for (int64 i = 0; i < user.shape().rank(); ++i) {
        if (absl::c_count(user.dimensions(), i) == 0) {
          dims_to_replicate.push_back(i);
          if (user.sharding().tile_assignment().dim(i) > 1) {
            needs_replication = true;
          }
        }
      }
      // If not SPMD, only support when none of the partitioned dimensions in
      // the broadcast output belong to new dimensions.
      if (!is_spmd && needs_replication) {
        return absl::nullopt;
      }
      return hlo_sharding_util::RemoveShapeDimensions(
          hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
              user.sharding(), dims_to_replicate),
          dims_to_replicate);
    }
    case HloOpcode::kConcatenate: {
      if (user.sharding().IsReplicated()) {
        return user.sharding();
      }
      const int64 cdim = user.concatenate_dimension();
      const Array<int64>& tile_assignment = user.sharding().tile_assignment();
      if (tile_assignment.dim(cdim) == 1) {
        // If we are concatenating along a non-sharded dimension then the
        // operands should have the same sharding as the result.
        return user.sharding();
      }
      if (is_spmd) {
        // SPMD doesn't support tiling with part of the devices. Return the same
        // sharding.
        return user.sharding();
      }
      // If we are concatenating along a sharded dimension then we want the
      // operands to be distributed among the devices their data is used.
      int64 start_offset = 0;
      for (HloInstruction* op : user.operands()) {
        if (op == &instruction) {
          break;
        }
        start_offset += op->shape().dimensions(cdim);
      }
      const int64 tile_shape = CeilOfRatio(user.shape().dimensions(cdim),
                                           tile_assignment.dimensions()[cdim]);
      std::vector<int64> start_indices(tile_assignment.num_dimensions());
      std::vector<int64> end_indices = tile_assignment.dimensions();
      start_indices[cdim] = start_offset / tile_shape;
      end_indices[cdim] = CeilOfRatio(
          start_offset + instruction.shape().dimensions(cdim), tile_shape);
      auto new_tile_assignment =
          tile_assignment.Slice(start_indices, end_indices);
      if (new_tile_assignment.num_elements() == 1) {
        return HloSharding::AssignDevice(*new_tile_assignment.begin(),
                                         user.sharding().metadata());
      }
      return HloSharding::Tile(new_tile_assignment, user.sharding().metadata());
    }
    case HloOpcode::kConvolution: {
      // Only convolutions that are really dots (no spatial dims) are handled.
      auto dot_dims = dot_as_convolution_util::ParseConvolutionDimsInfo(&user);
      if (dot_dims.conv_spatial_dims.empty()) {
        int64 op_idx = user.operand_index(&instruction);
        return InferDotOperandSharding(&user, dot_dims, op_idx,
                                       may_combine_partial_sharding);
      }
      return absl::nullopt;
    }
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice: {
      if (user.sharding().IsReplicated()) {
        return user.sharding();
      }
      if (user.opcode() == HloOpcode::kDynamicUpdateSlice &&
          &instruction == user.operand(0)) {
        return user.sharding();
      }
      const HloInstruction* operand = user.opcode() == HloOpcode::kDynamicSlice
                                          ? user.operand(0)
                                          : user.operand(1);
      if (&instruction != operand) {
        return absl::nullopt;
      }
      // Only propagate if every partitioned dimension is not sliced.
      const auto& tile_assignment = user.sharding().tile_assignment();
      for (int64 i = 0; i < user.shape().rank(); ++i) {
        if (tile_assignment.dim(i) > 1 &&
            user.shape().dimensions(i) != operand->shape().dimensions(i)) {
          return absl::nullopt;
        }
      }
      return user.sharding();
    }
    case HloOpcode::kReduceWindow: {
      if (user.shape().IsTuple()) {
        auto sub_sharding = user.sharding().GetSubSharding(
            user.shape(), {user.operand_index(&instruction)});
        return sub_sharding;
      }
      if (&instruction != user.operand(0)) {
        return absl::nullopt;
      }
      return user.sharding();
    }
    case HloOpcode::kReshape: {
      return hlo_sharding_util::ReshapeSharding(
          user.shape(), instruction.shape(), user.sharding());
    }
    case HloOpcode::kPad: {
      if (&instruction != user.operand(0)) {
        return absl::nullopt;
      }
      return user.sharding();
    }
    case HloOpcode::kSlice: {
      return user.sharding();
    }
    case HloOpcode::kTranspose: {
      // Calculate the dimension numbers for reversing the current transpose
      // and then use TransposeSharding to convert the output sharding to an
      // input sharding.
      std::vector<int64> reverse_dimensions(user.dimensions().size());
      for (int64 i = 0; i < user.dimensions().size(); ++i) {
        reverse_dimensions[user.dimensions(i)] = i;
      }
      return hlo_sharding_util::TransposeSharding(user.sharding(),
                                                  reverse_dimensions);
    }
    case HloOpcode::kTuple: {
      auto sub_sharding = user.sharding().GetSubSharding(
          user.shape(), {user.operand_index(&instruction)});
      return sub_sharding;
    }
    case HloOpcode::kGetTupleElement: {
      // Start from the operand's current tuple sharding (or all-replicated)
      // and overwrite the leaves that belong to the extracted element.
      HloSharding new_sharding =
          instruction.has_sharding()
              ? instruction.sharding()
              : HloSharding::SingleTuple(instruction.shape(),
                                         HloSharding::Replicate());
      // A tuple sharding holds one entry per leaf, in flattened order, so the
      // first leaf of element `tuple_index` sits after all leaves of the
      // preceding elements. Count leaves recursively: a nested tuple element
      // can contain more leaves than it has immediate children, and
      // non-array, non-tuple elements (e.g. tokens) still occupy one leaf.
      int64 sharding_index = 0;
      for (int64 i = 0; i < instruction.shape().tuple_shapes_size() &&
                        i < user.tuple_index();
           ++i) {
        sharding_index +=
            ShapeUtil::GetLeafCount(instruction.shape().tuple_shapes(i));
      }
      if (user.shape().IsArray()) {
        new_sharding.tuple_elements()[sharding_index] = user.sharding();
      }
      for (int64 i = 0; i < user.sharding().tuple_elements().size(); ++i) {
        new_sharding.tuple_elements()[sharding_index + i] =
            user.sharding().tuple_elements()[i];
      }
      return new_sharding;
    }
    case HloOpcode::kDot: {
      int64 op_idx = user.operand_index(&instruction);
      auto dnums = dot_as_convolution_util::ParseDotGeneralFromDot(&user);
      return InferDotOperandSharding(&user, dnums, op_idx,
                                     may_combine_partial_sharding);
    }
    case HloOpcode::kReduce: {
      // Scalars (e.g. init values) carry no tiling to inherit.
      if (instruction.shape().rank() == 0) {
        return absl::nullopt;
      }
      auto user_sharding =
          user.shape().IsTuple()
              ? user.sharding().GetSubSharding(
                    user.shape(), {user.operand_index(&instruction)})
              : user.sharding();
      if (user_sharding.IsTileMaximal()) {
        return user_sharding;
      }
      // Re-insert size-1 tile dims for the reduced dimensions; the trailing
      // replication dim (if any) is carried over unchanged.
      std::vector<int64> target_tile_assignment_dimensions(
          instruction.shape().rank() +
          (user_sharding.ReplicateOnLastTileDim() ? 1 : 0));
      const auto& dimensions = user.dimensions();
      int64 next_output_dim = 0;
      for (int64 i = 0; i < target_tile_assignment_dimensions.size(); ++i) {
        if (absl::c_find(dimensions, i) == dimensions.end()) {
          target_tile_assignment_dimensions[i] =
              user_sharding.tile_assignment().dim(next_output_dim++);
        } else {
          target_tile_assignment_dimensions[i] = 1;
        }
      }
      auto tile_assignment = user_sharding.tile_assignment();
      tile_assignment.Reshape(target_tile_assignment_dimensions);
      return user_sharding.ReplicateOnLastTileDim()
                 ? HloSharding::PartialTile(tile_assignment,
                                            user_sharding.metadata())
                 : HloSharding::Tile(tile_assignment, user_sharding.metadata());
    }
    case HloOpcode::kSort: {
      if (user.sharding().IsTuple()) {
        return user.sharding().GetSubSharding(
            user.shape(), {user.operand_index(&instruction)});
      }
      return user.sharding();
    }
    case HloOpcode::kReverse: {
      return hlo_sharding_util::ReverseSharding(user.sharding(),
                                                user.dimensions());
    }
    case HloOpcode::kGather: {
      if (&instruction == user.operand(1)) {
        return hlo_sharding_util::GatherIndexSharding(user.sharding(), &user);
      }
      if (is_spmd) {
        return hlo_sharding_util::GatherDataOperandShardingFromOutput(
            user.sharding(), user);
      }
      return absl::nullopt;
    }
    case HloOpcode::kScatter: {
      if (&instruction == user.operand(0)) {
        return user.sharding();
      }
      if (&instruction == user.operand(1)) {
        auto update = user.operand(2);
        if (!IsSpatiallyPartitioned(update)) {
          return absl::nullopt;
        }
        return hlo_sharding_util::ScatterIndexSharding(update->sharding(),
                                                       &user);
      }
      CHECK_EQ(&instruction, user.operand(2));
      auto indices = user.operand(1);
      if (IsSpatiallyPartitioned(indices)) {
        auto from_indices =
            hlo_sharding_util::ScatterDataSharding(indices->sharding(), &user);
        if (!from_indices.IsTileMaximal()) {
          return from_indices;
        }
      }
      if (is_spmd) {
        return hlo_sharding_util::ScatterUpdateShardingFromOutput(
            user.sharding(), user);
      }
      return absl::nullopt;
    }
    default: {
      // If the user output shape is compatible with the current instruction
      // shape excluding element type and the current instruction is supported
      // by spatial partitioning, then the user sharding can be used for
      // propagation to the current instruction.
      if (ShapeUtil::CompatibleIgnoringElementType(instruction.shape(),
                                                   user.shape())) {
        return user.sharding();
      }
      return absl::nullopt;
    }
  }
}
// Tries to refine the sharding of `instruction` from the shardings its users
// imply for it. Returns true iff the instruction's sharding changed.
//
// Special cases handled before general inference:
//  - broadcasts are left alone below aggressiveness 2;
//  - an existing manual sharding is never overwritten;
//  - an unsharded array-shaped instruction adopts a manual sharding from the
//    first manually-sharded user that is not an "SPMDFullToShardShape"
//    custom-call.
bool InferShardingFromUsers(HloInstruction* instruction,
                            const ComputationMap& computation_map,
                            int64 aggressiveness, bool is_spmd) {
  const bool skip_broadcast =
      aggressiveness < 2 && instruction->opcode() == HloOpcode::kBroadcast;
  if (skip_broadcast) {
    return false;
  }
  const bool already_manual =
      instruction->has_sharding() && instruction->sharding().IsManual();
  if (already_manual) {
    return false;
  }

  // Manual sharding propagates as-is from users to unsharded arrays.
  if (!instruction->has_sharding() && instruction->shape().IsArray()) {
    for (const HloInstruction* user : instruction->users()) {
      const bool user_forces_manual =
          user->has_sharding() && user->sharding().IsManual() &&
          !user->IsCustomCall("SPMDFullToShardShape");
      if (user_forces_manual) {
        instruction->set_sharding(
            HloSharding::Manual(user->sharding().metadata()));
        return true;
      }
    }
  }

  if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) {
    return false;
  }

  const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0;
  bool improved = false;
  for (const HloInstruction* user : instruction->users()) {
    if (absl::optional<HloSharding> from_user = GetShardingFromUser(
            *instruction, *user, aggressiveness, is_spmd)) {
      improved |= MaybeImproveInstructionSharding(
          std::move(*from_user), instruction, may_combine_partial_sharding);
    }
  }
  return improved;
}
// Checks if two HloShardings have the same metadata attached.
bool SameShardingMetadata(const HloSharding& a, const HloSharding& b) {
DCHECK_EQ(a, b);
auto same_metadata = [](absl::Span<const OpMetadata> a,
absl::Span<const OpMetadata> b) {
if (a.size() != b.size()) return false;
for (int i = 0, e = a.size(); i < e; ++i) {
if (!protobuf_util::ProtobufEquals(a[i], b[i])) {
return false;
}
}
return true;
};
if (a.IsTuple()) {
for (int i = 0, e = a.tuple_elements().size(); i < e; ++i) {
if (!same_metadata(a.tuple_elements()[i].metadata(),
b.tuple_elements()[i].metadata())) {
return false;
}
}
return true;
} else {
return same_metadata(a.metadata(), b.metadata());
}
}
// Attaches each sharded instruction's own op metadata to its sharding.
// Instructions without a sharding or with empty metadata are skipped. The
// metadata is added with overwrite=false, so a sharding that already has
// metadata keeps it. Returns true if any instruction's sharding was replaced.
bool AssignShardingMetadata(HloModule* module) {
  bool modified = false;
  for (HloComputation* comp : module->computations()) {
    for (HloInstruction* hlo : comp->instructions()) {
      const auto& op_metadata = hlo->metadata();
      if (hlo->has_sharding() && op_metadata.ByteSizeLong() != 0) {
        HloSharding annotated =
            hlo->sharding().WithMetadata({op_metadata}, /*overwrite=*/false);
        if (!SameShardingMetadata(hlo->sharding(), annotated)) {
          hlo->set_sharding(std::move(annotated));
          modified = true;
        }
      }
    }
  }
  return modified;
}
// Strips op metadata from every sharding in `module`. Returns true if any
// instruction's sharding actually carried metadata and was replaced.
bool RemoveShardingMetadata(HloModule* module) {
  bool modified = false;
  for (HloComputation* comp : module->computations()) {
    for (HloInstruction* hlo : comp->instructions()) {
      if (!hlo->has_sharding()) continue;
      HloSharding stripped = hlo->sharding().WithoutMetadata();
      if (SameShardingMetadata(hlo->sharding(), stripped)) continue;
      hlo->set_sharding(std::move(stripped));
      modified = true;
    }
  }
  return modified;
}
// Removes "Sharding" custom-call instructions by folding their sharding
// attribute into the operand. If the operand already has a different
// sharding, the custom-call is replaced by a kCopy of the operand that carries
// the custom-call's sharding, so the reshard is preserved. Returns true if any
// custom-call was removed; fails if a "Sharding" custom-call has no sharding.
StatusOr<bool> ProcessShardingInstruction(HloModule* module) {
  bool changed = false;
  for (HloComputation* computation : module->computations()) {
    // Work on a snapshot (reverse post order) because instructions are
    // removed from `computation` inside the loop.
    auto instructions = computation->MakeInstructionPostOrder();
    std::reverse(instructions.begin(), instructions.end());
    for (HloInstruction* instruction : instructions) {
      if (instruction->opcode() != HloOpcode::kCustomCall ||
          instruction->custom_call_target() != "Sharding") {
        continue;
      }
      TF_RET_CHECK(instruction->has_sharding())
          << "Sharding instruction must have a sharding attribute";
      // `sharding` refers into `instruction`; it must only be read while
      // `instruction` is still part of the computation.
      const HloSharding& sharding = instruction->sharding();
      HloInstruction* operand = instruction->mutable_operand(0);
      // If the operand has a different sharding from the current sharding
      // instruction, create a copy node. Otherwise, just remove the sharding
      // instruction and set the operand sharding.
      if (operand->has_sharding() && operand->sharding() != sharding) {
        HloInstruction* copy = computation->AddInstruction(
            HloInstruction::CreateUnary(instruction->shape(), HloOpcode::kCopy,
                                        operand));
        // Annotate the copy before ReplaceInstruction removes `instruction`,
        // so `sharding` is never read from a removed instruction.
        copy->set_sharding(sharding);
        TF_RETURN_IF_ERROR(computation->ReplaceInstruction(instruction, copy));
      } else {
        operand->set_sharding(sharding);
        TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(operand));
        TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
      }
      changed = true;
    }
  }
  return changed;
}
// Validates the device assignment of a while loop that contains a channel
// instruction (kSend, kRecv, or a kAllReduce with a channel_id) pinned to a
// single device D:
//  - every other single-device instruction in the body must also be on D;
//  - the while itself, if pinned to one device, must be on D;
//  - the body root is pinned to D if it has no unique device, and must be on D
//    otherwise.
// Returns FailedPrecondition describing the first conflict found.
Status CheckAndUpdateDeviceAssignmentsInWhileBody(
    HloInstruction* while_instruction) {
  auto bad_status = [](HloInstruction* instruction, int64 device,
                       HloInstruction* channel_instruction,
                       int64 correct_device) {
    return FailedPrecondition(
        "Instruction: %s is on device: %d, which conflicts with device: %d "
        "of channel instruction: %s",
        instruction->name(), device, correct_device,
        channel_instruction->name());
  };
  CHECK_EQ(while_instruction->opcode(), HloOpcode::kWhile);
  HloComputation* while_body = while_instruction->while_body();

  // Single-device instructions seen before the first channel instruction,
  // keyed by device.
  std::map<int64, HloInstruction*> devices_to_instructions;
  absl::optional<int64> unique_device = absl::nullopt;
  HloInstruction* channel_instruction = nullptr;

  for (HloInstruction* instruction : while_body->instructions()) {
    const auto maybe_device = instruction->sharding_unique_device();
    if (!maybe_device) {
      continue;
    }
    const int64 device = *maybe_device;

    // Once the channel device is known, everything else must agree with it.
    if (unique_device.has_value()) {
      if (device != *unique_device) {
        return bad_status(instruction, device, channel_instruction,
                          *unique_device);
      }
      continue;
    }

    const auto opcode = instruction->opcode();
    // Cross-replica AllReduces don't have a channel_id, and we don't enforce
    // any invariant about their device assignment.
    const bool is_channel_op =
        opcode == HloOpcode::kSend || opcode == HloOpcode::kRecv ||
        (opcode == HloOpcode::kAllReduce && instruction->channel_id());
    if (!is_channel_op) {
      devices_to_instructions[device] = instruction;
      continue;
    }

    channel_instruction = instruction;
    unique_device = device;
    // Re-check everything recorded before the channel device was known.
    for (const auto& entry : devices_to_instructions) {
      if (entry.first != device) {
        return bad_status(entry.second, entry.first, channel_instruction,
                          device);
      }
    }
  }

  if (!unique_device.has_value()) {
    return Status::OK();
  }

  const auto while_device = while_instruction->sharding_unique_device();
  if (while_device.has_value() && *while_device != *unique_device) {
    return bad_status(while_instruction, *while_device, channel_instruction,
                      *unique_device);
  }

  auto body_root = while_body->root_instruction();
  const auto root_device = body_root->sharding_unique_device();
  if (!root_device.has_value()) {
    body_root->set_device_sharding(*unique_device);
  } else if (*root_device != *unique_device) {
    return bad_status(body_root, *root_device, channel_instruction,
                      *unique_device);
  }
  return Status::OK();
}
} // namespace
// Normalizes a sharding domain. When the domain metadata carries a sharding
// that is spatially partitioned (non-tuple: no unique device; tuple: some
// element has no unique device), that sharding is applied directly to the
// operand of every exit-domain instruction. Otherwise the generic
// ShardingMetadata::NormalizeShardingDomain handles it.
/*static*/ Status ShardingPropagation::NormalizeDomain(
    const DomainMetadata::Domain& domain, const DomainMetadata* metadata) {
  if (metadata != nullptr) {
    TF_ASSIGN_OR_RETURN(const auto& sharding_metadata,
                        ShardingMetadata::ToShardingMetadata(metadata));
    const auto& sharding = sharding_metadata->sharding();
    const bool spatially_partitioned =
        sharding != nullptr &&
        (sharding->IsTuple()
             ? absl::c_any_of(
                   sharding->tuple_elements(),
                   [](const HloSharding& s) { return !s.HasUniqueDevice(); })
             : !sharding->HasUniqueDevice());
    if (spatially_partitioned) {
      for (HloInstruction* exit_inst : domain.exit_domains) {
        exit_inst->mutable_operand(0)->set_sharding(*sharding);
      }
      return Status::OK();
    }
  }
  return ShardingMetadata::NormalizeShardingDomain(domain, metadata);
}
// Entry point of the pass. Returns true if any sharding in `module` changed.
//
// Phases, in order:
//  1. Attach op metadata to existing shardings (if `propagate_metadata_`),
//     otherwise strip it.
//  2. Fold "Sharding" custom-calls into their operands.
//  3. Check / annotate device assignments of every while body.
//  4. For each while / conditional, copy the sharding of the first sharded
//     related instruction to all related instructions, and record which
//     while / conditional calls each body / branch computation.
//  5. For aggressiveness 0..3, alternate forward (from operands) and backward
//     (from users) inference until nothing changes. Instructions that were
//     already sharded when this phase starts, plus the entry root, are
//     skipped by the inference loops.
StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
  bool any_changed = propagate_metadata_ ? AssignShardingMetadata(module)
                                         : RemoveShardingMetadata(module);
  auto status_or_changed = ProcessShardingInstruction(module);
  if (!status_or_changed.ok()) return status_or_changed;
  any_changed |= status_or_changed.ValueOrDie();
  // Association of partitionable embedded computations with their parent
  // instruction.
  ComputationMap computation_map;
  // Instructions that are related through a computation and need to share the
  // same sharding.
  // While: the while, its body root, its body parameter and its condition
  // parameter. Conditional: the conditional and the root of every called
  // computation. Any other opcode is a fatal error (CHECK).
  auto get_related_instructions = [](HloInstruction* inst) {
    if (inst->opcode() == HloOpcode::kWhile) {
      return std::vector<HloInstruction*>{
          inst, inst->while_body()->root_instruction(),
          inst->while_body()->parameter_instruction(0),
          inst->while_condition()->parameter_instruction(0)};
    } else if (inst->opcode() == HloOpcode::kConditional) {
      std::vector<HloInstruction*> comps{inst};
      for (HloComputation* c : inst->called_computations()) {
        comps.push_back(c->root_instruction());
      }
      return comps;
    } else {
      CHECK(false);
    }
  };
  // If instruction is a while, or the root or a parameter of a while body,
  // then propagate its sharding to the while instruction, to its body root,
  // and to its condition parameter.
  // Conditionals are handled likewise for the conditional and its branch
  // roots. Propagation only happens when `instruction` is itself one of the
  // related instructions; each instruction whose sharding is overwritten is
  // inserted into `changed` and propagation recurses from it.
  std::function<void(HloInstruction*, absl::flat_hash_set<HloInstruction*>*)>
      maybe_computation_propagation = [&](HloInstruction* instruction,
                                          absl::flat_hash_set<HloInstruction*>*
                                              changed) {
        auto propagate_to_instruction = [&](HloInstruction* search_inst) {
          auto related_instructions = get_related_instructions(search_inst);
          if (absl::c_count(related_instructions, instruction)) {
            for (HloInstruction* inst : related_instructions) {
              if (!inst->has_sharding() ||
                  inst->sharding() != instruction->sharding()) {
                VLOG(2) << "Add computation sharding: " << inst->name();
                inst->set_sharding(instruction->sharding());
                changed->insert(inst);
                maybe_computation_propagation(inst, changed);
              }
            }
          }
        };
        if (instruction->opcode() == HloOpcode::kConditional ||
            instruction->opcode() == HloOpcode::kWhile) {
          propagate_to_instruction(instruction);
        }
        if (instruction->opcode() == HloOpcode::kParameter ||
            instruction->parent()->root_instruction() == instruction) {
          auto it = computation_map.find(instruction->parent());
          if (it != computation_map.end()) {
            propagate_to_instruction(it->second);
          }
        }
      };
  for (auto computation : module->computations()) {
    for (auto instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kWhile) {
        TF_RETURN_IF_ERROR(
            CheckAndUpdateDeviceAssignmentsInWhileBody(instruction));
      }
    }
  }
  // Populate computation_map in order to associate while bodies to their
  // while instructions.
  // Conditional branch computations are mapped to their conditional as well.
  // NOTE(review): while condition computations are not registered here.
  for (auto computation : module->computations()) {
    for (auto instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kWhile ||
          instruction->opcode() == HloOpcode::kConditional) {
        // Check if any of the related instructions has sharding, in which case
        // propagate it to the other instructions, so they all share the same
        // sharding, in case the user didn't shard all of them. We don't check
        // that user shardings are consistent, because such check is already
        // done by HloShardingVerifier.
        const HloInstruction* sharded_inst = nullptr;
        auto related_instructions = get_related_instructions(instruction);
        for (auto inst : related_instructions) {
          if (inst->has_sharding()) {
            sharded_inst = inst;
            break;
          }
        }
        if (sharded_inst != nullptr) {
          // Set the same sharding to all the other related instructions.
          for (auto inst : related_instructions) {
            inst->set_sharding(sharded_inst->sharding());
          }
        }
        if (instruction->opcode() == HloOpcode::kWhile) {
          computation_map[instruction->while_body()] = instruction;
        } else {
          for (HloComputation* c : instruction->called_computations()) {
            computation_map[c] = instruction;
          }
        }
      }
    }
  }
  // Collect all pre-sharded instructions as we aren't allowed to modify their
  // sharding.
  absl::flat_hash_set<const HloInstruction*> provided_shardings;
  for (const HloComputation* computation : module->computations()) {
    for (const HloInstruction* inst : computation->instructions()) {
      if (inst->has_sharding()) {
        provided_shardings.insert(inst);
      }
    }
  }
  // Consider the root instruction of the entry module as one with provided
  // sharding as its sharding have to match with the one expected by the host.
  provided_shardings.insert(module->entry_computation()->root_instruction());
  // Iterate to a fixpoint that is guaranteed to be reached because we only
  // strictly improve the sharding of the graph and it can't be improved
  // indefinitely.
  int64 iterations = 0;
  // Runs forward + backward passes at one aggressiveness level until a full
  // pass over all computations makes no change. The two caches remember
  // instructions already processed in each direction so unchanged
  // neighborhoods are not re-inferred.
  auto run_to_fix_point = [&](int64 aggressiveness) {
    absl::flat_hash_set<const HloInstruction*> already_inferred_from_operands;
    absl::flat_hash_set<const HloInstruction*> already_inferred_from_users;
    bool changed_last_iter = true;
    while (changed_last_iter) {
      changed_last_iter = false;
      // Per-iteration statistics, only used for VLOG output.
      int64 inferred_from_operand_counter = 0;
      int64 inferred_from_user_counter = 0;
      int64 instruction_counter = 0;
      int64 already_sharded_counter = 0;
      for (const HloComputation* computation : module->computations()) {
        std::vector<HloInstruction*> instructions =
            computation->MakeInstructionPostOrder();
        instruction_counter += instructions.size();
        for (const HloInstruction* instruction : instructions) {
          already_sharded_counter += (instruction->has_sharding() ? 1 : 0);
        }
        // After `hlo` changes, its operands may learn something new from it
        // (as a user) and its users may learn something new from it (as an
        // operand), so evict them from the respective caches.
        auto clear_cache = [&](HloInstruction* hlo) {
          for (auto operand : hlo->operands()) {
            already_inferred_from_users.erase(operand);
          }
          for (auto user : hlo->users()) {
            already_inferred_from_operands.erase(user);
          }
        };
        // First iterate the HLO graph in post order taking shardings from
        // operands.
        for (HloInstruction* instruction : instructions) {
          if (already_inferred_from_operands.contains(instruction) ||
              provided_shardings.contains(instruction)) {
            continue;
          }
          already_inferred_from_operands.insert(instruction);
          if (InferShardingFromOperands(instruction, computation_map, is_spmd_,
                                        aggressiveness)) {
            ++inferred_from_operand_counter;
            any_changed = true;
            VLOG(2) << "Add sharding (forward-pass): "
                    << instruction->ToString();
            absl::flat_hash_set<HloInstruction*> changed_in_comp_prop;
            maybe_computation_propagation(instruction, &changed_in_comp_prop);
            clear_cache(instruction);
            for (auto hlo : changed_in_comp_prop) {
              clear_cache(hlo);
            }
            changed_last_iter = true;
          }
        }
        // Then iterate the HLO graph in reverse post order taking shardings
        // from users.
        for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) {
          if (already_inferred_from_users.contains(*it) ||
              provided_shardings.contains(*it)) {
            continue;
          }
          already_inferred_from_users.insert(*it);
          if (InferShardingFromUsers(*it, computation_map, aggressiveness,
                                     is_spmd_)) {
            ++inferred_from_user_counter;
            any_changed = true;
            VLOG(2) << "Add sharding (backward-pass): " << (*it)->ToString();
            absl::flat_hash_set<HloInstruction*> changed_in_comp_prop;
            maybe_computation_propagation(*it, &changed_in_comp_prop);
            clear_cache(*it);
            for (auto hlo : changed_in_comp_prop) {
              clear_cache(hlo);
            }
            changed_last_iter = true;
          }
        }
      }
      VLOG(1) << "Sharding propagation iteration " << iterations << ";";
      VLOG(1) << "  total instructions: " << instruction_counter;
      VLOG(1) << "  instructions already sharded: " << already_sharded_counter;
      VLOG(1) << "  shardings inferred from operands: "
              << inferred_from_operand_counter;
      VLOG(1) << "  shardings inferred from users: "
              << inferred_from_user_counter;
      VLOG(1) << "  aggressiveness: " << aggressiveness;
      ++iterations;
    }
  };
  // Run progressively more aggressive propagation; each level starts from
  // the fixpoint reached by the previous one.
  for (int64 aggressiveness = 0; aggressiveness < 4; ++aggressiveness) {
    run_to_fix_point(aggressiveness);
  }
  VLOG(1) << "Sharding propagation completed after " << iterations
          << " iterations";
  return any_changed;
}
} // namespace xla