A. Unique TensorFlower c587b9a2dc [XLA:SPMD] Add basic support for SPMD FFT.
PiperOrigin-RevId: 338293995
Change-Id: I05ee210f469d4f1609c5ecaf8de6eaa347f7419a
2020-10-21 11:19:54 -07:00

437 lines
20 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <float.h>
#include <cmath>
#include <functional>
#include <memory>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace spmd {
namespace {
// Pads the last dimension of every partition up to the next multiple of
// num_partitions by shifting data in from the right-hand neighbour, so that the
// padding ends up only at the end of the global sequence.
// For example, if the input is {0, 1, 2, 3, 4, 5} (3 elements per partition)
// and num_partitions = 2, the result is {0, 1, 2, 3} in partition 0 and
// {4, 5, 0, 0} in partition 1.
//
// Returns `hlo` itself when the per-partition size is already a multiple of
// num_partitions, and absl::nullopt when ExchangeHalo cannot build the halo.
absl::optional<HloInstruction*> PadEachPartitionWithHaloExchange(
    HloInstruction* hlo, int64 num_partitions, const HloSharding& sharding,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {
  const int64 shard_size = hlo->shape().dimensions().back();
  const int64 padded_shard_size =
      CeilOfRatio(shard_size, num_partitions) * num_partitions;
  if (padded_shard_size == shard_size) {
    // Nothing to pad.
    return hlo;
  }

  // Notation: D = padded_shard_size, S = shard_size, i = shard ordinal.
  const int64 growth = padded_shard_size - shard_size;  // D - S
  const int64 last_dim = hlo->shape().rank() - 1;

  // Halo sizes along the last dimension:
  //   left halo  = 0
  //   right halo = (D - S) * (i + 2) = (D - S) * i + 2 * (D - S)
  OffsetCalculation no_left_halo{MultiplyAddDivideOffsetCalculation(0, 0, 1)};
  OffsetCalculation right_halo{
      MultiplyAddDivideOffsetCalculation(growth, 2 * growth, 1)};
  auto exchanged =
      ExchangeHalo(hlo, no_left_halo, right_halo, last_dim, sharding,
                   collective_ops_creator, next_channel_id, b);
  if (!exchanged.has_value()) {
    return absl::nullopt;
  }
  HloInstruction* concat = exchanged.value();

  // From the concatenation keep D elements starting at offset (D - S) * i.
  OffsetCalculation slice_start{
      MultiplyAddDivideOffsetCalculation(growth, 0, 1)};
  Shape slice_shape = concat->shape();
  slice_shape.set_dimensions(concat->shape().rank() - 1, padded_shard_size);

  HloInstruction* zero_s32 =
      b->AddInstruction(HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
  std::vector<HloInstruction*> start_indices(concat->shape().rank(), zero_s32);
  auto ordinals = MakeTiledPartitionOrdinals(sharding, partition_id, b);
  const int64 slice_dim = concat->shape().rank() - 1;
  start_indices[slice_dim] = slice_start.Calculate(ordinals[slice_dim], b);
  return b->AddInstruction(HloInstruction::CreateDynamicSlice(
      slice_shape, concat, start_indices, slice_shape.dimensions()));
}
// Permutes the last dimension of `hlo` locally in each partition: the element
// at position a * num_partitions + b moves to position
// b * (size_per_partition / num_partitions) + a, i.e. elements are grouped by
// their index modulo num_partitions.
// For example, if partition 0 has {0, 1, 2, 3} and num partitions is 2, after
// shuffling, the data becomes {0, 2, 1, 3}.
//
// The permutation is expressed as a dot against a 0/1 one-hot matrix of shape
// [size_per_partition, size_per_partition]. CHECK-fails unless
// size_per_partition is divisible by num_partitions.
HloInstruction* ShuffleWithinEachPartitionUsingOneHot(HloInstruction* hlo,
                                                      int64 num_partitions,
                                                      SpmdBuilder* b) {
  int64 size_per_partition = hlo->shape().dimensions().back();
  CHECK_EQ(size_per_partition % num_partitions, 0);
  // perm[c] (the source index for output position c) is built as
  // reshape(transpose(reshape(iota(S), [S / m, m])), [S]).
  auto indices_iota = b->AddInstruction(HloInstruction::CreateIota(
      ShapeUtil::MakeShape(S32, {size_per_partition}), 0));
  auto reshape_indices_iota = b->AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(
          S32, {size_per_partition / num_partitions, num_partitions}),
      indices_iota));
  auto transpose_indices_iota =
      b->AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(
              S32, {num_partitions, size_per_partition / num_partitions}),
          reshape_indices_iota, {1, 0}));
  // one_hot_indices[r][c] = perm[c].
  auto one_hot_indices = b->AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(S32, {size_per_partition, size_per_partition}),
      b->AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(S32, {size_per_partition}),
          transpose_indices_iota)),
      /*broadcast_dimensions=*/{1}));
  // partition_indices[r][c] = r.
  auto partition_indices = b->AddInstruction(HloInstruction::CreateIota(
      ShapeUtil::MakeShape(S32, {size_per_partition, size_per_partition}), 0));
  // shuffle_one_hot[r][c] = (perm[c] == r), converted to hlo's element type, so
  // the dot below yields out[..., c] = hlo[..., perm[c]].
  auto shuffle_one_hot = b->AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(partition_indices->shape(),
                                   hlo->shape().element_type()),
      b->AddInstruction(HloInstruction::CreateCompare(
          ShapeUtil::ChangeElementType(partition_indices->shape(), PRED),
          one_hot_indices, partition_indices, ComparisonDirection::kEq))));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(hlo->shape().rank() - 1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  // This dot only moves values: each output is one input multiplied by 1. Ask
  // for HIGHEST precision, because at DEFAULT precision a backend may evaluate
  // f32/c64 dots with reduced-precision inputs (e.g. bf16 passes on TPU or
  // TF32 on recent NVIDIA GPUs), which would silently round the shuffled data.
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::HIGHEST);
  HloInstruction* dot = b->AddInstruction(HloInstruction::CreateDot(
      hlo->shape(), hlo, shuffle_one_hot, dot_dnums, precision_config));
  return dot;
}
// Exchanges data between partitions with a single all-to-all over the last
// dimension, using one replica group that lists partition ids
// 0 .. num_partitions - 1 in ascending order.
// For example, if partition 0 has {0, 2, 1, 3}, partition 1 has {4, 0, 5, 0}
// and num partitions is 2, after all-to-all, partition 0 will have
// {0, 2, 4, 0} and partition 1 will have {1, 3, 5, 0}.
// Consumes one channel id from *next_channel_id.
HloInstruction* ShuffleDataWithAllToAll(
    HloInstruction* hlo, int64 num_partitions,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdBuilder* b) {
  std::vector<int64> every_partition(num_partitions);
  std::iota(every_partition.begin(), every_partition.end(), 0);
  const std::vector<std::vector<int64>> replica_groups = {every_partition};
  const int64 channel_id = (*next_channel_id)++;
  const int64 last_dim = hlo->shape().rank() - 1;
  return collective_ops_creator.create_cross_partition_all_to_all(
      b, {hlo}, replica_groups, channel_id, last_dim);
}
// Builds the twiddle factors applied after the local (per-partition) FFT:
//   factor[..., k] = exp(-2j * pi * p * k / (n * m))
// where n is the size of the last dimension of `hlo`, m = num_partitions,
// p is the value of the scalar `partition_id`, and k indexes the last
// dimension. The result has the shape and element type of `hlo`.
//
// Reference formulation:
//   factor = tf.exp(-2.0j * np.pi * tf.cast(position_index, tf.complex64) *
//                   tf.cast(tf.range(n), dtype=tf.complex64) / (n * m))
HloInstruction* GetCorrectionFactor(HloInstruction* hlo, int64 num_partitions,
                                    HloInstruction* partition_id,
                                    SpmdBuilder* b) {
  auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) {
    return b->AddInstruction(std::move(to_add));
  };
  const PrimitiveType element_type = hlo->shape().element_type();
  int64 per_replica_size = hlo->shape().dimensions().back();
  const double angle = -2.0 * M_PI / (num_partitions * per_replica_size);
  // Build the constant at the FFT's own complex type. A hard-coded complex64
  // constant would make every op below ill-typed for a C128 FFT.
  auto constant_factor = add_hlo(HloInstruction::CreateConstant(
      element_type == C128 ? LiteralUtil::CreateR0(complex128(0, angle))
                           : LiteralUtil::CreateR0(complex64(0, angle))));
  auto converted_partition_id = add_hlo(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(partition_id->shape(), element_type),
      partition_id));
  // Multiply the two scalars first so only one full-size broadcast is needed.
  auto scaled_factor = add_hlo(HloInstruction::CreateBinary(
      converted_partition_id->shape(), HloOpcode::kMultiply, constant_factor,
      converted_partition_id));
  auto exp_operand = add_hlo(HloInstruction::CreateBroadcast(
      hlo->shape(), scaled_factor, /*broadcast_dimensions=*/{}));
  auto iota = add_hlo(
      HloInstruction::CreateIota(hlo->shape(), hlo->shape().rank() - 1));
  exp_operand = add_hlo(HloInstruction::CreateBinary(
      hlo->shape(), HloOpcode::kMultiply, exp_operand, iota));
  return add_hlo(
      HloInstruction::CreateUnary(hlo->shape(), HloOpcode::kExp, exp_operand));
}
// Second phase of the distributed FFT: a length-num_partitions DFT across
// partitions. Each partition computes
//   dest_transform = sum_s exp(-2j * pi * d * s / num_partitions) * X_s
// where d is this partition's id and X_s is the `hlo` value that starts on the
// partition whose id is s. The X_s (together with their ids) are rotated around
// a ring along the last tile dimension of `sharding` by collective permutes,
// inside a while loop that runs num_partitions iterations.
//
// NOTE(review): partition ids are used directly as DFT indices, which assumes
// they coincide with the shard ordinals along the last tile dimension.
//
// Pseudocode for the loop body. The loop state is
// (dest_transform, source_transform, dest_partition_id, source_partition_id,
// i):
//   factor = exp(-2j * pi * dest_partition_id * source_partition_id /
//                num_partitions)
//   dest_transform += factor * source_transform
//   source_partition_id = CollectivePermute(source_partition_id)
//   source_transform = CollectivePermute(source_transform)
//   i += 1
HloInstruction* GetFinalFftUsingCollectivePermute(
    HloInstruction* hlo, const HloSharding& sharding,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64 num_partitions, HloInstruction* partition_id, int64* next_channel_id,
    HloModule* module, SpmdBuilder* b) {
  auto iteration = b->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(0)));
  auto converted_partition_id = b->AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(partition_id->shape(),
                                   hlo->shape().element_type()),
      partition_id));
  // Build the while loop body.
  SpmdBuilder body_b("fft_collective_permute_body", hlo);
  auto param = body_b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0,
      ShapeUtil::MakeTupleShape(
          {hlo->shape(), hlo->shape(), converted_partition_id->shape(),
           converted_partition_id->shape(), iteration->shape()}),
      "param"));
  auto dest_transform = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(hlo->shape(), param, 0));
  auto source_transform = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(hlo->shape(), param, 1));
  auto dest_partition_id =
      body_b.AddInstruction(HloInstruction::CreateGetTupleElement(
          converted_partition_id->shape(), param, 2));
  auto source_partition_id =
      body_b.AddInstruction(HloInstruction::CreateGetTupleElement(
          converted_partition_id->shape(), param, 3));
  auto i = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(iteration->shape(), param, 4));
  // factor = exp(-2j * pi * dest_partition_id * source_partition_id /
  //              num_partitions)
  // dest_transform += factor * source_transform
  const double angle = -2.0 * M_PI / num_partitions;
  // Build the constant at the FFT's own complex type. A hard-coded complex64
  // constant would make the multiplies below ill-typed for a C128 FFT.
  auto constant_factor = body_b.AddInstruction(HloInstruction::CreateConstant(
      hlo->shape().element_type() == C128
          ? LiteralUtil::CreateR0(complex128(0, angle))
          : LiteralUtil::CreateR0(complex64(0, angle))));
  constant_factor = body_b.AddInstruction(HloInstruction::CreateBinary(
      constant_factor->shape(), HloOpcode::kMultiply, constant_factor,
      dest_partition_id));
  constant_factor = body_b.AddInstruction(HloInstruction::CreateBinary(
      constant_factor->shape(), HloOpcode::kMultiply, constant_factor,
      source_partition_id));
  auto phase_factor = body_b.AddInstruction(HloInstruction::CreateUnary(
      constant_factor->shape(), HloOpcode::kExp, constant_factor));
  phase_factor = body_b.AddInstruction(
      HloInstruction::CreateBroadcast(hlo->shape(), phase_factor, {}));
  auto phase_adjust_source_transform =
      body_b.AddInstruction(HloInstruction::CreateBinary(
          hlo->shape(), HloOpcode::kMultiply, phase_factor, source_transform));
  dest_transform = body_b.AddInstruction(HloInstruction::CreateBinary(
      hlo->shape(), HloOpcode::kAdd, phase_adjust_source_transform,
      dest_transform));
  // Collective-permute the source partition id and source transform one step
  // along the ring formed by the last tile dimension.
  std::vector<std::pair<int64, int64>> src_dst_pairs;
  sharding.tile_assignment().Each(
      [&](absl::Span<const int64> indices, int64 src_device) {
        std::vector<int64> target_indices(indices.begin(), indices.end());
        target_indices.back() = (indices.back() + 1) % num_partitions;
        int64 dst_device = sharding.tile_assignment()(target_indices);
        src_dst_pairs.emplace_back(src_device, dst_device);
      });
  source_partition_id =
      collective_ops_creator.create_cross_partition_collective_permute(
          &body_b, source_partition_id, src_dst_pairs, (*next_channel_id)++);
  source_transform =
      collective_ops_creator.create_cross_partition_collective_permute(
          &body_b, source_transform, src_dst_pairs, (*next_channel_id)++);
  // ++i
  i = body_b.AddInstruction(HloInstruction::CreateBinary(
      i->shape(), HloOpcode::kAdd, i,
      body_b.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(1)))));
  body_b.AddInstruction(
      HloInstruction::CreateTuple({dest_transform, source_transform,
                                   dest_partition_id, source_partition_id, i}));
  // Build the while loop condition: i < num_partitions.
  auto zero = CreateZero(hlo->shape(), b);
  SpmdBuilder cond_b("fft_collective_permute_condition", hlo);
  auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0,
      ShapeUtil::MakeTupleShape(
          {hlo->shape(), hlo->shape(), converted_partition_id->shape(),
           converted_partition_id->shape(), iteration->shape()}),
      "param"));
  auto cond_i = cond_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(iteration->shape(), cond_param, 4));
  cond_b.AddInstruction(HloInstruction::CreateCompare(
      ShapeUtil::MakeShape(PRED, {}), cond_i,
      cond_b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<uint32>(num_partitions))),
      ComparisonDirection::kLt));
  // Build the while loop. Initial state: dest_transform = 0,
  // source_transform = hlo, both ids = this partition's id, i = 0.
  auto while_loop = b->AddInstruction(HloInstruction::CreateWhile(
      cond_param->shape(), module->AddEmbeddedComputation(cond_b.Build()),
      module->AddEmbeddedComputation(body_b.Build()),
      b->AddInstruction(
          HloInstruction::CreateTuple({zero, hlo, converted_partition_id,
                                       converted_partition_id, iteration}))));
  return b->AddInstruction(
      HloInstruction::CreateGetTupleElement(hlo->shape(), while_loop, 0));
}
// Keeps, in each partition, the leading target_shape.dimensions() elements of
// every dimension of `hlo` (a unit-stride slice starting at the origin).
HloInstruction* SliceValidData(HloInstruction* hlo, const Shape& target_shape,
                               SpmdBuilder* b) {
  const int64 rank = target_shape.rank();
  const std::vector<int64> origin(rank, 0);
  const std::vector<int64> unit_strides(rank, 1);
  auto slice = HloInstruction::CreateSlice(target_shape, hlo, origin,
                                           target_shape.dimensions(),
                                           unit_strides);
  return b->AddInstruction(std::move(slice));
}
} // namespace
// Distributed FFT using the algorithm described in go/tpu-spmd-fft.
//
// Handles only complex-to-complex FFTs (FftType::FFT) on operands of rank >= 3
// where:
//   - the last dimension equals fft_length().back() and is divisible by
//     num_partitions_;
//   - the output sharding tiles the last dimension num_partitions_ ways, with
//     shard ordinal k placed on partition k.
// Everything else, and any case where halo padding cannot be built, goes
// through DefaultAction.
Status SpmdPartitioningVisitor::HandleFft(HloInstruction* hlo) {
  if (hlo->operand(0)->shape().rank() < 3 || hlo->fft_type() != FftType::FFT) {
    return DefaultAction(hlo);
  }
  // Only support input_length equals fft_length's case.
  int64 input_length = hlo->operand(0)->shape().dimensions().back();
  int64 fft_length = hlo->fft_length().back();
  if (input_length != fft_length || input_length % num_partitions_ != 0) {
    return DefaultAction(hlo);
  }
  // Support partition at the last dimension only.
  if (!hlo->has_sharding() ||
      hlo->sharding().tile_assignment().dimensions().back() !=
          num_partitions_) {
    return DefaultAction(hlo);
  }
  // The helpers below use the raw partition id as the index of a partition's
  // data block (correction factor, final DFT) and as the all-to-all group
  // order. That is only correct when shard ordinal k along the last dimension
  // lives on partition k.
  bool partition_ids_are_ordinals = true;
  hlo->sharding().tile_assignment().Each(
      [&](absl::Span<const int64> indices, int64 device) {
        if (device != indices.back()) {
          partition_ids_are_ordinals = false;
        }
      });
  if (!partition_ids_are_ordinals) {
    return DefaultAction(hlo);
  }
  // The algorithm assumes the operand is tiled exactly like the output, so
  // reshard it first (a no-op when the shardings already match).
  auto partitioned_input =
      GetPartitionedHlo(hlo->operand(0))
          .Reshard(hlo->sharding())
          .PadWithValue(CreateR0WithType(hlo->shape().element_type(), 0, &b_));
  // 1.a. Use right halo exchange to shuffle data first and slice with
  // valid data. Data shuffling ensures an in-order transform, i.e. the
  // sequences of data before and after the transform are in the same order.
  // The data shuffling requires the size of data per partition to be divisible
  // by the number of partitions. For example, if input is {0, 1, 2, 3, 4, 5}
  // and num partitions is 2, after halo exchange partition 0 has {0, 1, 2, 3}
  // and partition 1 has {4, 5, 0, 0}, where the 0s in partition 1 are padding.
  // The zero padding is appended to the end of the full data.
  auto padded_hlo = PadEachPartitionWithHaloExchange(
      partitioned_input.hlo(), num_partitions_, hlo->sharding(),
      partitioned_input.state().collective_ops_creator,
      partitioned_input.state().next_channel_id,
      partitioned_input.state().partition_id, partitioned_input.state().b);
  if (!padded_hlo.has_value()) {
    // Padding was required but the halo exchange could not be built. The
    // unpadded size is not divisible by num_partitions_, which the shuffle
    // below CHECKs, so fall back instead of aborting.
    return DefaultAction(hlo);
  }
  HloInstruction* result = padded_hlo.value();
  // 1.b Shuffle data within each partition using one hot and matmul.
  // If partition 0 has {0, 1, 2, 3} and num partitions is 2, after shuffling,
  // the data becomes {0, 2, 1, 3}.
  result = ShuffleWithinEachPartitionUsingOneHot(result, num_partitions_,
                                                 partitioned_input.state().b);
  // 1.c all-to-all
  // If partition 0 has {0, 2, 1, 3}, partition 1 has {4, 0, 5, 0} and
  // num partitions is 2, after all-to-all, partition 0 will have {0, 2, 4, 0}
  // and partition 1 will have {1, 3, 5, 0}.
  result = ShuffleDataWithAllToAll(
      result, num_partitions_, partitioned_input.state().collective_ops_creator,
      partitioned_input.state().next_channel_id, partitioned_input.state().b);
  // 1.d Slice valid data in each partition.
  result = SliceValidData(result, partitioned_input.hlo()->shape(), &b_);
  // 2. Do local fft transform.
  auto partitioned_fft_length = hlo->fft_length();
  partitioned_fft_length.back() /= num_partitions_;
  result = b_.AddInstruction(HloInstruction::CreateFft(
      result->shape(), result, hlo->fft_type(), partitioned_fft_length));
  // Multiply by the correction factor for the local phase adjustment.
  auto correction_factor = GetCorrectionFactor(
      result, num_partitions_, partitioned_input.state().partition_id,
      partitioned_input.state().b);
  result = b_.AddInstruction(HloInstruction::CreateBinary(
      result->shape(), HloOpcode::kMultiply, result, correction_factor));
  // 3. Second phase FFT with collective permute. fft_length = num_partitions.
  result = GetFinalFftUsingCollectivePermute(
      result, hlo->sharding(), partitioned_input.state().collective_ops_creator,
      num_partitions_, partitioned_input.state().partition_id,
      partitioned_input.state().next_channel_id, module_,
      partitioned_input.state().b);
  result->set_sharding(hlo->sharding());
  auto partitioned_fft =
      PartitionedHlo(result, hlo->shape(), partitioned_input.state());
  SetPartitionedHlo(hlo, partitioned_fft);
  return Status::OK();
}
} // namespace spmd
} // namespace xla