Yuanzhong Xu 7893e4bcc1 [XLA] Avoid quadratic behavior in DevicesForSharding
PiperOrigin-RevId: 333652350
Change-Id: I8f0b73f5a584cfcb462a2083ede813331c50c90e
2020-09-24 20:08:08 -07:00

984 lines
38 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"

#include <algorithm>
#include <map>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace hlo_sharding_util {
// Returns the device with the largest count in `device_map` (device id ->
// occurrence count). Only counts strictly greater than the running maximum
// replace it, so on a tie the entry visited first (the smallest device id, as
// std::map iterates in key order) wins. If `top_count` is non-null it receives
// the winning count, which is 0 when no entry has a positive count. Returns
// nullopt when no entry has a positive count.
absl::optional<int64> SelectDominantDevice(
    const std::map<int64, int64>& device_map, int64* top_count) {
  int64 best_device = 0;
  int64 best_count = 0;
  for (const auto& entry : device_map) {
    if (entry.second <= best_count) {
      continue;
    }
    best_device = entry.first;
    best_count = entry.second;
  }
  if (top_count != nullptr) {
    *top_count = best_count;
  }
  if (best_count > 0) {
    return absl::optional<int64>(best_device);
  }
  return absl::optional<int64>();
}
// Assigns a single-device sharding on `device` to every instruction of
// `computation` that does not already carry a sharding. Instructions that
// already have a sharding are left untouched. Always returns OK.
Status AssignComputationDevice(HloComputation* computation, int64 device) {
  VLOG(4) << "Assigning device " << device << " to " << computation->name()
          << " computation";
  for (HloInstruction* instr : computation->instructions()) {
    if (instr->has_sharding()) {
      continue;
    }
    VLOG(4) << "Assigning device " << device << " to " << instr->name();
    instr->set_device_sharding(device);
  }
  return Status::OK();
}
// Accumulates, over all `instructions` that have a sharding, the per-device
// occurrence counts reported by HloSharding::UsedDevices(), and returns the
// device with the highest total (nullopt if no device was counted).
absl::optional<int64> GetMostOccurringDevice(
    absl::Span<HloInstruction* const> instructions) {
  std::map<int64, int64> occurrences;
  for (HloInstruction* instr : instructions) {
    if (!instr->has_sharding()) {
      continue;
    }
    // UsedDevices() yields a map<device, occurrence_count>.
    for (const auto& device_and_count :
         instr->sharding().UsedDevices(nullptr)) {
      occurrences[device_and_count.first] += device_and_count.second;
    }
  }
  return SelectDominantDevice(occurrences, nullptr);
}
// Finds the device that dominates the instructions of `computations`.
//
// Per-device occurrence counts are accumulated from the shardings of all
// instructions. Each instruction contributes `count` to the total, where
// `count` starts at 1 and may be overwritten by UsedDevices() through the
// pointer it is given. The most frequent device is returned only when
// (its count / total) >= `dominant_factor`; otherwise nullopt is returned.
StatusOr<absl::optional<int64>> GetDominantDevice(
    absl::Span<HloComputation* const> computations, double dominant_factor) {
  std::map<int64, int64> occurrences;
  int64 total_count = 0;
  for (HloComputation* computation : computations) {
    for (HloInstruction* instr : computation->instructions()) {
      int64 element_count = 1;
      if (instr->has_sharding()) {
        // UsedDevices() yields a map<device, occurrence_count>.
        for (const auto& device_and_count :
             instr->sharding().UsedDevices(&element_count)) {
          occurrences[device_and_count.first] += device_and_count.second;
        }
      }
      total_count += element_count;
    }
  }

  int64 top_count;
  const absl::optional<int64> candidate =
      SelectDominantDevice(occurrences, &top_count);
  absl::optional<int64> dominant_device;
  if (candidate.has_value()) {
    const double share =
        static_cast<double>(top_count) / static_cast<double>(total_count);
    if (share >= dominant_factor) {
      dominant_device = candidate;
    }
  }
  return dominant_device;
}
// Returns `sharding` with its tile assignment dimensions permuted so that
// result dimension i corresponds to source dimension `dimensions[i]`.
// Tile-maximal shardings are returned unchanged. For a sharding that
// replicates on its last tile dimension, if `dimensions` does not cover that
// trailing dimension it is appended to the permutation so it stays last.
HloSharding TransposeSharding(const HloSharding& sharding,
                              const std::vector<int64>& dimensions) {
  if (sharding.IsTileMaximal()) {
    return sharding;
  }
  const bool partially_replicated = sharding.ReplicateOnLastTileDim();
  const Array<int64>& source_tiles = sharding.tile_assignment();

  std::vector<int64> permutation = dimensions;
  if (partially_replicated &&
      dimensions.size() < source_tiles.num_dimensions()) {
    permutation.push_back(dimensions.size());
  }

  // Shape of the transposed tile assignment.
  std::vector<int64> transposed_dims;
  transposed_dims.reserve(permutation.size());
  for (int64 source_dim : permutation) {
    transposed_dims.push_back(source_tiles.dim(source_dim));
  }

  Array<int64> transposed_tiles = source_tiles;
  transposed_tiles.Reshape(transposed_dims);
  // Fill every cell from the source cell it maps back to.
  transposed_tiles.Each([&](absl::Span<const int64> indices, int64* device) {
    std::vector<int64> source_indices(indices.size(), -1);
    for (int64 i = 0; i < indices.size(); ++i) {
      source_indices[permutation[i]] = indices[i];
    }
    *device = source_tiles(source_indices);
  });

  if (partially_replicated) {
    return HloSharding::PartialTile(transposed_tiles);
  }
  return HloSharding::Tile(transposed_tiles);
}
// Computes the sharding that results from reshaping a tensor of shape
// `source_shape`, sharded with `sharding`, into `target_shape`.
// Tile-maximal shardings are returned unchanged. Returns nullopt when the
// tiling cannot be carried through the reshape by the rules described below.
absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
const Shape& target_shape,
const HloSharding& sharding) {
if (sharding.IsTileMaximal()) {
return sharding;
}
// In case of a tiled sharding, the reshaped sharding will be valid if the
// reshape is composed of the following operations:
// * Adding or removing dimensions with size 1.
// * Merging consecutive dimensions where only the most major is sharded.
// * Splitting a dimension to consecutive dimensions.
// * Any reshaping of unsharded dimensions.
// Note that merge and split can happen consecutively on the same dimension,
// e.g., f32[1024,256,1024] to f32[128,2048,1024] can be considered that 1024
// gets split into 128 and 8, but 8 then gets merged with 256. We use stacks
// to make supporting such cases easy.
const Shape tile_shape = sharding.TileShape(source_shape);
std::vector<int64> target_tile_assignment_dimensions;
std::vector<int64> source_dims_stack(source_shape.rank());
std::vector<int64> target_dims_stack(target_shape.rank());
std::vector<int64> sharding_tile_dims_stack(source_shape.rank());
// The stacks are filled in reverse dimension order, so back() always refers
// to the lowest-numbered dimension that has not been processed yet.
for (int64 i = 0; i < source_shape.rank(); ++i) {
source_dims_stack[i] = source_shape.dimensions(source_shape.rank() - 1 - i);
sharding_tile_dims_stack[i] =
sharding.tile_assignment().dim(source_shape.rank() - 1 - i);
}
for (int64 i = 0; i < target_shape.rank(); ++i) {
target_dims_stack[i] = target_shape.dimensions(target_shape.rank() - 1 - i);
}
// Each iteration pops one source dimension (if any remain) and one target
// dimension, and matches them; partially consumed dimensions are pushed back.
while (!source_dims_stack.empty() || !target_dims_stack.empty()) {
if (target_dims_stack.empty()) {
// Only source dimensions remain; they must all be unpartitioned.
if (Product(sharding_tile_dims_stack) != 1) {
return absl::nullopt;
}
break;
}
int64 s_size = 1;
int64 t_size = 1;
int64 s_partitions = 1;
if (!source_dims_stack.empty()) {
s_size = source_dims_stack.back();
source_dims_stack.pop_back();
s_partitions = sharding_tile_dims_stack.back();
sharding_tile_dims_stack.pop_back();
}
t_size = target_dims_stack.back();
target_dims_stack.pop_back();
if (s_partitions * Product(sharding_tile_dims_stack) == 1) {
// No more partitions left.
target_tile_assignment_dimensions.push_back(1);
continue;
}
if (s_size == t_size) {
// Same dimension.
target_tile_assignment_dimensions.push_back(s_partitions);
} else if (t_size == 1) {
// Trivial dimension added.
target_tile_assignment_dimensions.push_back(1);
// The source dimension was not consumed; put it back for the next round.
source_dims_stack.push_back(s_size);
sharding_tile_dims_stack.push_back(s_partitions);
} else if (s_size == 1) {
// Trivial dimension removed.
if (s_partitions != 1) {
return absl::nullopt;
}
// The target dimension was not consumed; put it back for the next round.
target_dims_stack.push_back(t_size);
} else if (s_size > t_size) {
// Dimension split.
if (s_size % t_size != 0 || s_size % s_partitions != 0) {
return absl::nullopt;
}
if (t_size % s_partitions == 0) {
// All partitions of the source dimension fit in this target dimension.
target_tile_assignment_dimensions.push_back(s_partitions);
// We have part of the s_size unprocessed, so put it back to stack.
source_dims_stack.push_back(s_size / t_size);
sharding_tile_dims_stack.push_back(1);
} else if (s_partitions % t_size == 0) {
// This target dimension takes t_size partitions; the remaining
// s_partitions / t_size stay with the leftover source dimension.
target_tile_assignment_dimensions.push_back(t_size);
// We have part of the s_size unprocessed, so put it back to stack.
source_dims_stack.push_back(s_size / t_size);
sharding_tile_dims_stack.push_back(s_partitions / t_size);
} else {
return absl::nullopt;
}
} else {
// Dimension merge. Also merge the source dimension with the next, and
// process it next time.
if (s_size % s_partitions != 0) {
return absl::nullopt;
}
CHECK(!source_dims_stack.empty());
if (sharding_tile_dims_stack.back() != 1 && s_size != s_partitions) {
// If the next dimension to combine is sharded, we require that the
// current dimension's shard size to be 1. Otherwise, the new shard
// would be non-contiguous.
return absl::nullopt;
}
source_dims_stack.back() *= s_size;
sharding_tile_dims_stack.back() *= s_partitions;
target_dims_stack.push_back(t_size);
}
}
// The device order is kept as is; only the tile assignment's shape changes.
Array<int64> new_tile_assignment = sharding.tile_assignment();
if (sharding.ReplicateOnLastTileDim()) {
// Carry the trailing replication dimension over unchanged.
target_tile_assignment_dimensions.push_back(
sharding.tile_assignment().dimensions().back());
}
new_tile_assignment.Reshape(target_tile_assignment_dimensions);
return sharding.ReplicateOnLastTileDim()
? HloSharding::PartialTile(new_tile_assignment)
: HloSharding::Tile(new_tile_assignment);
}
// Returns `sharding` with the order of tiles flipped along every tile
// dimension listed in `dimensions`. Tile-maximal shardings, or an empty
// `dimensions`, yield the input sharding unchanged.
HloSharding ReverseSharding(const HloSharding& sharding,
                            absl::Span<const int64> dimensions) {
  if (sharding.IsTileMaximal() || dimensions.empty()) {
    return sharding;
  }
  const Array<int64>& source_tiles = sharding.tile_assignment();
  Array<int64> reversed_tiles(source_tiles.dimensions());
  reversed_tiles.Each([&](absl::Span<const int64> indices, int64* device) {
    // Mirror the index along each reversed dimension to find the source cell.
    std::vector<int64> source_indices(indices.begin(), indices.end());
    for (int64 dim : dimensions) {
      source_indices[dim] = reversed_tiles.dim(dim) - 1 - source_indices[dim];
    }
    *device = source_tiles(source_indices);
  });
  if (sharding.ReplicateOnLastTileDim()) {
    return HloSharding::PartialTile(reversed_tiles);
  }
  return HloSharding::Tile(reversed_tiles);
}
// Returns a tiled sharding in which all tiling along the dimensions in `dims`
// is moved onto the single dimension `dim` (which must be a member of `dims`).
// Tile dimensions not listed in `dims` keep their sizes. CHECK-fails for tuple
// or tile-maximal shardings.
//
// The tile assignment on `dim` is ordered so as to minimize communication
// among devices caused by the reshard:
//   +---+---+               +---+---+              +-+-+-+-+
//   |   |   |               |   0   |              | | | | |
//   | 0 | 1 |               +-------+              | | | | |
//   |   |   |  reshape on   |   1   |  reshape on  | | | | |
//   +---+---+   dim 0  =>   +-------+   dim 1  =>  |0|2|1|3|
//   |   |   |               |   2   |              | | | | |
//   | 2 | 3 |               +-------+              | | | | |
//   |   |   |               |   3   |              | | | | |
//   +---+---+               +---+---+              +-+-+-+-+
HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
                                   absl::Span<const int64> dims) {
  CHECK(!sharding.IsTuple() && !sharding.IsTileMaximal());
  CHECK_NE(absl::c_find(dims, dim), dims.end()) << "dim is not in dims";

  const Array<int64>& source_tiles = sharding.tile_assignment();
  const int64 tile_rank = source_tiles.num_dimensions();
  // A dimension is "kept" (ignored by the reshape) when it is not in `dims`.
  const auto is_kept_dim = [&](int64 d) {
    return absl::c_find(dims, d) == dims.end();
  };

  // Kept dimensions retain their tile counts; all others start at 1.
  std::vector<int64> result_dims(tile_rank, 1);
  std::vector<int64> kept_sizes;
  int64 kept_product = 1;
  for (int64 i = 0; i < tile_rank; ++i) {
    if (!is_kept_dim(i)) {
      continue;
    }
    const int64 size = source_tiles.dim(i);
    kept_sizes.push_back(size);
    result_dims[i] = size;
    kept_product *= size;
  }

  // For every index over the kept dimensions, group the devices by their
  // position along `dim`.
  using Buckets = std::vector<std::vector<int64>>;
  Array<Buckets> buckets(kept_sizes, Buckets(source_tiles.dim(dim)));
  source_tiles.Each([&](absl::Span<const int64> index, int64 device) {
    std::vector<int64> kept_index;
    for (int64 i = 0; i < index.size(); ++i) {
      if (is_kept_dim(i)) {
        kept_index.push_back(index[i]);
      }
    }
    buckets(kept_index)[index[dim]].push_back(device);
  });

  // Flatten the buckets into the device order of the new tile assignment.
  std::vector<int64> ordered_devices;
  buckets.Each([&](absl::Span<const int64> /*index*/, const Buckets& cell) {
    for (const auto& bucket : cell) {
      ordered_devices.insert(ordered_devices.end(), bucket.begin(),
                             bucket.end());
    }
  });

  result_dims[dim] = ordered_devices.size() / kept_product;
  Array<int64> result_tiles(result_dims);
  result_tiles.SetValues(ordered_devices);
  return HloSharding::Tile(result_tiles);
}
// Returns true if any instruction in any computation of `module` has a
// sharding that is not tile-maximal.
bool ContainsTileSharding(const HloModule& module) {
  for (const HloComputation* computation : module.computations()) {
    for (const HloInstruction* instr : computation->instructions()) {
      if (!instr->has_sharding()) {
        continue;
      }
      if (!instr->sharding().IsTileMaximal()) {
        return true;
      }
    }
  }
  return false;
}
// Derives the sharding of the output of gather `hlo` from the sharding of its
// indices. Output offset dimensions get a tile count of 1; every other output
// dimension takes the tile count of the next index dimension, in order.
// Tile-maximal index shardings are returned as is. Falls back to a replicated
// sharding when the derived tile shape does not have the same number of tiles
// as the index sharding.
HloSharding GatherOutputSharding(const HloSharding& index_sharding,
                                 const HloInstruction* hlo) {
  if (index_sharding.IsTileMaximal()) {
    return index_sharding;
  }
  const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
  const Array<int64>& index_tiles = index_sharding.tile_assignment();
  const bool partially_replicated = index_sharding.ReplicateOnLastTileDim();

  std::vector<int64> output_tile_dims;
  int64 next_index_dim = 0;
  for (int64 i = 0; i < hlo->shape().rank(); ++i) {
    if (absl::c_binary_search(dnums.offset_dims(), i)) {
      output_tile_dims.push_back(1);
      continue;
    }
    output_tile_dims.push_back(index_tiles.dim(next_index_dim));
    ++next_index_dim;
  }
  if (partially_replicated) {
    // Keep the trailing replication dimension.
    output_tile_dims.push_back(index_tiles.dimensions().back());
  }

  if (index_tiles.num_elements() != Product(output_tile_dims)) {
    return HloSharding::Replicate();
  }
  Array<int64> output_tiles = index_tiles;
  output_tiles.Reshape(output_tile_dims);
  if (partially_replicated) {
    return HloSharding::PartialTile(output_tiles);
  }
  return HloSharding::Tile(output_tiles);
}
// Derives the sharding of the indices operand of gather `hlo` from the
// sharding of its output: the tile counts of the non-offset output dimensions
// are kept, and a tile count of 1 is inserted at index_vector_dim when the
// indices operand has more dimensions than were collected. Tile-maximal output
// shardings are returned as is. Falls back to a replicated sharding when the
// derived tile shape does not have the same number of tiles as the output
// sharding. CHECK-fails if `hlo` is not a gather.
HloSharding GatherIndexSharding(const HloSharding& output_sharding,
                                const HloInstruction* hlo) {
  CHECK(hlo->opcode() == HloOpcode::kGather);
  if (output_sharding.IsTileMaximal()) {
    return output_sharding;
  }
  const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
  const Array<int64>& output_tiles = output_sharding.tile_assignment();
  const bool partially_replicated = output_sharding.ReplicateOnLastTileDim();

  std::vector<int64> index_tile_dims;
  for (int64 i = 0; i < hlo->shape().rank(); ++i) {
    if (absl::c_binary_search(dnums.offset_dims(), i)) {
      continue;
    }
    index_tile_dims.push_back(output_tiles.dim(i));
  }

  const int64 index_rank = hlo->operand(1)->shape().rank();
  // Vector indices sharding is not supported yet.
  if (index_rank > index_tile_dims.size()) {
    index_tile_dims.insert(index_tile_dims.begin() + dnums.index_vector_dim(),
                           1);
  }
  if (partially_replicated) {
    // Keep the trailing replication dimension.
    index_tile_dims.push_back(output_tiles.dimensions().back());
  }

  if (output_tiles.num_elements() != Product(index_tile_dims)) {
    return HloSharding::Replicate();
  }
  Array<int64> index_tiles = output_tiles;
  index_tiles.Reshape(index_tile_dims);
  if (partially_replicated) {
    return HloSharding::PartialTile(index_tiles);
  }
  return HloSharding::Tile(index_tiles);
}
// Returns the part of the output sharding of gather `hlo` that can be used to
// partition the gather itself, i.e. the sharding restricted to the non-offset
// output dimensions:
//  * tile-maximal shardings are returned unchanged;
//  * if all tiling is on non-offset dimensions, the sharding is returned
//    unchanged;
//  * if all tiling is on offset dimensions, a single-device sharding on the
//    first device of the tile assignment is returned;
//  * otherwise the tile assignment is sliced down to index 0 along every
//    offset dimension.
HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) {
  if (hlo.sharding().IsTileMaximal()) {
    return hlo.sharding();
  }
  const GatherDimensionNumbers& dnums = hlo.gather_dimension_numbers();
  const Array<int64>& output_tiles = hlo.sharding().tile_assignment();
  const int64 output_rank = hlo.shape().rank();

  // Per-dimension tile count of the effective sharding: the output tile count
  // on non-offset dimensions, 1 on offset dimensions. These values double as
  // the slice limits used below, so they are computed only once.
  std::vector<int64> slice_limits(output_rank, 1);
  int64 num_elements = 1;
  for (int64 i = 0; i < output_rank; ++i) {
    if (!absl::c_binary_search(dnums.offset_dims(), i)) {
      slice_limits[i] = output_tiles.dim(i);
      num_elements *= output_tiles.dim(i);
    }
  }
  if (num_elements == output_tiles.num_elements()) {
    // Output sharding is only on non offset dimensions. We use output sharding
    // to shard this gather op directly.
    return hlo.sharding();
  }
  if (num_elements == 1) {
    // Output sharding is only on offset dimensions. We do not shard this gather
    // op. Return a tile maximal sharding with the first device in output
    // sharding tile assignment.
    return HloSharding::AssignDevice(*output_tiles.begin());
  }
  // Output sharding is on both offset and non offset dimensions. We shard the
  // gather op only on non offset dimensions.
  // For example:
  // - the gather op has sharding [2,2]{0,1,2,3},
  // - first dimension is non offset dimension,
  // - second dimension is offset dimension,
  // Then the result sharding will be [2,1]{0,2}.
  const std::vector<int64> slice_starts(output_rank, 0LL);
  Array<int64> tile_assignment = output_tiles.Slice(slice_starts, slice_limits);
  return HloSharding::Tile(tile_assignment);
}
// Derives the sharding of the indices operand of scatter `hlo` from the
// sharding of its data: the tile counts of the dimensions that are not update
// window dimensions are kept, and a trailing tile count of 1 is appended when
// the indices operand has more dimensions than were collected. Tile-maximal
// data shardings are returned as is. Falls back to a replicated sharding when
// the derived tile shape does not have the same number of tiles as the data
// sharding.
HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
                                 const HloInstruction* hlo) {
  if (data_sharding.IsTileMaximal()) {
    return data_sharding;
  }
  const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers();
  const Array<int64>& data_tiles = data_sharding.tile_assignment();
  const bool partially_replicated = data_sharding.ReplicateOnLastTileDim();

  std::vector<int64> index_tile_dims;
  for (int64 i = 0; i < hlo->shape().rank(); ++i) {
    if (absl::c_binary_search(dnums.update_window_dims(), i)) {
      continue;
    }
    index_tile_dims.push_back(data_tiles.dim(i));
  }
  if (index_tile_dims.size() < hlo->operand(1)->shape().rank()) {
    index_tile_dims.push_back(1);
  }
  if (partially_replicated) {
    // Keep the trailing replication dimension.
    index_tile_dims.push_back(data_tiles.dimensions().back());
  }

  if (data_tiles.num_elements() != Product(index_tile_dims)) {
    return HloSharding::Replicate();
  }
  Array<int64> index_tiles = data_tiles;
  index_tiles.Reshape(index_tile_dims);
  if (partially_replicated) {
    return HloSharding::PartialTile(index_tiles);
  }
  return HloSharding::Tile(index_tiles);
}
// Derives the sharding of the data of scatter `hlo` from the sharding of its
// indices. Update window dimensions get a tile count of 1; every other
// dimension takes the tile count of the next index dimension, in order.
// Tile-maximal index shardings are returned as is. Falls back to a replicated
// sharding when the derived tile shape does not have the same number of tiles
// as the index sharding.
HloSharding ScatterDataSharding(const HloSharding& index_sharding,
                                const HloInstruction* hlo) {
  if (index_sharding.IsTileMaximal()) {
    return index_sharding;
  }
  const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers();
  const Array<int64>& index_tiles = index_sharding.tile_assignment();
  const bool partially_replicated = index_sharding.ReplicateOnLastTileDim();

  std::vector<int64> data_tile_dims;
  int64 next_index_dim = 0;
  for (int64 i = 0; i < hlo->shape().rank(); ++i) {
    if (absl::c_binary_search(dnums.update_window_dims(), i)) {
      data_tile_dims.push_back(1);
      continue;
    }
    data_tile_dims.push_back(index_tiles.dim(next_index_dim));
    ++next_index_dim;
  }
  if (partially_replicated) {
    // Keep the trailing replication dimension.
    data_tile_dims.push_back(index_tiles.dimensions().back());
  }

  if (index_tiles.num_elements() != Product(data_tile_dims)) {
    return HloSharding::Replicate();
  }
  Array<int64> data_tiles = index_tiles;
  data_tiles.Reshape(data_tile_dims);
  if (partially_replicated) {
    return HloSharding::PartialTile(data_tiles);
  }
  return HloSharding::Tile(data_tiles);
}
// Returns the part of `index_sharding` that can be used to partition scatter
// `hlo`. Only the leading index dimensions are considered, one per entry of
// inserted_window_dims that falls within the rank of the scatter:
//  * tile-maximal shardings are returned unchanged;
//  * if all tiling is on those leading dimensions, the sharding is returned
//    unchanged;
//  * if none of the tiling is on them, a single-device sharding on the first
//    device of the tile assignment is returned;
//  * otherwise the tile assignment is sliced down to index 0 along all other
//    dimensions.
HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
                                          const HloInstruction& hlo) {
  if (index_sharding.IsTileMaximal()) {
    return index_sharding;
  }
  // Only shard on first "number of scatter_window_dims" dimensions.
  const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers();
  const Array<int64>& index_tiles = index_sharding.tile_assignment();
  int64 num_elements = 1;
  int64 num_sharded_dims = 0;
  for (int64 i = 0; i < hlo.shape().rank(); ++i) {
    if (!absl::c_binary_search(dnums.inserted_window_dims(), i)) {
      continue;
    }
    num_elements *= index_tiles.dim(num_sharded_dims);
    ++num_sharded_dims;
  }
  if (num_elements == index_tiles.num_elements()) {
    // Index sharding is only on scatter_window_dims. We use this index sharding
    // directly.
    return index_sharding;
  }
  if (num_elements == 1) {
    // Index sharding is only on update_window_dims. We do not shard this
    // scatter op. Return a tile maximal sharding with the first device in index
    // sharding tile assignment.
    return HloSharding::AssignDevice(*index_tiles.begin());
  }

  // Keep the full extent of the leading sharded dimensions and a single tile
  // along every remaining dimension.
  const int64 index_rank = hlo.operand(1)->shape().rank();
  std::vector<int64> slice_starts(index_rank, 0LL);
  std::vector<int64> slice_limits(index_rank);
  for (int64 i = 0; i < index_rank; ++i) {
    slice_limits[i] = (i < num_sharded_dims) ? index_tiles.dim(i) : 1;
  }
  Array<int64> tile_assignment = index_tiles.Slice(slice_starts, slice_limits);
  return HloSharding::Tile(tile_assignment);
}
// Returns the part of `data_sharding` that can be used to partition scatter
// `hlo`, i.e. the sharding restricted to the dimensions listed in
// inserted_window_dims:
//  * tile-maximal shardings are returned unchanged;
//  * if all tiling is on those dimensions, the sharding is returned unchanged;
//  * if none of the tiling is on them, a single-device sharding on the first
//    device of the tile assignment is returned;
//  * otherwise the tile assignment is sliced down to index 0 along all other
//    dimensions.
HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
                                         const HloInstruction& hlo) {
  if (data_sharding.IsTileMaximal()) {
    return data_sharding;
  }
  const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers();
  const Array<int64>& data_tiles = data_sharding.tile_assignment();
  const int64 data_rank = hlo.operand(2)->shape().rank();

  // Tile counts of the effective sharding; they also serve as slice limits.
  std::vector<int64> slice_limits(data_rank, 1LL);
  int64 num_elements = 1;
  for (int64 i = 0; i < hlo.shape().rank(); ++i) {
    if (!absl::c_binary_search(dnums.inserted_window_dims(), i)) {
      continue;
    }
    CHECK_LT(i, data_rank);
    slice_limits[i] = data_tiles.dim(i);
    num_elements *= data_tiles.dim(i);
  }
  if (num_elements == data_tiles.num_elements()) {
    // Data sharding is only on scatter_window_dims. We use this data sharding
    // directly.
    return data_sharding;
  }
  if (num_elements == 1) {
    // Data sharding is only on update_window_dims. We do not shard this
    // scatter op. Return a tile maximal sharding with the first device in
    // data sharding tile assignment.
    return HloSharding::AssignDevice(*data_tiles.begin());
  }
  // Data sharding is on both update_window_dims and scatter_window_dims. We
  // shard the scatter op only on scatter_window_dims. For example:
  // - the scatter data has sharding [2,2]{0,1,2,3},
  // - first dimension is scatter_window_dims,
  // - second dimension is update_window_dims,
  // Then the result sharding will be [2,1]{0,2}.
  std::vector<int64> slice_starts(data_rank, 0LL);
  Array<int64> tile_assignment = data_tiles.Slice(slice_starts, slice_limits);
  return HloSharding::Tile(tile_assignment);
}
namespace {
// If partitioning in the operand only happens in passthrough dimensions
// (offset dimensions in the gather output (or scatter update) that have the
// same size as the operand), returns the corresponding output (or update)
// sharding by passing through the input sharding.
//
// Returns nullopt when the operand is partitioned along a collapsed/inserted
// or indexed dimension, along a dimension whose slice size differs from the
// operand dimension size, or when the offset/window dimensions are not in
// increasing order. Tile-maximal shardings are returned unchanged.
absl::optional<HloSharding> PassthroughOperandToGatherOutputOrScatterUpdate(
    const Shape& operand_shape, const HloSharding& operand_sharding,
    const Shape& update_or_gather_shape,
    absl::Span<const int64> collapsed_or_inserted_dims,
    absl::Span<const int64> index_map,
    absl::Span<const int64> offset_or_window_dims,
    absl::Span<const int64> slice_size) {
  if (operand_sharding.IsTileMaximal()) {
    return operand_sharding;
  }
  const Array<int64>& operand_tiles = operand_sharding.tile_assignment();
  std::vector<int64> passthrough_tile(update_or_gather_shape.rank(), 1);
  // Number of operand dimensions seen so far that have no offset/window
  // dimension of their own.
  int64 num_skipped = 0;
  for (int64 i = 0; i < operand_shape.rank(); ++i) {
    const int64 dim_partitions = operand_tiles.dim(i);
    const bool has_no_window_dim =
        absl::c_linear_search(collapsed_or_inserted_dims, i) ||
        absl::c_linear_search(index_map, i);
    if (has_no_window_dim) {
      if (dim_partitions > 1) {
        return absl::nullopt;
      }
      ++num_skipped;
      continue;
    }
    if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) {
      return absl::nullopt;
    }
    const int64 window_pos = i - num_skipped;
    const int64 offset_dim = offset_or_window_dims[window_pos];
    if (window_pos > 0 && offset_dim < offset_or_window_dims[window_pos - 1]) {
      // Output offsets are transposed, we do not support this case.
      return absl::nullopt;
    }
    passthrough_tile[offset_dim] = dim_partitions;
  }

