[XLA:SPMD] Add basic support for SPMD FFT.
PiperOrigin-RevId: 338293995
Change-Id: I05ee210f469d4f1609c5ecaf8de6eaa347f7419a
parent 54bc354d2e, commit c587b9a2dc
@@ -20,6 +20,7 @@ cc_library(
    srcs = [
        "convolution_handler.cc",
        "dot_handler.cc",
        "fft_handler.cc",
        "spmd_partitioner.cc",
        "spmd_partitioner_util.cc",
    ],
tensorflow/compiler/xla/service/spmd/fft_handler.cc (new file, 436 lines)
@@ -0,0 +1,436 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <float.h>

#include <cmath>
#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace spmd {

namespace {

// Pad each partition to a size that is a multiple of num_partitions.
// For example, if the input is {0, 1, 2, 3, 4, 5} and num_partitions = 2,
// after padding it becomes {0, 1, 2, 3} in partition 0 and {4, 5, 0, 0} in
// partition 1.
absl::optional<HloInstruction*> PadEachPartitionWithHaloExchange(
    HloInstruction* hlo, int64 num_partitions, const HloSharding& sharding,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {
  int64 size_per_partition = hlo->shape().dimensions().back();
  int64 size_padded_per_partition =
      CeilOfRatio(size_per_partition, num_partitions) * num_partitions;
  if (size_per_partition == size_padded_per_partition) {
    return hlo;
  }
  // 1. Calculate the left-halo size.
  // The left-halo size is always 0.
  OffsetCalculation left_halo_size_function =
      OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1));

  // 2. Calculate the right-halo size.
  // D = size_padded_per_partition
  // S = size_per_partition
  // i = shard_ordinal
  // The right-halo size is D * (i + 2) - S * (i + 2) = (D - S) * i + 2 * (D - S).
  OffsetCalculation right_halo_size_function =
      OffsetCalculation(MultiplyAddDivideOffsetCalculation(
          size_padded_per_partition - size_per_partition,
          2 * (size_padded_per_partition - size_per_partition), 1));

  auto concat = hlo;
  // 3. Halo exchange.
  auto halo_exchange_result =
      ExchangeHalo(hlo, left_halo_size_function, right_halo_size_function,
                   hlo->shape().rank() - 1, sharding, collective_ops_creator,
                   next_channel_id, b);

  if (halo_exchange_result.has_value()) {
    concat = halo_exchange_result.value();
  } else {
    return absl::nullopt;
  }

  // 4. Slice the valid result.
  // The slice offset is (D - S) * i.
  OffsetCalculation start_offset_on_padded_concat_calculation =
      OffsetCalculation(MultiplyAddDivideOffsetCalculation(
          size_padded_per_partition - size_per_partition, 0, 1));
  auto slice_shape = concat->shape();
  slice_shape.set_dimensions(concat->shape().rank() - 1,
                             size_padded_per_partition);
  auto zero_s32 =
      b->AddInstruction(HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
  std::vector<HloInstruction*> slice_offsets(concat->shape().rank(), zero_s32);
  auto partition_ordinals =
      MakeTiledPartitionOrdinals(sharding, partition_id, b);
  slice_offsets[concat->shape().rank() - 1] =
      start_offset_on_padded_concat_calculation.Calculate(
          partition_ordinals[concat->shape().rank() - 1], b);
  return b->AddInstruction(HloInstruction::CreateDynamicSlice(
      slice_shape, concat, slice_offsets, slice_shape.dimensions()));
}
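
// A minimal numpy sketch of the padding semantics above (hypothetical helper
// name; the SPMD version produces the same result with a halo exchange plus
// a dynamic-slice instead of materializing the full array):
//
//   import numpy as np
//
//   def pad_each_partition(x, m):    # x: full input, m: number of partitions
//     s = x.size // m                # current size per partition
//     d = -(-s // m) * m             # ceil(s / m) * m, padded size
//     padded = np.concatenate([x, np.zeros(d * m - x.size, x.dtype)])
//     return padded.reshape(m, d)    # row r is partition r's shard
//
//   pad_each_partition(np.arange(6), 2)  # -> [[0, 1, 2, 3], [4, 5, 0, 0]]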

// If partition 0 has {0, 1, 2, 3} and num partitions is 2, after shuffling,
// the data becomes {0, 2, 1, 3}.
HloInstruction* ShuffleWithinEachPartitionUsingOneHot(HloInstruction* hlo,
                                                      int64 num_partitions,
                                                      SpmdBuilder* b) {
  int64 size_per_partition = hlo->shape().dimensions().back();
  CHECK_EQ(size_per_partition % num_partitions, 0);
  auto indices_iota = b->AddInstruction(HloInstruction::CreateIota(
      ShapeUtil::MakeShape(S32, {size_per_partition}), 0));
  auto reshape_indices_iota = b->AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(
          S32, {size_per_partition / num_partitions, num_partitions}),
      indices_iota));
  auto transpose_indices_iota =
      b->AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(
              S32, {num_partitions, size_per_partition / num_partitions}),
          reshape_indices_iota, {1, 0}));
  auto one_hot_indices = b->AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(S32, {size_per_partition, size_per_partition}),
      b->AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(S32, {size_per_partition}),
          transpose_indices_iota)),
      /*broadcast_dimensions=*/{1}));

  auto partition_indices = b->AddInstruction(HloInstruction::CreateIota(
      ShapeUtil::MakeShape(S32, {size_per_partition, size_per_partition}), 0));

  auto shuffle_one_hot = b->AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(partition_indices->shape(),
                                   hlo->shape().element_type()),
      b->AddInstruction(HloInstruction::CreateCompare(
          ShapeUtil::ChangeElementType(partition_indices->shape(), PRED),
          one_hot_indices, partition_indices, ComparisonDirection::kEq))));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(hlo->shape().rank() - 1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);
  HloInstruction* dot = b->AddInstruction(HloInstruction::CreateDot(
      hlo->shape(), hlo, shuffle_one_hot, dot_dnums, precision_config));
  return dot;
}
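
// The shuffle above is a stride permutation (reshape to [p/m, m], transpose,
// flatten) expressed as a dot with a one-hot permutation matrix, so it stays
// a matmul on TPU rather than becoming a gather. A numpy sketch of the
// equivalence (hypothetical name; shard is 1-D of size p, with p % m == 0):
//
//   import numpy as np
//
//   def shuffle_within_partition(shard, m):
//     p = shard.size
//     perm = shard.reshape(p // m, m).T.ravel()     # [0,1,2,3] -> [0,2,1,3]
//     one_hot = np.eye(p)[np.arange(p).reshape(p // m, m).T.ravel()]
//     assert np.allclose(shard @ one_hot.T, perm)   # the dot form used above
//     return perm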

// If partition 0 has {0, 2, 1, 3}, partition 1 has {4, 0, 5, 0} and
// num partitions is 2, after all-to-all, partition 0 will have {0, 2, 4, 0}
// and partition 1 will have {1, 3, 5, 0}.
HloInstruction* ShuffleDataWithAllToAll(
    HloInstruction* hlo, int64 num_partitions,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdBuilder* b) {
  std::vector<std::vector<int64>> groups(1);
  std::vector<int64> partition_subgroups(num_partitions);
  std::iota(partition_subgroups.begin(), partition_subgroups.end(), 0);
  groups[0] = partition_subgroups;
  auto all_to_all = collective_ops_creator.create_cross_partition_all_to_all(
      b, {hlo}, groups, (*next_channel_id)++, hlo->shape().rank() - 1);
  return all_to_all;
}
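
// Numpy sketch of the all-to-all on the last dimension (hypothetical name;
// bufs stacks the per-partition buffers as rows): each partition splits its
// buffer into m equal blocks, and block c ends up on partition c.
//
//   import numpy as np
//
//   def all_to_all_last_dim(bufs):                  # bufs: [m, p]
//     m, p = bufs.shape
//     return bufs.reshape(m, m, p // m).swapaxes(0, 1).reshape(m, p)
//
//   all_to_all_last_dim(np.array([[0, 2, 1, 3], [4, 0, 5, 0]]))
//   # -> [[0, 2, 4, 0], [1, 3, 5, 0]]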

HloInstruction* GetCorrectionFactor(HloInstruction* hlo, int64 num_partitions,
                                    HloInstruction* partition_id,
                                    SpmdBuilder* b) {
  /*
  n = size_per_replica
  m = num_partitions
  factor = tf.exp(-2.0j * np.pi * tf.cast(position_index, tf.complex64) *
                  tf.cast(tf.range(n), dtype=tf.complex64) / (n * m))
  */
  auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) {
    return b->AddInstruction(std::move(to_add));
  };
  int64 per_replica_size = hlo->shape().dimensions().back();
  auto constant_factor =
      add_hlo(HloInstruction::CreateConstant(LiteralUtil::CreateR0(
          complex64(0, -2.0 * M_PI / (num_partitions * per_replica_size)))));
  constant_factor = add_hlo(HloInstruction::CreateBroadcast(
      hlo->shape(), constant_factor, /*broadcast_dimensions=*/{}));
  auto converted_partition_id = add_hlo(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(partition_id->shape(),
                                   hlo->shape().element_type()),
      partition_id));
  // TODO(wangtao): multiply before broadcast.
  auto broadcast_partition_id = add_hlo(HloInstruction::CreateBroadcast(
      hlo->shape(), converted_partition_id, /*broadcast_dimensions=*/{}));
  auto exp_operand = add_hlo(
      HloInstruction::CreateBinary(hlo->shape(), HloOpcode::kMultiply,
                                   constant_factor, broadcast_partition_id));
  auto iota = add_hlo(
      HloInstruction::CreateIota(hlo->shape(), hlo->shape().rank() - 1));
  exp_operand = add_hlo(HloInstruction::CreateBinary(
      hlo->shape(), HloOpcode::kMultiply, exp_operand, iota));
  return add_hlo(
      HloInstruction::CreateUnary(hlo->shape(), HloOpcode::kExp, exp_operand));
}
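
// Why this factor: after the redistribution in step 1, partition r holds the
// decimated subsequence x[j] with j = m*t + r (t = 0..S-1, where S is the
// per-partition size and N = S * m). The standard Cooley-Tukey split of the
// full DFT is then
//
//   X[c*S + t'] = sum_r exp(-2*pi*i*r*c/m) * (exp(-2*pi*i*r*t'/N) * F_r[t'])
//
// where F_r is the length-S FFT of partition r's data. The inner exponential
// is the per-partition correction computed here; the outer one is applied by
// the collective-permute loop below.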

// Pseudocode for the while loop:
// def body(dest_transform, dest_core_position, source_transform,
//          source_core_position, i):
//   factor = tf.exp(-2.0j * np.pi *
//                   tf.cast(dest_core_position, tf.complex64) *
//                   tf.cast(source_core_position, tf.complex64) /
//                   num_partitions)
//   dest_transform += factor * source_transform
//   source_core_position = tf.raw_ops.CollectivePermute(
//       input=source_core_position,
//       source_target_pairs=source_target_pairs,
//       name='source_core_position_permute')
//   source_transform = tf.raw_ops.CollectivePermute(
//       input=source_transform,
//       source_target_pairs=source_target_pairs,
//       name='source_transform_permute')
//   i += 1
//   return (dest_transform, dest_core_position, source_transform,
//           source_core_position, i)
HloInstruction* GetFinalFftUsingCollectivePermute(
    HloInstruction* hlo, const HloSharding& sharding,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64 num_partitions, HloInstruction* partition_id, int64* next_channel_id,
    HloModule* module, SpmdBuilder* b) {
  auto iteration = b->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(0)));
  auto converted_partition_id = b->AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(partition_id->shape(),
                                   hlo->shape().element_type()),
      partition_id));
  // Build the while loop body.
  SpmdBuilder body_b("fft_collective_permute_body", hlo);
  auto param = body_b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0,
      ShapeUtil::MakeTupleShape(
          {hlo->shape(), hlo->shape(), converted_partition_id->shape(),
           converted_partition_id->shape(), iteration->shape()}),
      "param"));
  auto dest_transform = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(hlo->shape(), param, 0));
  auto source_transform = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(hlo->shape(), param, 1));
  auto dest_partition_id =
      body_b.AddInstruction(HloInstruction::CreateGetTupleElement(
          converted_partition_id->shape(), param, 2));
  auto source_partition_id =
      body_b.AddInstruction(HloInstruction::CreateGetTupleElement(
          converted_partition_id->shape(), param, 3));
  auto i = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(iteration->shape(), param, 4));
  /*
  factor = tf.exp(-2.0j * np.pi *
                  tf.cast(dest_partition_id, tf.complex64) *
                  tf.cast(source_partition_id, tf.complex64) /
                  num_partitions)
  dest_transform += factor * source_transform
  */
  auto constant_factor = body_b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR0(complex64(0, -2.0 * M_PI / num_partitions))));

  constant_factor = body_b.AddInstruction(HloInstruction::CreateBinary(
      constant_factor->shape(), HloOpcode::kMultiply, constant_factor,
      dest_partition_id));
  constant_factor = body_b.AddInstruction(HloInstruction::CreateBinary(
      constant_factor->shape(), HloOpcode::kMultiply, constant_factor,
      source_partition_id));
  auto phase_factor = body_b.AddInstruction(HloInstruction::CreateUnary(
      constant_factor->shape(), HloOpcode::kExp, constant_factor));
  phase_factor = body_b.AddInstruction(
      HloInstruction::CreateBroadcast(hlo->shape(), phase_factor, {}));
  auto phase_adjust_source_transform =
      body_b.AddInstruction(HloInstruction::CreateBinary(
          hlo->shape(), HloOpcode::kMultiply, phase_factor, source_transform));
  dest_transform = body_b.AddInstruction(HloInstruction::CreateBinary(
      hlo->shape(), HloOpcode::kAdd, phase_adjust_source_transform,
      dest_transform));
  // Collective permute for the source partition_id and source_transform.
  std::vector<std::pair<int64, int64>> src_dst_pairs;
  sharding.tile_assignment().Each(
      [&](absl::Span<const int64> indices, int64 src_device) {
        std::vector<int64> target_indices(indices.begin(), indices.end());
        target_indices.back() = (indices.back() + 1) % num_partitions;
        int64 dst_device = sharding.tile_assignment()(target_indices);
        src_dst_pairs.emplace_back(src_device, dst_device);
      });

  source_partition_id =
      collective_ops_creator.create_cross_partition_collective_permute(
          &body_b, source_partition_id, src_dst_pairs, (*next_channel_id)++);

  source_transform =
      collective_ops_creator.create_cross_partition_collective_permute(
          &body_b, source_transform, src_dst_pairs, (*next_channel_id)++);

  // ++i
  i = body_b.AddInstruction(HloInstruction::CreateBinary(
      i->shape(), HloOpcode::kAdd, i,
      body_b.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(1)))));
  body_b.AddInstruction(
      HloInstruction::CreateTuple({dest_transform, source_transform,
                                   dest_partition_id, source_partition_id, i}));

  // Build the while loop condition.
  auto zero = CreateZero(hlo->shape(), b);
  SpmdBuilder cond_b("fft_collective_permute_condition", hlo);
  auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0,
      ShapeUtil::MakeTupleShape(
          {hlo->shape(), hlo->shape(), converted_partition_id->shape(),
           converted_partition_id->shape(), iteration->shape()}),
      "param"));
  auto cond_i = cond_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(iteration->shape(), cond_param, 4));
  cond_b.AddInstruction(HloInstruction::CreateCompare(
      ShapeUtil::MakeShape(PRED, {}), cond_i,
      cond_b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<uint32>(num_partitions))),
      ComparisonDirection::kLt));

  // Build the while loop.
  auto while_loop = b->AddInstruction(HloInstruction::CreateWhile(
      cond_param->shape(), module->AddEmbeddedComputation(cond_b.Build()),
      module->AddEmbeddedComputation(body_b.Build()),
      b->AddInstruction(
          HloInstruction::CreateTuple({zero, hlo, converted_partition_id,
                                       converted_partition_id, iteration}))));

  return b->AddInstruction(
      HloInstruction::CreateGetTupleElement(hlo->shape(), while_loop, 0));
}
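
// Numpy sketch of the rotation implemented by the while loop (hypothetical
// name; corrected stacks each partition's twiddled local FFT as a row). Each
// step applies the cross-partition phase and then rotates the source values
// one partition forward, matching the (i -> i + 1) src_dst_pairs above:
//
//   import numpy as np
//
//   def final_stage(corrected):             # corrected: [m, S], complex
//     m = corrected.shape[0]
//     dest = np.zeros_like(corrected)
//     src, src_id = corrected.copy(), np.arange(m)
//     for _ in range(m):
//       for c in range(m):                  # each c runs on its own core
//         dest[c] += np.exp(-2j * np.pi * c * src_id[c] / m) * src[c]
//       src = np.roll(src, 1, axis=0)       # the two collective permutes
//       src_id = np.roll(src_id, 1)
//     return dest                           # row c holds X[c*S : (c+1)*S]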

// Slice the valid data in each partition.
HloInstruction* SliceValidData(HloInstruction* hlo, const Shape& target_shape,
                               SpmdBuilder* b) {
  std::vector<int64> start_indices(target_shape.rank(), 0);
  std::vector<int64> strides(target_shape.rank(), 1);
  return b->AddInstruction(HloInstruction::CreateSlice(
      target_shape, hlo, start_indices, target_shape.dimensions(), strides));
}

}  // namespace

// Distributed FFT using the algorithm described in go/tpu-spmd-fft.
Status SpmdPartitioningVisitor::HandleFft(HloInstruction* hlo) {
  if (hlo->operand(0)->shape().rank() < 3 || hlo->fft_type() != FftType::FFT) {
    return DefaultAction(hlo);
  }

  // Only the case where input_length equals fft_length is supported.
  int64 input_length = hlo->operand(0)->shape().dimensions().back();
  int64 fft_length = hlo->fft_length().back();
  if (input_length != fft_length || input_length % num_partitions_ != 0) {
    return DefaultAction(hlo);
  }

  // Only partitioning along the last dimension is supported.
  if (!hlo->has_sharding() ||
      hlo->sharding().tile_assignment().dimensions().back() !=
          num_partitions_) {
    return DefaultAction(hlo);
  }

  auto partitioned_input =
      GetPartitionedHlo(hlo->operand(0))
          .PadWithValue(CreateR0WithType(hlo->shape().element_type(), 0, &b_));

  // 1.a. Use a right halo exchange to shuffle the data first, then slice out
  // the valid data. The shuffling ensures an in-order transform, i.e. the
  // ordering of the data is the same before and after the transform. It
  // requires that the size of the data per partition be divisible by the
  // number of partitions. For example, if the input is {0, 1, 2, 3, 4, 5} and
  // the number of partitions is 2, after the halo exchange partition 0 has
  // {0, 1, 2, 3} and partition 1 has {4, 5, 0, 0}, where the 0s in partition 1
  // are padding data. Zero padding appends zeros to the end of the full data.
  auto result = partitioned_input.hlo();
  auto padded_hlo = PadEachPartitionWithHaloExchange(
      partitioned_input.hlo(), num_partitions_, hlo->sharding(),
      partitioned_input.state().collective_ops_creator,
      partitioned_input.state().next_channel_id,
      partitioned_input.state().partition_id, partitioned_input.state().b);

  if (padded_hlo.has_value()) {
    result = padded_hlo.value();
  }

  // 1.b. Shuffle the data within each partition using a one-hot matrix and a
  // matmul. If partition 0 has {0, 1, 2, 3} and the number of partitions is 2,
  // after shuffling the data becomes {0, 2, 1, 3}.
  result = ShuffleWithinEachPartitionUsingOneHot(result, num_partitions_,
                                                 partitioned_input.state().b);
  // 1.c. All-to-all. If partition 0 has {0, 2, 1, 3}, partition 1 has
  // {4, 0, 5, 0} and the number of partitions is 2, after the all-to-all
  // partition 0 will have {0, 2, 4, 0} and partition 1 will have {1, 3, 5, 0}.
  result = ShuffleDataWithAllToAll(
      result, num_partitions_, partitioned_input.state().collective_ops_creator,
      partitioned_input.state().next_channel_id, partitioned_input.state().b);
  // 1.d. Slice the valid data in each partition.
  result = SliceValidData(result, partitioned_input.hlo()->shape(), &b_);

  // 2. Do the local FFT transform.
  auto partitioned_fft_length = hlo->fft_length();
  partitioned_fft_length.back() /= num_partitions_;
  result = b_.AddInstruction(HloInstruction::CreateFft(
      result->shape(), result, hlo->fft_type(), partitioned_fft_length));

  // Multiply by the correction factor for the local phase adjustment.
  auto correction_factor = GetCorrectionFactor(
      result, num_partitions_, partitioned_input.state().partition_id,
      partitioned_input.state().b);
  result = b_.AddInstruction(HloInstruction::CreateBinary(
      result->shape(), HloOpcode::kMultiply, result, correction_factor));

  // 3. Second-phase FFT with collective permute; fft_length = num_partitions.
  result = GetFinalFftUsingCollectivePermute(
      result, hlo->sharding(), partitioned_input.state().collective_ops_creator,
      num_partitions_, partitioned_input.state().partition_id,
      partitioned_input.state().next_channel_id, module_,
      partitioned_input.state().b);

  result->set_sharding(hlo->sharding());
  auto partitioned_fft =
      PartitionedHlo(result, hlo->shape(), partitioned_input.state());
  SetPartitionedHlo(hlo, partitioned_fft);
  return Status::OK();
}
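
// End-to-end numpy reference for HandleFft (a sketch under the same
// assumptions: 1-D FFT of length N = S * m, sharded into m contiguous shards
// of size S along the last dimension):
//
//   import numpy as np
//
//   def spmd_fft(x, m):
//     N = x.size; S = N // m
//     decimated = [x[r::m] for r in range(m)]            # steps 1.a-1.d
//     out = np.zeros((m, S), dtype=complex)
//     for r in range(m):
//       f = np.fft.fft(decimated[r])                     # step 2: local FFT
//       f *= np.exp(-2j * np.pi * r * np.arange(S) / N)  # correction factor
//       for c in range(m):                               # step 3: while loop
//         out[c] += np.exp(-2j * np.pi * r * c / m) * f
//     return out.ravel()                                 # shard c = out[c]
//
//   x = np.arange(6, dtype=complex)
//   assert np.allclose(spmd_fft(x, 2), np.fft.fft(x))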

}  // namespace spmd
}  // namespace xla
tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
@@ -382,6 +382,7 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
  Status HandleDot(HloInstruction* hlo) override;
  Status HandleDynamicSlice(HloInstruction* hlo) override;
  Status HandleDynamicUpdateSlice(HloInstruction* hlo) override;
  Status HandleFft(HloInstruction* hlo) override;
  Status HandleGather(HloInstruction* hlo) override;
  Status HandleGetTupleElement(HloInstruction* hlo) override;
  Status HandleInfeed(HloInstruction* hlo) override;
tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
@@ -6134,6 +6134,43 @@ ENTRY entry {
                          op::Shape("f32[8,105,210,32]")));
}

TEST_F(SpmdPartitioningTest, Fft3D) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  constant = c64[1,1,6]
    constant({{{(0,0),(1,1),(2,2),(3,3),(4,4),(5,5)}}}),
    sharding={devices=[1,1,2]0,1}
  ROOT fft = c64[1,1,6] fft(c64[1,1,6] constant), fft_type=FFT, fft_length={6},
    sharding={devices=[1,1,2]0,1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  auto root = module->entry_computation()->root_instruction();
  auto input = AllOf(op::DynamicSlice(op::Constant(), op::Constant(),
                                      op::Constant(), op::Reshape()),
                     op::Shape("c64[1,1,3]"));
  auto padded_input =
      AllOf(op::DynamicSlice(
                op::Concatenate(input, op::CollectivePermute(op::Slice())),
                op::Constant(), op::Constant(), op::Reshape()),
            op::Shape("c64[1,1,4]"));

  auto shuffled_input =
      AllOf(op::Slice(op::AllToAll(op::Dot(padded_input, op::Convert()))),
            op::Shape("c64[1,1,3]"));

  auto local_fft = AllOf(op::Fft(shuffled_input), op::Shape("c64[1,1,3]"));

  EXPECT_THAT(root, AllOf(op::GetTupleElement(op::While(op::Tuple(
                              _, op::Multiply(local_fft, op::Exp()), _, _, _))),
                          op::Shape("c64[1,1,3]")));
}
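
// To run only this test (a sketch; assuming the usual Bazel target for this
// directory, which is not shown in this diff):
//   bazel test //tensorflow/compiler/xla/service/spmd:spmd_partitioner_test \
//     --test_arg=--gtest_filter=SpmdPartitioningTest.Fft3D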

}  // namespace
}  // namespace spmd
}  // namespace xla