[XLA] Move SPMD partitioner to third_party
This change moves the SPMD partitioning work that the XLA team has been developing over the past 12 months into third_party.

PiperOrigin-RevId: 311367525
Change-Id: If174527128c222c53736dc8db2ef1ea4177fb476
parent 59239ab499
commit d45abae4e9
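For reference, the pass introduced below (tensorflow/compiler/xla/service/spmd/spmd_partitioner.h) is a regular HloModulePass, so it can be run directly on an HloModule. A minimal sketch of an invocation follows; the wrapper function is illustrative, and only SpmdPartitioner, SpmdPartitionerOptions, and Run() come from this change:

    #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"

    // Partitions `module` across 8 partitions with a single replica.
    xla::StatusOr<bool> RunSpmdPartitioning(xla::HloModule* module) {
      xla::spmd::SpmdPartitionerOptions options;  // defaults declared in the header
      options.allow_module_signature_change = true;
      xla::spmd::SpmdPartitioner partitioner(/*num_partitions=*/8,
                                             /*num_replicas=*/1, options);
      return partitioner.Run(module);  // true if the module was changed
    }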
@@ -460,6 +460,37 @@ cc_library(
    ],
)

cc_library(
    name = "hlo_sharding_util",
    srcs = [
        "hlo_sharding_util.cc",
    ],
    hdrs = [
        "hlo_sharding_util.h",
    ],
    deps = [
        ":hlo",
        "//tensorflow/compiler/xla:array",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/types:optional",
    ],
)

tf_cc_test(
    name = "hlo_sharding_util_test",
    srcs = [
        "hlo_sharding_util_test.cc",
    ],
    deps = [
        ":hlo_sharding_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

tf_cc_test(
    name = "dynamic_parameter_binding_test",
    srcs = ["dynamic_parameter_binding_test.cc"],
574  tensorflow/compiler/xla/service/hlo_sharding_util.cc  (new file)
@@ -0,0 +1,574 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"

#include <map>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace hlo_sharding_util {

absl::optional<int64> SelectDominantDevice(
    const std::map<int64, int64>& device_map, int64* top_count) {
  int64 device = 0;
  int64 count = 0;
  for (auto& it : device_map) {
    if (it.second > count) {
      count = it.second;
      device = it.first;
    }
  }
  if (top_count != nullptr) {
    *top_count = count;
  }
  return count > 0 ? absl::optional<int64>(device) : absl::optional<int64>();
}

Status AssignComputationDevice(HloComputation* computation, int64 device) {
  VLOG(4) << "Assigning device " << device << " to " << computation->name()
          << " computation";
  for (HloInstruction* instruction : computation->instructions()) {
    if (!instruction->has_sharding()) {
      VLOG(4) << "Assigning device " << device << " to " << instruction->name();
      instruction->set_device_sharding(device);
    }
  }
  return Status::OK();
}

absl::optional<int64> GetMostOccurringDevice(
    absl::Span<HloInstruction* const> instructions) {
  std::map<int64, int64> device_map;
  for (HloInstruction* instruction : instructions) {
    if (instruction->has_sharding()) {
      for (auto& it : instruction->sharding().UsedDevices(nullptr)) {
        // The UsedDevices() API returns a map<device, occurrence_count>.
        device_map[it.first] += it.second;
      }
    }
  }
  return SelectDominantDevice(device_map, nullptr);
}

StatusOr<absl::optional<int64>> GetDominantDevice(
    absl::Span<HloComputation* const> computations, double dominant_factor) {
  int64 instruction_count = 0;
  std::map<int64, int64> device_map;
  for (HloComputation* computation : computations) {
    for (HloInstruction* instruction : computation->instructions()) {
      int64 count = 1;
      if (instruction->has_sharding()) {
        for (auto& it : instruction->sharding().UsedDevices(&count)) {
          // The UsedDevices() API returns a map<device, occurrence_count>.
          device_map[it.first] += it.second;
        }
      }
      instruction_count += count;
    }
  }
  int64 count;
  absl::optional<int64> device = SelectDominantDevice(device_map, &count);
  absl::optional<int64> dominant_device;
  if (device) {
    double factor =
        static_cast<double>(count) / static_cast<double>(instruction_count);
    if (factor >= dominant_factor) {
      dominant_device = device;
    }
  }
  return dominant_device;
}

HloSharding TransposeSharding(const HloSharding& sharding,
                              const std::vector<int64>& dimensions) {
  if (sharding.IsTileMaximal()) {
    return sharding;
  }
  const int64 rank = dimensions.size();
  std::vector<int64> tile_assignment_dim(rank);
  for (int64 i = 0; i < rank; ++i) {
    tile_assignment_dim[i] = sharding.tile_assignment().dim(dimensions[i]);
  }
  Array<int64> tile_assignment = sharding.tile_assignment();
  tile_assignment.Reshape(tile_assignment_dim);
  tile_assignment.Each([&](absl::Span<const int64> indices, int64* value) {
    std::vector<int64> src_indices(indices.size(), -1);
    for (int64 i = 0; i < indices.size(); ++i) {
      src_indices[dimensions[i]] = indices[i];
    }
    *value = sharding.tile_assignment()(src_indices);
  });
  return HloSharding::Tile(tile_assignment);
}

absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
                                            const Shape& target_shape,
                                            const HloSharding& sharding) {
  if (sharding.IsTileMaximal()) {
    return sharding;
  }

  // In case of a tiled sharding, the reshaped sharding will be valid if the
  // reshape is composed of the following operations:
  // * Adding or removing dimensions with size 1.
  // * Merging consecutive dimensions where only the most major is sharded.
  // * Splitting a dimension to consecutive dimensions.
  // * Any reshaping of unsharded dimensions.
  // Note that merge and split can happen consecutively on the same dimension,
  // e.g., f32[1024,256,1024] to f32[128,2048,1024] can be seen as 1024 being
  // split into 128 and 8, with the 8 then merged with 256. We use stacks to
  // make supporting such cases easy.
  const Shape tile_shape = sharding.TileShape(source_shape);
  std::vector<int64> target_tile_assignment_dimensions;
  std::vector<int64> source_dims_stack(source_shape.rank());
  std::vector<int64> target_dims_stack(target_shape.rank());
  std::vector<int64> sharding_tile_dims_stack(source_shape.rank());
  for (int64 i = 0; i < source_shape.rank(); ++i) {
    source_dims_stack[i] = source_shape.dimensions(source_shape.rank() - 1 - i);
    sharding_tile_dims_stack[i] =
        sharding.tile_assignment().dim(source_shape.rank() - 1 - i);
  }
  for (int64 i = 0; i < target_shape.rank(); ++i) {
    target_dims_stack[i] = target_shape.dimensions(target_shape.rank() - 1 - i);
  }
  while (!source_dims_stack.empty() || !target_dims_stack.empty()) {
    if (target_dims_stack.empty()) {
      if (Product(sharding_tile_dims_stack) != 1) {
        return absl::nullopt;
      }
      break;
    }
    int64 s_size = 1;
    int64 t_size = 1;
    int64 s_partitions = 1;
    if (!source_dims_stack.empty()) {
      s_size = source_dims_stack.back();
      source_dims_stack.pop_back();
      s_partitions = sharding_tile_dims_stack.back();
      sharding_tile_dims_stack.pop_back();
    }
    t_size = target_dims_stack.back();
    target_dims_stack.pop_back();
    if (s_partitions * Product(sharding_tile_dims_stack) == 1) {
      // No more partitions left.
      target_tile_assignment_dimensions.push_back(1);
      continue;
    }
    if (s_size == t_size) {
      // Same dimension.
      target_tile_assignment_dimensions.push_back(s_partitions);
    } else if (t_size == 1) {
      // Trivial dimension added.
      target_tile_assignment_dimensions.push_back(1);
      source_dims_stack.push_back(s_size);
      sharding_tile_dims_stack.push_back(s_partitions);
    } else if (s_size == 1) {
      // Trivial dimension removed.
      if (s_partitions != 1) {
        return absl::nullopt;
      }
      target_dims_stack.push_back(t_size);
    } else if (s_size > t_size) {
      // Dimension split.
      if (s_size % t_size != 0 || t_size % s_partitions != 0) {
        return absl::nullopt;
      }
      target_tile_assignment_dimensions.push_back(s_partitions);
      // We have part of the s_size unprocessed, so put it back on the stack.
      source_dims_stack.push_back(s_size / t_size);
      sharding_tile_dims_stack.push_back(1);
    } else {
      // Dimension merge. Also merge the source dimension with the next, and
      // process it next time.
      if (s_size % s_partitions != 0) {
        return absl::nullopt;
      }
      CHECK(!source_dims_stack.empty());
      if (sharding_tile_dims_stack.back() != 1 && s_size != s_partitions) {
        // If the next dimension to combine is sharded, we require the current
        // dimension's shard size to be 1. Otherwise, the new shard would be
        // non-contiguous.
        return absl::nullopt;
      }
      source_dims_stack.back() *= s_size;
      sharding_tile_dims_stack.back() *= s_partitions;
      target_dims_stack.push_back(t_size);
    }
  }
  Array<int64> new_tile_assignment = sharding.tile_assignment();
  new_tile_assignment.Reshape(target_tile_assignment_dimensions);
  return HloSharding::Tile(new_tile_assignment);
}

HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
                                   absl::Span<const int64> dims) {
  CHECK(!sharding.IsTuple() && !sharding.IsTileMaximal());
  CHECK_NE(absl::c_find(dims, dim), dims.end()) << "dim is not in dims";
  // We optimize the tile assignment on the single dimension dim in a way to
  // minimize communication among devices caused by the reshard:
  // +---+---+                   +---+---+              +-+-+-+-+
  // |   |   |                   |   0   |              | | | | |
  // | 0 | 1 |                   +-------+              | | | | |
  // |   |   |  reshape on       |   1   |  reshape on  | | | | |
  // +---+---+  dim 0  =>        +-------+  dim 1  =>   |0|2|1|3|
  // |   |   |                   |   2   |              | | | | |
  // | 2 | 3 |                   +-------+              | | | | |
  // |   |   |                   |   3   |              | | | | |
  // +---+---+                   +---+---+              +-+-+-+-+

  std::vector<int64> tile_dims(sharding.tile_assignment().num_dimensions(), 1);
  // Handle ignore dimensions.
  std::vector<int64> ignore_sizes;
  int64 ignore_size = 1;
  for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
    if (absl::c_find(dims, i) == dims.end()) {
      int64 size = sharding.tile_assignment().dim(i);
      ignore_sizes.push_back(size);
      tile_dims[i] = size;
      ignore_size *= size;
    }
  }

  using Buckets = std::vector<std::vector<int64>>;
  Array<Buckets> buckets(ignore_sizes,
                         Buckets(sharding.tile_assignment().dim(dim)));
  sharding.tile_assignment().Each(
      [&](absl::Span<const int64> index, int64 device) {
        std::vector<int64> ignore_index;
        for (int64 i = 0; i < index.size(); ++i) {
          if (absl::c_find(dims, i) == dims.end()) {
            ignore_index.push_back(index[i]);
          }
        }
        buckets(ignore_index)[index[dim]].push_back(device);
      });
  std::vector<int64> devices;
  buckets.Each([&](absl::Span<const int64> index, const Buckets& buckets) {
    for (auto& bucket : buckets) {
      devices.insert(devices.end(), bucket.begin(), bucket.end());
    }
  });
  tile_dims[dim] = devices.size() / ignore_size;
  Array<int64> tile_assignment(tile_dims);
  tile_assignment.SetValues(devices);
  return HloSharding::Tile(tile_assignment);
}

bool ContainsTileSharding(const HloModule& module) {
  for (const HloComputation* computation : module.computations()) {
    for (const HloInstruction* instruction : computation->instructions()) {
      if (instruction->has_sharding() &&
          !instruction->sharding().IsTileMaximal()) {
        return true;
      }
    }
  }
  return false;
}

HloSharding GatherOutputSharding(const HloSharding& index_sharding,
                                 const HloInstruction* hlo) {
  if (index_sharding.IsTileMaximal()) {
    return index_sharding;
  }

  const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
  std::vector<int64> output_tile_assignment_dims;
  for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) {
    if (absl::c_binary_search(dnums.offset_dims(), i)) {
      output_tile_assignment_dims.push_back(1);
    } else {
      output_tile_assignment_dims.push_back(
          index_sharding.tile_assignment().dim(index_dim));
      index_dim++;
    }
  }
  Array<int64> new_tile_assignment = index_sharding.tile_assignment();
  new_tile_assignment.Reshape(output_tile_assignment_dims);
  return HloSharding::Tile(new_tile_assignment);
}

HloSharding GatherIndexSharding(const HloSharding& output_sharding,
                                const HloInstruction* hlo) {
  if (output_sharding.IsTileMaximal()) {
    return output_sharding;
  }

  const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
  std::vector<int64> index_tile_assignment_dims;
  for (int64 i = 0; i < hlo->shape().rank(); ++i) {
    if (!absl::c_binary_search(dnums.offset_dims(), i)) {
      index_tile_assignment_dims.push_back(
          output_sharding.tile_assignment().dim(i));
    }
  }
  Array<int64> new_tile_assignment = output_sharding.tile_assignment();
  new_tile_assignment.Reshape(index_tile_assignment_dims);
  return HloSharding::Tile(new_tile_assignment);
}

HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) {
  if (hlo.sharding().IsTileMaximal()) {
    return hlo.sharding();
  }

  const GatherDimensionNumbers& dnums = hlo.gather_dimension_numbers();
  std::vector<int64> tile_assignment_dims(hlo.shape().rank());
  int64 num_elements = 1;
  for (int64 i = 0; i < hlo.shape().rank(); ++i) {
    if (!absl::c_binary_search(dnums.offset_dims(), i)) {
      tile_assignment_dims[i] = hlo.sharding().tile_assignment().dim(i);
      num_elements *= hlo.sharding().tile_assignment().dim(i);
    } else {
      tile_assignment_dims[i] = 1;
    }
  }
  if (num_elements == hlo.sharding().tile_assignment().num_elements()) {
    // Output sharding is only on non offset dimensions. We use output sharding
    // to shard this gather op directly.
    return hlo.sharding();
  }

  if (num_elements == 1) {
    // Output sharding is only on offset dimensions. We do not shard this gather
    // op. Return a tile maximal sharding with the first device in output
    // sharding tile assignment.
    return HloSharding::AssignDevice(*hlo.sharding().tile_assignment().begin());
  }

  // Output sharding is on both offset and non offset dimensions. We shard the
  // gather op only on non offset dimensions.
  // For example:
  // - the gather op has sharding [2,2]{0,1,2,3},
  // - first dimension is non offset dimension,
  // - second dimension is offset dimension,
  // Then the result sharding will be [2,1]{0,2}.
  std::vector<int64> slice_starts(hlo.shape().rank(), 0LL),
      slice_limits(hlo.shape().rank());
  for (int64 i = 0; i < hlo.shape().rank(); ++i) {
    if (!absl::c_binary_search(dnums.offset_dims(), i)) {
      slice_limits[i] = hlo.sharding().tile_assignment().dim(i);
    } else {
      slice_limits[i] = 1;
    }
  }
  Array<int64> tile_assignment =
      hlo.sharding().tile_assignment().Slice(slice_starts, slice_limits);
  return HloSharding::Tile(tile_assignment);
}

HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
                                 const HloInstruction* hlo) {
  if (data_sharding.IsTileMaximal()) {
    return data_sharding;
  }

  const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers();
  std::vector<int64> index_tile_assignment_dims;
  for (int64 i = 0; i < hlo->shape().rank(); ++i) {
    if (!absl::c_binary_search(dnums.update_window_dims(), i)) {
      index_tile_assignment_dims.push_back(
          data_sharding.tile_assignment().dim(i));
    }
  }
  if (index_tile_assignment_dims.size() < hlo->operand(1)->shape().rank()) {
    index_tile_assignment_dims.push_back(1);
  }
  Array<int64> new_tile_assignment = data_sharding.tile_assignment();
  new_tile_assignment.Reshape(index_tile_assignment_dims);
  return HloSharding::Tile(new_tile_assignment);
}

HloSharding ScatterDataSharding(const HloSharding& index_sharding,
                                const HloInstruction* hlo) {
  if (index_sharding.IsTileMaximal()) {
    return index_sharding;
  }

  const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers();
  std::vector<int64> data_tile_assignment_dims;
  for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) {
    if (absl::c_binary_search(dnums.update_window_dims(), i)) {
      data_tile_assignment_dims.push_back(1);
    } else {
      data_tile_assignment_dims.push_back(
          index_sharding.tile_assignment().dim(index_dim));
      index_dim++;
    }
  }
  Array<int64> new_tile_assignment = index_sharding.tile_assignment();
  new_tile_assignment.Reshape(data_tile_assignment_dims);
  return HloSharding::Tile(new_tile_assignment);
}

HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
                                          const HloInstruction& hlo) {
  if (index_sharding.IsTileMaximal()) {
    return index_sharding;
  }

  // Only shard on first "number of scatter_window_dims" dimensions.
  const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers();
  int64 num_elements = 1;
  int64 index_dim = 0;
  for (int64 i = 0; i < hlo.shape().rank(); ++i) {
    if (absl::c_binary_search(dnums.inserted_window_dims(), i)) {
      num_elements *= index_sharding.tile_assignment().dim(index_dim);
      index_dim++;
    }
  }
  if (num_elements == index_sharding.tile_assignment().num_elements()) {
    // Index sharding is only on scatter_window_dims. We use this index sharding
    // directly.
    return index_sharding;
  }

  // Index sharding is only on update_window_dims. We do not shard this scatter
  // op. Return a tile maximal sharding with the first device in index sharding
  // tile assignment.
  if (num_elements == 1) {
    return HloSharding::AssignDevice(*index_sharding.tile_assignment().begin());
  }

  const int64 index_rank = hlo.operand(1)->shape().rank();
  std::vector<int64> slice_starts(index_rank, 0LL), slice_limits(index_rank);
  for (int64 i = 0; i < index_rank; ++i) {
    if (i < index_dim) {
      slice_limits[i] = index_sharding.tile_assignment().dim(i);
    } else {
      slice_limits[i] = 1;
    }
  }
  Array<int64> tile_assignment =
      index_sharding.tile_assignment().Slice(slice_starts, slice_limits);
  return HloSharding::Tile(tile_assignment);
}

HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
                                         const HloInstruction& hlo) {
  if (data_sharding.IsTileMaximal()) {
    return data_sharding;
  }

  const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers();
  const int64 data_rank = hlo.operand(2)->shape().rank();
  std::vector<int64> tile_assignment_dims(data_rank, 1LL);
  int64 num_elements = 1;
  for (int64 i = 0; i < hlo.shape().rank(); ++i) {
    if (absl::c_binary_search(dnums.inserted_window_dims(), i)) {
      CHECK_LT(i, data_rank);
      tile_assignment_dims[i] = data_sharding.tile_assignment().dim(i);
      num_elements *= data_sharding.tile_assignment().dim(i);
    }
  }
  if (num_elements == data_sharding.tile_assignment().num_elements()) {
    // Data sharding is only on scatter_window_dims. We use this data sharding
    // directly.
    return data_sharding;
  }

  if (num_elements == 1) {
    // Data sharding is only on update_window_dims. We do not shard this
    // scatter op. Return a tile maximal sharding with the first device in
    // data sharding tile assignment.
    return HloSharding::AssignDevice(*data_sharding.tile_assignment().begin());
  }

  // Data sharding is on both update_window_dims and scatter_window_dims. We
  // shard the scatter op only on scatter_window_dims. For example:
  // - the scatter data has sharding [2,2]{0,1,2,3},
  // - first dimension is scatter_window_dims,
  // - second dimension is update_window_dims,
  // Then the result sharding will be [2,1]{0,2}.
  std::vector<int64> slice_starts(data_rank, 0LL);
  Array<int64> tile_assignment =
      data_sharding.tile_assignment().Slice(slice_starts, tile_assignment_dims);
  return HloSharding::Tile(tile_assignment);
}

StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>>
IdentityValueAndHloOpcodeForScatterReduceComputation(
    const HloScatterInstruction& scatter) {
  auto computation = scatter.to_apply();
  // We only handle computations with 2 parameters and only 1 calculation.
  if (computation->instruction_count() != 3) {
    return Status(
        tensorflow::error::Code::INVALID_ARGUMENT,
        "Expected scatter reduce computation with 2 parameters and only 1 "
        "calculation");
  }

  auto root_instruction = computation->root_instruction();
  if (root_instruction->opcode() == HloOpcode::kAdd ||
      root_instruction->opcode() == HloOpcode::kOr) {
    return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::Zero(
                              scatter.shape().element_type())),
                          root_instruction->opcode());
  } else if (root_instruction->opcode() == HloOpcode::kMultiply ||
             root_instruction->opcode() == HloOpcode::kAnd) {
    return std::make_pair(HloInstruction::CreateConstant(
                              LiteralUtil::One(scatter.shape().element_type())),
                          root_instruction->opcode());
  } else if (root_instruction->opcode() == HloOpcode::kMaximum) {
    return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MinValue(
                              scatter.shape().element_type())),
                          root_instruction->opcode());
  } else if (root_instruction->opcode() == HloOpcode::kMinimum) {
    return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MaxValue(
                              scatter.shape().element_type())),
                          root_instruction->opcode());
  }

  return Status(tensorflow::error::Code::INVALID_ARGUMENT,
                "Expected scatter reduce computation which is "
                "add/or/multiply/and/min/max");
}

std::vector<int64> DevicesForSharding(
    const HloSharding& sharding, const std::vector<int64>& available_devices) {
  std::vector<int64> devices;
  if (sharding.IsReplicated()) {
    for (int64 d : available_devices) {
      if (!HloSharding::IsReservedDevice(d)) {
        devices.push_back(d);
      }
    }
    return devices;
  }

  for (int64 i : available_devices) {
    if (sharding.UsesDevice(i)) {
      devices.push_back(i);
    }
  }
  DCHECK(std::all_of(sharding.tile_assignment().begin(),
                     sharding.tile_assignment().end(), [&](int64 device) {
                       return std::find(available_devices.begin(),
                                        available_devices.end(),
                                        device) != available_devices.end();
                     }));
  return devices;
}

}  // namespace hlo_sharding_util
}  // namespace xla

143  tensorflow/compiler/xla/service/hlo_sharding_util.h  (new file)
@@ -0,0 +1,143 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_

#include <map>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"

namespace xla {
namespace hlo_sharding_util {

// Given a map<device, occurrence_count>, selects the device with the highest
// occurrence count (if any). If top_count is not nullptr, it will receive the
// count of the dominant device returned.
absl::optional<int64> SelectDominantDevice(
    const std::map<int64, int64>& device_map, int64* top_count);

// Assigns all the instructions of a computation to a given device.
// This API does not recurse into called computations, and does not assign
// instructions which already have sharding.
Status AssignComputationDevice(HloComputation* computation, int64 device);

// Given an instruction container, returns the device which is most commonly
// occurring among the instructions.
absl::optional<int64> GetMostOccurringDevice(
    absl::Span<HloInstruction* const> instructions);

// Given a set of computations, tries to extract the dominant device. A device
// is dominant if its combined occurrence among all the instructions of the
// input computations is greater than or equal to dominant_factor (a real
// number from 0 to 1).
// This API does not recurse into called computations.
// If no device exists that satisfies the condition, the returned optional will
// hold no value.
StatusOr<absl::optional<int64>> GetDominantDevice(
    absl::Span<HloComputation* const> computations, double dominant_factor);

// Returns the HloSharding with the tile dimensions and tile assignment
// transposed based on the specified dimension numbers. In case of a tile
// maximal sharding returns the original sharding.
HloSharding TransposeSharding(const HloSharding& sharding,
                              const std::vector<int64>& dimensions);

// Returns the HloSharding with the tile shape reshaped based on the source and
// target shapes and the tile assignment adjusted to correspond to the new tile
// shape, or absl::nullopt if the resulting reshape would create an invalid
// sharding (non-contiguous or non-uniformly sized tiles). In case of a tile
// maximal sharding returns the original sharding.
absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
                                            const Shape& target_shape,
                                            const HloSharding& sharding);

// Returns a sharding tiled on unique dimension dim by reshaping the tile
// assignment of the sharding argument. Only dimensions in the dims span
// argument are considered for reshaping, the others are ignored.
// Assumptions: sharding is tile sharded, and dim must be included in dims.
HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
                                   absl::Span<const int64> dims);

// Returns true if the provided module includes one or more instructions with
// a tile sharding.
bool ContainsTileSharding(const HloModule& module);

// Returns the preferred output sharding for a gather op based on the sharding
// of the indices.
HloSharding GatherOutputSharding(const HloSharding& index_sharding,
                                 const HloInstruction* hlo);

// Returns the preferred index sharding for a gather op based on the sharding
// of the output.
HloSharding GatherIndexSharding(const HloSharding& output_sharding,
                                const HloInstruction* hlo);

// Returns a new HloSharding for a gather op so that only non offset dimensions
// are sharded. Assume "result" is returned by this function. It is ensured
// that "GatherIndexSharding(result, hlo)" will have the same number of
// elements as "result".
HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo);

// Returns the preferred index sharding for a scatter op based on the sharding
// of the data.
HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
                                 const HloInstruction* hlo);

// Returns the preferred data sharding for a scatter op based on the sharding
// of the index.
HloSharding ScatterDataSharding(const HloSharding& index_sharding,
                                const HloInstruction* hlo);

// Returns a new index sharding for a scatter op so that we only shard on first
// "number of scatter_window_dims" dimensions. Assume "result" is returned by
// this function. It is ensured that "ScatterDataSharding(result, hlo)" will
// have the same number of elements as "result".
HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
                                          const HloInstruction& hlo);

// Returns a new data sharding for a scatter op so that we only shard on
// scatter_window_dims. Assume "result" is returned by this function. It is
// ensured that "ScatterIndexSharding(result, hlo)" will have the same number
// of elements as "result".
HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
                                         const HloInstruction& hlo);

// Returns an identity value and an HloOpcode for the reduce computation of a
// scatter instruction.
// - If computation is add/or, return 0/false with the corresponding op code;
// - If computation is multiply/and, return 1/true with the corresponding op
//   code;
// - If computation is min/max, return the max value/min value with the
//   corresponding op code;
// - Otherwise, return an error status.
StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>>
IdentityValueAndHloOpcodeForScatterReduceComputation(
    const HloScatterInstruction& scatter);

// Given a sharding and a list of devices in the topology, return a
// list of the devices that `sharding` applies to.
std::vector<int64> DevicesForSharding(
    const HloSharding& sharding, const std::vector<int64>& available_devices);

}  // namespace hlo_sharding_util
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_

206  tensorflow/compiler/xla/service/hlo_sharding_util_test.cc  (new file)
@@ -0,0 +1,206 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"

#include "tensorflow/compiler/xla/test.h"

namespace xla {
namespace hlo_sharding_util {
namespace {

TEST(HloShardingUtilTest, TransposeShardingReplicated) {
  EXPECT_EQ(TransposeSharding(HloSharding::Replicate(), {0, 1, 2}),
            HloSharding::Replicate());
}

TEST(HloShardingUtilTest, TransposeShardingTiled) {
  HloSharding input = HloSharding::Tile(Array4D<int64>({{{{0, 1}}, {{2, 3}}}}));
  HloSharding output =
      HloSharding::Tile(Array4D<int64>({{{{0}, {2}}}, {{{1}, {3}}}}));
  EXPECT_EQ(TransposeSharding(input, {3, 0, 1, 2}), output);
}

TEST(HloShardingUtilTest, ReshapeShardingMaximal) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5});
  Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 2});
  HloSharding sharding = HloSharding::AssignDevice(7);
  absl::optional<HloSharding> result =
      ReshapeSharding(input_shape, output_shape, sharding);
  EXPECT_TRUE(result.has_value());
  EXPECT_EQ(result.value(), sharding);
}

TEST(HloShardingUtilTest, ReshapeShardingTiledInvalid) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5});
  Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 2});
  HloSharding sharding = HloSharding::Tile(Array3D<int64>({{{0}, {1}}}));
  absl::optional<HloSharding> result =
      ReshapeSharding(input_shape, output_shape, sharding);
  EXPECT_FALSE(result.has_value());
}

TEST(HloShardingUtilTest, ReshapeShardingTiledMerge) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {4, 5, 7});
  Shape output_shape = ShapeUtil::MakeShape(F32, {20, 7});
  HloSharding input_sharding =
      HloSharding::Tile(Array3D<int64>({{{0}}, {{1}}}));
  HloSharding output_sharding = HloSharding::Tile(Array2D<int64>({{0}, {1}}));
  absl::optional<HloSharding> result =
      ReshapeSharding(input_shape, output_shape, input_sharding);
  EXPECT_TRUE(result.has_value());
  EXPECT_EQ(result.value(), output_sharding);
}

TEST(HloShardingUtilTest, ReshapeShardingTiledSplit) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {16, 7});
  Shape output_shape = ShapeUtil::MakeShape(F32, {4, 4, 7});
  HloSharding input_sharding = HloSharding::Tile(Array2D<int64>({{0}, {1}}));
  HloSharding output_sharding =
      HloSharding::Tile(Array3D<int64>({{{0}}, {{1}}}));
  absl::optional<HloSharding> result =
      ReshapeSharding(input_shape, output_shape, input_sharding);
  EXPECT_TRUE(result.has_value());
  EXPECT_EQ(result.value(), output_sharding);
}

TEST(HloShardingUtilTest, ReshapeShardingTiledSplitThenMerge) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {16, 4, 7});
  Shape output_shape = ShapeUtil::MakeShape(F32, {4, 16, 7});
  HloSharding input_sharding =
      HloSharding::Tile(Array3D<int64>({{{0}}, {{1}}}));
  HloSharding output_sharding =
      HloSharding::Tile(Array3D<int64>({{{0}}, {{1}}}));
  absl::optional<HloSharding> result =
      ReshapeSharding(input_shape, output_shape, input_sharding);
  EXPECT_TRUE(result.has_value());
  EXPECT_EQ(result.value(), output_sharding);
}

TEST(HloShardingUtilTest, ReshapeShardingTiledArbitraryMinorDimensions) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {16, 7, 5, 3});
  Shape output_shape = ShapeUtil::MakeShape(F32, {4, 15, 2, 14});
  Array<int64> sharding_array({2, 1, 1, 1});
  sharding_array(0, 0, 0, 0) = 0;
  sharding_array(1, 0, 0, 0) = 1;
  HloSharding sharding = HloSharding::Tile(sharding_array);
  absl::optional<HloSharding> result =
      ReshapeSharding(input_shape, output_shape, sharding);
  EXPECT_TRUE(result.has_value());
  EXPECT_EQ(result.value(), sharding);
}

TEST(HloShardingUtilTest, ReshapeShardingTiledTrivialDimensions) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {3, 1, 5, 7});
  Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 1, 7});
  HloSharding input_sharding =
      HloSharding::Tile(Array4D<int64>({{{{0}, {1}}}}));
  HloSharding output_sharding =
      HloSharding::Tile(Array4D<int64>({{{{0}}, {{1}}}}));
  absl::optional<HloSharding> result =
      ReshapeSharding(input_shape, output_shape, input_sharding);
  EXPECT_TRUE(result.has_value());
  EXPECT_EQ(result.value(), output_sharding);
}

TEST(HloShardingUtilTest, ReshapeShardingTrivialDimensionInsertedToEnd) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {8, 16});
  Shape output_shape = ShapeUtil::MakeShape(F32, {8, 16, 1});
  HloSharding input_sharding = HloSharding::Tile(Array2D<int64>({{0}, {1}}));
  HloSharding output_sharding =
      HloSharding::Tile(Array3D<int64>({{{0}}, {{1}}}));
  absl::optional<HloSharding> result =
      ReshapeSharding(input_shape, output_shape, input_sharding);
  EXPECT_TRUE(result.has_value());
  EXPECT_EQ(result.value(), output_sharding);
}

TEST(HloShardingUtilTest, NoopReshapeShardingEmptyTile) {
  Shape shape = ShapeUtil::MakeShape(F32, {7, 1, 1});
  HloSharding sharding = HloSharding::Tile(Array3D<int64>({{{0}, {1}}}));
  absl::optional<HloSharding> result = ReshapeSharding(shape, shape, sharding);
  EXPECT_TRUE(result.has_value());
  EXPECT_EQ(result.value(), sharding);
}

TEST(HloShardingUtilTest, ReshapeShardingScalar) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1});
  Shape output_shape = ShapeUtil::MakeShape(F32, {});
  HloSharding sharding = HloSharding::Tile(Array3D<int64>({{{0}, {1}}}));
  absl::optional<HloSharding> result =
      ReshapeSharding(input_shape, output_shape, sharding);
  EXPECT_FALSE(result.has_value());
}

TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim0) {
  HloSharding sharding = HloSharding::Tile(Array2D<int64>({{0, 1}, {2, 3}}));
  HloSharding result =
      ReshapeToTileDimension(sharding, /*dim=*/0, /*dims=*/{0, 1});
  EXPECT_EQ(result.tile_assignment(), Array2D<int64>({{0}, {1}, {2}, {3}}));
}

TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim1) {
  HloSharding sharding = HloSharding::Tile(Array2D<int64>({{0, 1}, {2, 3}}));
  HloSharding result =
      ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1});
  EXPECT_EQ(result.tile_assignment(), Array2D<int64>({{0, 2, 1, 3}}));
}

TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim0) {
  HloSharding sharding =
      HloSharding::Tile(Array3D<int64>({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}));
  HloSharding result =
      ReshapeToTileDimension(sharding, /*dim=*/0, /*dims=*/{0, 1, 2});
  EXPECT_EQ(
      result.tile_assignment(),
      Array3D<int64>({{{0}}, {{1}}, {{2}}, {{3}}, {{4}}, {{5}}, {{6}}, {{7}}}));
}

TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim1) {
  HloSharding sharding =
      HloSharding::Tile(Array3D<int64>({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}));
  HloSharding result =
      ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1, 2});
  EXPECT_EQ(result.tile_assignment(),
            Array3D<int64>({{{0}, {1}, {4}, {5}, {2}, {3}, {6}, {7}}}));
}

TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim2) {
  HloSharding sharding =
      HloSharding::Tile(Array3D<int64>({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}));
  HloSharding result =
      ReshapeToTileDimension(sharding, /*dim=*/2, /*dims=*/{0, 1, 2});
  EXPECT_EQ(result.tile_assignment(),
            Array3D<int64>({{{0, 2, 4, 6, 1, 3, 5, 7}}}));
}

TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim2_Batch1) {
  // Tile sharding in batch dimension, i.e.
  // sharding={devices=[2,2,2]0,1,2,3,4,5,6,7}.
  HloSharding sharding =
      HloSharding::Tile(Array3D<int64>({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}));
  // Reshape on dimensions {1, 2} only, therefore ignoring batch dimension 0.
  HloSharding result = ReshapeToTileDimension(sharding, /*dim=*/2,
                                              /*dims=*/{1, 2});
  // Expected result is {devices=[2,1,4]0,2,1,3,4,6,5,7}, i.e. the two
  // non-batch dimensions {{0, 1}, {2, 3}} and {{4, 5}, {6, 7}} are
  // individually reshaped to tile dimension 2, i.e. {{0, 2, 1, 3}},
  // {{4, 6, 5, 7}}.
  EXPECT_EQ(result.tile_assignment(),
            Array3D<int64>({{{0, 2, 1, 3}}, {{4, 6, 5, 7}}}));
}

}  // namespace
}  // namespace hlo_sharding_util
}  // namespace xla

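As an additional illustration of the utilities above (not part of the original change, and assuming none of the listed device ids is reserved), DevicesForSharding simply filters the available devices by the given sharding:

    // Replicated sharding: every available device participates.
    std::vector<int64> all = DevicesForSharding(HloSharding::Replicate(),
                                                /*available_devices=*/{0, 1, 2, 3});
    // all == {0, 1, 2, 3}

    // Maximal (single-device) sharding: only the assigned device remains.
    std::vector<int64> one = DevicesForSharding(HloSharding::AssignDevice(2),
                                                /*available_devices=*/{0, 1, 2, 3});
    // one == {2}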
69  tensorflow/compiler/xla/service/spmd/BUILD  (new file)
@@ -0,0 +1,69 @@
# Description: SPMD partitioning pass.

load("//tensorflow:tensorflow.bzl", "tf_cc_test")

package(
    default_visibility = [":friends"],
    licenses = ["notice"],  # Apache 2.0
)

package_group(
    name = "friends",
    includes = [
        "//tensorflow/compiler/xla:friends",
    ],
)

cc_library(
    name = "spmd_partitioner",
    srcs = [
        "spmd_partitioner.cc",
        "spmd_partitioner_util.cc",
    ],
    hdrs = [
        "spmd_partitioner.h",
        "spmd_partitioner_util.h",
    ],
    deps = [
        "//tensorflow/compiler/xla:comparison_util",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:protobuf_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:window_util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client/lib:comparators",
        "//tensorflow/compiler/xla/service:flatten_call_graph",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_casting_utils",
        "//tensorflow/compiler/xla/service:hlo_cse",
        "//tensorflow/compiler/xla/service:hlo_dce",
        "//tensorflow/compiler/xla/service:hlo_pass",
        "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
        "//tensorflow/compiler/xla/service:hlo_query",
        "//tensorflow/compiler/xla/service:hlo_sharding_util",
        "//tensorflow/compiler/xla/service:shape_inference",
        "//tensorflow/compiler/xla/service:tuple_simplifier",
        "//tensorflow/core/platform:numbers",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
    ],
)

tf_cc_test(
    name = "spmd_partitioner_test",
    srcs = ["spmd_partitioner_test.cc"],
    deps = [
        ":spmd_partitioner",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/service:hlo_matchers",
        "//tensorflow/compiler/xla/service:hlo_parser",
        "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
        "//tensorflow/compiler/xla/service:hlo_verifier",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:test",
    ],
)

4655  tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc  (new file; diff suppressed because it is too large)
435  tensorflow/compiler/xla/service/spmd/spmd_partitioner.h  (new file)
@@ -0,0 +1,435 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_

#include <memory>
#include <string>
#include <unordered_map>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"

namespace xla {
namespace spmd {

struct SpmdPartitionerOptions {
  // Always exchange halo on LHS for all convolutions. If false, backprop filter
  // convolution exchanges halo on RHS.
  bool conv_halo_exchange_always_on_lhs = true;

  // The number of instructions to be reported for the highest memory profile
  // instructions.
  int64 report_instruction_count = 5;

  // The minimum size in MiB of an einsum operand to be considered for the
  // windowed implementation in an HLO loop.
  int64 threshold_for_windowed_einsum_mib = 256;

  // Whether the entry computations' signature could change after partitioning.
  bool allow_module_signature_change = false;
};

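// (Illustrative sketch, not part of the original change.) The fields above are
// typically adjusted before constructing the pass, for example:
//
//   SpmdPartitionerOptions options;
//   options.conv_halo_exchange_always_on_lhs = false;  // exchange halo on RHS
//   options.threshold_for_windowed_einsum_mib = 512;   // raise windowed-einsum threshold
//   SpmdPartitioner partitioner(/*num_partitions=*/8, /*num_replicas=*/1,
//                               options);
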
// Class to wrap the computation builder to capture information during SPMD
|
||||
// transformation.
|
||||
class SpmdBuilder : public HloComputation::Builder {
|
||||
public:
|
||||
SpmdBuilder(const std::string& name, HloInstruction* hlo)
|
||||
: HloComputation::Builder(name) {
|
||||
visiting_hlo_ = hlo;
|
||||
}
|
||||
HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction);
|
||||
|
||||
const std::vector<HloInstruction*>& derived_instructions(
|
||||
HloInstruction* hlo) {
|
||||
return instructions_.at(hlo);
|
||||
}
|
||||
|
||||
void set_visiting_hlo(HloInstruction* hlo) { visiting_hlo_ = hlo; }
|
||||
|
||||
HloInstruction* visiting_hlo() const { return visiting_hlo_; }
|
||||
|
||||
private:
|
||||
// Currently visiting instruction.
|
||||
HloInstruction* visiting_hlo_;
|
||||
|
||||
// Map from the currently visiting (old) instruction to new instructions
|
||||
// created during SPMD partitioning.
|
||||
HloInstructionMap<std::vector<HloInstruction*>> instructions_;
|
||||
};
|
||||
|
||||
// A set of functions that create the cross-partition collective ops.
|
||||
struct SPMDCollectiveOpsCreator {
|
||||
// Function used to create a partition ID HLO.
|
||||
std::function<HloInstruction*(SpmdBuilder*)> create_partition_id;
|
||||
|
||||
// Function used to create a cross-partition all-reduce HLO.
|
||||
std::function<HloInstruction*(SpmdBuilder*, HloInstruction* operand,
|
||||
HloComputation* reduction, int64 channel_id)>
|
||||
create_cross_partition_all_reduce;
|
||||
|
||||
// Function used to create a cross-partition collective-permute HLO.
|
||||
std::function<HloInstruction*(
|
||||
SpmdBuilder*, HloInstruction* operand,
|
||||
std::vector<std::pair<int64, int64>>& src_dst_pairs,
|
||||
int64 next_channel_id)>
|
||||
create_cross_partition_collective_permute;
|
||||
|
||||
// Function used to create a cross-partition all-to-all HLO.
|
||||
std::function<HloInstruction*(
|
||||
SpmdBuilder*, absl::Span<HloInstruction* const> operands,
|
||||
const std::vector<ReplicaGroup>& replica_groups, int64 channel_id,
|
||||
absl::optional<int64> split_dimension)>
|
||||
create_cross_partition_all_to_all;
|
||||
};
|
||||
|
||||
// Logger to report memory usage during SPMD partitioning.
|
||||
class SpmdLogger {
|
||||
public:
|
||||
explicit SpmdLogger(int64 report_instruction_count)
|
||||
: report_instruction_count_(report_instruction_count) {}
|
||||
static std::string ReportBeforePartition(const HloModule& module,
|
||||
int64 report_instruction_count);
|
||||
static std::string ReportAfterPartition(const HloModule& module,
|
||||
int64 report_instruction_count);
|
||||
|
||||
// Registers the logging for the groups of instructions created to transform
|
||||
// the given hlo.
|
||||
void RegisterLogEntry(HloInstruction* hlo,
|
||||
const std::vector<HloInstruction*>& group);
|
||||
|
||||
std::string MakeReport();
|
||||
|
||||
private:
|
||||
template <typename F>
|
||||
static std::string ReportMemoryUsage(const HloModule& module, const F& filter,
|
||||
int64 report_instruction_count);
|
||||
|
||||
// A vector of logging messages (one for each original HLO instruction), where
|
||||
// the first integer of the pair represents the size of the HBM used.
|
||||
std::vector<std::pair<int64, std::string>> entries_;
|
||||
|
||||
int64 report_instruction_count_;
|
||||
};
|
||||
|
||||
class SpmdPartitioningVisitor;
|
||||
|
||||
class SpmdPartitioner : public HloModulePass {
|
||||
public:
|
||||
SpmdPartitioner(int64 num_partitions, int64 num_replicas,
|
||||
SpmdPartitionerOptions options);
|
||||
SpmdPartitioner(int64 num_partitions, int64 num_replicas,
|
||||
SpmdPartitionerOptions options,
|
||||
SPMDCollectiveOpsCreator collective_ops_creator)
|
||||
: num_partitions_(num_partitions),
|
||||
num_replicas_(num_replicas),
|
||||
options_(std::move(options)),
|
||||
collective_ops_creator_(std::move(collective_ops_creator)) {}
|
||||
absl::string_view name() const override { return "spmd-partitioning"; }
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
// Transforms the given computation with SPMD instructions, replacing it with
|
||||
// a new computation.
|
||||
StatusOr<bool> PartitionComputation(HloComputation* computation,
|
||||
const HloSharding& root_sharding,
|
||||
int64* next_channel_id,
|
||||
SpmdLogger* logger);
|
||||
|
||||
protected:
|
||||
virtual std::unique_ptr<SpmdPartitioningVisitor> CreateVisitor(
|
||||
HloComputation* computation, int64 num_partitions, int64 num_replicas,
|
||||
const SPMDCollectiveOpsCreator& collective_ops_creator,
|
||||
int64* next_channel_id, SpmdLogger* logger,
|
||||
SpmdPartitionerOptions options);
|
||||
|
||||
private:
|
||||
// Verify that the sharding of instructions in the module are valid, and also
|
||||
// fill in missing sharding information.
|
||||
Status PreprocessSharding(HloModule* module);
|
||||
|
||||
const int64 num_partitions_;
|
||||
const int64 num_replicas_;
|
||||
|
||||
SpmdPartitionerOptions options_;
|
||||
SPMDCollectiveOpsCreator collective_ops_creator_;
|
||||
};
|
||||
|
||||
// Class describes partition state of the data represented by an HLO created
|
||||
// during SPMD partitioning pass.
|
||||
//
|
||||
// Data on some devices may include padding region, if the base (full) shape
|
||||
// could not be evenly partitioned.
class PartitionedHlo {
 public:
  // Return value for ReshardAsWindowedInput which describes the resharded HLO,
  // the window for the user on the shard, and if necessary, the dynamic slice
  // offsets to be applied to the output of the op being sharded.
  struct WindowedInputShardReturnValue {
    HloInstruction* sharded_input;
    Window shard_window;
    absl::optional<std::vector<HloInstruction*>> dynamic_slice_index_on_output;
  };
  // A cache for resharding each partitioned HLO.
  struct ReshardCache {
    struct PerHloCache {
      std::vector<std::pair<HloSharding, PartitionedHlo>> reshard_cache;
      std::vector<
          std::tuple<HloSharding, Window, WindowedInputShardReturnValue>>
          window_reshard_cache;
    };
    std::unordered_map<HloInstruction*, PerHloCache> per_hlo_cache;
  };
  struct PartitioningState {
    SpmdBuilder* b;
    HloModule* module;
    int64 num_replicas;
    HloInstruction* partition_id;
    SPMDCollectiveOpsCreator collective_ops_creator;
    int64* next_channel_id;
    ReshardCache* reshard_cache;
  };
  PartitionedHlo(HloInstruction* hlo, Shape base_shape, PartitioningState state)
      : hlo_(hlo), base_shape_(base_shape), state_(std::move(state)) {
    CHECK(hlo->has_sharding())
        << "PartitionedHlo is missing sharding:" << hlo->ToString();
    // If the tuple shape instruction does not have a tuple sharding, reassign
    // to use the tuple sharding. Reshard() implementation assumes this.
    if (hlo_->shape().IsTuple() && !hlo_->sharding().IsTuple()) {
      hlo_->set_sharding(
          hlo_->sharding().GetTupleSharding(hlo_->shape()).ValueOrDie());
    }
  }

  // Reshards the current SPMD instruction to a new sharding. Could only modify
  // the reshard cache.
  PartitionedHlo Reshard(const HloSharding& target);

  // Pads the garbage area of the output with the provided value.
  PartitionedHlo PadWithValue(HloInstruction* pad_value) const;

  // Returns the SPMD instruction.
  HloInstruction* hlo() const { return hlo_; }

  // Returns the sharding of the SPMD instruction.
  const HloSharding& sharding() const { return hlo_->sharding(); }

  // Original full shape of the data.
  const Shape& base_shape() const { return base_shape_; }

  int64 NewChannel() const { return (*state_.next_channel_id)++; }

  // Reshards the HLO to a usable partitioned input for a windowed user. Could
  // only modify the reshard cache.
  absl::optional<WindowedInputShardReturnValue> ReshardAsWindowedInput(
      const Window& window, const HloSharding& target,
      HloInstruction* pad_value, bool mask_invalid_region = true);

 private:
  // Same as Reshard except that it does not explicitly modify the reshard
  // cache, although it would indirectly modify by calling Replicate().
  PartitionedHlo ReshardNoCache(const HloSharding& target);

  // Helper function to replicate the data on all devices. Could only modify
  // the reshard cache.
  PartitionedHlo Replicate();

  // Helper function to broadcast data from a single device to all devices.
  PartitionedHlo Broadcast() const;

  // Helper function to reshard the tensor using AllToAll (instead of the
  // default of Replicate followed by Slice).
  PartitionedHlo ReshardWithAllToAll(const HloSharding& target) const;

  // Helper function to reshard the tensor using CollectivePermute.
  PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const;

  // SPMD instruction.
  HloInstruction* hlo_;

  // The original shape of the data before SPMD transformation is applied.
  Shape base_shape_;

  PartitioningState state_;
};

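// A minimal usage sketch (added for illustration; `builder`, `param` and
// `state` are assumed to exist in the caller and are not defined in this
// file):
//
//   PartitionedHlo partitioned(param, param->shape(), state);
//   PartitionedHlo replicated = partitioned.Reshard(HloSharding::Replicate());
//   HloInstruction* full_data = replicated.hlo();
//
// Reshard() consults state.reshard_cache first, so asking for the same target
// sharding twice reuses the instructions generated for the first request.
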
struct DotGeneralDimsMapping {
  // The dimension numbers for the operands and output corresponding to a
  // logical dimension (e.g., batch, contracting, non-contracting). If an
  // operand or the output doesn't have the logical dimension, it is set to
  // -1.
  struct DimsMapping {
    int64 lhs;
    int64 rhs;
    int64 output;
  };
  std::vector<DimsMapping> batch_dims;
  std::vector<DimsMapping> contracting_dims;
  std::vector<DimsMapping> lhs_non_contracting_dims;
  std::vector<DimsMapping> rhs_non_contracting_dims;
};

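// An illustrative filling of this mapping (not part of the original header):
// for a batched matmul dot with lhs = f32[b,m,k], rhs = f32[b,k,n] and
// output = f32[b,m,n], the entries would be
//
//   batch_dims               = {{/*lhs=*/0, /*rhs=*/0, /*output=*/0}};
//   contracting_dims         = {{/*lhs=*/2, /*rhs=*/1, /*output=*/-1}};
//   lhs_non_contracting_dims = {{/*lhs=*/1, /*rhs=*/-1, /*output=*/1}};
//   rhs_non_contracting_dims = {{/*lhs=*/-1, /*rhs=*/2, /*output=*/2}};
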
class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
 public:
  SpmdPartitioningVisitor(
      HloComputation* computation, int64 num_partitions, int64 num_replicas,
      const SPMDCollectiveOpsCreator& collective_ops_creator,
      int64* next_channel_id, SpmdLogger* logger,
      SpmdPartitionerOptions options, SpmdPartitioner* partitioner);

  Status DefaultAction(HloInstruction* hlo) override;
  Status HandleAllReduce(HloInstruction* hlo) override;
  Status HandleBroadcast(HloInstruction* hlo) override;
  Status HandleConstant(HloInstruction* hlo) override;
  Status HandleCustomCall(HloInstruction* hlo) override;
  Status HandleDot(HloInstruction* hlo) override;
  Status HandleDynamicSlice(HloInstruction* hlo) override;
  Status HandleDynamicUpdateSlice(HloInstruction* hlo) override;
  Status HandleGather(HloInstruction* hlo) override;
  Status HandleGetTupleElement(HloInstruction* hlo) override;
  Status HandleInfeed(HloInstruction* hlo) override;
  Status HandleOutfeed(HloInstruction* hlo) override;
  Status HandlePad(HloInstruction* hlo) override;
  Status HandleParameter(HloInstruction* hlo) override;
  Status HandleReduce(HloInstruction* hlo) override;
  Status HandleReverse(HloInstruction* hlo) override;
  Status HandleWhile(HloInstruction* hlo) override;
  Status HandleConditional(HloInstruction* hlo) override;
  Status HandleReduceWindow(HloInstruction* hlo) override;
  Status HandleSelectAndScatter(HloInstruction* hlo) override;
  Status HandleTuple(HloInstruction* hlo) override;
  Status HandleRng(HloInstruction* hlo) override;
  Status HandleConvolution(HloInstruction* hlo) override;
  Status HandleConcatenate(HloInstruction* hlo) override;
  Status HandleScatter(HloInstruction* hlo) override;
  Status HandleSlice(HloInstruction* hlo) override;
  Status HandleSort(HloInstruction* hlo) override;
  Status HandleTranspose(HloInstruction* hlo) override;
  Status HandleReshape(HloInstruction* hlo) override;
  Status HandleIota(HloInstruction* hlo) override;
  Status HandlePartitionId(HloInstruction* hlo) override;

  // Handles convolution where both LHS and RHS operands are tiled.
  Status HandleConvolutionTiledLhsAndRhs(HloInstruction* hlo);

  // Implementation of dot partitioning given DotGeneralDimsMapping.
  Status HandleDotHelper(
      HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping,
      const std::function<StatusOr<HloInstruction*>(
          HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot);

  // Common handle for elementwise HLOs.
  Status HandleElementwise(HloInstruction* hlo);

  // Common handle for HLOs that run on a single device.
  Status HandleSingleDevice(const HloInstruction* hlo);

  // Returns the PartitionedHlo that corresponds to the original hlo.
  PartitionedHlo& GetPartitionedHlo(const HloInstruction* hlo) {
    CHECK_EQ(partitioned_instructions_.count(hlo), 1);
    return partitioned_instructions_.find(hlo)->second;
  }

  // Sets the PartitionedHlo for the original hlo.
  void SetPartitionedHlo(const HloInstruction* hlo,
                         const PartitionedHlo& partitioned_hlo) {
    CHECK_EQ(partitioned_instructions_.count(hlo), 0);
    partitioned_instructions_.emplace(hlo, partitioned_hlo);
    changed_ = true;
  }

  // Convenient wrapper that creates PartitionedHlo from the result of the func
  // and maps it to the given original hlo.
  void SetPartitionedHlo(const HloInstruction* hlo,
                         const std::function<HloInstruction*()>& func) {
    HloInstruction* new_hlo = func();
    new_hlo->set_sharding(hlo->sharding());
    new_hlo->set_metadata(hlo->metadata());
    SetPartitionedHlo(
        hlo, PartitionedHlo(new_hlo, hlo->shape(), MakePartitioningState()));
    changed_ = true;
  }

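  // A typical handler is a thin wrapper around these helpers. The following
  // sketch (added for illustration, assuming an elementwise `hlo`) shows how
  // the wrappers above are meant to be used; it is not a verbatim copy of any
  // handler in this change:
  //
  //   std::vector<HloInstruction*> new_operands;
  //   for (HloInstruction* operand : hlo->operands()) {
  //     new_operands.push_back(
  //         GetPartitionedHlo(operand).Reshard(hlo->sharding()).hlo());
  //   }
  //   SetPartitionedHlo(hlo, [&] {
  //     return b_.AddInstruction(hlo->CloneWithNewOperands(
  //         MakePartitionedShape(hlo->shape(), hlo->sharding()),
  //         new_operands));
  //   });
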
  int64 NewChannel() { return (*next_channel_id_)++; }

  PartitionedHlo::PartitioningState MakePartitioningState() {
    return PartitionedHlo::PartitioningState{
        .b = &b_,
        .module = module_,
        .num_replicas = num_replicas_,
        .partition_id = partition_id_,
        .collective_ops_creator = collective_ops_creator_,
        .next_channel_id = next_channel_id_,
        .reshard_cache = &reshard_cache_};
  }

  SpmdBuilder* builder() { return &b_; }

  StatusOr<bool> DoPartition(HloComputation* computation,
                             const HloSharding& root_sharding);

 private:
  Status Preprocess(HloInstruction* hlo) override;
  Status Postprocess(HloInstruction* hlo) override;

  // Performs code motion for windowed dot-general loops in
  // windowed_dot_general_loops_. Invoked after the visitor finishes traversing
  // the graph.
  Status DoCodeMotionForWindowedDotGeneralLoops(HloComputation* computation);

  bool changed_;
  HloModule* module_;
  int64 num_partitions_;
  int64 num_replicas_;

  SPMDCollectiveOpsCreator collective_ops_creator_;

  // Tracks the next channel id to use for cross-partition all-reduce.
  int64* next_channel_id_;
  SpmdBuilder b_;

  HloInstruction* partition_id_;

  PartitionedHlo::ReshardCache reshard_cache_;

  // Mapping from the instruction in the original computation to the new SPMD
  // partitioned instruction.
  ConstHloInstructionMap<PartitionedHlo> partitioned_instructions_;

  // Information about a loop created for windowed dot-general. Used when
  // DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor
  // finishes traversing the graph.
  struct WindowedDotGeneralLoop {
    HloInstruction* while_loop;
    int64 windowed_operand;
    bool windowed_in_contracting_dims;
    bool windowed_in_batch_dims;
  };
  std::vector<WindowedDotGeneralLoop> windowed_dot_general_loops_;

  HloInstruction* visiting_hlo_;
  SpmdLogger* logger_;
  const SpmdPartitionerOptions options_;
  SpmdPartitioner* partitioner_;
};

}  // namespace spmd
}  // namespace xla
#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_
3191  tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc  Normal file
File diff suppressed because it is too large

662  tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc  Normal file
@ -0,0 +1,662 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace spmd {

bool HasReplicatedSharding(const HloSharding& sharding) {
  if (sharding.IsTuple()) {
    return absl::c_any_of(sharding.tuple_elements(), HasReplicatedSharding);
  }
  return sharding.IsReplicated();
}

HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b) {
  if (shape.IsTuple()) {
    std::vector<HloInstruction*> elements;
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      elements.push_back(
          CreateZero(ShapeUtil::GetTupleElementShape(shape, i), b));
    }
    return b->AddInstruction(HloInstruction::CreateTuple(elements));
  }

  if (shape.IsToken()) {
    return b->AddInstruction(HloInstruction::CreateToken());
  }
  auto zero = b->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
  return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {}));
}

HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) {
  HloComputation::Builder sum_b("add");
  auto x = sum_b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, ShapeUtil::MakeShape(type, {}), "x"));
  auto y = sum_b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, ShapeUtil::MakeShape(type, {}), "y"));
  if (type == PRED) {
    sum_b.AddInstruction(HloInstruction::CreateBinary(
        ShapeUtil::MakeShape(type, {}), HloOpcode::kOr, x, y));
  } else {
    sum_b.AddInstruction(HloInstruction::CreateBinary(
        ShapeUtil::MakeShape(type, {}), HloOpcode::kAdd, x, y));
  }
  HloComputation* reduction = module->AddEmbeddedComputation(sum_b.Build());
  return reduction;
}

bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding) {
  if (sharding.IsTuple()) {
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      if (!EvenlyPartitions(ShapeUtil::GetTupleElementShape(shape, i),
                            sharding.GetSubSharding(shape, {i}))) {
        return false;
      }
    }
  }

  if (sharding.IsTileMaximal()) {
    return sharding.IsReplicated();
  }
  for (int64 i = 0; i < shape.dimensions_size(); ++i) {
    if (shape.dimensions(i) % sharding.tile_assignment().dim(i) != 0) {
      return false;
    }
  }
  return true;
}

Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) {
  if (sharding.IsTuple()) {
    std::vector<Shape> subshapes;
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      subshapes.push_back(
          MakePartitionedShape(ShapeUtil::GetTupleElementShape(shape, i),
                               sharding.GetSubSharding(shape, {i})));
    }
    return ShapeUtil::MakeTupleShape(subshapes);
  }
  return sharding.TileShape(shape);
}

Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
                                          const HloSharding& sharding,
                                          int64 partition_id) {
  if (sharding.IsTuple()) {
    std::vector<Shape> subshapes;
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      subshapes.push_back(MakeNonPaddedShapeForGivenPartition(
          ShapeUtil::GetTupleElementShape(shape, i),
          sharding.GetSubSharding(shape, {i}), partition_id));
    }
    return ShapeUtil::MakeTupleShape(subshapes);
  }

  auto partition_shape = shape;
  std::vector<int64> tile_offset =
      sharding.TileOffsetForDevice(shape, partition_id);
  std::vector<int64> tile_limit =
      sharding.TileLimitForDevice(shape, partition_id);
  for (int64 i = 0; i < tile_offset.size(); ++i) {
    if (sharding.UsesDevice(partition_id)) {
      partition_shape.set_dimensions(i, tile_limit[i] - tile_offset[i]);
    } else {
      partition_shape.set_dimensions(i, 0);
    }
  }
  return partition_shape;
}

std::vector<HloInstruction*> MakePartitionOffsets(const Shape& shape,
                                                  const HloSharding& sharding,
                                                  HloInstruction* partition_id,
                                                  SpmdBuilder* b) {
  CHECK(!shape.IsTuple());

  Array2D<int32> offset_array(
      {sharding.tile_assignment().num_elements(), shape.rank()});
  offset_array.Each([&](int64 i, int64 j, int32* value) {
    *value = sharding.TileOffsetForDevice(shape, i)[j];
  });
  auto offset_table = b->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D(offset_array)));
  std::vector<HloInstruction*> offsets;
  for (int64 i = 0; i < shape.rank(); ++i) {
    if (sharding.tile_assignment().dim(i) == 1) {
      offsets.push_back(b->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(S32))));
    } else {
      auto index = b->AddInstruction(HloInstruction::CreateDynamicSlice(
          ShapeUtil::MakeShape(S32, {1, 1}), offset_table,
          {partition_id, b->AddInstruction(HloInstruction::CreateConstant(
                             LiteralUtil::CreateR0<uint32>(i)))},
          {1, 1}));
      offsets.push_back(b->AddInstruction(
          HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), index)));
    }
  }
  return offsets;
}

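// A concrete illustration of the offset table above (added, not part of the
// original file): for shape f32[8,16] and a sharding whose tile assignment is
//   [[0, 1],
//    [2, 3]]
// the table constant is the 4x2 array {{0, 0}, {0, 8}, {4, 0}, {4, 8}}, and at
// run time each partition dynamic-slices row `partition_id` out of it to get
// its per-dimension start offsets.
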
std::vector<HloInstruction*> MakeTiledPartitionOrdinals(
    const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) {
  CHECK(!sharding.IsTileMaximal());
  auto table_shape =
      ShapeUtil::MakeShape(S32, sharding.tile_assignment().dimensions());
  return MakePartitionOffsets(table_shape, sharding, partition_id, b);
}

HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape,
                           SpmdBuilder* b, HloComputation* computation) {
  CHECK(b == nullptr || computation == nullptr);
  if (ShapeUtil::Compatible(hlo->shape(), padded_shape)) {
    return hlo;
  }
  PaddingConfig padding_config;
  for (int64 i = 0; i < padded_shape.rank(); ++i) {
    auto padding_config_dim = padding_config.add_dimensions();
    padding_config_dim->set_edge_padding_low(0);
    padding_config_dim->set_interior_padding(0);
    padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) -
                                              hlo->shape().dimensions(i));
  }
  auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) {
    if (b == nullptr) {
      return computation->AddInstruction(std::move(to_add));
    }
    return b->AddInstruction(std::move(to_add));
  };
  auto zero = add_hlo(HloInstruction::CreateConstant(
      LiteralUtil::Zero(hlo->shape().element_type())));
  return add_hlo(
      HloInstruction::CreatePad(padded_shape, hlo, zero, padding_config));
}

Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape,
                                          const HloSharding& sharding) {
  if (sharding.IsTileMaximal()) {
    return base_shape;
  }
  if (EvenlyPartitions(base_shape, sharding)) {
    return base_shape;
  }
  auto shard_shape = MakePartitionedShape(base_shape, sharding);
  Shape padded_base_shape = base_shape;
  for (int64 i = 0; i < padded_base_shape.rank(); ++i) {
    padded_base_shape.set_dimensions(
        i, shard_shape.dimensions(i) * sharding.tile_assignment().dim(i));
  }
  return padded_base_shape;
}

HloInstruction* PadBaseShapeBeforeUnevenTiledSharding(
    HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b) {
  auto padded_base_shape =
      GetPaddedShapeForUnevenPartitioning(hlo->shape(), sharding);
  if (ShapeUtil::Compatible(padded_base_shape, hlo->shape())) {
    return hlo;
  }
  return PadToShape(hlo, padded_base_shape, b);
}

absl::optional<int64> UniqueTiledDim(const HloSharding& sharding) {
  if (sharding.IsTileMaximal()) {
    return absl::nullopt;
  }
  int64 dim = -1;
  for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
    if (sharding.tile_assignment().dim(i) > 1) {
      if (dim != -1) {
        return absl::nullopt;
      }
      dim = i;
    }
  }
  CHECK_NE(dim, -1);
  return dim;
}

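// A small worked example for the padding helpers above (added illustration):
// for base shape f32[13] tiled across 4 partitions, MakePartitionedShape()
// yields a shard shape of f32[4], EvenlyPartitions() is false (13 % 4 != 0),
// and GetPaddedShapeForUnevenPartitioning() returns f32[16] = 4 * 4, which is
// the shape PadBaseShapeBeforeUnevenTiledSharding() pads the base data to.
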
MultiplyAddDivideOffsetCalculation::MultiplyAddDivideOffsetCalculation(
    int64 multiplier, int64 offset, int64 divisor)
    : multiplier_(multiplier), offset_(offset), divisor_(divisor) {
  CHECK_GT(divisor_, 0);
  Simplify();
}

OffsetCalculation MultiplyAddDivideOffsetCalculation::operator-(
    const MultiplyAddDivideOffsetCalculation& other) const {
  if (divisor_ == 1 && other.divisor_ == 1) {
    return OffsetCalculation(MultiplyAddDivideOffsetCalculation(
        multiplier_ - other.multiplier_, offset_ - other.offset_, 1));
  }
  return OffsetCalculation(HloOpcode::kSubtract, *this, other);
}

void MultiplyAddDivideOffsetCalculation::Simplify() {
  // We could simplify the calculation when multiplier is a multiple of
  // divisor_. However, when offset_ is not a multiple of divisor_, we must
  // make sure that offset_ and multiplier_ are both non-negative or both
  // non-positive. E.g., (3 * i - 1) / 3 is not equivalent to i or i - 1.
  if (divisor_ != 1 && multiplier_ % divisor_ == 0 &&
      (offset_ % divisor_ == 0 || offset_ * multiplier_ > 0)) {
    multiplier_ /= divisor_;
    offset_ /= divisor_;
    divisor_ = 1;
  }
}

int64 MultiplyAddDivideOffsetCalculation::Calculate(int64 shard_ordinal) const {
  return (shard_ordinal * multiplier_ + offset_) / divisor_;
}

HloInstruction* MultiplyAddDivideOffsetCalculation::Calculate(
    HloInstruction* shard_ordinal, SpmdBuilder* b) const {
  auto scalar_shape = ShapeUtil::MakeShape(S32, {});
  if (multiplier_ == 0) {
    return b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR0<int32>(offset_ / divisor_)));
  }
  HloInstruction* result = shard_ordinal;
  if (multiplier_ != 1) {
    result = b->AddInstruction(HloInstruction::CreateBinary(
        scalar_shape, HloOpcode::kMultiply, shard_ordinal,
        b->AddInstruction(HloInstruction::CreateConstant(
            LiteralUtil::CreateR0<int32>(multiplier_)))));
  }
  if (offset_ != 0) {
    auto offset = b->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(offset_)));
    result = b->AddInstruction(HloInstruction::CreateBinary(
        scalar_shape, HloOpcode::kAdd, result, offset));
  }
  if (divisor_ != 1) {
    auto divisor = b->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(divisor_)));
    result = b->AddInstruction(HloInstruction::CreateBinary(
        scalar_shape, HloOpcode::kDivide, result, divisor));
  }
  return result;
}

int64 MultiplyAddDivideOffsetCalculation::MaxInRange(
    int64 start_ordinal, int64 limit_ordinal) const {
  int64 max = Calculate(start_ordinal);
  for (int64 i = start_ordinal + 1; i < limit_ordinal; ++i) {
    max = std::max(max, Calculate(i));
  }
  return max;
}

OffsetCalculation& OffsetCalculation::operator=(
    const OffsetCalculation& other) {
  opcode_ = other.opcode_;
  copy_from_ = other.copy_from_;
  if (opcode_ != HloOpcode::kCopy) {
    lhs_ = absl::make_unique<OffsetCalculation>(*other.lhs_);
    rhs_ = absl::make_unique<OffsetCalculation>(*other.rhs_);
  }
  return *this;
}

bool OffsetCalculation::IsConstant() const {
  if (opcode_ == HloOpcode::kCopy) {
    return copy_from_.IsConstant();
  }
  if (opcode_ == HloOpcode::kSubtract && *lhs_ == *rhs_) {
    return true;
  }
  return lhs_->IsConstant() && rhs_->IsConstant();
}

OffsetCalculation OffsetCalculation::operator-(
    const OffsetCalculation& other) const {
  if (opcode_ == HloOpcode::kCopy && other.opcode_ == HloOpcode::kCopy) {
    return copy_from_ - other.copy_from_;
  }
  return OffsetCalculation(HloOpcode::kSubtract, *this, other);
}

bool OffsetCalculation::operator==(const OffsetCalculation& other) const {
  if (opcode_ != other.opcode_) {
    return false;
  }
  if (opcode_ == HloOpcode::kCopy) {
    return copy_from_ == other.copy_from_;
  }
  return *lhs_ == *other.lhs_ && *rhs_ == *other.rhs_;
}

int64 OffsetCalculation::Calculate(int64 shard_ordinal) const {
  switch (opcode_) {
    case HloOpcode::kCopy:
      return copy_from_.Calculate(shard_ordinal);
    case HloOpcode::kSubtract:
      return lhs_->Calculate(shard_ordinal) - rhs_->Calculate(shard_ordinal);
    case HloOpcode::kMultiply:
      return lhs_->Calculate(shard_ordinal) * rhs_->Calculate(shard_ordinal);
    default:
      LOG(FATAL) << "Should not happen";
  }
}

HloInstruction* OffsetCalculation::Calculate(HloInstruction* shard_ordinal,
                                             SpmdBuilder* b) const {
  if (opcode_ == HloOpcode::kCopy) {
    return copy_from_.Calculate(shard_ordinal, b);
  }
  auto lhs = lhs_->Calculate(shard_ordinal, b);
  auto rhs = rhs_->Calculate(shard_ordinal, b);
  return b->AddInstruction(
      HloInstruction::CreateBinary(lhs->shape(), opcode_, lhs, rhs));
}

int64 OffsetCalculation::MaxInRange(int64 start_ordinal,
                                    int64 limit_ordinal) const {
  if (IsConstant()) {
    return Calculate(start_ordinal);
  }
  if (opcode_ == HloOpcode::kCopy) {
    return std::max(Calculate(start_ordinal), Calculate(limit_ordinal - 1));
  }
  int64 max = Calculate(start_ordinal);
  for (int64 i = start_ordinal + 1; i < limit_ordinal; ++i) {
    max = std::max(max, Calculate(i));
  }
  return max;
}

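// To make the offset arithmetic concrete (added illustration):
//
//   MultiplyAddDivideOffsetCalculation(2, 0, 4).Calculate(3)      // (3*2+0)/4 == 1
//   MultiplyAddDivideOffsetCalculation(2, 0, 4).MaxInRange(0, 4)  // == 1
//
// and subtracting two calculations whose divisors are both 1, e.g.
// (i + 2) - (i + 0), folds into MultiplyAddDivideOffsetCalculation(0, 2, 1),
// i.e. the constant 2; otherwise the difference is kept symbolically as an
// OffsetCalculation with opcode kSubtract.
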
absl::optional<HloInstruction*> ExchangeHalo(
    HloInstruction* hlo, const OffsetCalculation& left_halo_size_function,
    const OffsetCalculation& right_halo_size_function, int64 dim,
    const HloSharding& target,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdBuilder* b) {
  int64 input_shard_size = hlo->shape().dimensions(dim);
  int64 shard_count = target.tile_assignment().dim(dim);

  std::vector<HloInstruction*> concat_pieces;

  int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count);
  if (max_left_halo_size > input_shard_size) {
    VLOG(1) << "ExchangeHalo failed: halo is beyond the left neighbor.";
    return absl::nullopt;
  }
  if (max_left_halo_size > 0) {
    std::vector<std::pair<int64, int64>> source_target_pairs;
    target.tile_assignment().Each(
        [&](absl::Span<const int64> indices, int64 device) {
          if (indices[dim] > 0) {
            std::vector<int64> source_indices(indices.begin(), indices.end());
            source_indices[dim] -= 1;
            source_target_pairs.emplace_back(
                target.tile_assignment()(source_indices), device);
          }
        });
    auto halo_shape = hlo->shape();
    auto source_halo_slice = hlo;
    if (max_left_halo_size != hlo->shape().dimensions(dim)) {
      halo_shape.set_dimensions(dim, max_left_halo_size);
      std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
      halo_start_indices[dim] =
          hlo->shape().dimensions(dim) - max_left_halo_size;
      std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);

      source_halo_slice = b->AddInstruction(
          hlo->CreateSlice(halo_shape, hlo, halo_start_indices,
                           hlo->shape().dimensions(), halo_slice_strides));
    }
    auto left_halo =
        collective_ops_creator.create_cross_partition_collective_permute(
            b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
    concat_pieces.push_back(left_halo);
  }

  concat_pieces.push_back(hlo);

  // Right halo.
  int64 max_right_halo_size =
      right_halo_size_function.MaxInRange(0, shard_count - 1);
  if (max_right_halo_size > input_shard_size) {
    VLOG(1) << "ExchangeHalo failed: halo is beyond the right neighbor.";
    return absl::nullopt;
  }
  if (max_right_halo_size > 0) {
    std::vector<std::pair<int64, int64>> source_target_pairs;
    target.tile_assignment().Each(
        [&](absl::Span<const int64> indices, int64 device) {
          if (indices[dim] > 0) {
            std::vector<int64> target_indices(indices.begin(), indices.end());
            target_indices[dim] -= 1;
            source_target_pairs.emplace_back(
                device, target.tile_assignment()(target_indices));
          }
        });
    auto halo_shape = hlo->shape();
    halo_shape.set_dimensions(dim, max_right_halo_size);
    std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
    std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);

    auto source_halo_slice = b->AddInstruction(
        hlo->CreateSlice(halo_shape, hlo, halo_start_indices,
                         halo_shape.dimensions(), halo_slice_strides));
    auto right_halo =
        collective_ops_creator.create_cross_partition_collective_permute(
            b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
    concat_pieces.push_back(right_halo);
  }

  auto concat = hlo;
  // Concat with halos/padding.
  if (concat_pieces.size() > 1) {
    auto concat_shape = hlo->shape();
    int64 concat_dim_size = 0;
    for (auto piece : concat_pieces) {
      concat_dim_size += piece->shape().dimensions(dim);
    }
    concat_shape.set_dimensions(dim, concat_dim_size);
    concat = b->AddInstruction(
        HloInstruction::CreateConcatenate(concat_shape, concat_pieces, dim));
  }

  return concat;
}

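// For intuition (added note): with a shard size of 4 along `dim` and constant
// left/right halo sizes of 1, every shard slices out its last element and its
// first element, sends them to its right and left neighbors respectively via
// collective-permute, and concatenates [left_halo(1), own_data(4),
// right_halo(1)] into a 6-element buffer. A requested halo larger than the
// 4-element shard would have to come from beyond the immediate neighbor, so
// the function returns absl::nullopt in that case.
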
absl::optional<HloInstruction*> ExchangeHalo(
    HloInstruction* hlo,
    std::vector<OffsetCalculation> left_halo_size_functions,
    std::vector<OffsetCalculation> right_halo_size_functions,
    const HloSharding& target,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdBuilder* b) {
  CHECK(left_halo_size_functions.size() == hlo->shape().rank());
  CHECK(right_halo_size_functions.size() == hlo->shape().rank());

  HloInstruction* visiting_hlo = hlo;
  for (int dim = 0; dim < hlo->shape().rank(); ++dim) {
    auto concat = ExchangeHalo(visiting_hlo, left_halo_size_functions[dim],
                               right_halo_size_functions[dim], dim, target,
                               collective_ops_creator, next_channel_id, b);
    if (!concat) {
      return absl::nullopt;
    }
    visiting_hlo = *concat;
  }
  return visiting_hlo;
}

absl::optional<HloInstruction*> ExchangeHaloAndGetValidData(
    HloInstruction* hlo, const Shape& base_shape,
    const OffsetCalculation& left_halo_size_function,
    const OffsetCalculation& right_halo_size_function,
    int64 explicit_left_padding_on_full_shape, int64 padded_full_shape_size,
    int64 shard_size_with_halo, int64 dim, const HloSharding& target,
    HloInstruction* offset_on_padded_shape, HloInstruction* pad_value,
    HloInstruction* partition_ordinal,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region) {
  auto halo_exchange_result =
      ExchangeHalo(hlo, left_halo_size_function, right_halo_size_function, dim,
                   target, collective_ops_creator, next_channel_id, b);
  if (!halo_exchange_result) {
    return absl::nullopt;
  }
  auto concat = *halo_exchange_result;
  int64 shard_count = target.tile_assignment().dim(dim);
  int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count);

  // Now we determine if we need extra padding after the concat.
  //
  // The max of halo size or the first shard's explicit left padding.
  int64 max_left_halo_or_padding_size =
      std::max(std::max(int64{0}, max_left_halo_size),
               explicit_left_padding_on_full_shape);
  // The calculation that returns the dynamic slice index for a shard on the
  // padded concat, which is the difference between
  // max_left_halo_or_padding_size and its left halo size.
  auto start_offset_on_padded_concat_calculation =
      OffsetCalculation(MultiplyAddDivideOffsetCalculation(
          0, max_left_halo_or_padding_size, 1)) -
      left_halo_size_function;

  // See if we need to pad the concat before dynamic slice.
  int64 extra_left_padding =
      std::max(int64{0}, max_left_halo_or_padding_size -
                             std::max(int64{0}, max_left_halo_size));
  int64 extra_right_padding =
      start_offset_on_padded_concat_calculation.MaxInRange(0, shard_count) +
      shard_size_with_halo - concat->shape().dimensions(dim) -
      extra_left_padding;
  extra_right_padding = std::max(int64{0}, extra_right_padding);
  if (extra_left_padding > 0 || extra_right_padding > 0) {
    PaddingConfig padding_config;
    auto padded_concat_shape = concat->shape();
    for (int64 i = 0; i < base_shape.rank(); ++i) {
      auto padding_config_dim = padding_config.add_dimensions();
      padding_config_dim->set_interior_padding(0);
      padding_config_dim->set_edge_padding_low(0);
      padding_config_dim->set_edge_padding_high(0);
      if (i != dim) {
        continue;
      }
      padding_config_dim->set_edge_padding_low(extra_left_padding);
      padding_config_dim->set_edge_padding_high(extra_right_padding);
      padded_concat_shape.set_dimensions(dim, concat->shape().dimensions(dim) +
                                                  extra_left_padding +
                                                  extra_right_padding);
    }
    concat = b->AddInstruction(HloInstruction::CreatePad(
        padded_concat_shape, concat, pad_value, padding_config));
  }

  auto valid_slice = concat;
  if (shard_size_with_halo != concat->shape().dimensions(dim)) {
    // Concat is bigger than the shard shape, so we need a dynamic slice.
    CHECK_LT(shard_size_with_halo, concat->shape().dimensions(dim));
    auto slice_shape = concat->shape();
    slice_shape.set_dimensions(dim, shard_size_with_halo);

    if (left_halo_size_function.IsConstant() &&
        left_halo_size_function.Calculate(0) ==
            explicit_left_padding_on_full_shape) {
      std::vector<int64> start_indices(slice_shape.rank(), 0);
      std::vector<int64> strides(slice_shape.rank(), 1);
      valid_slice = b->AddInstruction(
          HloInstruction::CreateSlice(slice_shape, concat, start_indices,
                                      slice_shape.dimensions(), strides));
    } else {
      auto zero = b->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
      std::vector<HloInstruction*> slice_offsets(base_shape.rank(), zero);
      slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate(
          partition_ordinal, b);
      valid_slice = b->AddInstruction(HloInstruction::CreateDynamicSlice(
          slice_shape, concat, slice_offsets, slice_shape.dimensions()));
    }
  }

  if (!mask_invalid_region) {
    return valid_slice;
  }

  int64 total_right_padding = padded_full_shape_size -
                              base_shape.dimensions(dim) -
                              explicit_left_padding_on_full_shape;
  // Mask off garbage data due to uneven partition or low/high padding.
  if (explicit_left_padding_on_full_shape > 0 || total_right_padding > 0) {
    auto index_shape = ShapeUtil::ChangeElementType(valid_slice->shape(), S32);
    auto iota = b->AddInstruction(HloInstruction::CreateIota(index_shape, dim));
    auto broadcast_start_index_in_padded_shape =
        b->AddInstruction(HloInstruction::CreateBroadcast(
            index_shape, offset_on_padded_shape, {}));
    auto index_in_padded_shape = b->AddInstruction(
        HloInstruction::CreateBinary(index_shape, HloOpcode::kAdd, iota,
                                     broadcast_start_index_in_padded_shape));
    auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED);
    std::vector<HloInstruction*> predicates;
    if (explicit_left_padding_on_full_shape > 0) {
      auto valid_index_start =
          b->AddInstruction(HloInstruction::CreateBroadcast(
              index_shape,
              b->AddInstruction(
                  HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
                      explicit_left_padding_on_full_shape))),
              {}));
      predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare(
          mask_shape, index_in_padded_shape, valid_index_start,
          ComparisonDirection::kGe)));
    }
    if (total_right_padding > 0) {
      auto valid_index_limit =
          b->AddInstruction(HloInstruction::CreateBroadcast(
              index_shape,
              b->AddInstruction(
                  HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
                      base_shape.dimensions(dim) +
                      explicit_left_padding_on_full_shape))),
              {}));
      predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare(
          mask_shape, index_in_padded_shape, valid_index_limit,
          ComparisonDirection::kLt)));
    }
    CHECK(!predicates.empty());
    auto is_valid =
        predicates.size() == 2
            ? b->AddInstruction(HloInstruction::CreateBinary(
                  mask_shape, HloOpcode::kAnd, predicates[0], predicates[1]))
            : predicates[0];
    auto masking_value = b->AddInstruction(
        HloInstruction::CreateBroadcast(valid_slice->shape(), pad_value, {}));
    valid_slice = b->AddInstruction(
        HloInstruction::CreateTernary(valid_slice->shape(), HloOpcode::kSelect,
                                      is_valid, valid_slice, masking_value));
  }
  return valid_slice;
}

}  // namespace spmd
}  // namespace xla
229  tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h  Normal file
@ -0,0 +1,229 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_

#include <memory>
#include <string>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"

namespace xla {
namespace spmd {

// Returns true if the given sharding contains any replicated sharding.
bool HasReplicatedSharding(const HloSharding& sharding);

// Creates zero value instructions of the given shape.
HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b);

template <typename NativeT>
HloInstruction* CreateR0WithType(PrimitiveType type, NativeT value,
                                 SpmdBuilder* b) {
  auto literal = LiteralUtil::CreateR0(value)
                     .ConvertToShape(ShapeUtil::MakeShape(type, {}))
                     .ValueOrDie();
  return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
}

// Create a binary add computation of the given type and add to the module.
HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module);

// Returns true if the shape can be evenly partitioned for the given sharding.
// All tile sharded dimensions should be evenly divisible and there should be
// no single-device sharding. Replicate sharding is considered an even
// partition.
bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding);

// Returns the shard shape of the given shape when it is partitioned for the
// target sharding.
Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding);

// Returns the shard shape for a partition without padding due to uneven
// sharding.
Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
                                          const HloSharding& sharding,
                                          int64 partition_id);

// Generates the HLO instructions that represent the dimension offsets on any
// device. The size of the returned vector is the rank of the given shape.
std::vector<HloInstruction*> MakePartitionOffsets(const Shape& shape,
                                                  const HloSharding& sharding,
                                                  HloInstruction* partition_id,
                                                  SpmdBuilder* b);

// Returns the offsets of the partition in the tile assignment.
std::vector<HloInstruction*> MakeTiledPartitionOrdinals(
    const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b);

// Pads hlo to the desired shape using high padding. Either a builder or a
// computation needs to be supplied, but not both.
HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape,
                           SpmdBuilder* b,
                           HloComputation* computation = nullptr);

// Returns the padded shape when combining all partitions.
Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape,
                                          const HloSharding& sharding);

// Pads the HLO (with base shape) for uneven tiled partition to make it evenly
// partitionable.
HloInstruction* PadBaseShapeBeforeUnevenTiledSharding(
    HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b);

// Returns the index of the unique tile dimension. Returns absl::nullopt if the
// given sharding is not tiled or tiled along multiple dimensions.
absl::optional<int64> UniqueTiledDim(const HloSharding& sharding);

// Utilities for symbolic offset calculation and halo exchange.
class OffsetCalculation;

// Represents a calculation over integers:
//   (shard_ordinal * multiplier + offset) / divisor
class MultiplyAddDivideOffsetCalculation {
 public:
  MultiplyAddDivideOffsetCalculation()
      : multiplier_(0), offset_(0), divisor_(1) {}
  MultiplyAddDivideOffsetCalculation(int64 multiplier, int64 offset,
                                     int64 divisor);

  OffsetCalculation operator-(
      const MultiplyAddDivideOffsetCalculation& other) const;

  bool operator==(const MultiplyAddDivideOffsetCalculation& other) const {
    return multiplier_ == other.multiplier_ && offset_ == other.offset_ &&
           divisor_ == other.divisor_;
  }

  bool IsConstant() const { return multiplier_ == 0; }
  void Simplify();
  int64 Calculate(int64 shard_ordinal) const;
  HloInstruction* Calculate(HloInstruction* shard_ordinal,
                            SpmdBuilder* b) const;

  // Returns the maximum result for shard ordinals in the range
  // [start_ordinal, limit_ordinal).
  int64 MaxInRange(int64 start_ordinal, int64 limit_ordinal) const;

 private:
  int64 multiplier_;
  int64 offset_;
  int64 divisor_;
};

// Represents a calculation over integers based on results of other calculations
// defined by an opcode. If the opcode is kCopy, it simply wraps a
// MultiplyAddDivideOffsetCalculation.
class OffsetCalculation {
 public:
  OffsetCalculation() : opcode_(HloOpcode::kCopy), copy_from_() {}
  explicit OffsetCalculation(
      const MultiplyAddDivideOffsetCalculation& copy_from)
      : opcode_(HloOpcode::kCopy), copy_from_(copy_from) {}
  OffsetCalculation(const OffsetCalculation& copy_from) { *this = copy_from; }
  OffsetCalculation(HloOpcode opcode,
                    const MultiplyAddDivideOffsetCalculation& lhs,
                    const MultiplyAddDivideOffsetCalculation& rhs)
      : opcode_(opcode),
        lhs_(absl::make_unique<OffsetCalculation>(lhs)),
        rhs_(absl::make_unique<OffsetCalculation>(rhs)) {}
  OffsetCalculation(HloOpcode opcode, const OffsetCalculation& lhs,
                    const OffsetCalculation& rhs)
      : opcode_(opcode),
        lhs_(absl::make_unique<OffsetCalculation>(lhs)),
        rhs_(absl::make_unique<OffsetCalculation>(rhs)) {}

  OffsetCalculation& operator=(const OffsetCalculation& other);

  // Returns whether the calculation returns the same value for all shards.
  // This is conservative and could return false even if it is actually
  // constant.
  bool IsConstant() const;

  OffsetCalculation operator-(const OffsetCalculation& other) const;
  bool operator==(const OffsetCalculation& other) const;
  int64 Calculate(int64 shard_ordinal) const;
  HloInstruction* Calculate(HloInstruction* shard_ordinal,
                            SpmdBuilder* b) const;

  // Returns the maximum result for shard ordinals in the range
  // [start_ordinal, limit_ordinal).
  int64 MaxInRange(int64 start_ordinal, int64 limit_ordinal) const;

 private:
  HloOpcode opcode_;
  std::unique_ptr<OffsetCalculation> lhs_;
  std::unique_ptr<OffsetCalculation> rhs_;
  MultiplyAddDivideOffsetCalculation copy_from_;
};

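// A typical composition of the two classes above (added sketch; the constant
// `max_left_halo` and the function `left_halo_size_function` are assumed to be
// provided by the caller): the start offset of a shard inside a padded,
// concatenated buffer can be written as
//
//   OffsetCalculation start_offset =
//       OffsetCalculation(MultiplyAddDivideOffsetCalculation(
//           0, max_left_halo, 1)) -
//       left_halo_size_function;
//
// which evaluates, for each shard ordinal, to max_left_halo minus that shard's
// left halo size.
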
// Performs halo exchange on the given dimension based on the provided
// left/right halo size functions. Returns nullopt if the halo is beyond the
// direct neighbor of the shard.
absl::optional<HloInstruction*> ExchangeHalo(
    HloInstruction* hlo, const OffsetCalculation& left_halo_size_function,
    const OffsetCalculation& right_halo_size_function, int64 dim,
    const HloSharding& target,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdBuilder* b);

// Exchange halo on all dimensions of the HLO. Returns nullopt if any one of
// the dimensions fails to exchange halo (halo is beyond the neighbor shard).
absl::optional<HloInstruction*> ExchangeHalo(
    HloInstruction* hlo,
    std::vector<OffsetCalculation> left_halo_size_functions,
    std::vector<OffsetCalculation> right_halo_size_functions,
    const HloSharding& target,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdBuilder* b);

// Exchanges halos and performs pad/dynamic-slice on the concatenated data such
// that the result starts with the first needed element on each shard. It also
// masks off invalid data due to padding.
// Arguments:
//  hlo: the HLO op before halo exchange
//  explicit_left_padding_on_full_shape: the amount of left padding to be added
//   explicitly by this function on the base shape before partitioning. Without
//   base dilation, this is usually set to the window's padding_low so that the
//   sharded op does not need to add padding_low on the window; however, with
//   base dilation, this could only be set to a custom size.
//  padded_full_shape_size: the size of the padded full shape on the given
//   dimension, which includes explicit_left_padding_on_full_shape and required
//   right padding to make the shape evenly shardable.
//  shard_size_with_halo: the shard size on the dimension after halo exchange.
//   If different shards have different sizes, use the maximum size.
//  offset_on_padded_shape: the offset HLO (S32) that represents the start of
//   each shard on the padded full shape.
//  pad_value: the padding value used on the full shape.
absl::optional<HloInstruction*> ExchangeHaloAndGetValidData(
    HloInstruction* hlo, const Shape& base_shape,
    const OffsetCalculation& left_halo_size_function,
    const OffsetCalculation& right_halo_size_function,
    int64 explicit_left_padding_on_full_shape, int64 padded_full_shape_size,
    int64 shard_size_with_halo, int64 dim, const HloSharding& target,
    HloInstruction* offset_on_padded_shape, HloInstruction* pad_value,
    HloInstruction* partition_ordinal,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region = true);

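// A sketch of a call site for a windowed op on dimension `dim` (added for
// illustration; every argument name here is assumed to be computed by the
// caller from the window and sharding, and `b_`, `next_channel_id_` and
// `collective_ops_creator_` are assumed members of the calling visitor):
//
//   auto valid_data = ExchangeHaloAndGetValidData(
//       shard, base_shape, left_halo_fn, right_halo_fn,
//       /*explicit_left_padding_on_full_shape=*/pad_low,
//       /*padded_full_shape_size=*/padded_size,
//       /*shard_size_with_halo=*/shard_size, dim, sharding,
//       offset_on_padded_shape, pad_value, partition_ordinal,
//       collective_ops_creator_, next_channel_id_, &b_);
//   if (!valid_data) { /* fall back to a replicated implementation */ }
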
}  // namespace spmd
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_