/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include <map>
#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace hlo_sharding_util {
absl::optional<int64> SelectDominantDevice(
const std::map<int64, int64>& device_map, int64* top_count) {
int64 device = 0;
int64 count = 0;
for (auto& it : device_map) {
if (it.second > count) {
count = it.second;
device = it.first;
}
}
if (top_count != nullptr) {
*top_count = count;
}
return count > 0 ? absl::optional<int64>(device) : absl::optional<int64>();
}
Status AssignComputationDevice(HloComputation* computation, int64 device) {
VLOG(4) << "Assigning device " << device << " to " << computation->name()
<< " computation";
for (HloInstruction* instruction : computation->instructions()) {
if (!instruction->has_sharding()) {
VLOG(4) << "Assigning device " << device << " to " << instruction->name();
instruction->set_device_sharding(device);
}
}
return Status::OK();
}
absl::optional<int64> GetMostOccurringDevice(
absl::Span<HloInstruction* const> instructions) {
std::map<int64, int64> device_map;
for (HloInstruction* instruction : instructions) {
if (instruction->has_sharding()) {
for (auto& it : instruction->sharding().UsedDevices(nullptr)) {
// The UsedDevices() API returns a map<device, occurrence_count>.
device_map[it.first] += it.second;
}
}
}
return SelectDominantDevice(device_map, nullptr);
}
StatusOr<absl::optional<int64>> GetDominantDevice(
absl::Span<HloComputation* const> computations, double dominant_factor) {
int64 instruction_count = 0;
std::map<int64, int64> device_map;
for (HloComputation* computation : computations) {
for (HloInstruction* instruction : computation->instructions()) {
int64 count = 1;
if (instruction->has_sharding()) {
for (auto& it : instruction->sharding().UsedDevices(&count)) {
// The UsedDevices() API returns a map<device, occurrence_count>.
device_map[it.first] += it.second;
}
}
instruction_count += count;
}
}
int64 count;
absl::optional<int64> device = SelectDominantDevice(device_map, &count);
absl::optional<int64> dominant_device;
if (device) {
double factor =
static_cast<double>(count) / static_cast<double>(instruction_count);
if (factor >= dominant_factor) {
dominant_device = device;
}
}
return dominant_device;
}
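// Numeric illustration (editor's addition, hypothetical values): with
// instruction_count = 10, of which 8 instruction occurrences use device 3,
// the factor is 8.0 / 10.0 = 0.8, so device 3 is returned iff
// dominant_factor <= 0.8.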
HloSharding TransposeSharding(const HloSharding& sharding,
const std::vector<int64>& dimensions) {
if (sharding.IsTileMaximal()) {
return sharding;
}
const int64 rank = dimensions.size();
std::vector<int64> tile_assignment_dim(rank);
for (int64 i = 0; i < rank; ++i) {
tile_assignment_dim[i] = sharding.tile_assignment().dim(dimensions[i]);
}
Array<int64> tile_assignment = sharding.tile_assignment();
tile_assignment.Reshape(tile_assignment_dim);
tile_assignment.Each([&](absl::Span<const int64> indices, int64* value) {
std::vector<int64> src_indices(indices.size(), -1);
for (int64 i = 0; i < indices.size(); ++i) {
src_indices[dimensions[i]] = indices[i];
}
*value = sharding.tile_assignment()(src_indices);
});
return HloSharding::Tile(tile_assignment);
}
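// Illustrative usage (editor's sketch, hypothetical values):
//   Array<int64> tiles({2, 2});
//   tiles.SetValues({0, 1, 2, 3});  // [2,2]{0,1,2,3}
//   // Output dim i takes source dim dimensions[i], so swapping the two
//   // dimensions transposes the device grid:
//   TransposeSharding(HloSharding::Tile(tiles), {1, 0});  // [2,2]{0,2,1,3}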
absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
const Shape& target_shape,
const HloSharding& sharding) {
if (sharding.IsTileMaximal()) {
return sharding;
}
// In case of a tiled sharding, the reshaped sharding will be valid if the
// reshape is composed of the following operations:
// * Adding or removing dimensions with size 1.
// * Merging consecutive dimensions where only the most major is sharded.
// * Splitting a dimension to consecutive dimensions.
// * Any reshaping of unsharded dimensions.
// Note that merge and split can happen consecutively on the same dimension,
// e.g., f32[1024,256,1024] to f32[128,2048,1024] can be viewed as 1024 being
// split into 128 and 8, with the 8 then merged into 256. We use stacks to
// make supporting such cases easy.
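// As a worked illustration of the stacks below (editor's addition): with
// source f32[1024,256,1024] tiled [4,1,1] and target f32[128,2048,1024], the
// major 1024 is split as 128x8, keeping all 4 partitions on the new 128
// dimension; after that no partitions remain, so the 8 (merged into 256 to
// form 2048) and the trailing 1024 both get tile dimension 1, and the result
// is the valid tiling [4,1,1] on the target shape.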
const Shape tile_shape = sharding.TileShape(source_shape);
std::vector<int64> target_tile_assignment_dimensions;
std::vector<int64> source_dims_stack(source_shape.rank());
std::vector<int64> target_dims_stack(target_shape.rank());
std::vector<int64> sharding_tile_dims_stack(source_shape.rank());
for (int64 i = 0; i < source_shape.rank(); ++i) {
source_dims_stack[i] = source_shape.dimensions(source_shape.rank() - 1 - i);
sharding_tile_dims_stack[i] =
sharding.tile_assignment().dim(source_shape.rank() - 1 - i);
}
for (int64 i = 0; i < target_shape.rank(); ++i) {
target_dims_stack[i] = target_shape.dimensions(target_shape.rank() - 1 - i);
}
while (!source_dims_stack.empty() || !target_dims_stack.empty()) {
if (target_dims_stack.empty()) {
if (Product(sharding_tile_dims_stack) != 1) {
return absl::nullopt;
}
break;
}
int64 s_size = 1;
int64 t_size = 1;
int64 s_partitions = 1;
if (!source_dims_stack.empty()) {
s_size = source_dims_stack.back();
source_dims_stack.pop_back();
s_partitions = sharding_tile_dims_stack.back();
sharding_tile_dims_stack.pop_back();
}
t_size = target_dims_stack.back();
target_dims_stack.pop_back();
if (s_partitions * Product(sharding_tile_dims_stack) == 1) {
// No more partitions left.
target_tile_assignment_dimensions.push_back(1);
continue;
}
if (s_size == t_size) {
// Same dimension.
target_tile_assignment_dimensions.push_back(s_partitions);
} else if (t_size == 1) {
// Trivial dimension added.
target_tile_assignment_dimensions.push_back(1);
source_dims_stack.push_back(s_size);
sharding_tile_dims_stack.push_back(s_partitions);
} else if (s_size == 1) {
// Trivial dimension removed.
if (s_partitions != 1) {
return absl::nullopt;
}
target_dims_stack.push_back(t_size);
} else if (s_size > t_size) {
// Dimension split.
if (s_size % t_size != 0 || t_size % s_partitions != 0) {
return absl::nullopt;
}
target_tile_assignment_dimensions.push_back(s_partitions);
// We have part of the s_size unprocessed, so put it back to stack.
source_dims_stack.push_back(s_size / t_size);
sharding_tile_dims_stack.push_back(1);
} else {
// Dimension merge. Also merge the source dimension with the next, and
// process it next time.
if (s_size % s_partitions != 0) {
return absl::nullopt;
}
CHECK(!source_dims_stack.empty());
if (sharding_tile_dims_stack.back() != 1 && s_size != s_partitions) {
// If the next dimension to combine is sharded, we require the current
// dimension's shard size to be 1. Otherwise, the new shard would be
// non-contiguous.
return absl::nullopt;
}
source_dims_stack.back() *= s_size;
sharding_tile_dims_stack.back() *= s_partitions;
target_dims_stack.push_back(t_size);
}
}
Array<int64> new_tile_assignment = sharding.tile_assignment();
new_tile_assignment.Reshape(target_tile_assignment_dimensions);
return HloSharding::Tile(new_tile_assignment);
}
HloSharding ReverseSharding(const HloSharding& sharding,
absl::Span<const int64> dimensions) {
if (sharding.IsTileMaximal() || dimensions.empty()) {
return sharding;
}
Array<int64> new_tile_assignment(sharding.tile_assignment().dimensions());
new_tile_assignment.Each([&](absl::Span<const int64> indices, int64* device) {
std::vector<int64> original_indices(indices.begin(), indices.end());
for (int64 d : dimensions) {
original_indices[d] =
new_tile_assignment.dim(d) - 1 - original_indices[d];
}
*device = sharding.tile_assignment()(original_indices);
});
return HloSharding::Tile(new_tile_assignment);
}
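// Illustrative example (editor's addition): reversing dimension 0 of a 1-D
// tile assignment [0,1,2,3] yields [3,2,1,0]; the tile that covered the first
// quarter of the dimension now covers the last quarter, mirroring what a
// kReverse on that dimension does to the data.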
HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
absl::Span<const int64> dims) {
CHECK(!sharding.IsTuple() && !sharding.IsTileMaximal());
CHECK_NE(absl::c_find(dims, dim), dims.end()) << "dim is not in dims";
// We optimize the tile assignment along the single dimension dim so as to
// minimize the communication among devices caused by the reshard:
// +---+---+                   +---+---+              +-+-+-+-+
// |   |   |                   |   0   |              | | | | |
// | 0 | 1 |                   +-------+              | | | | |
// |   |   |  reshape on       |   1   |  reshape on  | | | | |
// +---+---+  dim 0       =>   +-------+  dim 1  =>   |0|2|1|3|
// |   |   |                   |   2   |              | | | | |
// | 2 | 3 |                   +-------+              | | | | |
// |   |   |                   |   3   |              | | | | |
// +---+---+                   +---+---+              +-+-+-+-+
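// In tile-assignment terms (editor's note): the 2x2 assignment [[0,1],[2,3]]
// reshaped to tile dimension 0 becomes the 4x1 assignment [0,1,2,3] (middle),
// and reshaped to tile dimension 1 becomes the 1x4 assignment [0,2,1,3]
// (right), so each device keeps as much of its original data as possible.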
std::vector<int64> tile_dims(sharding.tile_assignment().num_dimensions(), 1);
// Handle ignore dimensions.
std::vector<int64> ignore_sizes;
int64 ignore_size = 1;
for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
if (absl::c_find(dims, i) == dims.end()) {
int64 size = sharding.tile_assignment().dim(i);
ignore_sizes.push_back(size);
tile_dims[i] = size;
ignore_size *= size;
}
}
using Buckets = std::vector<std::vector<int64>>;
Array<Buckets> buckets(ignore_sizes,
Buckets(sharding.tile_assignment().dim(dim)));
sharding.tile_assignment().Each(
[&](absl::Span<const int64> index, int64 device) {
std::vector<int64> ignore_index;
for (int64 i = 0; i < index.size(); ++i) {
if (absl::c_find(dims, i) == dims.end()) {
ignore_index.push_back(index[i]);
}
}
buckets(ignore_index)[index[dim]].push_back(device);
});
std::vector<int64> devices;
buckets.Each([&](absl::Span<const int64> index, const Buckets& buckets) {
for (auto& bucket : buckets) {
devices.insert(devices.end(), bucket.begin(), bucket.end());
}
});
tile_dims[dim] = devices.size() / ignore_size;
Array<int64> tile_assignment(tile_dims);
tile_assignment.SetValues(devices);
return HloSharding::Tile(tile_assignment);
}
bool ContainsTileSharding(const HloModule& module) {
for (const HloComputation* computation : module.computations()) {
for (const HloInstruction* instruction : computation->instructions()) {
if (instruction->has_sharding() &&
!instruction->sharding().IsTileMaximal()) {
return true;
}
}
}
return false;
}
HloSharding GatherOutputSharding(const HloSharding& index_sharding,
const HloInstruction* hlo) {
if (index_sharding.IsTileMaximal()) {
return index_sharding;
}
const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
std::vector<int64> output_tile_assignment_dims;
for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) {
if (absl::c_binary_search(dnums.offset_dims(), i)) {
output_tile_assignment_dims.push_back(1);
} else {
output_tile_assignment_dims.push_back(
index_sharding.tile_assignment().dim(index_dim));
index_dim++;
}
}
Array<int64> new_tile_assignment = index_sharding.tile_assignment();
new_tile_assignment.Reshape(output_tile_assignment_dims);
return HloSharding::Tile(new_tile_assignment);
}
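// Illustrative example (editor's sketch): for a rank-2 gather output with
// offset_dims = {1} and an index sharding tiled [4], the loop above produces
// output tile assignment dimensions [4,1]: the batch dimension inherits the
// index partitioning and the offset dimension stays unpartitioned.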
HloSharding GatherIndexSharding(const HloSharding& output_sharding,
const HloInstruction* hlo) {
if (output_sharding.IsTileMaximal()) {
return output_sharding;
}
const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
std::vector<int64> index_tile_assignment_dims;
for (int64 i = 0; i < hlo->shape().rank(); ++i) {
if (!absl::c_binary_search(dnums.offset_dims(), i)) {
index_tile_assignment_dims.push_back(
output_sharding.tile_assignment().dim(i));
}
}
Array<int64> new_tile_assignment = output_sharding.tile_assignment();
new_tile_assignment.Reshape(index_tile_assignment_dims);
return HloSharding::Tile(new_tile_assignment);
}
HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) {
if (hlo.sharding().IsTileMaximal()) {
return hlo.sharding();
}
const GatherDimensionNumbers& dnums = hlo.gather_dimension_numbers();
std::vector<int64> tile_assignment_dims(hlo.shape().rank());
int64 num_elements = 1;
for (int64 i = 0; i < hlo.shape().rank(); ++i) {
if (!absl::c_binary_search(dnums.offset_dims(), i)) {
tile_assignment_dims[i] = hlo.sharding().tile_assignment().dim(i);
num_elements *= hlo.sharding().tile_assignment().dim(i);
} else {
tile_assignment_dims[i] = 1;
}
}
if (num_elements == hlo.sharding().tile_assignment().num_elements()) {
// Output sharding is only on non-offset dimensions. We use the output
// sharding to shard this gather op directly.
return hlo.sharding();
}
if (num_elements == 1) {
// Output sharding is only on offset dimensions. We do not shard this gather
// op. Return a tile maximal sharding with the first device in output
// sharding tile assignment.
return HloSharding::AssignDevice(*hlo.sharding().tile_assignment().begin());
}
// Output sharding is on both offset and non-offset dimensions. We shard the
// gather op only on non-offset dimensions.
// For example, if:
// - the gather op has sharding [2,2]{0,1,2,3},
// - the first dimension is a non-offset dimension, and
// - the second dimension is an offset dimension,
// then the result sharding will be [2,1]{0,2}.
std::vector<int64> slice_starts(hlo.shape().rank(), 0LL),
slice_limits(hlo.shape().rank());
for (int64 i = 0; i < hlo.shape().rank(); ++i) {
if (!absl::c_binary_search(dnums.offset_dims(), i)) {
slice_limits[i] = hlo.sharding().tile_assignment().dim(i);
} else {
slice_limits[i] = 1;
}
}
Array<int64> tile_assignment =
hlo.sharding().tile_assignment().Slice(slice_starts, slice_limits);
return HloSharding::Tile(tile_assignment);
}
HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
const HloInstruction* hlo) {
if (data_sharding.IsTileMaximal()) {
return data_sharding;
}
const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers();
std::vector<int64> index_tile_assignment_dims;
for (int64 i = 0; i < hlo->shape().rank(); ++i) {
if (!absl::c_binary_search(dnums.update_window_dims(), i)) {
index_tile_assignment_dims.push_back(
data_sharding.tile_assignment().dim(i));
}
}
if (index_tile_assignment_dims.size() < hlo->operand(1)->shape().rank()) {
index_tile_assignment_dims.push_back(1);
}
Array<int64> new_tile_assignment = data_sharding.tile_assignment();
new_tile_assignment.Reshape(index_tile_assignment_dims);
return HloSharding::Tile(new_tile_assignment);
}
HloSharding ScatterDataSharding(const HloSharding& index_sharding,
const HloInstruction* hlo) {
if (index_sharding.IsTileMaximal()) {
return index_sharding;
}
const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers();
std::vector<int64> data_tile_assignment_dims;
for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) {
if (absl::c_binary_search(dnums.update_window_dims(), i)) {
data_tile_assignment_dims.push_back(1);
} else {
data_tile_assignment_dims.push_back(
index_sharding.tile_assignment().dim(index_dim));
index_dim++;
}
}
Array<int64> new_tile_assignment = index_sharding.tile_assignment();
new_tile_assignment.Reshape(data_tile_assignment_dims);
return HloSharding::Tile(new_tile_assignment);
}
HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
const HloInstruction& hlo) {
if (index_sharding.IsTileMaximal()) {
return index_sharding;
}
// Only shard on the first "number of scatter_window_dims" dimensions.
const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers();
int64 num_elements = 1;
int64 index_dim = 0;
for (int64 i = 0; i < hlo.shape().rank(); ++i) {
if (absl::c_binary_search(dnums.inserted_window_dims(), i)) {
num_elements *= index_sharding.tile_assignment().dim(index_dim);
index_dim++;
}
}
if (num_elements == index_sharding.tile_assignment().num_elements()) {
// Index sharding is only on scatter_window_dims. We use this index sharding
// directly.
return index_sharding;
}
// Index sharding is only on update_window_dims. We do not shard this scatter
// op. Return a tile maximal sharding with the first device in index sharding
// tile assignment.
if (num_elements == 1) {
return HloSharding::AssignDevice(*index_sharding.tile_assignment().begin());
}
const int64 index_rank = hlo.operand(1)->shape().rank();
std::vector<int64> slice_starts(index_rank, 0LL), slice_limits(index_rank);
for (int64 i = 0; i < index_rank; ++i) {
if (i < index_dim) {
slice_limits[i] = index_sharding.tile_assignment().dim(i);
} else {
slice_limits[i] = 1;
}
}
Array<int64> tile_assignment =
index_sharding.tile_assignment().Slice(slice_starts, slice_limits);
return HloSharding::Tile(tile_assignment);
}
HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
const HloInstruction& hlo) {
if (data_sharding.IsTileMaximal()) {
return data_sharding;
}
const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers();
const int64 data_rank = hlo.operand(2)->shape().rank();
std::vector<int64> tile_assignment_dims(data_rank, 1LL);
int64 num_elements = 1;
for (int64 i = 0; i < hlo.shape().rank(); ++i) {
if (absl::c_binary_search(dnums.inserted_window_dims(), i)) {
CHECK_LT(i, data_rank);
tile_assignment_dims[i] = data_sharding.tile_assignment().dim(i);
num_elements *= data_sharding.tile_assignment().dim(i);
}
}
if (num_elements == data_sharding.tile_assignment().num_elements()) {
// Data sharding is only on scatter_window_dims. We use this data sharding
// directly.
return data_sharding;
}
if (num_elements == 1) {
// Data sharding is only on update_window_dims. We do not shard this
// scatter op. Return a tile maximal sharding with the first device in
// data sharding tile assignment.
return HloSharding::AssignDevice(*data_sharding.tile_assignment().begin());
}
// Data sharding is on both update_window_dims and scatter_window_dims. We
// shard the scatter op only on scatter_window_dims. For example:
// - the scatter data has sharding [2,2]{0,1,2,3},
// - first dimension is scatter_window_dims,
// - second dimension is update_window_dims,
// Then the result sharding will be [2,1]{0,2}.
std::vector<int64> slice_starts(data_rank, 0LL);
Array<int64> tile_assignment =
data_sharding.tile_assignment().Slice(slice_starts, tile_assignment_dims);
return HloSharding::Tile(tile_assignment);
}
StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>>
IdentityValueAndHloOpcodeForScatterReduceComputation(
const HloScatterInstruction& scatter) {
auto computation = scatter.to_apply();
// We only handle computations with two parameters and a single calculation,
// i.e. exactly three instructions: the two parameters plus the root.
if (computation->instruction_count() != 3) {
return Status(
tensorflow::error::Code::INVALID_ARGUMENT,
"Expected scatter reduce computation with 2 parameters and only 1 "
"calculation");
}
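// For example (editor's illustration, simplified HLO text), an acceptable
// reduce computation is:
//   add_computation {
//     p0 = f32[] parameter(0)
//     p1 = f32[] parameter(1)
//     ROOT sum = f32[] add(p0, p1)
//   }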
auto root_instruction = computation->root_instruction();
if (root_instruction->opcode() == HloOpcode::kAdd ||
root_instruction->opcode() == HloOpcode::kOr) {
return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::Zero(
scatter.shape().element_type())),
root_instruction->opcode());
} else if (root_instruction->opcode() == HloOpcode::kMultiply ||
root_instruction->opcode() == HloOpcode::kAnd) {
return std::make_pair(HloInstruction::CreateConstant(
LiteralUtil::One(scatter.shape().element_type())),
root_instruction->opcode());
} else if (root_instruction->opcode() == HloOpcode::kMaximum) {
return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MinValue(
scatter.shape().element_type())),
root_instruction->opcode());
} else if (root_instruction->opcode() == HloOpcode::kMinimum) {
return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MaxValue(
scatter.shape().element_type())),
root_instruction->opcode());
}
return Status(tensorflow::error::Code::INVALID_ARGUMENT,
"Expected scatter reduce computation which is "
"add/or/multiply/and/min/max");
}
std::vector<int64> DevicesForSharding(
const HloSharding& sharding, const std::vector<int64>& available_devices) {
std::vector<int64> devices;
if (sharding.IsReplicated()) {
for (int64 d : available_devices) {
if (!HloSharding::IsReservedDevice(d)) {
devices.push_back(d);
}
}
return devices;
}
for (int64 i : available_devices) {
if (sharding.UsesDevice(i)) {
devices.push_back(i);
}
}
DCHECK(std::all_of(sharding.tile_assignment().begin(),
sharding.tile_assignment().end(), [&](int64 device) {
return std::find(available_devices.begin(),
available_devices.end(),
device) != available_devices.end();
}));
return devices;
}
} // namespace hlo_sharding_util
} // namespace xla