Support partitioning Sort in TopK when the input is partitioned at the sort dimension.

PiperOrigin-RevId: 314174499
Change-Id: I8fbac47edf5a2691c5a51aacda885b0300b53247
A. Unique TensorFlower 2020-06-01 11:58:35 -07:00 committed by TensorFlower Gardener
parent f3930469e4
commit 1a430ba06b
5 changed files with 681 additions and 25 deletions
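The shard sizes asserted by the new tests can be reproduced by hand. Below is a minimal standalone sketch (not part of the change; CeilOfRatio here is a local stand-in for the helper the partitioner uses, and the numbers come from the [2,209664] test input sharded across 2 devices with k = 2000):

#include <cstdint>
#include <iostream>

// Ceiling division, mirroring the rounding the partitioner applies when it
// pads the sort dimension to a multiple of the partition count.
int64_t CeilOfRatio(int64_t numerator, int64_t denominator) {
  return (numerator + denominator - 1) / denominator;
}

int main() {
  const int64_t input_size = 209664;   // sort-dimension size of the input
  const int64_t partition_count = 2;   // sharding {devices=[1,2]0,1}
  const int64_t k = 2000;              // slice limit taken from the TopK slice
  // Each partition sorts only its own shard of the sort dimension.
  const int64_t per_partition_size = CeilOfRatio(input_size, partition_count);
  // The final replicated sort sees the per-partition top k from every shard.
  const int64_t final_sort_size = partition_count * k;
  std::cout << per_partition_size << "\n";  // 104832, checked on "sort"
  std::cout << final_sort_size << "\n";     // 4000, checked on "sort.1"
  return 0;
}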

View File

@@ -43,6 +43,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_query",
"//tensorflow/compiler/xla/service:hlo_sharding_util",
"//tensorflow/compiler/xla/service:pattern_matcher",
"//tensorflow/compiler/xla/service:shape_inference",
"//tensorflow/compiler/xla/service:tuple_simplifier",
"//tensorflow/core/platform:numbers",

View File

@@ -1282,6 +1282,106 @@ Status SpmdPartitioningVisitor::HandleSlice(HloInstruction* hlo) {
Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) {
HloSharding sharding = hlo->sharding();
// Special handling for sort in TopK when the first operand is partitioned at
// the sort dimension.
auto k = GetKValueInTopKWhenPartitionSortDim(hlo);
if (k.has_value()) {
// When the first operand is partitioned at the sort dimension:
// 1. Partition the sort computation across the different partitions;
// 2. Slice the TopK value and index out of each partition;
// 3. Gather and replicate the values and indices from all partitions;
// the shape of the replicated value and index will be
// [batch_size, ..., partition_count * k, ...];
// 4. The final sort uses the replicated values and indices from all
// partitions as input.
// GetTupleElement and Slice after the non-partitioned sort won't change
// at this point, as HandleGetTupleElement and HandleSlice will update them.
HloSortInstruction* sort = DynCast<HloSortInstruction>(hlo);
const int64 sort_dim = sort->sort_dimension();
auto input = hlo->operand(0);
auto index = hlo->operand(1);
const HloSharding& input_sharding = input->sharding();
const int64 partition_count =
input_sharding.tile_assignment().dim(sort_dim);
const int64 input_size = input->shape().dimensions(sort_dim);
const int64 per_partition_size = CeilOfRatio(input_size, partition_count);
const auto element_type = input->shape().element_type();
const auto index_type = index->shape().element_type();
// Partition and pad input and index.
// Pad input with minimal value.
auto partitioned_input = GetPartitionedHlo(input).PadWithValue(
CreateFirstWithType(element_type, &b_));
// Pad index with max value.
auto partitioned_index =
GetPartitionedHlo(index)
.Reshard(input_sharding)
.PadWithValue(CreateLastWithType(index_type, &b_));
// Each partition needs to do TopK separately, thus the base shape
// becomes the padded shape.
std::vector<int64> replicated_dimensions(
input->shape().dimensions().begin(), input->shape().dimensions().end());
replicated_dimensions[sort_dim] = per_partition_size * partition_count;
const Shape replicated_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(element_type, replicated_dimensions),
ShapeUtil::MakeShape(index_type, replicated_dimensions)});
// Partition original topk to different shards.
auto topk_sharding =
input_sharding.GetTupleSharding(replicated_shape).ValueOrDie();
auto shard_shape = MakePartitionedShape(replicated_shape, topk_sharding);
auto topk = b_.AddInstruction(hlo->CloneWithNewOperands(
shard_shape, {partitioned_input.hlo(), partitioned_index.hlo()}));
// Get the value and index from the per-partition sort.
HloInstruction* value_gte =
b_.AddInstruction(HloInstruction::CreateGetTupleElement(
topk->shape().tuple_shapes(0), topk, 0));
HloInstruction* index_gte =
b_.AddInstruction(HloInstruction::CreateGetTupleElement(
topk->shape().tuple_shapes(1), topk, 1));
// Slice top K value from the first partitioned sort.
replicated_dimensions[sort_dim] = k.value() * partition_count;
auto slice_input = SliceFirstK(value_gte, &b_, sort_dim, k.value());
slice_input->set_sharding(input_sharding);
PartitionedHlo partitioned_slice_input(
slice_input, ShapeUtil::MakeShape(element_type, replicated_dimensions),
MakePartitioningState());
// Reshard value to be replicated.
auto replicated_slice_input =
partitioned_slice_input.Reshard(HloSharding::Replicate()).hlo();
// Slice top K index from the first partitioned sort.
auto slice_index = SliceFirstK(index_gte, &b_, sort_dim, k.value());
slice_index->set_sharding(input_sharding);
PartitionedHlo partitioned_slice_index(
slice_index, ShapeUtil::MakeShape(index_type, replicated_dimensions),
MakePartitioningState());
// Reshard index to be replicated.
auto replicated_slice_index =
partitioned_slice_index.Reshard(HloSharding::Replicate()).hlo();
// Create a replicated sort to do TopK; the inputs are the value and index
// pairs gathered from all the partitions.
const Shape final_topk_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(element_type, replicated_dimensions),
ShapeUtil::MakeShape(index_type, replicated_dimensions)});
auto final_sort = b_.AddInstruction(HloInstruction::CreateSort(
final_topk_shape, sort_dim,
{replicated_slice_input, replicated_slice_index}, sort->to_apply(),
sort->is_stable()));
final_sort->set_sharding(HloSharding::Replicate()
.GetTupleSharding(final_sort->shape())
.ValueOrDie());
PartitionedHlo replicated_sort(final_sort, final_topk_shape,
MakePartitioningState());
SetPartitionedHlo(hlo, replicated_sort.Reshard(hlo->sharding()));
return Status::OK();
}
if (hlo->shape().IsTuple()) {
// Check that all elements are sharded in the same way.
if (hlo->shape().tuple_shapes_size() == 0) {
@@ -1373,16 +1473,8 @@ Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
auto input = hlo->operand(0);
const auto element_type = input->shape().element_type();
// Pad input with minimal value.
auto min_value = b_.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::MinValue(element_type)));
// TODO(wangtao): add test to see if -NaN < -Inf in BF16.
if (element_type == F32) {
auto float_pad_value = std::numeric_limits<float>::quiet_NaN();
min_value = b_.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<float>(-float_pad_value)));
}
auto partitioned_input = GetPartitionedHlo(input).PadWithValue(min_value);
auto partitioned_input = GetPartitionedHlo(input).PadWithValue(
CreateFirstWithType(element_type, &b_));
// Each partition needs to do TopK separately, thus the base shape
// becomes [batch_size, k * shard_count].
@@ -1476,24 +1568,12 @@ Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
b_.AddInstruction(HloInstruction::CreateGetTupleElement(
replicated_sort.hlo()->shape().tuple_shapes(1), replicated_sort.hlo(),
1));
const Shape& hlo_shape = sort_value_gte->shape();
auto hlo_dims = hlo_shape.dimensions();
std::vector<int64> start_indices(hlo_shape.dimensions_size(), 0);
std::vector<int64> limit_indices(hlo_dims.begin(), hlo_dims.end());
std::vector<int64> strides(hlo_shape.dimensions_size(), sort_dim);
limit_indices[sort_dim] = k;
auto output_shape = hlo_shape;
output_shape.set_dimensions(sort_dim, k);
// Slice value from final sort.
HloInstruction* slice_sort_value =
b_.AddInstruction(HloInstruction::CreateSlice(
output_shape, sort_value_gte, start_indices, limit_indices, strides));
SliceFirstK(sort_value_gte, &b_, sort_dim, k);
// Slice index from final sort.
auto index_output_shape = sort_index_gte->shape();
index_output_shape.set_dimensions(sort_dim, k);
HloInstruction* slice_index_value = b_.AddInstruction(
HloInstruction::CreateSlice(index_output_shape, sort_index_gte,
start_indices, limit_indices, strides));
HloInstruction* slice_index_value =
SliceFirstK(sort_index_gte, &b_, sort_dim, k);
auto create_tuple = b_.AddInstruction(
HloInstruction::CreateTuple({slice_sort_value, slice_index_value}));
create_tuple->set_sharding(HloSharding::Replicate());
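The -NaN padding used above (and the comparator pattern matched by IsNanSafeGt in the utility file below) relies on mapping float bits to a totally ordered integer key. Here is a minimal standalone illustration of that key, assuming the same bit trick as the matched HLO (bitcast to s32, and for negative floats INT32_MAX minus the unsigned bits); it is not XLA code, just a sanity check of the ordering:

#include <cassert>
#include <cstdint>
#include <cstring>
#include <limits>

// Maps a float to an int32 key such that a > b under the NaN-safe total order
// iff NanSafeSortKey(a) > NanSafeSortKey(b). Negative floats are remapped
// through INT32_MAX - (unsigned bits), mirroring the select/bitcast pattern
// in the sort comparator.
int32_t NanSafeSortKey(float value) {
  int32_t as_s32;
  uint32_t as_u32;
  std::memcpy(&as_s32, &value, sizeof(value));
  std::memcpy(&as_u32, &value, sizeof(value));
  if (as_s32 < 0) {
    const uint32_t flipped = 0x7FFFFFFFu - as_u32;
    std::memcpy(&as_s32, &flipped, sizeof(flipped));
  }
  return as_s32;
}

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const float inf = std::numeric_limits<float>::infinity();
  // -NaN (the pad value) sorts below -Inf, so padding never enters the top k;
  // +NaN sorts above +Inf.
  assert(NanSafeSortKey(-nan) < NanSafeSortKey(-inf));
  assert(NanSafeSortKey(nan) > NanSafeSortKey(inf));
  assert(NanSafeSortKey(1.0f) > NanSafeSortKey(-1.0f));
  return 0;
}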

View File

@@ -1947,6 +1947,385 @@ ENTRY %cluster_2013453984438090939__.47
EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 4000);
}
TEST_F(SpmdPartitioningTest, PartitionSortInTopK) {
const char* const hlo_string = R"(
HloModule module
%compare-greater-than.8 (p.0.lhs.9: bf16[], p.0.rhs.10: bf16[], p.1.lhs.11:
s32[], p.1.rhs.12: s32[]) -> pred[] {
%p.1.lhs.11 = s32[] parameter(2)
%p.1.rhs.12 = s32[] parameter(3)
%p.0.lhs.9 = bf16[] parameter(0)
%convert.13 = f32[] convert(bf16[] %p.0.lhs.9)
%bitcast-convert.16 = s32[] bitcast-convert(f32[] %convert.13)
%constant.20 = s32[] constant(0)
%compare.21 = pred[] compare(s32[] %bitcast-convert.16, s32[] %constant.20),
direction=LT
%constant.15 = u32[] constant(2147483647)
%bitcast-convert.17 = u32[] bitcast-convert(f32[] %convert.13)
%subtract.18 = u32[] subtract(u32[] %constant.15, u32[] %bitcast-convert.17)
%bitcast-convert.19 = s32[] bitcast-convert(u32[] %subtract.18)
%select.22 = s32[] select(pred[] %compare.21, s32[] %bitcast-convert.19, s32[]
%bitcast-convert.16)
%p.0.rhs.10 = bf16[] parameter(1)
%convert.14 = f32[] convert(bf16[] %p.0.rhs.10)
%bitcast-convert.24 = s32[] bitcast-convert(f32[] %convert.14)
%constant.28 = s32[] constant(0)
%compare.29 = pred[] compare(s32[] %bitcast-convert.24, s32[] %constant.28),
direction=LT
%constant.23 = u32[] constant(2147483647)
%bitcast-convert.25 = u32[] bitcast-convert(f32[] %convert.14)
%subtract.26 = u32[] subtract(u32[] %constant.23, u32[] %bitcast-convert.25)
%bitcast-convert.27 = s32[] bitcast-convert(u32[] %subtract.26)
%select.30 = s32[] select(pred[] %compare.29, s32[] %bitcast-convert.27, s32[]
%bitcast-convert.24)
ROOT %compare.31 = pred[] compare(s32[] %select.22, s32[] %select.30),
direction=GT
}
ENTRY entry
(arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
%arg_tuple.1 = bf16[2,209664] parameter(0)
%copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
%iota.7 = s32[2,209664] iota(), iota_dimension=1,
metadata={op_type="TopKV2" op_name="TopKV2"}
%sort.32 = (bf16[2,209664], s32[2,209664])
sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.33 = bf16[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.34 = bf16[2,2000] slice(bf16[2,209664]
%get-tuple-element.33), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.35 = s32[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.36 = s32[2,2000] slice(s32[2,209664]
%get-tuple-element.35), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
tuple(bf16[2,2000] %slice.34, s32[2,2000]
%slice.36), sharding={{replicated}, {replicated}},
metadata={op_name="XLA_Retvals"}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto sort = FindInstruction(module.get(), "sort");
EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 104832);
EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 104832);
auto final_sort = FindInstruction(module.get(), "sort.1");
EXPECT_EQ(final_sort->operand(0)->shape().dimensions(1), 4000);
EXPECT_EQ(final_sort->operand(1)->shape().dimensions(1), 4000);
}
TEST_F(SpmdPartitioningTest, PartitionSortInTopKWhenComparisonWithSelect) {
const char* const hlo_string = R"(
HloModule module
%compare-greater-than.8 (p.0.lhs.2566: bf16[],
p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
%p.0.lhs.2566 = bf16[] parameter(0)
%convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
%bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
%constant.285 = s32[] constant(0)
%compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
direction=LT
%constant.286 = u32[] constant(2147483647)
%bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
%subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
%bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
%select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
s32[] %bitcast-convert.48)
%p.0.rhs.2567 = bf16[] parameter(1)
%convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
%bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
%compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
direction=LT
%bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
%subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
%bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
%select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
s32[] %bitcast-convert.51)
%compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
%compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
%compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
direction=EQ
%p.1.lhs.2586 = s32[] parameter(2)
%p.1.rhs.2587 = s32[] parameter(3)
%compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
direction=LT
ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
pred[] %compare.86)
}
ENTRY entry
(arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
%arg_tuple.1 = bf16[2,209664] parameter(0)
%copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
%iota.7 = s32[2,209664] iota(), iota_dimension=1,
metadata={op_type="TopKV2" op_name="TopKV2"}
%sort.32 = (bf16[2,209664], s32[2,209664])
sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.33 = bf16[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.34 = bf16[2,2000] slice(bf16[2,209664]
%get-tuple-element.33), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.35 = s32[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.36 = s32[2,2000] slice(s32[2,209664]
%get-tuple-element.35), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
tuple(bf16[2,2000] %slice.34, s32[2,2000]
%slice.36), sharding={{replicated}, {replicated}},
metadata={op_name="XLA_Retvals"}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto sort = FindInstruction(module.get(), "sort");
EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 104832);
EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 104832);
auto final_sort = FindInstruction(module.get(), "sort.1");
EXPECT_EQ(final_sort->operand(0)->shape().dimensions(1), 4000);
EXPECT_EQ(final_sort->operand(1)->shape().dimensions(1), 4000);
}
TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenSecondOperandIsNotIota) {
const char* const hlo_string = R"(
HloModule module
%compare-greater-than.8 (p.0.lhs.2566: bf16[],
p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
%p.0.lhs.2566 = bf16[] parameter(0)
%convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
%bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
%constant.285 = s32[] constant(0)
%compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
direction=LT
%constant.286 = u32[] constant(2147483647)
%bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
%subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
%bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
%select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
s32[] %bitcast-convert.48)
%p.0.rhs.2567 = bf16[] parameter(1)
%convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
%bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
%compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
direction=LT
%bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
%subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
%bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
%select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
s32[] %bitcast-convert.51)
%compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
%compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
%compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
direction=EQ
%p.1.lhs.2586 = s32[] parameter(2)
%p.1.rhs.2587 = s32[] parameter(3)
%compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
direction=LT
ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
pred[] %compare.86)
}
ENTRY entry {
%arg_tuple.1 = bf16[2,209664] parameter(0)
%arg_tuple.2 = s32[2,209664] parameter(1)
%copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
%sort.32 = (bf16[2,209664], s32[2,209664])
sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %arg_tuple.2),
dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.33 = bf16[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.34 = bf16[2,2000] slice(bf16[2,209664]
%get-tuple-element.33), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.35 = s32[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.36 = s32[2,2000] slice(s32[2,209664]
%get-tuple-element.35), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
tuple(bf16[2,2000] %slice.34, s32[2,2000]
%slice.36), sharding={{replicated}, {replicated}},
metadata={op_name="XLA_Retvals"}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto sort = FindInstruction(module.get(), "sort");
EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
}
TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenNoPartitionInSortDim) {
const char* const hlo_string = R"(
HloModule module
%compare-greater-than.8 (p.0.lhs.2566: bf16[],
p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
%p.0.lhs.2566 = bf16[] parameter(0)
%convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
%bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
%constant.285 = s32[] constant(0)
%compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
direction=LT
%constant.286 = u32[] constant(2147483647)
%bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
%subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
%bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
%select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
s32[] %bitcast-convert.48)
%p.0.rhs.2567 = bf16[] parameter(1)
%convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
%bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
%compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
direction=LT
%bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
%subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
%bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
%select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
s32[] %bitcast-convert.51)
%compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
%compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
%compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
direction=EQ
%p.1.lhs.2586 = s32[] parameter(2)
%p.1.rhs.2587 = s32[] parameter(3)
%compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
direction=LT
ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
pred[] %compare.86)
}
ENTRY entry
(arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
%arg_tuple.1 = bf16[2,209664] parameter(0)
%copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[2,1]0,1}
%iota.7 = s32[2,209664] iota(), iota_dimension=1,
metadata={op_type="TopKV2" op_name="TopKV2"}
%sort.32 = (bf16[2,209664], s32[2,209664])
sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.33 = bf16[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.34 = bf16[2,2000] slice(bf16[2,209664]
%get-tuple-element.33), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.35 = s32[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.36 = s32[2,2000] slice(s32[2,209664]
%get-tuple-element.35), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
tuple(bf16[2,2000] %slice.34, s32[2,2000]
%slice.36), sharding={{replicated}, {replicated}},
metadata={op_name="XLA_Retvals"}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto sort = FindInstruction(module.get(), "sort");
EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
}
TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenSliceInOtherDim) {
const char* const hlo_string = R"(
HloModule module
%compare-greater-than.8 (p.0.lhs.2566: bf16[],
p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
%p.0.lhs.2566 = bf16[] parameter(0)
%convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
%bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
%constant.285 = s32[] constant(0)
%compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
direction=LT
%constant.286 = u32[] constant(2147483647)
%bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
%subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
%bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
%select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
s32[] %bitcast-convert.48)
%p.0.rhs.2567 = bf16[] parameter(1)
%convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
%bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
%compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
direction=LT
%bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
%subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
%bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
%select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
s32[] %bitcast-convert.51)
%compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
%compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
%compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
direction=EQ
%p.1.lhs.2586 = s32[] parameter(2)
%p.1.rhs.2587 = s32[] parameter(3)
%compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
direction=LT
ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
pred[] %compare.86)
}
ENTRY entry {
%arg_tuple.1 = bf16[2,209664] parameter(0)
%copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
%iota.7 = s32[2,209664] iota(), iota_dimension=1,
metadata={op_type="TopKV2" op_name="TopKV2"}
%sort.32 = (bf16[2,209664], s32[2,209664])
sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.33 = bf16[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.34 = bf16[1,209664] slice(bf16[2,209664]
%get-tuple-element.33), slice={[0:1], [0:209664]},
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.35 = s32[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.36 = s32[1,209664] slice(s32[2,209664]
%get-tuple-element.35), slice={[0:1], [0:209664]},
metadata={op_type="TopKV2" op_name="TopKV2"}
ROOT %tuple.46 = (bf16[1,209664], s32[1,209664])
tuple(bf16[1,209664] %slice.34, s32[1,209664]
%slice.36), sharding={{replicated}, {replicated}},
metadata={op_name="XLA_Retvals"}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto sort = FindInstruction(module.get(), "sort");
EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
}
TEST_F(SpmdPartitioningTest, ShardableTranspose) {
const char* const hlo_string = R"(
HloModule module

View File

@@ -19,10 +19,13 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
@@ -702,5 +705,170 @@ HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original,
return reshard_window->sharded_input;
}
bool IsNanSafeGt(HloComputation* comp) {
namespace m = match;
auto match_bitcast_f32 = [](int64 parameter_number) {
auto param = m::Parameter(parameter_number)
.WithShape(m::Shape().WithElementType(F32));
auto param_s32 =
m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
auto param_u32 =
m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
return m::Select(
m::Lt(param_s32, m::ConstantScalar(0)),
m::BitcastConvert(
m::Subtract(m::ConstantScalar(std::numeric_limits<int32>::max()),
param_u32))
.WithShape(m::Shape().WithElementType(S32)),
param_s32);
};
auto match_bitcast_bf16 = [](int64 parameter_number) {
auto param = m::Convert(m::Parameter(parameter_number)
.WithShape(m::Shape().WithElementType(BF16)))
.WithShape(m::Shape().WithElementType(F32));
auto param_s32 =
m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
auto param_u32 =
m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
return m::Select(
m::Lt(param_s32, m::ConstantScalar(0)),
m::BitcastConvert(
m::Subtract(m::ConstantScalar(std::numeric_limits<int32>::max()),
param_u32))
.WithShape(m::Shape().WithElementType(S32)),
param_s32);
};
// If the root instruction is a kSelect, indices are compared when the values
// are equal, so check the value comparison in the select's last operand.
if (comp->root_instruction()->opcode() == HloOpcode::kSelect) {
return Match(comp->root_instruction()->operand(2),
m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) ||
Match(comp->root_instruction()->operand(2),
m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1)));
}
return Match(comp->root_instruction(),
m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) ||
Match(comp->root_instruction(),
m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1)));
}
absl::optional<int64> GetKValueInTopKWhenPartitionSortDim(HloInstruction* hlo) {
HloSortInstruction* sort = DynCast<HloSortInstruction>(hlo);
if (sort == nullptr || sort->operand_count() != 2) {
return absl::nullopt;
}
if (!IsNanSafeGt(sort->to_apply())) {
return absl::nullopt;
}
HloInstruction* data = sort->mutable_operand(0);
HloIotaInstruction* iota =
DynCast<HloIotaInstruction>(sort->mutable_operand(1));
const PrimitiveType element_type = data->shape().element_type();
if (iota == nullptr || iota->shape().element_type() != S32 ||
iota->opcode() != HloOpcode::kIota ||
iota->iota_dimension() != sort->sort_dimension()) {
return absl::nullopt;
}
const int64 sort_dim = sort->sort_dimension();
if (element_type != F32 && element_type != BF16 && element_type != S32 &&
element_type != U32) {
return absl::nullopt;
}
bool supported = true;
absl::optional<int64> k;
for (HloInstruction* gte : sort->users()) {
if (gte->opcode() != HloOpcode::kGetTupleElement) {
supported = false;
break;
}
const HloInstruction* slice = gte->users()[0];
if (slice->opcode() != HloOpcode::kSlice) {
// A non-slice user means we are not doing a TopK.
supported = false;
break;
}
if (absl::c_any_of(slice->slice_starts(), [](int x) { return x != 0; }) ||
absl::c_any_of(slice->slice_strides(), [](int x) { return x != 1; })) {
// Strided slices or slices that do not start at the beginning aren't
// supported.
supported = false;
break;
}
for (int64 dim = 0; dim < data->shape().dimensions_size(); dim++) {
if (dim == sort_dim) {
continue;
}
if (slice->slice_limits(dim) !=
slice->operand(0)->shape().dimensions(dim)) {
// Slicing along dimensions other than the sort dimension isn't supported.
supported = false;
break;
}
}
if (!k.has_value()) {
k = slice->slice_limits(sort_dim);
} else if (k != slice->slice_limits(sort_dim)) {
// Different k values for the different operands aren't supported.
supported = false;
break;
}
}
if (k == absl::nullopt || !supported) {
return absl::nullopt;
}
// Only support when sort dim is sharded.
if (!data->has_sharding()) {
return absl::nullopt;
}
const HloSharding& sharding = sort->operand(0)->sharding();
if (sharding.IsTileMaximal()) {
return absl::nullopt;
}
// Check if partitioned at sort dimension.
for (int64 dim : sort->dimensions()) {
if (sharding.tile_assignment().dim(dim) > 1) {
if (dim != sort_dim) {
return absl::nullopt;
}
}
}
// Only supported when k is smaller than the per-partition size.
const int64 shard_count = sharding.tile_assignment().dim(sort_dim);
if (shard_count <= 1) {
return absl::nullopt;
}
const int64 input_size = hlo->operand(0)->shape().dimensions(sort_dim);
const int64 per_partition_size = CeilOfRatio(input_size, shard_count);
if (k.value() >= per_partition_size) {
return absl::nullopt;
}
return k;
}
// Slices the first k elements along slice_dim.
HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
int64 slice_dim, int64 k) {
const Shape& hlo_shape = hlo->shape();
auto hlo_dims = hlo_shape.dimensions();
std::vector<int64> start_indices(hlo_shape.dimensions_size(), 0);
std::vector<int64> limit_indices(hlo_dims.begin(), hlo_dims.end());
std::vector<int64> strides(hlo_shape.dimensions_size(), 1);
limit_indices[slice_dim] = k;
auto output_shape = hlo_shape;
output_shape.set_dimensions(slice_dim, k);
return builder->AddInstruction(HloInstruction::CreateSlice(
output_shape, hlo, start_indices, limit_indices, strides));
}
} // namespace spmd
} // namespace xla
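SliceFirstK above builds a slice that keeps only the leading k elements of one dimension. A standalone sketch of the bounds it requests (plain C++, not the XLA builder API; the example shape matches the per-partition TopK in the new HandleSort path):

#include <cstdint>
#include <vector>

// Slice bounds equivalent to what SliceFirstK asks the builder for: start 0
// everywhere, full extent on every dimension except slice_dim (limited to k),
// and unit strides throughout.
struct SliceSpec {
  std::vector<int64_t> starts;
  std::vector<int64_t> limits;
  std::vector<int64_t> strides;
};

SliceSpec FirstKSliceSpec(const std::vector<int64_t>& dims, int64_t slice_dim,
                          int64_t k) {
  SliceSpec spec;
  spec.starts.assign(dims.size(), 0);
  spec.limits = dims;
  spec.strides.assign(dims.size(), 1);
  spec.limits[slice_dim] = k;
  return spec;
}

// Example: a bf16[2,104832] per-partition sort result with slice_dim = 1 and
// k = 2000 yields starts {0,0}, limits {2,2000}, strides {1,1}.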

View File

@@ -45,6 +45,24 @@ HloInstruction* CreateR0WithType(PrimitiveType type, NativeT value,
return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
}
inline HloInstruction* CreateFirstWithType(PrimitiveType type, SpmdBuilder* b) {
if (type == F32) {
auto float_pad_value = std::numeric_limits<float>::quiet_NaN();
return CreateR0WithType(type, -float_pad_value, b);
}
auto literal = LiteralUtil::MinValue(type);
return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
}
inline HloInstruction* CreateLastWithType(PrimitiveType type, SpmdBuilder* b) {
if (type == F32) {
auto float_pad_value = std::numeric_limits<float>::quiet_NaN();
return CreateR0WithType(type, float_pad_value, b);
}
auto literal = LiteralUtil::MaxValue(type);
return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
}
// Create a binary add computation of the given type and add to the module.
HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module);
@@ -234,6 +252,16 @@ absl::optional<HloInstruction*> ExchangeHaloAndGetValidData(
HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original,
absl::Span<const int64> dims);
// Checks whether the computation is a NaN-safe greater-than comparison.
bool IsNanSafeGt(HloComputation* computation);
// Returns k in TopK when the input value is partitioned at the sort dimension.
absl::optional<int64> GetKValueInTopKWhenPartitionSortDim(HloInstruction* hlo);
// Slices the first k elements along the slice dimension.
HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
int64 slice_dim, int64 k);
} // namespace spmd
} // namespace xla
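For reference, a small standalone sketch of the padding scalars the two new helpers choose for the element types exercised by the TopK path (an illustration under the F32 special case shown above; for other types the helpers fall back to the type's minimum and maximum literals):

#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  // F32: pad "first" with -NaN and "last" with NaN, so that under the
  // NaN-safe comparator a padded element can never displace a real value.
  const float f32_first = -std::numeric_limits<float>::quiet_NaN();
  const float f32_last = std::numeric_limits<float>::quiet_NaN();
  // S32 (the TopK index operand): plain minimum / maximum values, so padded
  // index slots sort after every real index.
  const int32_t s32_first = std::numeric_limits<int32_t>::min();
  const int32_t s32_last = std::numeric_limits<int32_t>::max();
  std::cout << f32_first << " " << f32_last << "\n";   // e.g. -nan nan
  std::cout << s32_first << " " << s32_last << "\n";   // -2147483648 2147483647
  return 0;
}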