[XLA] Move SPMD partitioner to third_party

This change moves the SPMD partitioning work that the XLA team has developed over the past 12 months into third_party.

PiperOrigin-RevId: 311367525
Change-Id: If174527128c222c53736dc8db2ef1ea4177fb476
Yuanzhong Xu 2020-05-13 11:20:11 -07:00 committed by TensorFlower Gardener
parent 59239ab499
commit d45abae4e9
10 changed files with 10195 additions and 0 deletions
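
For context, a minimal sketch of how the new pass could be invoked (illustrative only: it assumes an already-built HloModule whose instructions carry sharding annotations, the partition/replica counts are placeholders, and the interface is the SpmdPartitioner declared in spmd_partitioner.h below):

    #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"

    xla::StatusOr<bool> PartitionModule(xla::HloModule* module) {
      xla::spmd::SpmdPartitionerOptions options;
      // Allow the entry computation's signature to change to partitioned shapes.
      options.allow_module_signature_change = true;
      // Placeholder topology: 8 partitions, 1 replica.
      xla::spmd::SpmdPartitioner partitioner(/*num_partitions=*/8,
                                             /*num_replicas=*/1, options);
      // Returns true if the module was rewritten into its SPMD form.
      return partitioner.Run(module);
    }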

@@ -460,6 +460,37 @@ cc_library(
],
)
cc_library(
name = "hlo_sharding_util",
srcs = [
"hlo_sharding_util.cc",
],
hdrs = [
"hlo_sharding_util.h",
],
deps = [
":hlo",
"//tensorflow/compiler/xla:array",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/types:optional",
],
)
tf_cc_test(
name = "hlo_sharding_util_test",
srcs = [
"hlo_sharding_util_test.cc",
],
deps = [
":hlo_sharding_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
tf_cc_test(
name = "dynamic_parameter_binding_test",
srcs = ["dynamic_parameter_binding_test.cc"],

@@ -0,0 +1,574 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include <map>
#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace hlo_sharding_util {
absl::optional<int64> SelectDominantDevice(
const std::map<int64, int64>& device_map, int64* top_count) {
int64 device = 0;
int64 count = 0;
for (auto& it : device_map) {
if (it.second > count) {
count = it.second;
device = it.first;
}
}
if (top_count != nullptr) {
*top_count = count;
}
return count > 0 ? absl::optional<int64>(device) : absl::optional<int64>();
}
Status AssignComputationDevice(HloComputation* computation, int64 device) {
VLOG(4) << "Assigning device " << device << " to " << computation->name()
<< " computation";
for (HloInstruction* instruction : computation->instructions()) {
if (!instruction->has_sharding()) {
VLOG(4) << "Assigning device " << device << " to " << instruction->name();
instruction->set_device_sharding(device);
}
}
return Status::OK();
}
absl::optional<int64> GetMostOccurringDevice(
absl::Span<HloInstruction* const> instructions) {
std::map<int64, int64> device_map;
for (HloInstruction* instruction : instructions) {
if (instruction->has_sharding()) {
for (auto& it : instruction->sharding().UsedDevices(nullptr)) {
// The UsedDevices() API returns a map<device, occurrence_count>.
device_map[it.first] += it.second;
}
}
}
return SelectDominantDevice(device_map, nullptr);
}
StatusOr<absl::optional<int64>> GetDominantDevice(
absl::Span<HloComputation* const> computations, double dominant_factor) {
int64 instruction_count = 0;
std::map<int64, int64> device_map;
for (HloComputation* computation : computations) {
for (HloInstruction* instruction : computation->instructions()) {
int64 count = 1;
if (instruction->has_sharding()) {
for (auto& it : instruction->sharding().UsedDevices(&count)) {
// The UsedDevices() API returns a map<device, occurrence_count>.
device_map[it.first] += it.second;
}
}
instruction_count += count;
}
}
int64 count;
absl::optional<int64> device = SelectDominantDevice(device_map, &count);
absl::optional<int64> dominant_device;
if (device) {
double factor =
static_cast<double>(count) / static_cast<double>(instruction_count);
if (factor >= dominant_factor) {
dominant_device = device;
}
}
return dominant_device;
}
HloSharding TransposeSharding(const HloSharding& sharding,
const std::vector<int64>& dimensions) {
if (sharding.IsTileMaximal()) {
return sharding;
}
const int64 rank = dimensions.size();
std::vector<int64> tile_assignment_dim(rank);
for (int64 i = 0; i < rank; ++i) {
tile_assignment_dim[i] = sharding.tile_assignment().dim(dimensions[i]);
}
Array<int64> tile_assignment = sharding.tile_assignment();
tile_assignment.Reshape(tile_assignment_dim);
tile_assignment.Each([&](absl::Span<const int64> indices, int64* value) {
std::vector<int64> src_indices(indices.size(), -1);
for (int64 i = 0; i < indices.size(); ++i) {
src_indices[dimensions[i]] = indices[i];
}
*value = sharding.tile_assignment()(src_indices);
});
return HloSharding::Tile(tile_assignment);
}
absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
const Shape& target_shape,
const HloSharding& sharding) {
if (sharding.IsTileMaximal()) {
return sharding;
}
// In the case of a tiled sharding, the reshaped sharding will be valid if the
// reshape is composed of the following operations:
// * Adding or removing dimensions with size 1.
// * Merging consecutive dimensions where only the most major is sharded.
// * Splitting a dimension to consecutive dimensions.
// * Any reshaping of unsharded dimensions.
// Note that merge and split can happen consecutively on the same dimension,
// e.g., f32[1024,256,1024] to f32[128,2048,1024] can be considered that 1024
// gets split into 128 and 8, but 8 then gets merged with 256. We use stacks
// to make supporting such cases easy.
const Shape tile_shape = sharding.TileShape(source_shape);
std::vector<int64> target_tile_assignment_dimensions;
std::vector<int64> source_dims_stack(source_shape.rank());
std::vector<int64> target_dims_stack(target_shape.rank());
std::vector<int64> sharding_tile_dims_stack(source_shape.rank());
for (int64 i = 0; i < source_shape.rank(); ++i) {
source_dims_stack[i] = source_shape.dimensions(source_shape.rank() - 1 - i);
sharding_tile_dims_stack[i] =
sharding.tile_assignment().dim(source_shape.rank() - 1 - i);
}
for (int64 i = 0; i < target_shape.rank(); ++i) {
target_dims_stack[i] = target_shape.dimensions(target_shape.rank() - 1 - i);
}
while (!source_dims_stack.empty() || !target_dims_stack.empty()) {
if (target_dims_stack.empty()) {
if (Product(sharding_tile_dims_stack) != 1) {
return absl::nullopt;
}
break;
}
int64 s_size = 1;
int64 t_size = 1;
int64 s_partitions = 1;
if (!source_dims_stack.empty()) {
s_size = source_dims_stack.back();
source_dims_stack.pop_back();
s_partitions = sharding_tile_dims_stack.back();
sharding_tile_dims_stack.pop_back();
}
t_size = target_dims_stack.back();
target_dims_stack.pop_back();
if (s_partitions * Product(sharding_tile_dims_stack) == 1) {
// No more partitions left.
target_tile_assignment_dimensions.push_back(1);
continue;
}
if (s_size == t_size) {
// Same dimension.
target_tile_assignment_dimensions.push_back(s_partitions);
} else if (t_size == 1) {
// Trivial dimension added.
target_tile_assignment_dimensions.push_back(1);
source_dims_stack.push_back(s_size);
sharding_tile_dims_stack.push_back(s_partitions);
} else if (s_size == 1) {
// Trivial dimension removed.
if (s_partitions != 1) {
return absl::nullopt;
}
target_dims_stack.push_back(t_size);
} else if (s_size > t_size) {
// Dimension split.
if (s_size % t_size != 0 || t_size % s_partitions != 0) {
return absl::nullopt;
}
target_tile_assignment_dimensions.push_back(s_partitions);
// We have part of the s_size unprocessed, so put it back on the stack.
source_dims_stack.push_back(s_size / t_size);
sharding_tile_dims_stack.push_back(1);
} else {
// Dimension merge. Also merge the source dimension with the next, and
// process it next time.
if (s_size % s_partitions != 0) {
return absl::nullopt;
}
CHECK(!source_dims_stack.empty());
if (sharding_tile_dims_stack.back() != 1 && s_size != s_partitions) {
// If the next dimension to combine is sharded, we require the
// current dimension's shard size to be 1. Otherwise, the new shard
// would be non-contiguous.
return absl::nullopt;
}
source_dims_stack.back() *= s_size;
sharding_tile_dims_stack.back() *= s_partitions;
target_dims_stack.push_back(t_size);
}
}
Array<int64> new_tile_assignment = sharding.tile_assignment();
new_tile_assignment.Reshape(target_tile_assignment_dimensions);
return HloSharding::Tile(new_tile_assignment);
}
HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
absl::Span<const int64> dims) {
CHECK(!sharding.IsTuple() && !sharding.IsTileMaximal());
CHECK_NE(absl::c_find(dims, dim), dims.end()) << "dim is not in dims";
// We optimize the tile assignment on the single dimension dim in a way to
// minimize communication among devices caused by the reshard:
// +---+---+ +---+---+ +-+-+-+-+
// | | | | 0 | | | | | |
// | 0 | 1 | +-------+ | | | | |
// | | | reshape on | 1 | reshape on | | | | |
// +---+---+ dim 0 => +-------+ dim 1 => |0|2|1|3|
// | | | | 2 | | | | | |
// | 2 | 3 | +-------+ | | | | |
// | | | | 3 | | | | | |
// +---+---+ +---+---+ +-+-+-+-+
std::vector<int64> tile_dims(sharding.tile_assignment().num_dimensions(), 1);
// Handle ignore dimensions.
std::vector<int64> ignore_sizes;
int64 ignore_size = 1;
for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
if (absl::c_find(dims, i) == dims.end()) {
int64 size = sharding.tile_assignment().dim(i);
ignore_sizes.push_back(size);
tile_dims[i] = size;
ignore_size *= size;
}
}
using Buckets = std::vector<std::vector<int64>>;
Array<Buckets> buckets(ignore_sizes,
Buckets(sharding.tile_assignment().dim(dim)));
sharding.tile_assignment().Each(
[&](absl::Span<const int64> index, int64 device) {
std::vector<int64> ignore_index;
for (int64 i = 0; i < index.size(); ++i) {
if (absl::c_find(dims, i) == dims.end()) {
ignore_index.push_back(index[i]);
}
}
buckets(ignore_index)[index[dim]].push_back(device);
});
std::vector<int64> devices;
buckets.Each([&](absl::Span<const int64> index, const Buckets& buckets) {
for (auto& bucket : buckets) {
devices.insert(devices.end(), bucket.begin(), bucket.end());
}
});
tile_dims[dim] = devices.size() / ignore_size;
Array<int64> tile_assignment(tile_dims);
tile_assignment.SetValues(devices);
return HloSharding::Tile(tile_assignment);
}
bool ContainsTileSharding(const HloModule& module) {
for (const HloComputation* computation : module.computations()) {
for (const HloInstruction* instruction : computation->instructions()) {
if (instruction->has_sharding() &&
!instruction->sharding().IsTileMaximal()) {
return true;
}
}
}
return false;
}
HloSharding GatherOutputSharding(const HloSharding& index_sharding,
const HloInstruction* hlo) {
if (index_sharding.IsTileMaximal()) {
return index_sharding;
}
const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
std::vector<int64> output_tile_assignment_dims;
for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) {
if (absl::c_binary_search(dnums.offset_dims(), i)) {
output_tile_assignment_dims.push_back(1);
} else {
output_tile_assignment_dims.push_back(
index_sharding.tile_assignment().dim(index_dim));
index_dim++;
}
}
Array<int64> new_tile_assignment = index_sharding.tile_assignment();
new_tile_assignment.Reshape(output_tile_assignment_dims);
return HloSharding::Tile(new_tile_assignment);
}
HloSharding GatherIndexSharding(const HloSharding& output_sharding,
const HloInstruction* hlo) {
if (output_sharding.IsTileMaximal()) {
return output_sharding;
}
const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
std::vector<int64> index_tile_assignment_dims;
for (int64 i = 0; i < hlo->shape().rank(); ++i) {
if (!absl::c_binary_search(dnums.offset_dims(), i)) {
index_tile_assignment_dims.push_back(
output_sharding.tile_assignment().dim(i));
}
}
Array<int64> new_tile_assignment = output_sharding.tile_assignment();
new_tile_assignment.Reshape(index_tile_assignment_dims);
return HloSharding::Tile(new_tile_assignment);
}
HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) {
if (hlo.sharding().IsTileMaximal()) {
return hlo.sharding();
}
const GatherDimensionNumbers& dnums = hlo.gather_dimension_numbers();
std::vector<int64> tile_assignment_dims(hlo.shape().rank());
int64 num_elements = 1;
for (int64 i = 0; i < hlo.shape().rank(); ++i) {
if (!absl::c_binary_search(dnums.offset_dims(), i)) {
tile_assignment_dims[i] = hlo.sharding().tile_assignment().dim(i);
num_elements *= hlo.sharding().tile_assignment().dim(i);
} else {
tile_assignment_dims[i] = 1;
}
}
if (num_elements == hlo.sharding().tile_assignment().num_elements()) {
// Output sharding is only on non offset dimensions. We use output sharding
// to shard this gather op directly.
return hlo.sharding();
}
if (num_elements == 1) {
// Output sharding is only on offset dimensions. We do not shard this gather
// op. Return a tile maximal sharding with the first device in output
// sharding tile assignment.
return HloSharding::AssignDevice(*hlo.sharding().tile_assignment().begin());
}
// Output sharding is on both offset and non offset dimensions. We shard the
// gather op only on non offset dimensions.
// For example:
// - the gather op has sharding [2,2]{0,1,2,3},
// - first dimension is non offset dimension,
// - second dimension is offset dimension,
// Then the result sharding will be [2,1]{0,2}.
std::vector<int64> slice_starts(hlo.shape().rank(), 0LL),
slice_limits(hlo.shape().rank());
for (int64 i = 0; i < hlo.shape().rank(); ++i) {
if (!absl::c_binary_search(dnums.offset_dims(), i)) {
slice_limits[i] = hlo.sharding().tile_assignment().dim(i);
} else {
slice_limits[i] = 1;
}
}
Array<int64> tile_assignment =
hlo.sharding().tile_assignment().Slice(slice_starts, slice_limits);
return HloSharding::Tile(tile_assignment);
}
HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
const HloInstruction* hlo) {
if (data_sharding.IsTileMaximal()) {
return data_sharding;
}
const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers();
std::vector<int64> index_tile_assignment_dims;
for (int64 i = 0; i < hlo->shape().rank(); ++i) {
if (!absl::c_binary_search(dnums.update_window_dims(), i)) {
index_tile_assignment_dims.push_back(
data_sharding.tile_assignment().dim(i));
}
}
if (index_tile_assignment_dims.size() < hlo->operand(1)->shape().rank()) {
index_tile_assignment_dims.push_back(1);
}
Array<int64> new_tile_assignment = data_sharding.tile_assignment();
new_tile_assignment.Reshape(index_tile_assignment_dims);
return HloSharding::Tile(new_tile_assignment);
}
HloSharding ScatterDataSharding(const HloSharding& index_sharding,
const HloInstruction* hlo) {
if (index_sharding.IsTileMaximal()) {
return index_sharding;
}
const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers();
std::vector<int64> data_tile_assignment_dims;
for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) {
if (absl::c_binary_search(dnums.update_window_dims(), i)) {
data_tile_assignment_dims.push_back(1);
} else {
data_tile_assignment_dims.push_back(
index_sharding.tile_assignment().dim(index_dim));
index_dim++;
}
}
Array<int64> new_tile_assignment = index_sharding.tile_assignment();
new_tile_assignment.Reshape(data_tile_assignment_dims);
return HloSharding::Tile(new_tile_assignment);
}
HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
const HloInstruction& hlo) {
if (index_sharding.IsTileMaximal()) {
return index_sharding;
}
// Only shard on first "number of scatter_window_dims" dimensions.
const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers();
int64 num_elements = 1;
int64 index_dim = 0;
for (int64 i = 0; i < hlo.shape().rank(); ++i) {
if (absl::c_binary_search(dnums.inserted_window_dims(), i)) {
num_elements *= index_sharding.tile_assignment().dim(index_dim);
index_dim++;
}
}
if (num_elements == index_sharding.tile_assignment().num_elements()) {
// Index sharding is only on scatter_window_dims. We use this index sharding
// directly.
return index_sharding;
}
// Index sharding is only on update_window_dims. We do not shard this scatter
// op. Return a tile maximal sharding with the first device in index sharding
// tile assignment.
if (num_elements == 1) {
return HloSharding::AssignDevice(*index_sharding.tile_assignment().begin());
}
const int64 index_rank = hlo.operand(1)->shape().rank();
std::vector<int64> slice_starts(index_rank, 0LL), slice_limits(index_rank);
for (int64 i = 0; i < index_rank; ++i) {
if (i < index_dim) {
slice_limits[i] = index_sharding.tile_assignment().dim(i);
} else {
slice_limits[i] = 1;
}
}
Array<int64> tile_assignment =
index_sharding.tile_assignment().Slice(slice_starts, slice_limits);
return HloSharding::Tile(tile_assignment);
}
HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
const HloInstruction& hlo) {
if (data_sharding.IsTileMaximal()) {
return data_sharding;
}
const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers();
const int64 data_rank = hlo.operand(2)->shape().rank();
std::vector<int64> tile_assignment_dims(data_rank, 1LL);
int64 num_elements = 1;
for (int64 i = 0; i < hlo.shape().rank(); ++i) {
if (absl::c_binary_search(dnums.inserted_window_dims(), i)) {
CHECK_LT(i, data_rank);
tile_assignment_dims[i] = data_sharding.tile_assignment().dim(i);
num_elements *= data_sharding.tile_assignment().dim(i);
}
}
if (num_elements == data_sharding.tile_assignment().num_elements()) {
// Data sharding is only on scatter_window_dims. We use this data sharding
// directly.
return data_sharding;
}
if (num_elements == 1) {
// Data sharding is only on update_window_dims. We do not shard this
// scatter op. Return a tile maximal sharding with the first device in
// data sharding tile assignment.
return HloSharding::AssignDevice(*data_sharding.tile_assignment().begin());
}
// Data sharding is on both update_window_dims and scatter_window_dims. We
// shard the scatter op only on scatter_window_dims. For example:
// - the scatter data has sharding [2,2]{0,1,2,3},
// - first dimension is scatter_window_dims,
// - second dimension is update_window_dims,
// Then the result sharding will be [2,1]{0,2}.
std::vector<int64> slice_starts(data_rank, 0LL);
Array<int64> tile_assignment =
data_sharding.tile_assignment().Slice(slice_starts, tile_assignment_dims);
return HloSharding::Tile(tile_assignment);
}
StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>>
IdentityValueAndHloOpcodeForScatterReduceComputation(
const HloScatterInstruction& scatter) {
auto computation = scatter.to_apply();
// We only handle computations with 2 parameters and only 1 calculation.
if (computation->instruction_count() != 3) {
return Status(
tensorflow::error::Code::INVALID_ARGUMENT,
"Expected scatter reduce computation with 2 parameters and only 1 "
"calculation");
}
auto root_instruction = computation->root_instruction();
if (root_instruction->opcode() == HloOpcode::kAdd ||
root_instruction->opcode() == HloOpcode::kOr) {
return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::Zero(
scatter.shape().element_type())),
root_instruction->opcode());
} else if (root_instruction->opcode() == HloOpcode::kMultiply ||
root_instruction->opcode() == HloOpcode::kAnd) {
return std::make_pair(HloInstruction::CreateConstant(
LiteralUtil::One(scatter.shape().element_type())),
root_instruction->opcode());
} else if (root_instruction->opcode() == HloOpcode::kMaximum) {
return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MinValue(
scatter.shape().element_type())),
root_instruction->opcode());
} else if (root_instruction->opcode() == HloOpcode::kMinimum) {
return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MaxValue(
scatter.shape().element_type())),
root_instruction->opcode());
}
return Status(tensorflow::error::Code::INVALID_ARGUMENT,
"Expected scatter reduce computation which is "
"add/or/multiply/add/min/max");
}
std::vector<int64> DevicesForSharding(
const HloSharding& sharding, const std::vector<int64>& available_devices) {
std::vector<int64> devices;
if (sharding.IsReplicated()) {
for (int64 d : available_devices) {
if (!HloSharding::IsReservedDevice(d)) {
devices.push_back(d);
}
}
return devices;
}
for (int64 i : available_devices) {
if (sharding.UsesDevice(i)) {
devices.push_back(i);
}
}
DCHECK(std::all_of(sharding.tile_assignment().begin(),
sharding.tile_assignment().end(), [&](int64 device) {
return std::find(available_devices.begin(),
available_devices.end(),
device) != available_devices.end();
}));
return devices;
}
} // namespace hlo_sharding_util
} // namespace xla

@@ -0,0 +1,143 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_
#include <map>
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
namespace xla {
namespace hlo_sharding_util {
// Given a map<device, occurrence_count>, selects the device with the highest
// occurrence count (if any). If top_count is not nullptr, it will receive the
// count of the dominant device returned.
absl::optional<int64> SelectDominantDevice(
const std::map<int64, int64>& device_map, int64* top_count);
// Assigns all the instructions of a computation to a given device.
// This API does not recurse into called computations, and does not assign
// instructions which already have sharding.
Status AssignComputationDevice(HloComputation* computation, int64 device);
// Given an instruction container, returns the device which is most commonly
// occurring among the instructions.
absl::optional<int64> GetMostOccurringDevice(
absl::Span<HloInstruction* const> instructions);
// Given a set of computations, tries to extract the dominant device. A device
// is dominant if its combined occurrence count among all the instructions of
// the input computations is greater than or equal to dominant_factor (a real
// number from 0 to 1).
// This API does not recurse into called computations.
// If no device exists that satisfies the condition, the returned optional will
// hold no value.
StatusOr<absl::optional<int64>> GetDominantDevice(
absl::Span<HloComputation* const> computations, double dominant_factor);
// Returns the HloSharding with the tile dimensions and tile assignment
// transposed based on the specified dimension numbers. In case of a tile
// maximal sharding returns the original sharding.
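// For example (mirroring hlo_sharding_util_test.cc): a [1,2,1,2] tile
// assignment {0,1,2,3} transposed with dimensions {3,0,1,2} becomes a
// [2,1,2,1] tile assignment {0,2,1,3}.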
HloSharding TransposeSharding(const HloSharding& sharding,
const std::vector<int64>& dimensions);
// Returns the HloSharding with the tile shape reshaped based on the source and
// target shapes and the tile assignment adjusted to correspond to the new tile
// shape or absl::nullopt if the resulting reshape would create an invalid
// sharding (non-contiguous or non-uniformly sized tiles). In case of a tile
// maximal sharding returns the original sharding.
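// For example (mirroring hlo_sharding_util_test.cc): reshaping f32[16,7] with
// a [2,1] tile assignment {0,1} to f32[4,4,7] yields a [2,1,1] tile assignment
// {0,1}; reshapes that cannot preserve the tiling return absl::nullopt.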
absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
const Shape& target_shape,
const HloSharding& sharding);
// Returns a sharding tiled on unique dimension dim by reshaping the tile
// assignment of the sharding argument. Only dimensions in the dims span
// argument are considered for reshaping; the others are ignored.
// Assumptions: sharding is tile sharded, and dim must be included in dims.
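// For example (mirroring hlo_sharding_util_test.cc): given a [2,2] tile
// assignment {{0,1},{2,3}}, ReshapeToTileDimension(sharding, /*dim=*/1,
// /*dims=*/{0,1}) returns a [1,4] tile assignment {{0,2,1,3}}.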
HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
absl::Span<const int64> dims);
// Returns true if the provided module includes one or more instructions with
// a tile sharding.
bool ContainsTileSharding(const HloModule& module);
// Returns the preferred output sharding for a gather op based on the sharding
// of the indices.
HloSharding GatherOutputSharding(const HloSharding& index_sharding,
const HloInstruction* hlo);
// Returns the preferred index sharding for a gather op based on the sharding
// of the output.
HloSharding GatherIndexSharding(const HloSharding& output_sharding,
const HloInstruction* hlo);
// Returns a new HloSharding for a gather op so that only non offset dimensions
// are sharded. Assume "result" is returned by this function. It is ensured that
// "GetIndexSharding(result, hlo)" will have the same number of elements as
// "result".
HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo);
// Returns the preferred index sharding for a scatter op based on the sharding
// of the data.
HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
const HloInstruction* hlo);
// Returns the preferred data sharding for a scatter op based on the sharding
// of the index.
HloSharding ScatterDataSharding(const HloSharding& index_sharding,
const HloInstruction* hlo);
// Returns a new index sharding for a scatter op so that we only shard on first
// "number of scatter_window_dims" dimensions. Assume "result" is returned by
// this function. It is ensured that "ScatterDataSharding(result, hlo)" will
// have the same number of elements as "result".
HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
const HloInstruction& hlo);
// Returns a new data sharding for a scatter op so that we only shard on
// scatter_window_dims. Assume "result" is returned by this function. It is
// ensured that "ScatterIndexSharding(result, hlo)" will have the same number of
// elements as "result".
HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
const HloInstruction& hlo);
// Returns an identity value and an HloOpcode for reduce computation of scatter
// instruction.
// - If computation is add/or, return 0/false with corresponding op code;
// - If computation is multiply/and, return 1/true with corresponding op code.
// - If computation is min/max, return max value/min value with corresponding op
// code.
// - Otherwise, return error status.
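// For example (illustrative): for a scatter whose to_apply computation is
// add(p0, p1) over f32 operands, this returns {constant 0, HloOpcode::kAdd}.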
StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>>
IdentityValueAndHloOpcodeForScatterReduceComputation(
const HloScatterInstruction& scatter);
// Given a sharding and a list of devices in the topology, return a
// list of the devices that `sharding` applies to.
std::vector<int64> DevicesForSharding(
const HloSharding& sharding, const std::vector<int64>& available_devices);
} // namespace hlo_sharding_util
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_

@@ -0,0 +1,206 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/test.h"
namespace xla {
namespace hlo_sharding_util {
namespace {
TEST(HloShardingUtilTest, TransposeShardingReplicated) {
EXPECT_EQ(TransposeSharding(HloSharding::Replicate(), {0, 1, 2}),
HloSharding::Replicate());
}
TEST(HloShardingUtilTest, TransposeShardingTiled) {
HloSharding input = HloSharding::Tile(Array4D<int64>({{{{0, 1}}, {{2, 3}}}}));
HloSharding output =
HloSharding::Tile(Array4D<int64>({{{{0}, {2}}}, {{{1}, {3}}}}));
EXPECT_EQ(TransposeSharding(input, {3, 0, 1, 2}), output);
}
TEST(HloShardingUtilTest, ReshapeShardingMaximal) {
Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5});
Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 2});
HloSharding sharding = HloSharding::AssignDevice(7);
absl::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingTiledInvalid) {
Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5});
Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 2});
HloSharding sharding = HloSharding::Tile(Array3D<int64>({{{0}, {1}}}));
absl::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, sharding);
EXPECT_FALSE(result.has_value());
}
TEST(HloShardingUtilTest, ReshapeShardingTiledMerge) {
Shape input_shape = ShapeUtil::MakeShape(F32, {4, 5, 7});
Shape output_shape = ShapeUtil::MakeShape(F32, {20, 7});
HloSharding input_sharding =
HloSharding::Tile(Array3D<int64>({{{0}}, {{1}}}));
HloSharding output_sharding = HloSharding::Tile(Array2D<int64>({{0}, {1}}));
absl::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingTiledSplit) {
Shape input_shape = ShapeUtil::MakeShape(F32, {16, 7});
Shape output_shape = ShapeUtil::MakeShape(F32, {4, 4, 7});
HloSharding input_sharding = HloSharding::Tile(Array2D<int64>({{0}, {1}}));
HloSharding output_sharding =
HloSharding::Tile(Array3D<int64>({{{0}}, {{1}}}));
absl::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingTiledSplitThenMerge) {
Shape input_shape = ShapeUtil::MakeShape(F32, {16, 4, 7});
Shape output_shape = ShapeUtil::MakeShape(F32, {4, 16, 7});
HloSharding input_sharding =
HloSharding::Tile(Array3D<int64>({{{0}}, {{1}}}));
HloSharding output_sharding =
HloSharding::Tile(Array3D<int64>({{{0}}, {{1}}}));
absl::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingTiledArbitraryMinorDimensions) {
Shape input_shape = ShapeUtil::MakeShape(F32, {16, 7, 5, 3});
Shape output_shape = ShapeUtil::MakeShape(F32, {4, 15, 2, 14});
Array<int64> sharding_array({2, 1, 1, 1});
sharding_array(0, 0, 0, 0) = 0;
sharding_array(1, 0, 0, 0) = 1;
HloSharding sharding = HloSharding::Tile(sharding_array);
absl::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingTiledTrivialDimensions) {
Shape input_shape = ShapeUtil::MakeShape(F32, {3, 1, 5, 7});
Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 1, 7});
HloSharding input_sharding =
HloSharding::Tile(Array4D<int64>({{{{0}, {1}}}}));
HloSharding output_sharding =
HloSharding::Tile(Array4D<int64>({{{{0}}, {{1}}}}));
absl::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingTrivialDimensionInsertedToEnd) {
Shape input_shape = ShapeUtil::MakeShape(F32, {8, 16});
Shape output_shape = ShapeUtil::MakeShape(F32, {8, 16, 1});
HloSharding input_sharding = HloSharding::Tile(Array2D<int64>({{0}, {1}}));
HloSharding output_sharding =
HloSharding::Tile(Array3D<int64>({{{0}}, {{1}}}));
absl::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, input_sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), output_sharding);
}
TEST(HloShardingUtilTest, NoopReshapeShardingEmptyTile) {
Shape shape = ShapeUtil::MakeShape(F32, {7, 1, 1});
HloSharding sharding = HloSharding::Tile(Array3D<int64>({{{0}, {1}}}));
absl::optional<HloSharding> result = ReshapeSharding(shape, shape, sharding);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result.value(), sharding);
}
TEST(HloShardingUtilTest, ReshapeShardingScalar) {
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1});
Shape output_shape = ShapeUtil::MakeShape(F32, {});
HloSharding sharding = HloSharding::Tile(Array3D<int64>({{{0}, {1}}}));
absl::optional<HloSharding> result =
ReshapeSharding(input_shape, output_shape, sharding);
EXPECT_FALSE(result.has_value());
}
TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim0) {
HloSharding sharding = HloSharding::Tile(Array2D<int64>({{0, 1}, {2, 3}}));
HloSharding result =
ReshapeToTileDimension(sharding, /*dim=*/0, /*dims=*/{0, 1});
EXPECT_EQ(result.tile_assignment(), Array2D<int64>({{0}, {1}, {2}, {3}}));
}
TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim1) {
HloSharding sharding = HloSharding::Tile(Array2D<int64>({{0, 1}, {2, 3}}));
HloSharding result =
ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1});
EXPECT_EQ(result.tile_assignment(), Array2D<int64>({{0, 2, 1, 3}}));
}
TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim0) {
HloSharding sharding =
HloSharding::Tile(Array3D<int64>({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}));
HloSharding result =
ReshapeToTileDimension(sharding, /*dim=*/0, /*dims=*/{0, 1, 2});
EXPECT_EQ(
result.tile_assignment(),
Array3D<int64>({{{0}}, {{1}}, {{2}}, {{3}}, {{4}}, {{5}}, {{6}}, {{7}}}));
}
TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim1) {
HloSharding sharding =
HloSharding::Tile(Array3D<int64>({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}));
HloSharding result =
ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1, 2});
EXPECT_EQ(result.tile_assignment(),
Array3D<int64>({{{0}, {1}, {4}, {5}, {2}, {3}, {6}, {7}}}));
}
TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim2) {
HloSharding sharding =
HloSharding::Tile(Array3D<int64>({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}));
HloSharding result =
ReshapeToTileDimension(sharding, /*dim=*/2, /*dims=*/{0, 1, 2});
EXPECT_EQ(result.tile_assignment(),
Array3D<int64>({{{0, 2, 4, 6, 1, 3, 5, 7}}}));
}
TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim2_Batch1) {
// Tile sharding in batch dimension, i.e.
// sharding={devices=[2,2,2]0,1,2,3,4,5,6,7}.
HloSharding sharding =
HloSharding::Tile(Array3D<int64>({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}));
// Reshape on dimensions {1, 2} only, therefore ignoring batch dimension 0.
HloSharding result = ReshapeToTileDimension(sharding, /*dim=*/2,
/*dims=*/{1, 2});
// Expected result is {devices=[2,1,4]0,2,1,3,4,6,5,7}, i.e. the two
// non-batch dimensions {{0, 1}, {2, 3}} and {{4, 5}, {6, 7}} are individually
// reshaped to tile dimension 2, i.e. {{0, 2, 1, 3}}, {{4, 6, 5, 7}}.
EXPECT_EQ(result.tile_assignment(),
Array3D<int64>({{{0, 2, 1, 3}}, {{4, 6, 5, 7}}}));
}
} // namespace
} // namespace hlo_sharding_util
} // namespace xla

@@ -0,0 +1,69 @@
# Description: SPMD partitioning pass.
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
package(
default_visibility = [":friends"],
licenses = ["notice"], # Apache 2.0
)
package_group(
name = "friends",
includes = [
"//tensorflow/compiler/xla:friends",
],
)
cc_library(
name = "spmd_partitioner",
srcs = [
"spmd_partitioner.cc",
"spmd_partitioner_util.cc",
],
hdrs = [
"spmd_partitioner.h",
"spmd_partitioner_util.h",
],
deps = [
"//tensorflow/compiler/xla:comparison_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client/lib:comparators",
"//tensorflow/compiler/xla/service:flatten_call_graph",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_cse",
"//tensorflow/compiler/xla/service:hlo_dce",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_query",
"//tensorflow/compiler/xla/service:hlo_sharding_util",
"//tensorflow/compiler/xla/service:shape_inference",
"//tensorflow/compiler/xla/service:tuple_simplifier",
"//tensorflow/core/platform:numbers",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
tf_cc_test(
name = "spmd_partitioner_test",
srcs = ["spmd_partitioner_test.cc"],
deps = [
":spmd_partitioner",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)

File diff suppressed because it is too large.

@@ -0,0 +1,435 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_
#include <memory>
#include <string>
#include <unordered_map>
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
namespace xla {
namespace spmd {
struct SpmdPartitionerOptions {
// Always exchange halo on LHS for all convolutions. If false, backprop filter
// convolution exchanges halo on RHS.
bool conv_halo_exchange_always_on_lhs = true;
// The number of instructions with the highest memory usage to include in the
// memory report.
int64 report_instruction_count = 5;
// The minimum size in MiB of an einsum operand for it to be considered for a
// windowed implementation in an HLO loop.
int64 threshold_for_windowed_einsum_mib = 256;
// Whether the entry computation's signature may change after partitioning.
bool allow_module_signature_change = false;
};
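// Example (illustrative): the options are plain fields that can be adjusted
// before constructing the partitioner, e.g.
//   SpmdPartitionerOptions options;
//   options.threshold_for_windowed_einsum_mib = 512;
//   SpmdPartitioner partitioner(/*num_partitions=*/8, /*num_replicas=*/1,
//                               options);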
// Class to wrap the computation builder to capture information during SPMD
// transformation.
class SpmdBuilder : public HloComputation::Builder {
public:
SpmdBuilder(const std::string& name, HloInstruction* hlo)
: HloComputation::Builder(name) {
visiting_hlo_ = hlo;
}
HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction);
const std::vector<HloInstruction*>& derived_instructions(
HloInstruction* hlo) {
return instructions_.at(hlo);
}
void set_visiting_hlo(HloInstruction* hlo) { visiting_hlo_ = hlo; }
HloInstruction* visiting_hlo() const { return visiting_hlo_; }
private:
// Currently visiting instruction.
HloInstruction* visiting_hlo_;
// Map from each visited (old) instruction to the new instructions created for
// it during SPMD partitioning.
HloInstructionMap<std::vector<HloInstruction*>> instructions_;
};
// A set of functions that create the cross-partition collective ops.
struct SPMDCollectiveOpsCreator {
// Function used to create a partition ID HLO.
std::function<HloInstruction*(SpmdBuilder*)> create_partition_id;
// Function used to create a cross-partition all-reduce HLO.
std::function<HloInstruction*(SpmdBuilder*, HloInstruction* operand,
HloComputation* reduction, int64 channel_id)>
create_cross_partition_all_reduce;
// Function used to create a cross-partition collective-permute HLO.
std::function<HloInstruction*(
SpmdBuilder*, HloInstruction* operand,
std::vector<std::pair<int64, int64>>& src_dst_pairs,
int64 next_channel_id)>
create_cross_partition_collective_permute;
// Function used to create a cross-partition all-to-all HLO.
std::function<HloInstruction*(
SpmdBuilder*, absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups, int64 channel_id,
absl::optional<int64> split_dimension)>
create_cross_partition_all_to_all;
};
// Logger to report memory usage during SPMD partitioning.
class SpmdLogger {
public:
explicit SpmdLogger(int64 report_instruction_count)
: report_instruction_count_(report_instruction_count) {}
static std::string ReportBeforePartition(const HloModule& module,
int64 report_instruction_count);
static std::string ReportAfterPartition(const HloModule& module,
int64 report_instruction_count);
// Registers a log entry for the group of instructions created to transform
// the given hlo.
void RegisterLogEntry(HloInstruction* hlo,
const std::vector<HloInstruction*>& group);
std::string MakeReport();
private:
template <typename F>
static std::string ReportMemoryUsage(const HloModule& module, const F& filter,
int64 report_instruction_count);
// A vector of logging messages (one for each original HLO instruction), where
// the integer in each pair is the amount of HBM used.
std::vector<std::pair<int64, std::string>> entries_;
int64 report_instruction_count_;
};
class SpmdPartitioningVisitor;
class SpmdPartitioner : public HloModulePass {
public:
SpmdPartitioner(int64 num_partitions, int64 num_replicas,
SpmdPartitionerOptions options);
SpmdPartitioner(int64 num_partitions, int64 num_replicas,
SpmdPartitionerOptions options,
SPMDCollectiveOpsCreator collective_ops_creator)
: num_partitions_(num_partitions),
num_replicas_(num_replicas),
options_(std::move(options)),
collective_ops_creator_(std::move(collective_ops_creator)) {}
absl::string_view name() const override { return "spmd-partitioning"; }
StatusOr<bool> Run(HloModule* module) override;
// Transforms the given computation with SPMD instructions, replacing it with
// a new computation.
StatusOr<bool> PartitionComputation(HloComputation* computation,
const HloSharding& root_sharding,
int64* next_channel_id,
SpmdLogger* logger);
protected:
virtual std::unique_ptr<SpmdPartitioningVisitor> CreateVisitor(
HloComputation* computation, int64 num_partitions, int64 num_replicas,
const SPMDCollectiveOpsCreator& collective_ops_creator,
int64* next_channel_id, SpmdLogger* logger,
SpmdPartitionerOptions options);
private:
// Verifies that the shardings of instructions in the module are valid, and
// also fills in missing sharding information.
Status PreprocessSharding(HloModule* module);
const int64 num_partitions_;
const int64 num_replicas_;
SpmdPartitionerOptions options_;
SPMDCollectiveOpsCreator collective_ops_creator_;
};
// Describes the partitioning state of data represented by an HLO created
// during the SPMD partitioning pass.
//
// Data on some devices may include padding region, if the base (full) shape
// could not be evenly partitioned.
class PartitionedHlo {
public:
// Return value for ReshardAsWindowedInput which describes the resharded HLO,
// the window for the user on the shard, and if necessary, the dynamic slice
// offsets to be applied to the output of the op being sharded.
struct WindowedInputShardReturnValue {
HloInstruction* sharded_input;
Window shard_window;
absl::optional<std::vector<HloInstruction*>> dynamic_slice_index_on_output;
};
// A cache for resharding each partitioned HLO.
struct ReshardCache {
struct PerHloCache {
std::vector<std::pair<HloSharding, PartitionedHlo>> reshard_cache;
std::vector<
std::tuple<HloSharding, Window, WindowedInputShardReturnValue>>
window_reshard_cache;
};
std::unordered_map<HloInstruction*, PerHloCache> per_hlo_cache;
};
struct PartitioningState {
SpmdBuilder* b;
HloModule* module;
int64 num_replicas;
HloInstruction* partition_id;
SPMDCollectiveOpsCreator collective_ops_creator;
int64* next_channel_id;
ReshardCache* reshard_cache;
};
PartitionedHlo(HloInstruction* hlo, Shape base_shape, PartitioningState state)
: hlo_(hlo), base_shape_(base_shape), state_(std::move(state)) {
CHECK(hlo->has_sharding())
<< "PartitionedHlo is missing sharding:" << hlo->ToString();
// If a tuple-shaped instruction does not have a tuple sharding, reassign it
// to use a tuple sharding. The Reshard() implementation assumes this.
if (hlo_->shape().IsTuple() && !hlo_->sharding().IsTuple()) {
hlo_->set_sharding(
hlo_->sharding().GetTupleSharding(hlo_->shape()).ValueOrDie());
}
}
// Reshards the current SPMD instruction to a new sharding. May only modify
// the reshard cache.
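// For example (illustrative): given a PartitionedHlo p wrapping a tiled
// parameter, p.Reshard(HloSharding::Replicate()).hlo() returns an HLO whose
// per-partition data is the full, replicated tensor.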
PartitionedHlo Reshard(const HloSharding& target);
// Pads the garbage area of the output with the provided value.
PartitionedHlo PadWithValue(HloInstruction* pad_value) const;
// Returns the SPMD instruction.
HloInstruction* hlo() const { return hlo_; }
// Returns the sharding of the SPMD instruction.
const HloSharding& sharding() const { return hlo_->sharding(); }
// Original full shape of the data.
const Shape& base_shape() const { return base_shape_; }
int64 NewChannel() const { return (*state_.next_channel_id)++; }
// Reshards the HLO to a usable partitioned input for a windowed user. May
// only modify the reshard cache.
absl::optional<WindowedInputShardReturnValue> ReshardAsWindowedInput(
const Window& window, const HloSharding& target,
HloInstruction* pad_value, bool mask_invalid_region = true);
private:
// Same as Reshard except that it does not explicitly modify the reshard
// cache, although it may modify it indirectly by calling Replicate().
PartitionedHlo ReshardNoCache(const HloSharding& target);
// Helper function to replicate the data on all devices. May only modify
// the reshard cache.
PartitionedHlo Replicate();
// Helper function to broadcast data from a single device to all devices.
PartitionedHlo Broadcast() const;
// Helper function to reshard the tensor using AllToAll (instead of the
// default of Replicate followed by Slice).
PartitionedHlo ReshardWithAllToAll(const HloSharding& target) const;
// Helper function to reshard the tensor using CollectivePermute.
PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const;
// SPMD instruction.
HloInstruction* hlo_;
// The original shape of the data before SPMD transformation is applied.
Shape base_shape_;
PartitioningState state_;
};
struct DotGeneralDimsMapping {
// The dimension numbers for the operands and output corresponding to a
// logical dimension (e.g., batch, contracting, non-contracting). If an
// operand or the output doesn't have the logical dimension, it is set to
// -1.
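// For example (illustrative): for a batch matmul
//   f32[B,M,K] dot f32[B,K,N] -> f32[B,M,N],
// one would have batch_dims = {{0, 0, 0}}, contracting_dims = {{2, 1, -1}},
// lhs_non_contracting_dims = {{1, -1, 1}}, and
// rhs_non_contracting_dims = {{-1, 2, 2}}.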
struct DimsMapping {
int64 lhs;
int64 rhs;
int64 output;
};
std::vector<DimsMapping> batch_dims;
std::vector<DimsMapping> contracting_dims;
std::vector<DimsMapping> lhs_non_contracting_dims;
std::vector<DimsMapping> rhs_non_contracting_dims;
};
class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
public:
SpmdPartitioningVisitor(
HloComputation* computation, int64 num_partitions, int64 num_replicas,
const SPMDCollectiveOpsCreator& collective_ops_creator,
int64* next_channel_id, SpmdLogger* logger,
SpmdPartitionerOptions options, SpmdPartitioner* partitioner);
Status DefaultAction(HloInstruction* hlo) override;
Status HandleAllReduce(HloInstruction* hlo) override;
Status HandleBroadcast(HloInstruction* hlo) override;
Status HandleConstant(HloInstruction* hlo) override;
Status HandleCustomCall(HloInstruction* hlo) override;
Status HandleDot(HloInstruction* hlo) override;
Status HandleDynamicSlice(HloInstruction* hlo) override;
Status HandleDynamicUpdateSlice(HloInstruction* hlo) override;
Status HandleGather(HloInstruction* hlo) override;
Status HandleGetTupleElement(HloInstruction* hlo) override;
Status HandleInfeed(HloInstruction* hlo) override;
Status HandleOutfeed(HloInstruction* hlo) override;
Status HandlePad(HloInstruction* hlo) override;
Status HandleParameter(HloInstruction* hlo) override;
Status HandleReduce(HloInstruction* hlo) override;
Status HandleReverse(HloInstruction* hlo) override;
Status HandleWhile(HloInstruction* hlo) override;
Status HandleConditional(HloInstruction* hlo) override;
Status HandleReduceWindow(HloInstruction* hlo) override;
Status HandleSelectAndScatter(HloInstruction* hlo) override;
Status HandleTuple(HloInstruction* hlo) override;
Status HandleRng(HloInstruction* hlo) override;
Status HandleConvolution(HloInstruction* hlo) override;
Status HandleConcatenate(HloInstruction* hlo) override;
Status HandleScatter(HloInstruction* hlo) override;
Status HandleSlice(HloInstruction* hlo) override;
Status HandleSort(HloInstruction* hlo) override;
Status HandleTranspose(HloInstruction* hlo) override;
Status HandleReshape(HloInstruction* hlo) override;
Status HandleIota(HloInstruction* hlo) override;
Status HandlePartitionId(HloInstruction* hlo) override;
// Handles convolution where both LHS and RHS operands are tiled.
Status HandleConvolutionTiledLhsAndRhs(HloInstruction* hlo);
// Implementation of dot partitioning given DotGeneralDimsMapping.
Status HandleDotHelper(
HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot);
// Common handler for elementwise HLOs.
Status HandleElementwise(HloInstruction* hlo);
// Common handler for HLOs that run on a single device.
Status HandleSingleDevice(const HloInstruction* hlo);
// Returns the PartitionedHlo that corresponds to the original hlo.
PartitionedHlo& GetPartitionedHlo(const HloInstruction* hlo) {
CHECK_EQ(partitioned_instructions_.count(hlo), 1);
return partitioned_instructions_.find(hlo)->second;
}
// Sets the PartitionedHlo for the original hlo.
void SetPartitionedHlo(const HloInstruction* hlo,
const PartitionedHlo& partitioned_hlo) {
CHECK_EQ(partitioned_instructions_.count(hlo), 0);
partitioned_instructions_.emplace(hlo, partitioned_hlo);
changed_ = true;
}
// Convenience wrapper that creates a PartitionedHlo from the result of the func
// and maps it to the given original hlo.
void SetPartitionedHlo(const HloInstruction* hlo,
const std::function<HloInstruction*()>& func) {
HloInstruction* new_hlo = func();
new_hlo->set_sharding(hlo->sharding());
new_hlo->set_metadata(hlo->metadata());
SetPartitionedHlo(
hlo, PartitionedHlo(new_hlo, hlo->shape(), MakePartitioningState()));
changed_ = true;
}
int64 NewChannel() { return (*next_channel_id_)++; }
PartitionedHlo::PartitioningState MakePartitioningState() {
return PartitionedHlo::PartitioningState{
.b = &b_,
.module = module_,
.num_replicas = num_replicas_,
.partition_id = partition_id_,
.collective_ops_creator = collective_ops_creator_,
.next_channel_id = next_channel_id_,
.reshard_cache = &reshard_cache_};
}
SpmdBuilder* builder() { return &b_; }
StatusOr<bool> DoPartition(HloComputation* computation,
const HloSharding& root_sharding);
private:
Status Preprocess(HloInstruction* hlo) override;
Status Postprocess(HloInstruction* hlo) override;
// Performs code motion for windowed dot-general loops in
// windowed_dot_general_loops_. Invoked after the visitor finishes traversing
// the graph.
Status DoCodeMotionForWindowedDotGeneralLoops(HloComputation* computation);
bool changed_;
HloModule* module_;
int64 num_partitions_;
int64 num_replicas_;
SPMDCollectiveOpsCreator collective_ops_creator_;
// Tracks the next channel id to use for cross-partition all-reduce.
int64* next_channel_id_;
SpmdBuilder b_;
HloInstruction* partition_id_;
PartitionedHlo::ReshardCache reshard_cache_;
// Mapping from the instruction in the original computation to the new SPMD
// partitioned instruction.
ConstHloInstructionMap<PartitionedHlo> partitioned_instructions_;
// Information about a loop created for windowed dot-general. Used when
// DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor
// finishes traversing the graph.
struct WindowedDotGeneralLoop {
HloInstruction* while_loop;
int64 windowed_operand;
bool windowed_in_contracting_dims;
bool windowed_in_batch_dims;
};
std::vector<WindowedDotGeneralLoop> windowed_dot_general_loops_;
HloInstruction* visiting_hlo_;
SpmdLogger* logger_;
const SpmdPartitionerOptions options_;
SpmdPartitioner* partitioner_;
};
} // namespace spmd
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_

File diff suppressed because it is too large.

@@ -0,0 +1,662 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace spmd {
bool HasReplicatedSharding(const HloSharding& sharding) {
if (sharding.IsTuple()) {
return absl::c_any_of(sharding.tuple_elements(), HasReplicatedSharding);
}
return sharding.IsReplicated();
}
HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b) {
if (shape.IsTuple()) {
std::vector<HloInstruction*> elements;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
elements.push_back(
CreateZero(ShapeUtil::GetTupleElementShape(shape, i), b));
}
return b->AddInstruction(HloInstruction::CreateTuple(elements));
}
if (shape.IsToken()) {
return b->AddInstruction(HloInstruction::CreateToken());
}
auto zero = b->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {}));
}
HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) {
HloComputation::Builder sum_b("add");
auto x = sum_b.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, ShapeUtil::MakeShape(type, {}), "x"));
auto y = sum_b.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, ShapeUtil::MakeShape(type, {}), "y"));
if (type == PRED) {
sum_b.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(type, {}), HloOpcode::kOr, x, y));
} else {
sum_b.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(type, {}), HloOpcode::kAdd, x, y));
}
HloComputation* reduction = module->AddEmbeddedComputation(sum_b.Build());
return reduction;
}
bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding) {
if (sharding.IsTuple()) {
for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
if (!EvenlyPartitions(ShapeUtil::GetTupleElementShape(shape, i),
sharding.GetSubSharding(shape, {i}))) {
return false;
}
}
}
if (sharding.IsTileMaximal()) {
return sharding.IsReplicated();
}
for (int64 i = 0; i < shape.dimensions_size(); ++i) {
if (shape.dimensions(i) % sharding.tile_assignment().dim(i) != 0) {
return false;
}
}
return true;
}
Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) {
if (sharding.IsTuple()) {
std::vector<Shape> subshapes;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
subshapes.push_back(
MakePartitionedShape(ShapeUtil::GetTupleElementShape(shape, i),
sharding.GetSubSharding(shape, {i})));
}
return ShapeUtil::MakeTupleShape(subshapes);
}
return sharding.TileShape(shape);
}
Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
const HloSharding& sharding,
int64 partition_id) {
if (sharding.IsTuple()) {
std::vector<Shape> subshapes;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
subshapes.push_back(MakeNonPaddedShapeForGivenPartition(
ShapeUtil::GetTupleElementShape(shape, i),
sharding.GetSubSharding(shape, {i}), partition_id));
}
return ShapeUtil::MakeTupleShape(subshapes);
}
auto partition_shape = shape;
std::vector<int64> tile_offset =
sharding.TileOffsetForDevice(shape, partition_id);
std::vector<int64> tile_limit =
sharding.TileLimitForDevice(shape, partition_id);
for (int64 i = 0; i < tile_offset.size(); ++i) {
if (sharding.UsesDevice(partition_id)) {
partition_shape.set_dimensions(i, tile_limit[i] - tile_offset[i]);
} else {
partition_shape.set_dimensions(i, 0);
}
}
return partition_shape;
}
std::vector<HloInstruction*> MakePartitionOffsets(const Shape& shape,
const HloSharding& sharding,
HloInstruction* partition_id,
SpmdBuilder* b) {
CHECK(!shape.IsTuple());
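// Build a [num_devices, rank] constant table of per-device tile offsets and
// dynamic-slice this partition's row at run time; dimensions that are not
// tiled get a constant 0 instead.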
Array2D<int32> offset_array(
{sharding.tile_assignment().num_elements(), shape.rank()});
offset_array.Each([&](int64 i, int64 j, int32* value) {
*value = sharding.TileOffsetForDevice(shape, i)[j];
});
auto offset_table = b->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR2FromArray2D(offset_array)));
std::vector<HloInstruction*> offsets;
for (int64 i = 0; i < shape.rank(); ++i) {
if (sharding.tile_assignment().dim(i) == 1) {
offsets.push_back(b->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(S32))));
} else {
auto index = b->AddInstruction(HloInstruction::CreateDynamicSlice(
ShapeUtil::MakeShape(S32, {1, 1}), offset_table,
{partition_id, b->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<uint32>(i)))},
{1, 1}));
offsets.push_back(b->AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), index)));
}
}
return offsets;
}
std::vector<HloInstruction*> MakeTiledPartitionOrdinals(
const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) {
CHECK(!sharding.IsTileMaximal());
auto table_shape =
ShapeUtil::MakeShape(S32, sharding.tile_assignment().dimensions());
return MakePartitionOffsets(table_shape, sharding, partition_id, b);
}
HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape,
SpmdBuilder* b, HloComputation* computation) {
CHECK(b == nullptr || computation == nullptr);
if (ShapeUtil::Compatible(hlo->shape(), padded_shape)) {
return hlo;
}
PaddingConfig padding_config;
for (int64 i = 0; i < padded_shape.rank(); ++i) {
auto padding_config_dim = padding_config.add_dimensions();
padding_config_dim->set_edge_padding_low(0);
padding_config_dim->set_interior_padding(0);
padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) -
hlo->shape().dimensions(i));
}
auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) {
if (b == nullptr) {
return computation->AddInstruction(std::move(to_add));
}
return b->AddInstruction(std::move(to_add));
};
auto zero = add_hlo(HloInstruction::CreateConstant(
LiteralUtil::Zero(hlo->shape().element_type())));
return add_hlo(
HloInstruction::CreatePad(padded_shape, hlo, zero, padding_config));
}
Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape,
const HloSharding& sharding) {
if (sharding.IsTileMaximal()) {
return base_shape;
}
if (EvenlyPartitions(base_shape, sharding)) {
return base_shape;
}
auto shard_shape = MakePartitionedShape(base_shape, sharding);
Shape padded_base_shape = base_shape;
for (int64 i = 0; i < padded_base_shape.rank(); ++i) {
padded_base_shape.set_dimensions(
i, shard_shape.dimensions(i) * sharding.tile_assignment().dim(i));
}
return padded_base_shape;
}
HloInstruction* PadBaseShapeBeforeUnevenTiledSharding(
HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b) {
auto padded_base_shape =
GetPaddedShapeForUnevenPartitioning(hlo->shape(), sharding);
if (ShapeUtil::Compatible(padded_base_shape, hlo->shape())) {
return hlo;
}
return PadToShape(hlo, padded_base_shape, b);
}
absl::optional<int64> UniqueTiledDim(const HloSharding& sharding) {
if (sharding.IsTileMaximal()) {
return absl::nullopt;
}
int64 dim = -1;
for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
if (sharding.tile_assignment().dim(i) > 1) {
if (dim != -1) {
return absl::nullopt;
}
dim = i;
}
}
CHECK_NE(dim, -1);
return dim;
}
MultiplyAddDivideOffsetCalculation::MultiplyAddDivideOffsetCalculation(
int64 multiplier, int64 offset, int64 divisor)
: multiplier_(multiplier), offset_(offset), divisor_(divisor) {
CHECK_GT(divisor_, 0);
Simplify();
}
OffsetCalculation MultiplyAddDivideOffsetCalculation::operator-(
const MultiplyAddDivideOffsetCalculation& other) const {
if (divisor_ == 1 && other.divisor_ == 1) {
return OffsetCalculation(MultiplyAddDivideOffsetCalculation(
multiplier_ - other.multiplier_, offset_ - other.offset_, 1));
}
return OffsetCalculation(HloOpcode::kSubtract, *this, other);
}
void MultiplyAddDivideOffsetCalculation::Simplify() {
// We could simplify the calculation when multiplier_ is a multiple of
// divisor_. However, when offset_ is not a multiple of divisor_, we must
// make sure that offset_ and multiplier_ are both non-negative or both
// non-positive. E.g., (3 * i - 1) / 3 is not equivalent to i or i - 1.
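// For example, (6 * i + 3) / 3 simplifies to 2 * i + 1, but (3 * i - 1) / 3
// does not: at i = 0 it evaluates to 0 rather than i - 1 = -1, and at i = 1 it
// evaluates to 0 rather than i = 1.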
if (divisor_ != 1 && multiplier_ % divisor_ == 0 &&
(offset_ % divisor_ == 0 || offset_ * multiplier_ > 0)) {
multiplier_ /= divisor_;
offset_ /= divisor_;
divisor_ = 1;
}
}
int64 MultiplyAddDivideOffsetCalculation::Calculate(int64 shard_ordinal) const {
return (shard_ordinal * multiplier_ + offset_) / divisor_;
}
HloInstruction* MultiplyAddDivideOffsetCalculation::Calculate(
HloInstruction* shard_ordinal, SpmdBuilder* b) const {
auto scalar_shape = ShapeUtil::MakeShape(S32, {});
if (multiplier_ == 0) {
return b->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<int32>(offset_ / divisor_)));
}
HloInstruction* result = shard_ordinal;
if (multiplier_ != 1) {
result = b->AddInstruction(HloInstruction::CreateBinary(
scalar_shape, HloOpcode::kMultiply, shard_ordinal,
b->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<int32>(multiplier_)))));
}
if (offset_ != 0) {
auto offset = b->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(offset_)));
result = b->AddInstruction(HloInstruction::CreateBinary(
scalar_shape, HloOpcode::kAdd, result, offset));
}
if (divisor_ != 1) {
auto divisor = b->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(divisor_)));
result = b->AddInstruction(HloInstruction::CreateBinary(
scalar_shape, HloOpcode::kDivide, result, divisor));
}
return result;
}
int64 MultiplyAddDivideOffsetCalculation::MaxInRange(
int64 start_ordinal, int64 limit_ordinal) const {
int64 max = Calculate(start_ordinal);
for (int64 i = start_ordinal + 1; i < limit_ordinal; ++i) {
max = std::max(max, Calculate(i));
}
return max;
}
OffsetCalculation& OffsetCalculation::operator=(
const OffsetCalculation& other) {
opcode_ = other.opcode_;
copy_from_ = other.copy_from_;
if (opcode_ != HloOpcode::kCopy) {
lhs_ = absl::make_unique<OffsetCalculation>(*other.lhs_);
rhs_ = absl::make_unique<OffsetCalculation>(*other.rhs_);
}
return *this;
}
bool OffsetCalculation::IsConstant() const {
if (opcode_ == HloOpcode::kCopy) {
return copy_from_.IsConstant();
}
if (opcode_ == HloOpcode::kSubtract && *lhs_ == *rhs_) {
return true;
}
return lhs_->IsConstant() && rhs_->IsConstant();
}
OffsetCalculation OffsetCalculation::operator-(
const OffsetCalculation& other) const {
if (opcode_ == HloOpcode::kCopy && other.opcode_ == HloOpcode::kCopy) {
return copy_from_ - other.copy_from_;
}
return OffsetCalculation(HloOpcode::kSubtract, *this, other);
}
bool OffsetCalculation::operator==(const OffsetCalculation& other) const {
if (opcode_ != other.opcode_) {
return false;
}
if (opcode_ == HloOpcode::kCopy) {
return copy_from_ == other.copy_from_;
}
return *lhs_ == *other.lhs_ && *rhs_ == *other.rhs_;
}
int64 OffsetCalculation::Calculate(int64 shard_ordinal) const {
switch (opcode_) {
case HloOpcode::kCopy:
return copy_from_.Calculate(shard_ordinal);
case HloOpcode::kSubtract:
return lhs_->Calculate(shard_ordinal) - rhs_->Calculate(shard_ordinal);
case HloOpcode::kMultiply:
return lhs_->Calculate(shard_ordinal) * rhs_->Calculate(shard_ordinal);
default:
LOG(FATAL) << "Should not happen";
}
}
HloInstruction* OffsetCalculation::Calculate(HloInstruction* shard_ordinal,
SpmdBuilder* b) const {
if (opcode_ == HloOpcode::kCopy) {
return copy_from_.Calculate(shard_ordinal, b);
}
auto lhs = lhs_->Calculate(shard_ordinal, b);
auto rhs = rhs_->Calculate(shard_ordinal, b);
return b->AddInstruction(
HloInstruction::CreateBinary(lhs->shape(), opcode_, lhs, rhs));
}
int64 OffsetCalculation::MaxInRange(int64 start_ordinal,
int64 limit_ordinal) const {
if (IsConstant()) {
return Calculate(start_ordinal);
}
if (opcode_ == HloOpcode::kCopy) {
return std::max(Calculate(start_ordinal), Calculate(limit_ordinal - 1));
}
int64 max = Calculate(start_ordinal);
for (int64 i = start_ordinal + 1; i < limit_ordinal; ++i) {
max = std::max(max, Calculate(i));
}
return max;
}
absl::optional<HloInstruction*> ExchangeHalo(
HloInstruction* hlo, const OffsetCalculation& left_halo_size_function,
const OffsetCalculation& right_halo_size_function, int64 dim,
const HloSharding& target,
const SPMDCollectiveOpsCreator& collective_ops_creator,
int64* next_channel_id, SpmdBuilder* b) {
int64 input_shard_size = hlo->shape().dimensions(dim);
int64 shard_count = target.tile_assignment().dim(dim);
std::vector<HloInstruction*> concat_pieces;
int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count);
if (max_left_halo_size > input_shard_size) {
VLOG(1) << "ExchangeHalo failed: halo is beyond the left neighbor.";
return absl::nullopt;
}
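// Left halo: every shard other than the first along 'dim' receives the last
// max_left_halo_size elements of its left neighbor's shard via a
// collective-permute.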
if (max_left_halo_size > 0) {
std::vector<std::pair<int64, int64>> source_target_pairs;
target.tile_assignment().Each(
[&](absl::Span<const int64> indices, int64 device) {
if (indices[dim] > 0) {
std::vector<int64> source_indices(indices.begin(), indices.end());
source_indices[dim] -= 1;
source_target_pairs.emplace_back(
target.tile_assignment()(source_indices), device);
}
});
auto halo_shape = hlo->shape();
auto source_halo_slice = hlo;
if (max_left_halo_size != hlo->shape().dimensions(dim)) {
halo_shape.set_dimensions(dim, max_left_halo_size);
std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
halo_start_indices[dim] =
hlo->shape().dimensions(dim) - max_left_halo_size;
std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
source_halo_slice = b->AddInstruction(
hlo->CreateSlice(halo_shape, hlo, halo_start_indices,
hlo->shape().dimensions(), halo_slice_strides));
}
auto left_halo =
collective_ops_creator.create_cross_partition_collective_permute(
b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
concat_pieces.push_back(left_halo);
}
concat_pieces.push_back(hlo);
// Right halo.
int64 max_right_halo_size =
right_halo_size_function.MaxInRange(0, shard_count - 1);
if (max_right_halo_size > input_shard_size) {
VLOG(1) << "ExchangeHalo failed: halo is beyond the right neighbor.";
return absl::nullopt;
}
if (max_right_halo_size > 0) {
std::vector<std::pair<int64, int64>> source_target_pairs;
target.tile_assignment().Each(
[&](absl::Span<const int64> indices, int64 device) {
if (indices[dim] > 0) {
std::vector<int64> target_indices(indices.begin(), indices.end());
target_indices[dim] -= 1;
source_target_pairs.emplace_back(
device, target.tile_assignment()(target_indices));
}
});
auto halo_shape = hlo->shape();
halo_shape.set_dimensions(dim, max_right_halo_size);
std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
auto source_halo_slice = b->AddInstruction(
hlo->CreateSlice(halo_shape, hlo, halo_start_indices,
halo_shape.dimensions(), halo_slice_strides));
auto right_halo =
collective_ops_creator.create_cross_partition_collective_permute(
b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
concat_pieces.push_back(right_halo);
}
auto concat = hlo;
// Concat with halos/padding.
if (concat_pieces.size() > 1) {
auto concat_shape = hlo->shape();
int64 concat_dim_size = 0;
for (auto piece : concat_pieces) {
concat_dim_size += piece->shape().dimensions(dim);
}
concat_shape.set_dimensions(dim, concat_dim_size);
concat = b->AddInstruction(
HloInstruction::CreateConcatenate(concat_shape, concat_pieces, dim));
}
return concat;
}
absl::optional<HloInstruction*> ExchangeHalo(
HloInstruction* hlo,
std::vector<OffsetCalculation> left_halo_size_functions,
std::vector<OffsetCalculation> right_halo_size_functions,
const HloSharding& target,
const SPMDCollectiveOpsCreator& collective_ops_creator,
int64* next_channel_id, SpmdBuilder* b) {
CHECK(left_halo_size_functions.size() == hlo->shape().rank());
CHECK(right_halo_size_functions.size() == hlo->shape().rank());
HloInstruction* visiting_hlo = hlo;
for (int dim = 0; dim < hlo->shape().rank(); ++dim) {
auto concat = ExchangeHalo(visiting_hlo, left_halo_size_functions[dim],
right_halo_size_functions[dim], dim, target,
collective_ops_creator, next_channel_id, b);
if (!concat) {
return absl::nullopt;
}
visiting_hlo = *concat;
}
return visiting_hlo;
}
absl::optional<HloInstruction*> ExchangeHaloAndGetValidData(
HloInstruction* hlo, const Shape& base_shape,
const OffsetCalculation& left_halo_size_function,
const OffsetCalculation& right_halo_size_function,
int64 explicit_left_padding_on_full_shape, int64 padded_full_shape_size,
int64 shard_size_with_halo, int64 dim, const HloSharding& target,
HloInstruction* offset_on_padded_shape, HloInstruction* pad_value,
HloInstruction* partition_ordinal,
const SPMDCollectiveOpsCreator& collective_ops_creator,
int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region) {
auto halo_exchange_result =
ExchangeHalo(hlo, left_halo_size_function, right_halo_size_function, dim,
target, collective_ops_creator, next_channel_id, b);
if (!halo_exchange_result) {
return absl::nullopt;
}
auto concat = *halo_exchange_result;
int64 shard_count = target.tile_assignment().dim(dim);
int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count);
// Now we determine if we need extra padding after the concat.
//
// The max of halo size or the first shard's explicit left padding.
int64 max_left_halo_or_padding_size =
std::max(std::max(int64{0}, max_left_halo_size),
explicit_left_padding_on_full_shape);
// The calculation that returns the dynamic slice index for a shard on the
// padded concat, which is the difference between
// max_left_halo_or_padding_size and its left halo size.
auto start_offset_on_padded_concat_calculation =
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
0, max_left_halo_or_padding_size, 1)) -
left_halo_size_function;
// See if we need to pad the concat before dynamic slice.
int64 extra_left_padding =
std::max(int64{0}, max_left_halo_or_padding_size -
std::max(int64{0}, max_left_halo_size));
int64 extra_right_padding =
start_offset_on_padded_concat_calculation.MaxInRange(0, shard_count) +
shard_size_with_halo - concat->shape().dimensions(dim) -
extra_left_padding;
extra_right_padding = std::max(int64{0}, extra_right_padding);
if (extra_left_padding > 0 || extra_right_padding > 0) {
PaddingConfig padding_config;
auto padded_concat_shape = concat->shape();
for (int64 i = 0; i < base_shape.rank(); ++i) {
auto padding_config_dim = padding_config.add_dimensions();
padding_config_dim->set_interior_padding(0);
padding_config_dim->set_edge_padding_low(0);
padding_config_dim->set_edge_padding_high(0);
if (i != dim) {
continue;
}
padding_config_dim->set_edge_padding_low(extra_left_padding);
padding_config_dim->set_edge_padding_high(extra_right_padding);
padded_concat_shape.set_dimensions(dim, concat->shape().dimensions(dim) +
extra_left_padding +
extra_right_padding);
}
concat = b->AddInstruction(HloInstruction::CreatePad(
padded_concat_shape, concat, pad_value, padding_config));
}
auto valid_slice = concat;
if (shard_size_with_halo != concat->shape().dimensions(dim)) {
// Concat is bigger than the shard shape, so we need a dynamic slice.
CHECK_LT(shard_size_with_halo, concat->shape().dimensions(dim));
auto slice_shape = concat->shape();
slice_shape.set_dimensions(dim, shard_size_with_halo);
if (left_halo_size_function.IsConstant() &&
left_halo_size_function.Calculate(0) ==
explicit_left_padding_on_full_shape) {
std::vector<int64> start_indices(slice_shape.rank(), 0);
std::vector<int64> strides(slice_shape.rank(), 1);
valid_slice = b->AddInstruction(
HloInstruction::CreateSlice(slice_shape, concat, start_indices,
slice_shape.dimensions(), strides));
} else {
auto zero = b->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
std::vector<HloInstruction*> slice_offsets(base_shape.rank(), zero);
slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate(
partition_ordinal, b);
valid_slice = b->AddInstruction(HloInstruction::CreateDynamicSlice(
slice_shape, concat, slice_offsets, slice_shape.dimensions()));
}
}
if (!mask_invalid_region) {
return valid_slice;
}
int64 total_right_padding = padded_full_shape_size -
base_shape.dimensions(dim) -
explicit_left_padding_on_full_shape;
// Mask off garbage data due to uneven partition or low/high padding.
if (explicit_left_padding_on_full_shape > 0 || total_right_padding > 0) {
auto index_shape = ShapeUtil::ChangeElementType(valid_slice->shape(), S32);
auto iota = b->AddInstruction(HloInstruction::CreateIota(index_shape, dim));
auto broadcast_start_index_in_padded_shape =
b->AddInstruction(HloInstruction::CreateBroadcast(
index_shape, offset_on_padded_shape, {}));
auto index_in_padded_shape = b->AddInstruction(
HloInstruction::CreateBinary(index_shape, HloOpcode::kAdd, iota,
broadcast_start_index_in_padded_shape));
auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED);
std::vector<HloInstruction*> predicates;
if (explicit_left_padding_on_full_shape > 0) {
auto valid_index_start =
b->AddInstruction(HloInstruction::CreateBroadcast(
index_shape,
b->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
explicit_left_padding_on_full_shape))),
{}));
predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare(
mask_shape, index_in_padded_shape, valid_index_start,
ComparisonDirection::kGe)));
}
if (total_right_padding > 0) {
auto valid_index_limit =
b->AddInstruction(HloInstruction::CreateBroadcast(
index_shape,
b->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
base_shape.dimensions(dim) +
explicit_left_padding_on_full_shape))),
{}));
predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare(
mask_shape, index_in_padded_shape, valid_index_limit,
ComparisonDirection::kLt)));
}
CHECK(!predicates.empty());
auto is_valid =
predicates.size() == 2
? b->AddInstruction(HloInstruction::CreateBinary(
mask_shape, HloOpcode::kAnd, predicates[0], predicates[1]))
: predicates[0];
auto masking_value = b->AddInstruction(
HloInstruction::CreateBroadcast(valid_slice->shape(), pad_value, {}));
valid_slice = b->AddInstruction(
HloInstruction::CreateTernary(valid_slice->shape(), HloOpcode::kSelect,
is_valid, valid_slice, masking_value));
}
return valid_slice;
}
} // namespace spmd
} // namespace xla

View File

@ -0,0 +1,229 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_
#include <memory>
#include <string>
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
namespace xla {
namespace spmd {
// Returns true if the given sharding contains any replicated sharding.
bool HasReplicatedSharding(const HloSharding& sharding);
// Creates zero value instructions of the given shape.
HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b);
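// Creates an R0 (scalar) constant of the given primitive type from a native
// value, converting the literal if the types differ, e.g.
// CreateR0WithType(BF16, 1.0f, b).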
template <typename NativeT>
HloInstruction* CreateR0WithType(PrimitiveType type, NativeT value,
SpmdBuilder* b) {
auto literal = LiteralUtil::CreateR0(value)
.ConvertToShape(ShapeUtil::MakeShape(type, {}))
.ValueOrDie();
return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
}
// Creates a binary add computation of the given type and adds it to the module.
HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module);
// Returns true if the shape can be evenly partitioned for the given sharding.
// All tile-sharded dimensions must be evenly divisible by their tile counts,
// and there must be no single-device sharding. Replicated sharding is
// considered an even partition.
bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding);
// Returns the shard shape of the given shape when it is partitioned for the
// target sharding.
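// For example, an f32[8,16] shape with a 2x4 tile assignment yields an
// f32[4,4] shard shape.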
Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding);
// Returns the shard shape for a partition without padding due to uneven
// sharding.
Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
const HloSharding& sharding,
int64 partition_id);
// Generates the HLO instructions that represent the dimension offsets on any
// device. The size of the returned vector is the rank of the given shape.
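// For example (assuming the default row-major iota tile assignment), for an
// f32[8,16] shape tiled 2x4, partition p receives offsets
// {(p / 4) * 4, (p % 4) * 4} as S32 scalars.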
std::vector<HloInstruction*> MakePartitionOffsets(const Shape& shape,
const HloSharding& sharding,
HloInstruction* partition_id,
SpmdBuilder* b);
// Returns the partition's per-dimension ordinals (offsets) in the tile
// assignment.
std::vector<HloInstruction*> MakeTiledPartitionOrdinals(
const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b);
// Pads hlo to the desired shape using high padding. Either a builder or a
// computation needs to be supplied, but not both.
HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape,
SpmdBuilder* b,
HloComputation* computation = nullptr);
// Returns the padded shape when combining all partitions.
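// For example, an f32[9] shape tiled across 4 partitions is padded to f32[12].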
Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape,
const HloSharding& sharding);
// Pads the HLO (with base shape) for uneven tiled partition to make it evenly
// partitionable.
HloInstruction* PadBaseShapeBeforeUnevenTiledSharding(
HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b);
// Returns the index of the unique tiled dimension. Returns absl::nullopt if the
// given sharding is not tiled or is tiled along more than one dimension.
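// For example, a 1x4 tile assignment has unique tiled dimension 1, while a
// 2x4 tile assignment returns absl::nullopt.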
absl::optional<int64> UniqueTiledDim(const HloSharding& sharding);
// Utilities for symbolic offset calculation and halo exchange.
class OffsetCalculation;
// Represents a calculation over integers:
// (shard_ordinal * multiplier + offset) / divisor
class MultiplyAddDivideOffsetCalculation {
public:
MultiplyAddDivideOffsetCalculation()
: multiplier_(0), offset_(0), divisor_(1) {}
MultiplyAddDivideOffsetCalculation(int64 multiplier, int64 offset,
int64 divisor);
OffsetCalculation operator-(
const MultiplyAddDivideOffsetCalculation& other) const;
bool operator==(const MultiplyAddDivideOffsetCalculation& other) const {
return multiplier_ == other.multiplier_ && offset_ == other.offset_ &&
divisor_ == other.divisor_;
}
bool IsConstant() const { return multiplier_ == 0; }
void Simplify();
int64 Calculate(int64 shard_ordinal) const;
HloInstruction* Calculate(HloInstruction* shard_ordinal,
SpmdBuilder* b) const;
// Returns the maximum result for shard ordinals in the range
// [start_ordinal, limit_ordinal).
int64 MaxInRange(int64 start_ordinal, int64 limit_ordinal) const;
private:
int64 multiplier_;
int64 offset_;
int64 divisor_;
};
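// Usage sketch with hypothetical values: a per-shard offset of
// "5 * shard_ordinal - 2" can be expressed and evaluated as
//   MultiplyAddDivideOffsetCalculation calc(/*multiplier=*/5, /*offset=*/-2,
//                                           /*divisor=*/1);
//   int64 offset_for_shard_3 = calc.Calculate(3);  // (3 * 5 - 2) / 1 = 13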
// Represents a calculation over integers that combines the results of other
// calculations using an opcode. If the opcode is kCopy, it simply wraps a
// MultiplyAddDivideOffsetCalculation.
class OffsetCalculation {
public:
OffsetCalculation() : opcode_(HloOpcode::kCopy), copy_from_() {}
explicit OffsetCalculation(
const MultiplyAddDivideOffsetCalculation& copy_from)
: opcode_(HloOpcode::kCopy), copy_from_(copy_from) {}
OffsetCalculation(const OffsetCalculation& copy_from) { *this = copy_from; }
OffsetCalculation(HloOpcode opcode,
const MultiplyAddDivideOffsetCalculation& lhs,
const MultiplyAddDivideOffsetCalculation& rhs)
: opcode_(opcode),
lhs_(absl::make_unique<OffsetCalculation>(lhs)),
rhs_(absl::make_unique<OffsetCalculation>(rhs)) {}
OffsetCalculation(HloOpcode opcode, const OffsetCalculation& lhs,
const OffsetCalculation& rhs)
: opcode_(opcode),
lhs_(absl::make_unique<OffsetCalculation>(lhs)),
rhs_(absl::make_unique<OffsetCalculation>(rhs)) {}
OffsetCalculation& operator=(const OffsetCalculation& other);
// Returns whether the calculation returns the same value for all shards. This
// is conservative and could return false even if it is actually constant.
bool IsConstant() const;
OffsetCalculation operator-(const OffsetCalculation& other) const;
bool operator==(const OffsetCalculation& other) const;
int64 Calculate(int64 shard_ordinal) const;
HloInstruction* Calculate(HloInstruction* shard_ordinal,
SpmdBuilder* b) const;
// Returns the maximum result for shard ordinals in the range
// [start_ordinal, limit_ordinal).
int64 MaxInRange(int64 start_ordinal, int64 limit_ordinal) const;
private:
HloOpcode opcode_;
std::unique_ptr<OffsetCalculation> lhs_;
std::unique_ptr<OffsetCalculation> rhs_;
MultiplyAddDivideOffsetCalculation copy_from_;
};
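// Composition sketch with hypothetical values: subtracting two calculations
// stays symbolic when the result cannot be folded into a single
// multiply-add-divide form.
//   OffsetCalculation lhs(MultiplyAddDivideOffsetCalculation(1, 0, 2));  // i / 2
//   OffsetCalculation rhs(MultiplyAddDivideOffsetCalculation(0, 1, 1));  // 1
//   OffsetCalculation diff = lhs - rhs;   // evaluates to i / 2 - 1
//   int64 d = diff.Calculate(5);          // 5 / 2 - 1 = 1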
// Performs halo exchange on the given dimension based on the provided
// left/right halo size functions. Returns nullopt if the halo is beyond the
// direct neighbor of the shard.
absl::optional<HloInstruction*> ExchangeHalo(
HloInstruction* hlo, const OffsetCalculation& left_halo_size_function,
const OffsetCalculation& right_halo_size_function, int64 dim,
const HloSharding& target,
const SPMDCollectiveOpsCreator& collective_ops_creator,
int64* next_channel_id, SpmdBuilder* b);
// Exchanges halos on all dimensions of the HLO. Returns nullopt if any of the
// dimensions fails to exchange halo (i.e., the halo extends beyond the neighbor
// shard).
absl::optional<HloInstruction*> ExchangeHalo(
HloInstruction* hlo,
std::vector<OffsetCalculation> left_halo_size_functions,
std::vector<OffsetCalculation> right_halo_size_functions,
const HloSharding& target,
const SPMDCollectiveOpsCreator& collective_ops_creator,
int64* next_channel_id, SpmdBuilder* b);
// Exchanges halos and performs pad/dynamic-slice on the concatenated data such
// that the result starts with the first needed element on each shard. It also
// masks off invalid data due to padding.
// Arguments:
// hlo: the HLO op before halo exchange
// explicit_left_padding_on_full_shape: the amount of left padding to be added
// explicitly by this function on the base shape before partitioning. Without
// base dilation, this is usually set to the window's padding_low so that the
// sharded op does not need to add padding_low on the window; however, with base
// dilation, this can only be set to a custom size.
// padded_full_shape_size: the size of the padded full shape on the given
// dimension, which includes explicit_left_padding_on_full_shape and required
// right padding to make the shape evenly shardable.
// shard_size_with_halo: the shard size on the dimension after halo exchange.
// If different shards have different sizes, use the maximum size.
// offset_on_padded_shape: the offset HLO (S32) that represents the start of
// each shard on the padded full shape.
// pad_value: the padding value used on the full shape.
absl::optional<HloInstruction*> ExchangeHaloAndGetValidData(
HloInstruction* hlo, const Shape& base_shape,
const OffsetCalculation& left_halo_size_function,
const OffsetCalculation& right_halo_size_function,
int64 explicit_left_padding_on_full_shape, int64 padded_full_shape_size,
int64 shard_size_with_halo, int64 dim, const HloSharding& target,
HloInstruction* offset_on_padded_shape, HloInstruction* pad_value,
HloInstruction* partition_ordinal,
const SPMDCollectiveOpsCreator& collective_ops_creator,
int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region = true);
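// As an illustrative (not prescriptive) example, a size-3 window with stride 1
// on a tiled dimension needs one element from each neighboring shard, which can
// be expressed with constant halo size functions:
//   OffsetCalculation one(MultiplyAddDivideOffsetCalculation(0, 1, 1));
//   // left_halo_size_function = one; right_halo_size_function = one;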
} // namespace spmd
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_