  const bool partially_replicated = operand_sharding.ReplicateOnLastTileDim();
  if (partially_replicated) {
    // Keep the trailing replication dimension.
    passthrough_tile.push_back(operand_tiles.dimensions().back());
  }
  Array<int64> tile_assignment = operand_tiles;
  tile_assignment.Reshape(passthrough_tile);
  if (partially_replicated) {
    return HloSharding::PartialTile(tile_assignment);
  }
  return HloSharding::Tile(tile_assignment);
}
// Inverse of PassthroughOperandToGatherOutputOrScatterUpdate: maps the tiling
// of the offset/window dimensions of a gather output (or scatter update)
// sharding back onto the matching operand dimensions.
//
// Returns nullopt when a partitioned dimension has a slice size different from
// the operand dimension size, when the offset/window dimensions are not in
// increasing order, or when the resulting tile shape does not have the same
// number of tiles as the input sharding. Tile-maximal shardings are returned
// unchanged.
absl::optional<HloSharding> PassthroughGatherOutputOrScatterUpdateToOperand(
    const Shape& operand_shape, const HloSharding& update_or_gather_sharding,
    absl::Span<const int64> collapsed_or_inserted_dims,
    absl::Span<const int64> index_map,
    absl::Span<const int64> offset_or_window_dims,
    absl::Span<const int64> slice_size) {
  if (update_or_gather_sharding.IsTileMaximal()) {
    return update_or_gather_sharding;
  }
  const Array<int64>& source_tiles =
      update_or_gather_sharding.tile_assignment();
  std::vector<int64> passthrough_tile(operand_shape.rank(), 1);
  // Number of operand dimensions seen so far that have no offset/window
  // dimension of their own.
  int64 num_skipped = 0;
  for (int64 i = 0; i < operand_shape.rank(); ++i) {
    const bool has_no_window_dim =
        absl::c_linear_search(collapsed_or_inserted_dims, i) ||
        absl::c_linear_search(index_map, i);
    if (has_no_window_dim) {
      ++num_skipped;
      continue;
    }
    const int64 window_pos = i - num_skipped;
    const int64 offset_dim = offset_or_window_dims[window_pos];
    const int64 dim_partitions = source_tiles.dim(offset_dim);
    if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) {
      return absl::nullopt;
    }
    if (window_pos > 0 && offset_dim < offset_or_window_dims[window_pos - 1]) {
      // Output offsets are transposed, we do not support this case.
      return absl::nullopt;
    }
    passthrough_tile[i] = dim_partitions;
  }

  const bool partially_replicated =
      update_or_gather_sharding.ReplicateOnLastTileDim();
  if (partially_replicated) {
    // Keep the trailing replication dimension.
    passthrough_tile.push_back(source_tiles.dimensions().back());
  }
  Array<int64> tile_assignment = source_tiles;
  if (tile_assignment.num_elements() != Product(passthrough_tile)) {
    return absl::nullopt;
  }
  tile_assignment.Reshape(passthrough_tile);
  if (partially_replicated) {
    return HloSharding::PartialTile(tile_assignment);
  }
  return HloSharding::Tile(tile_assignment);
}
} // namespace
// Derives the output sharding of gather `hlo` from the sharding of its data
// operand (operand 0) by passing the operand tiling through to the output
// offset dimensions. Returns nullopt when the passthrough helper rejects the
// sharding.
absl::optional<HloSharding> GatherOutputShardingFromDataOperand(
    const HloSharding& data_operand_sharding, const HloInstruction& hlo) {
  const auto& dnums = hlo.gather_dimension_numbers();
  // Copy the dimension-number fields into vectors so they can be passed as
  // spans of int64.
  const std::vector<int64> collapsed_dims(dnums.collapsed_slice_dims().begin(),
                                          dnums.collapsed_slice_dims().end());
  const std::vector<int64> index_map(dnums.start_index_map().begin(),
                                     dnums.start_index_map().end());
  const std::vector<int64> offsets(dnums.offset_dims().begin(),
                                   dnums.offset_dims().end());
  return PassthroughOperandToGatherOutputOrScatterUpdate(
      hlo.operand(0)->shape(), data_operand_sharding, hlo.shape(),
      collapsed_dims, index_map, offsets, hlo.gather_slice_sizes());
}
// Derives the sharding of the data operand (operand 0) of gather `hlo` from
// the sharding of its output by mapping the tiling of the output offset
// dimensions back onto the operand. Returns nullopt when the passthrough
// helper rejects the sharding.
absl::optional<HloSharding> GatherDataOperandShardingFromOutput(
    const HloSharding& output_sharding, const HloInstruction& hlo) {
  const auto& dnums = hlo.gather_dimension_numbers();
  // Copy the dimension-number fields into vectors so they can be passed as
  // spans of int64.
  const std::vector<int64> collapsed_dims(dnums.collapsed_slice_dims().begin(),
                                          dnums.collapsed_slice_dims().end());
  const std::vector<int64> index_map(dnums.start_index_map().begin(),
                                     dnums.start_index_map().end());
  const std::vector<int64> offsets(dnums.offset_dims().begin(),
                                   dnums.offset_dims().end());
  return PassthroughGatherOutputOrScatterUpdateToOperand(
      hlo.operand(0)->shape(), output_sharding, collapsed_dims, index_map,
      offsets, hlo.gather_slice_sizes());
}
// Derives the output sharding of scatter `hlo` from the sharding of its
// updates by mapping the tiling of the update window dimensions back onto the
// scatter output. Returns nullopt when the passthrough helper rejects the
// sharding.
absl::optional<HloSharding> ScatterOutputShardingFromUpdate(
    const HloSharding& update_sharding, const HloInstruction& hlo) {
  const auto& dnums = hlo.scatter_dimension_numbers();
  // Copy the dimension-number fields into vectors so they can be passed as
  // spans of int64.
  const std::vector<int64> inserted_dims(dnums.inserted_window_dims().begin(),
                                         dnums.inserted_window_dims().end());
  const std::vector<int64> index_map(
      dnums.scatter_dims_to_operand_dims().begin(),
      dnums.scatter_dims_to_operand_dims().end());
  const std::vector<int64> window_dims(dnums.update_window_dims().begin(),
                                       dnums.update_window_dims().end());

  // Per output dimension: the size of the matching update window dimension,
  // or 1 for inserted window dimensions.
  std::vector<int64> slice_size(hlo.shape().rank(), 1);
  int64 next_window_dim = 0;
  for (int64 i = 0; i < hlo.shape().rank(); ++i) {
    if (!absl::c_linear_search(dnums.inserted_window_dims(), i)) {
      slice_size[i] = hlo.operand(2)->shape().dimensions(
          dnums.update_window_dims(next_window_dim));
      ++next_window_dim;
    }
  }
  return PassthroughGatherOutputOrScatterUpdateToOperand(
      hlo.shape(), update_sharding, inserted_dims, index_map, window_dims,
      slice_size);
}
// Derives the sharding of the updates operand (operand 2) of scatter `hlo`
// from the sharding of its output by passing the output tiling through to the
// update window dimensions. Returns nullopt when the passthrough helper
// rejects the sharding.
absl::optional<HloSharding> ScatterUpdateShardingFromOutput(
    const HloSharding& output_sharding, const HloInstruction& hlo) {
  const auto& dnums = hlo.scatter_dimension_numbers();
  // Copy the dimension-number fields into vectors so they can be passed as
  // spans of int64.
  const std::vector<int64> inserted_dims(dnums.inserted_window_dims().begin(),
                                         dnums.inserted_window_dims().end());
  const std::vector<int64> index_map(
      dnums.scatter_dims_to_operand_dims().begin(),
      dnums.scatter_dims_to_operand_dims().end());
  const std::vector<int64> window_dims(dnums.update_window_dims().begin(),
                                       dnums.update_window_dims().end());

  // Per output dimension: the size of the matching update window dimension,
  // or 1 for inserted window dimensions.
  std::vector<int64> slice_size(hlo.shape().rank(), 1);
  int64 next_window_dim = 0;
  for (int64 i = 0; i < hlo.shape().rank(); ++i) {
    if (!absl::c_linear_search(dnums.inserted_window_dims(), i)) {
      slice_size[i] = hlo.operand(2)->shape().dimensions(
          dnums.update_window_dims(next_window_dim));
      ++next_window_dim;
    }
  }
  return PassthroughOperandToGatherOutputOrScatterUpdate(
      hlo.shape(), output_sharding, hlo.operand(2)->shape(), inserted_dims,
      index_map, window_dims, slice_size);
}
// Inspects the reduction computation applied by `scatter` and returns a pair
// of (constant instruction holding the identity value of the reduction, opcode
// of the reduction).
//
// Only computations with exactly three instructions (two parameters plus one
// calculation) are handled. The supported root opcodes and their identity
// values, taken at the scatter's element type, are:
//   add, or       -> zero
//   multiply, and -> one
//   maximum       -> minimum representable value
//   minimum       -> maximum representable value
// Any other computation yields an INVALID_ARGUMENT status.
StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>>
IdentityValueAndHloOpcodeForScatterReduceComputation(
    const HloScatterInstruction& scatter) {
  auto computation = scatter.to_apply();
  // We only handle computations with 2 parameters and only 1 calculation.
  if (computation->instruction_count() != 3) {
    return Status(
        tensorflow::error::Code::INVALID_ARGUMENT,
        "Expected scatter reduce computation with 2 parameters and only 1 "
        "calculation");
  }
  auto root_instruction = computation->root_instruction();
  if (root_instruction->opcode() == HloOpcode::kAdd ||
      root_instruction->opcode() == HloOpcode::kOr) {
    return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::Zero(
                              scatter.shape().element_type())),
                          root_instruction->opcode());
  } else if (root_instruction->opcode() == HloOpcode::kMultiply ||
             root_instruction->opcode() == HloOpcode::kAnd) {
    return std::make_pair(HloInstruction::CreateConstant(
                              LiteralUtil::One(scatter.shape().element_type())),
                          root_instruction->opcode());
  } else if (root_instruction->opcode() == HloOpcode::kMaximum) {
    return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MinValue(
                              scatter.shape().element_type())),
                          root_instruction->opcode());
  } else if (root_instruction->opcode() == HloOpcode::kMinimum) {
    return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MaxValue(
                              scatter.shape().element_type())),
                          root_instruction->opcode());
  }
  // The message lists exactly the opcodes accepted above.
  return Status(tensorflow::error::Code::INVALID_ARGUMENT,
                "Expected scatter reduce computation which is "
                "add/or/multiply/and/min/max");
}
namespace {
// Adds to `used` every device that `sharding` places data on:
//  * tuple shardings are handled by recursing into each element;
//  * replicated shardings use every non-reserved device in
//    `available_devices`;
//  * any other sharding uses exactly the devices in its tile assignment, which
//    are expected (DCHECK) to all be in `available_devices`.
void DevicesForShardingInternal(
    const HloSharding& sharding,
    const absl::flat_hash_set<int64>& available_devices,
    absl::flat_hash_set<int64>* used) {
  if (sharding.IsTuple()) {
    for (const HloSharding& element : sharding.tuple_elements()) {
      DevicesForShardingInternal(element, available_devices, used);
    }
    return;
  }

  if (sharding.IsReplicated()) {
    for (int64 device : available_devices) {
      if (HloSharding::IsReservedDevice(device)) {
        continue;
      }
      used->insert(device);
    }
    return;
  }

  const Array<int64>& tiles = sharding.tile_assignment();
  DCHECK(std::all_of(
      tiles.begin(), tiles.end(),
      [&](int64 device) { return available_devices.contains(device); }));
  for (int64 device : tiles) {
    used->insert(device);
  }
}
} // namespace
std::vector<int64> DevicesForSharding(
const HloSharding& sharding, const std::vector<int64>& available_devices) {
absl::flat_hash_set<int64> available_set;
for (int64 device : available_devices) {
available_set.insert(device);
}
absl::flat_hash_set<int64> used_set;
DevicesForShardingInternal(sharding, available_set, &used_set);
std::vector<int64> devices;
for (int64 device : available_devices) {
if (used_set.contains(device)) {
devices.push_back(device);
}
}
return devices;
}
// Returns a sharding in which the tiling along `dims_to_replicate` is replaced
// by replication: those tile dimensions become 1 and the devices that used to
// be split along them form replication groups on the last tile dimension.
//
//  * Tile-maximal shardings, and shardings with a single tile along all of
//    `dims_to_replicate`, are returned unchanged.
//  * If the replicated dimensions account for all tiles, a fully replicated
//    sharding is returned.
//  * Otherwise the result is a partial tile sharding. If `sharding` already
//    replicates on its last tile dimension, that dimension is enlarged by the
//    new group size instead of adding another dimension.
HloSharding PartiallyReplicateTiledShardingOnDims(
    const HloSharding& sharding, const std::vector<int64>& dims_to_replicate) {
  if (sharding.IsTileMaximal()) {
    return sharding;
  }
  // Number of devices that end up holding the same data.
  int64 group_count = 1;
  for (int64 dim : dims_to_replicate) {
    if (sharding.ReplicateOnLastTileDim()) {
      CHECK_LT(dim, sharding.tile_assignment().num_dimensions());
    }
    group_count *= sharding.tile_assignment().dim(dim);
  }
  if (group_count == 1) {
    return sharding;
  }
  if (group_count == sharding.NumTiles()) {
    return HloSharding::Replicate();
  }
  // Move the replicated dimensions to the end of the tile assignment. The sort
  // must be stable: the reshape below relies on the remaining dimensions (and
  // the replicated ones among themselves) keeping their original relative
  // order, which an unstable sort does not guarantee for equal keys.
  std::vector<int64> dim_permutation(
      sharding.tile_assignment().num_dimensions());
  std::iota(dim_permutation.begin(), dim_permutation.end(), 0);
  absl::c_stable_sort(dim_permutation, [&](const int64 a, const int64 b) {
    return absl::c_linear_search(dims_to_replicate, a) <
           absl::c_linear_search(dims_to_replicate, b);
  });
  auto transposed = TransposeSharding(sharding, dim_permutation);
  auto new_tile = transposed.tile_assignment();
  // Target tile shape: the original shape with the replicated dimensions set
  // to 1 and the group size folded into (or appended as) the last dimension.
  std::vector<int64> new_tile_shape(
      sharding.tile_assignment().dimensions().begin(),
      sharding.tile_assignment().dimensions().end());
  for (int64 dim : dims_to_replicate) {
    new_tile_shape[dim] = 1;
  }
  if (sharding.ReplicateOnLastTileDim()) {
    new_tile_shape.back() *= group_count;
  } else {
    new_tile_shape.push_back(group_count);
  }
  new_tile.Reshape(new_tile_shape);
  return HloSharding::PartialTile(new_tile);
}
// Returns `sharding` with the tile dimensions listed in `dims_to_remove`
// dropped from its tile assignment. Each removed dimension must have a tile
// count of 1 (CHECK-enforced). Tile-maximal shardings, or an empty
// `dims_to_remove`, yield the input sharding unchanged.
HloSharding RemoveShapeDimensions(const HloSharding& sharding,
                                  const std::vector<int64>& dims_to_remove) {
  if (sharding.IsTileMaximal() || dims_to_remove.empty()) {
    return sharding;
  }
  const Array<int64>& tiles = sharding.tile_assignment();
  std::vector<int64> kept_tile_dims;
  kept_tile_dims.reserve(tiles.num_dimensions() - dims_to_remove.size());
  for (int64 i = 0; i < tiles.num_dimensions(); ++i) {
    if (!absl::c_linear_search(dims_to_remove, i)) {
      kept_tile_dims.push_back(tiles.dim(i));
      continue;
    }
    CHECK_EQ(tiles.dim(i), 1);
  }
  Array<int64> new_tile = tiles;
  new_tile.Reshape(kept_tile_dims);
  if (sharding.ReplicateOnLastTileDim()) {
    return HloSharding::PartialTile(new_tile);
  }
  return HloSharding::Tile(new_tile);
}
// Transposes the tile assignment of `source` according to `src_to_tgt` /
// `tgt_to_src`, where a negative entry marks a dimension that has no
// counterpart on the other side (a source dimension that is collapsed, or a
// target dimension that is newly added). Newly added target dimensions get a
// tile count of 1. Returns nullopt if a collapsed source dimension is
// partitioned. Tile-maximal shardings are returned unchanged.
absl::optional<HloSharding> TransposeShardingWithCollapsedDims(
    const HloSharding& source, absl::Span<int64 const> src_to_tgt,
    absl::Span<int64 const> tgt_to_src) {
  if (source.IsTileMaximal()) {
    return source;
  }
  if (source.ReplicateOnLastTileDim() &&
      src_to_tgt.size() < source.tile_assignment().num_dimensions()) {
    // The maps do not cover the trailing replication dimension: extend both so
    // that it maps to a trailing target dimension, then retry.
    std::vector<int64> extended_src_to_tgt(src_to_tgt.begin(),
                                           src_to_tgt.end());
    extended_src_to_tgt.push_back(tgt_to_src.size());
    std::vector<int64> extended_tgt_to_src(tgt_to_src.begin(),
                                           tgt_to_src.end());
    extended_tgt_to_src.push_back(src_to_tgt.size());
    return TransposeShardingWithCollapsedDims(source, extended_src_to_tgt,
                                              extended_tgt_to_src);
  }

  // Position of each target dimension once newly added target dimensions are
  // left out; -1 for the newly added ones.
  std::vector<int64> tgt_pos_without_new(tgt_to_src.size(), -1);
  int64 num_new_tgt_dims = 0;
  for (int64 i = 0; i < tgt_to_src.size(); ++i) {
    if (tgt_to_src[i] < 0) {
      ++num_new_tgt_dims;
      continue;
    }
    tgt_pos_without_new[i] = i - num_new_tgt_dims;
  }

  // Build the source permutation: dimensions that survive are placed at their
  // target position, collapsed dimensions are placed at the end in order.
  const int64 num_collapsed_src_dims = absl::c_count(src_to_tgt, -1);
  int64 next_collapsed_slot = src_to_tgt.size() - num_collapsed_src_dims;
  std::vector<int64> perm(src_to_tgt.size());
  for (int64 i = 0; i < src_to_tgt.size(); ++i) {
    if (src_to_tgt[i] >= 0) {
      perm[tgt_pos_without_new[src_to_tgt[i]]] = i;
      continue;
    }
    if (source.tile_assignment().dim(i) > 1) {
      return absl::nullopt;
    }
    perm[next_collapsed_slot] = i;
    ++next_collapsed_slot;
  }

  auto transposed = hlo_sharding_util::TransposeSharding(source, perm);
  auto reshaped_tiles = transposed.tile_assignment();
  // Target tile shape: transposed tile counts, with 1 for every newly added
  // target dimension.
  std::vector<int64> tgt_tiles(tgt_to_src.size(), 1);
  for (int64 i = 0; i < tgt_tiles.size(); ++i) {
    if (tgt_to_src[i] < 0) {
      continue;
    }
    tgt_tiles[i] = reshaped_tiles.dim(tgt_pos_without_new[i]);
  }
  reshaped_tiles.Reshape(tgt_tiles);
  if (source.ReplicateOnLastTileDim()) {
    return HloSharding::PartialTile(reshaped_tiles);
  }
  return HloSharding::Tile(reshaped_tiles);
}
} // namespace hlo_sharding_util
} // namespace xla