Support partitioning Sort in TopK when the input is partitioned at the sort dimension.
PiperOrigin-RevId: 314174499
Change-Id: I8fbac47edf5a2691c5a51aacda885b0300b53247
This commit is contained in:
parent f3930469e4
commit 1a430ba06b
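The change follows the per-shard TopK + gather + final TopK scheme described in the HandleSort comments below. For intuition only, here is a minimal self-contained C++ sketch of that scheme; the shard representation, the DistributedTopK name, and the sample values are invented for illustration and are not part of this change. In the real pass, uneven shards are padded with the smallest value of the element type (-NaN for F32, via CreateFirstWithType) so padding can never enter the top k.

// Minimal standalone sketch of the distributed TopK idea (not XLA code):
// every shard keeps only its local top k, the candidates are gathered, and a
// final sort over partition_count * k elements yields the global top k.
#include <algorithm>
#include <functional>
#include <iostream>
#include <vector>

std::vector<float> DistributedTopK(
    const std::vector<std::vector<float>>& shards, int k) {
  std::vector<float> candidates;
  for (const auto& shard : shards) {
    // 1. Each shard sorts locally and keeps its first k elements.
    std::vector<float> local = shard;
    const size_t local_k = std::min<size_t>(k, local.size());
    std::partial_sort(local.begin(), local.begin() + local_k, local.end(),
                      std::greater<float>());
    local.resize(local_k);
    // 2.+3. Gather the per-shard top k into one replicated buffer.
    candidates.insert(candidates.end(), local.begin(), local.end());
  }
  // 4. A final sort over partition_count * k candidates gives the global top k.
  std::sort(candidates.begin(), candidates.end(), std::greater<float>());
  candidates.resize(std::min<size_t>(k, candidates.size()));
  return candidates;
}

int main() {
  // Two "partitions" of one row, k = 3.
  std::vector<std::vector<float>> shards = {{0.1f, 9.5f, 3.2f, 7.7f},
                                            {8.8f, 2.4f, 6.1f, 9.9f}};
  for (float v : DistributedTopK(shards, 3)) std::cout << v << " ";
  std::cout << "\n";  // Prints: 9.9 9.5 8.8
  return 0;
}

With partition_count shards and k kept per shard, the final sort only has to order partition_count * k candidates instead of the full input, which is what the partitioned HandleSort path below builds in HLO.
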
@@ -43,6 +43,7 @@ cc_library(
        "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
        "//tensorflow/compiler/xla/service:hlo_query",
        "//tensorflow/compiler/xla/service:hlo_sharding_util",
        "//tensorflow/compiler/xla/service:pattern_matcher",
        "//tensorflow/compiler/xla/service:shape_inference",
        "//tensorflow/compiler/xla/service:tuple_simplifier",
        "//tensorflow/core/platform:numbers",

@@ -1282,6 +1282,106 @@ Status SpmdPartitioningVisitor::HandleSlice(HloInstruction* hlo) {

Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) {
  HloSharding sharding = hlo->sharding();
  // Special handling for sort in TopK when the first operand is partitioned
  // at the sort dimension.
  auto k = GetKValueInTopKWhenPartitionSortDim(hlo);
  if (k.has_value()) {
    // When the first operand is partitioned at the sort dimension:
    // 1. Partition the sort computation across the partitions;
    // 2. Slice the TopK value and index out of each partition;
    // 3. Gather and replicate the value and index from all partitions;
    //    the shape of the replicated value and index will be
    //    [batch_size, ..., partition_count * k, ...];
    // 4. The final sort uses the replicated value and index from all
    //    partitions as input.
    // GetTupleElement and Slice after the non-partitioned sort won't change
    // at this point, as HandleGetTupleElement and HandleSlice will update them.
    HloSortInstruction* sort = DynCast<HloSortInstruction>(hlo);
    const int64 sort_dim = sort->sort_dimension();
    auto input = hlo->operand(0);
    auto index = hlo->operand(1);
    const HloSharding& input_sharding = input->sharding();
    const int64 partition_count =
        input_sharding.tile_assignment().dim(sort_dim);
    const int64 input_size = input->shape().dimensions(sort_dim);
    const int64 per_partition_size = CeilOfRatio(input_size, partition_count);
    const auto element_type = input->shape().element_type();
    const auto index_type = index->shape().element_type();

    // Partition and pad input and index.
    // Pad input with the minimal value.
    auto partitioned_input = GetPartitionedHlo(input).PadWithValue(
        CreateFirstWithType(element_type, &b_));
    // Pad index with the max value.
    auto partitioned_index =
        GetPartitionedHlo(index)
            .Reshard(input_sharding)
            .PadWithValue(CreateLastWithType(index_type, &b_));

    // Each partition needs to do TopK separately, thus the base shape
    // becomes the padded shape.
    std::vector<int64> replicated_dimensions(
        input->shape().dimensions().begin(), input->shape().dimensions().end());
    replicated_dimensions[sort_dim] = per_partition_size * partition_count;
    const Shape replicated_shape = ShapeUtil::MakeTupleShape(
        {ShapeUtil::MakeShape(element_type, replicated_dimensions),
         ShapeUtil::MakeShape(index_type, replicated_dimensions)});

    // Partition the original TopK across the shards.
    auto topk_sharding =
        input_sharding.GetTupleSharding(replicated_shape).ValueOrDie();
    auto shard_shape = MakePartitionedShape(replicated_shape, topk_sharding);
    auto topk = b_.AddInstruction(hlo->CloneWithNewOperands(
        shard_shape, {partitioned_input.hlo(), partitioned_index.hlo()}));

    // Get value and index from the per-shard sort.
    HloInstruction* value_gte =
        b_.AddInstruction(HloInstruction::CreateGetTupleElement(
            topk->shape().tuple_shapes(0), topk, 0));
    HloInstruction* index_gte =
        b_.AddInstruction(HloInstruction::CreateGetTupleElement(
            topk->shape().tuple_shapes(1), topk, 1));

    // Slice the top K values from the first, partitioned sort.
    replicated_dimensions[sort_dim] = k.value() * partition_count;
    auto slice_input = SliceFirstK(value_gte, &b_, sort_dim, k.value());
    slice_input->set_sharding(input_sharding);
    PartitionedHlo partitioned_slice_input(
        slice_input, ShapeUtil::MakeShape(element_type, replicated_dimensions),
        MakePartitioningState());
    // Reshard the values to be replicated.
    auto replicated_slice_input =
        partitioned_slice_input.Reshard(HloSharding::Replicate()).hlo();

    // Slice the top K indices from the first, partitioned sort.
    auto slice_index = SliceFirstK(index_gte, &b_, sort_dim, k.value());
    slice_index->set_sharding(input_sharding);
    PartitionedHlo partitioned_slice_index(
        slice_index, ShapeUtil::MakeShape(index_type, replicated_dimensions),
        MakePartitioningState());
    // Reshard the indices to be replicated.
    auto replicated_slice_index =
        partitioned_slice_index.Reshard(HloSharding::Replicate()).hlo();

    // Create a replicated sort to do the final TopK; its inputs are the value
    // and index pairs gathered from all the partitions.
    const Shape final_topk_shape = ShapeUtil::MakeTupleShape(
        {ShapeUtil::MakeShape(element_type, replicated_dimensions),
         ShapeUtil::MakeShape(index_type, replicated_dimensions)});
    auto final_sort = b_.AddInstruction(HloInstruction::CreateSort(
        final_topk_shape, sort_dim,
        {replicated_slice_input, replicated_slice_index}, sort->to_apply(),
        sort->is_stable()));
    final_sort->set_sharding(HloSharding::Replicate()
                                 .GetTupleSharding(final_sort->shape())
                                 .ValueOrDie());
    PartitionedHlo replicated_sort(final_sort, final_topk_shape,
                                   MakePartitioningState());
    SetPartitionedHlo(hlo, replicated_sort.Reshard(hlo->sharding()));

    return Status::OK();
  }

  if (hlo->shape().IsTuple()) {
    // Check that all elements are sharded in the same way.
    if (hlo->shape().tuple_shapes_size() == 0) {

@@ -1373,16 +1473,8 @@ Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
  auto input = hlo->operand(0);
  const auto element_type = input->shape().element_type();

  // Pad input with minimal value.
  auto min_value = b_.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MinValue(element_type)));
  // TODO(wangtao): add test to see if -NaN < -Inf in BF16.
  if (element_type == F32) {
    auto float_pad_value = std::numeric_limits<float>::quiet_NaN();
    min_value = b_.AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR0<float>(-float_pad_value)));
  }
  auto partitioned_input = GetPartitionedHlo(input).PadWithValue(min_value);
  auto partitioned_input = GetPartitionedHlo(input).PadWithValue(
      CreateFirstWithType(element_type, &b_));

  // Each partition needs to do TopK separately, thus the base shape
  // becomes [batch_size, k * shard_count].

@@ -1476,24 +1568,12 @@ Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
      b_.AddInstruction(HloInstruction::CreateGetTupleElement(
          replicated_sort.hlo()->shape().tuple_shapes(1), replicated_sort.hlo(),
          1));
  const Shape& hlo_shape = sort_value_gte->shape();
  auto hlo_dims = hlo_shape.dimensions();
  std::vector<int64> start_indices(hlo_shape.dimensions_size(), 0);
  std::vector<int64> limit_indices(hlo_dims.begin(), hlo_dims.end());
  std::vector<int64> strides(hlo_shape.dimensions_size(), sort_dim);
  limit_indices[sort_dim] = k;
  auto output_shape = hlo_shape;
  output_shape.set_dimensions(sort_dim, k);
  // Slice value from final sort.
  HloInstruction* slice_sort_value =
      b_.AddInstruction(HloInstruction::CreateSlice(
          output_shape, sort_value_gte, start_indices, limit_indices, strides));
      SliceFirstK(sort_value_gte, &b_, sort_dim, k);
  // Slice index from final sort.
  auto index_output_shape = sort_index_gte->shape();
  index_output_shape.set_dimensions(sort_dim, k);
  HloInstruction* slice_index_value = b_.AddInstruction(
      HloInstruction::CreateSlice(index_output_shape, sort_index_gte,
                                  start_indices, limit_indices, strides));
  HloInstruction* slice_index_value =
      SliceFirstK(sort_index_gte, &b_, sort_dim, k);
  auto create_tuple = b_.AddInstruction(
      HloInstruction::CreateTuple({slice_sort_value, slice_index_value}));
  create_tuple->set_sharding(HloSharding::Replicate());

@@ -1947,6 +1947,385 @@ ENTRY %cluster_2013453984438090939__.47
  EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 4000);
}

TEST_F(SpmdPartitioningTest, PartitionSortInTopK) {
  const char* const hlo_string = R"(
HloModule module

%compare-greater-than.8 (p.0.lhs.9: bf16[], p.0.rhs.10: bf16[], p.1.lhs.11:
s32[], p.1.rhs.12: s32[]) -> pred[] {
%p.1.lhs.11 = s32[] parameter(2)
%p.1.rhs.12 = s32[] parameter(3)
%p.0.lhs.9 = bf16[] parameter(0)
%convert.13 = f32[] convert(bf16[] %p.0.lhs.9)
%bitcast-convert.16 = s32[] bitcast-convert(f32[] %convert.13)
%constant.20 = s32[] constant(0)
%compare.21 = pred[] compare(s32[] %bitcast-convert.16, s32[] %constant.20),
direction=LT
%constant.15 = u32[] constant(2147483647)
%bitcast-convert.17 = u32[] bitcast-convert(f32[] %convert.13)
%subtract.18 = u32[] subtract(u32[] %constant.15, u32[] %bitcast-convert.17)
%bitcast-convert.19 = s32[] bitcast-convert(u32[] %subtract.18)
%select.22 = s32[] select(pred[] %compare.21, s32[] %bitcast-convert.19, s32[]
%bitcast-convert.16)
%p.0.rhs.10 = bf16[] parameter(1)
%convert.14 = f32[] convert(bf16[] %p.0.rhs.10)
%bitcast-convert.24 = s32[] bitcast-convert(f32[] %convert.14)
%constant.28 = s32[] constant(0)
%compare.29 = pred[] compare(s32[] %bitcast-convert.24, s32[] %constant.28),
direction=LT
%constant.23 = u32[] constant(2147483647)
%bitcast-convert.25 = u32[] bitcast-convert(f32[] %convert.14)
%subtract.26 = u32[] subtract(u32[] %constant.23, u32[] %bitcast-convert.25)
%bitcast-convert.27 = s32[] bitcast-convert(u32[] %subtract.26)
%select.30 = s32[] select(pred[] %compare.29, s32[] %bitcast-convert.27, s32[]
%bitcast-convert.24)
ROOT %compare.31 = pred[] compare(s32[] %select.22, s32[] %select.30),
direction=GT
}

ENTRY entry
(arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
%arg_tuple.1 = bf16[2,209664] parameter(0)
%copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
%iota.7 = s32[2,209664] iota(), iota_dimension=1,
metadata={op_type="TopKV2" op_name="TopKV2"}
%sort.32 = (bf16[2,209664], s32[2,209664])
sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.33 = bf16[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.34 = bf16[2,2000] slice(bf16[2,209664]
%get-tuple-element.33), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.35 = s32[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.36 = s32[2,2000] slice(s32[2,209664]
%get-tuple-element.35), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
tuple(bf16[2,2000] %slice.34, s32[2,2000]
%slice.36), sharding={{replicated}, {replicated}},
metadata={op_name="XLA_Retvals"}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  // The per-shard sort sees half of the 209664 elements: 104832 per partition.
  auto sort = FindInstruction(module.get(), "sort");
  EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 104832);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 104832);
  // The final sort sees 2 partitions * k (2000) = 4000 candidates per row.
  auto final_sort = FindInstruction(module.get(), "sort.1");
  EXPECT_EQ(final_sort->operand(0)->shape().dimensions(1), 4000);
  EXPECT_EQ(final_sort->operand(1)->shape().dimensions(1), 4000);
}

TEST_F(SpmdPartitioningTest, PartitionSortInTopKWhenComparisonWithSelect) {
  const char* const hlo_string = R"(
HloModule module

%compare-greater-than.8 (p.0.lhs.2566: bf16[],
p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
%p.0.lhs.2566 = bf16[] parameter(0)
%convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
%bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
%constant.285 = s32[] constant(0)
%compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
direction=LT
%constant.286 = u32[] constant(2147483647)
%bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
%subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
%bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
%select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
s32[] %bitcast-convert.48)
%p.0.rhs.2567 = bf16[] parameter(1)
%convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
%bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
%compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
direction=LT
%bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
%subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
%bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
%select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
s32[] %bitcast-convert.51)
%compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
%compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
%compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
direction=EQ
%p.1.lhs.2586 = s32[] parameter(2)
%p.1.rhs.2587 = s32[] parameter(3)
%compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
direction=LT
ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
pred[] %compare.86)
}

ENTRY entry
(arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
%arg_tuple.1 = bf16[2,209664] parameter(0)
%copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
%iota.7 = s32[2,209664] iota(), iota_dimension=1,
metadata={op_type="TopKV2" op_name="TopKV2"}
%sort.32 = (bf16[2,209664], s32[2,209664])
sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.33 = bf16[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.34 = bf16[2,2000] slice(bf16[2,209664]
%get-tuple-element.33), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.35 = s32[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.36 = s32[2,2000] slice(s32[2,209664]
%get-tuple-element.35), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
tuple(bf16[2,2000] %slice.34, s32[2,2000]
%slice.36), sharding={{replicated}, {replicated}},
metadata={op_name="XLA_Retvals"}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  auto sort = FindInstruction(module.get(), "sort");
  EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 104832);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 104832);
  auto final_sort = FindInstruction(module.get(), "sort.1");
  EXPECT_EQ(final_sort->operand(0)->shape().dimensions(1), 4000);
  EXPECT_EQ(final_sort->operand(1)->shape().dimensions(1), 4000);
}

TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenSecondOperandIsNotIota) {
  const char* const hlo_string = R"(
HloModule module

%compare-greater-than.8 (p.0.lhs.2566: bf16[],
p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
%p.0.lhs.2566 = bf16[] parameter(0)
%convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
%bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
%constant.285 = s32[] constant(0)
%compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
direction=LT
%constant.286 = u32[] constant(2147483647)
%bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
%subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
%bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
%select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
s32[] %bitcast-convert.48)
%p.0.rhs.2567 = bf16[] parameter(1)
%convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
%bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
%compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
direction=LT
%bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
%subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
%bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
%select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
s32[] %bitcast-convert.51)
%compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
%compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
%compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
direction=EQ
%p.1.lhs.2586 = s32[] parameter(2)
%p.1.rhs.2587 = s32[] parameter(3)
%compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
direction=LT
ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
pred[] %compare.86)
}

ENTRY entry {
%arg_tuple.1 = bf16[2,209664] parameter(0)
%arg_tuple.2 = s32[2,209664] parameter(1)
%copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
%sort.32 = (bf16[2,209664], s32[2,209664])
sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %arg_tuple.2),
dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.33 = bf16[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.34 = bf16[2,2000] slice(bf16[2,209664]
%get-tuple-element.33), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.35 = s32[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.36 = s32[2,2000] slice(s32[2,209664]
%get-tuple-element.35), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
tuple(bf16[2,2000] %slice.34, s32[2,2000]
%slice.36), sharding={{replicated}, {replicated}},
metadata={op_name="XLA_Retvals"}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  auto sort = FindInstruction(module.get(), "sort");
  EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
}

TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenNoPartitionInSortDim) {
  const char* const hlo_string = R"(
HloModule module

%compare-greater-than.8 (p.0.lhs.2566: bf16[],
p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
%p.0.lhs.2566 = bf16[] parameter(0)
%convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
%bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
%constant.285 = s32[] constant(0)
%compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
direction=LT
%constant.286 = u32[] constant(2147483647)
%bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
%subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
%bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
%select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
s32[] %bitcast-convert.48)
%p.0.rhs.2567 = bf16[] parameter(1)
%convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
%bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
%compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
direction=LT
%bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
%subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
%bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
%select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
s32[] %bitcast-convert.51)
%compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
%compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
%compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
direction=EQ
%p.1.lhs.2586 = s32[] parameter(2)
%p.1.rhs.2587 = s32[] parameter(3)
%compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
direction=LT
ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
pred[] %compare.86)
}

ENTRY entry
(arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
%arg_tuple.1 = bf16[2,209664] parameter(0)
%copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[2,1]0,1}
%iota.7 = s32[2,209664] iota(), iota_dimension=1,
metadata={op_type="TopKV2" op_name="TopKV2"}
%sort.32 = (bf16[2,209664], s32[2,209664])
sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.33 = bf16[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.34 = bf16[2,2000] slice(bf16[2,209664]
%get-tuple-element.33), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.35 = s32[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.36 = s32[2,2000] slice(s32[2,209664]
%get-tuple-element.35), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
tuple(bf16[2,2000] %slice.34, s32[2,2000]
%slice.36), sharding={{replicated}, {replicated}},
metadata={op_name="XLA_Retvals"}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  auto sort = FindInstruction(module.get(), "sort");
  EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
}

TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenSliceInOtherDim) {
  const char* const hlo_string = R"(
HloModule module

%compare-greater-than.8 (p.0.lhs.2566: bf16[],
p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
%p.0.lhs.2566 = bf16[] parameter(0)
%convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
%bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
%constant.285 = s32[] constant(0)
%compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
direction=LT
%constant.286 = u32[] constant(2147483647)
%bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
%subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
%bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
%select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
s32[] %bitcast-convert.48)
%p.0.rhs.2567 = bf16[] parameter(1)
%convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
%bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
%compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
direction=LT
%bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
%subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
%bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
%select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
s32[] %bitcast-convert.51)
%compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
%compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
%compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
direction=EQ
%p.1.lhs.2586 = s32[] parameter(2)
%p.1.rhs.2587 = s32[] parameter(3)
%compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
direction=LT
ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
pred[] %compare.86)
}

ENTRY entry {
%arg_tuple.1 = bf16[2,209664] parameter(0)
%copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
%iota.7 = s32[2,209664] iota(), iota_dimension=1,
metadata={op_type="TopKV2" op_name="TopKV2"}
%sort.32 = (bf16[2,209664], s32[2,209664])
sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.33 = bf16[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.34 = bf16[1,209664] slice(bf16[2,209664]
%get-tuple-element.33), slice={[0:1], [0:209664]},
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.35 = s32[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.36 = s32[1,209664] slice(s32[2,209664]
%get-tuple-element.35), slice={[0:1], [0:209664]},
metadata={op_type="TopKV2" op_name="TopKV2"}
ROOT %tuple.46 = (bf16[1,209664], s32[1,209664])
tuple(bf16[1,209664] %slice.34, s32[1,209664]
%slice.36), sharding={{replicated}, {replicated}},
metadata={op_name="XLA_Retvals"}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  auto sort = FindInstruction(module.get(), "sort");
  EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
}

TEST_F(SpmdPartitioningTest, ShardableTranspose) {
  const char* const hlo_string = R"(
HloModule module

@@ -19,10 +19,13 @@ limitations under the License.

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"

@@ -702,5 +705,170 @@ HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original,
  return reshard_window->sharded_input;
}

bool IsNanSafeGt(HloComputation* comp) {
  namespace m = match;
  auto match_bitcast_f32 = [](int64 parameter_number) {
    auto param = m::Parameter(parameter_number)
                     .WithShape(m::Shape().WithElementType(F32));
    auto param_s32 =
        m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
    auto param_u32 =
        m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
    return m::Select(
        m::Lt(param_s32, m::ConstantScalar(0)),
        m::BitcastConvert(
            m::Subtract(m::ConstantScalar(std::numeric_limits<int32>::max()),
                        param_u32))
            .WithShape(m::Shape().WithElementType(S32)),
        param_s32);
  };
  auto match_bitcast_bf16 = [](int64 parameter_number) {
    auto param = m::Convert(m::Parameter(parameter_number)
                                .WithShape(m::Shape().WithElementType(BF16)))
                     .WithShape(m::Shape().WithElementType(F32));
    auto param_s32 =
        m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
    auto param_u32 =
        m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
    return m::Select(
        m::Lt(param_s32, m::ConstantScalar(0)),
        m::BitcastConvert(
            m::Subtract(m::ConstantScalar(std::numeric_limits<int32>::max()),
                        param_u32))
            .WithShape(m::Shape().WithElementType(S32)),
        param_s32);
  };
  // If the root instruction is a kSelect, the comparator breaks ties on the
  // indices when the values compare equal; the value comparison is the
  // select's third operand.
  if (comp->root_instruction()->opcode() == HloOpcode::kSelect) {
    return Match(comp->root_instruction()->operand(2),
                 m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) ||
           Match(comp->root_instruction()->operand(2),
                 m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1)));
  }
  return Match(comp->root_instruction(),
               m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) ||
         Match(comp->root_instruction(),
               m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1)));
}

absl::optional<int64> GetKValueInTopKWhenPartitionSortDim(HloInstruction* hlo) {
  HloSortInstruction* sort = DynCast<HloSortInstruction>(hlo);
  if (sort == nullptr || sort->operand_count() != 2) {
    return absl::nullopt;
  }
  if (!IsNanSafeGt(sort->to_apply())) {
    return absl::nullopt;
  }
  HloInstruction* data = sort->mutable_operand(0);
  HloIotaInstruction* iota =
      DynCast<HloIotaInstruction>(sort->mutable_operand(1));
  const PrimitiveType element_type = data->shape().element_type();
  if (iota == nullptr || iota->shape().element_type() != S32 ||
      iota->opcode() != HloOpcode::kIota ||
      iota->iota_dimension() != sort->sort_dimension()) {
    return absl::nullopt;
  }

  const int64 sort_dim = sort->sort_dimension();

  if (element_type != F32 && element_type != BF16 && element_type != S32 &&
      element_type != U32) {
    return absl::nullopt;
  }

  bool supported = true;
  absl::optional<int64> k;
  for (HloInstruction* gte : sort->users()) {
    if (gte->opcode() != HloOpcode::kGetTupleElement) {
      supported = false;
      break;
    }

    const HloInstruction* slice = gte->users()[0];
    if (slice->opcode() != HloOpcode::kSlice) {
      // A non-slice user means this is not a TopK pattern.
      supported = false;
      break;
    }
    if (absl::c_any_of(slice->slice_starts(), [](int x) { return x != 0; }) ||
        absl::c_any_of(slice->slice_strides(), [](int x) { return x != 1; })) {
      // Strided slices, or slices that do not start at the beginning, aren't
      // supported.
      supported = false;
      break;
    }
    for (int64 dim = 0; dim < data->shape().dimensions_size(); dim++) {
      if (dim == sort_dim) {
        continue;
      }
      if (slice->slice_limits(dim) !=
          slice->operand(0)->shape().dimensions(dim)) {
        // Slicing along a dimension other than the sort dimension isn't
        // supported.
        supported = false;
        break;
      }
    }
    if (!k.has_value()) {
      k = slice->slice_limits(sort_dim);
    } else if (k != slice->slice_limits(sort_dim)) {
      // A different k for the different operands isn't supported.
      supported = false;
      break;
    }
  }
  if (k == absl::nullopt || !supported) {
    return absl::nullopt;
  }

  // Only supported when the sort dimension is sharded.
  if (!data->has_sharding()) {
    return absl::nullopt;
  }
  const HloSharding& sharding = sort->operand(0)->sharding();

  if (sharding.IsTileMaximal()) {
    return absl::nullopt;
  }

  // Check that only the sort dimension is partitioned.
  for (int64 dim : sort->dimensions()) {
    if (sharding.tile_assignment().dim(dim) > 1) {
      if (dim != sort_dim) {
        return absl::nullopt;
      }
    }
  }

  // k must be smaller than the per-partition size along the sort dimension.
  const int64 shard_count = sharding.tile_assignment().dim(sort_dim);

  if (shard_count <= 1) {
    return absl::nullopt;
  }

  const int64 input_size = hlo->operand(0)->shape().dimensions(sort_dim);
  const int64 per_partition_size = CeilOfRatio(input_size, shard_count);

  if (k.value() >= per_partition_size) {
    return absl::nullopt;
  }

  return k;
}

// Slices the first k elements along slice_dim.
HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
                            int64 slice_dim, int64 k) {
  const Shape& hlo_shape = hlo->shape();
  auto hlo_dims = hlo_shape.dimensions();
  std::vector<int64> start_indices(hlo_shape.dimensions_size(), 0);
  std::vector<int64> limit_indices(hlo_dims.begin(), hlo_dims.end());
  std::vector<int64> strides(hlo_shape.dimensions_size(), 1);
  limit_indices[slice_dim] = k;
  auto output_shape = hlo_shape;
  output_shape.set_dimensions(slice_dim, k);
  return builder->AddInstruction(HloInstruction::CreateSlice(
      output_shape, hlo, start_indices, limit_indices, strides));
}

}  // namespace spmd
}  // namespace xla

@@ -45,6 +45,24 @@ HloInstruction* CreateR0WithType(PrimitiveType type, NativeT value,
  return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
}

inline HloInstruction* CreateFirstWithType(PrimitiveType type, SpmdBuilder* b) {
  if (type == F32) {
    auto float_pad_value = std::numeric_limits<float>::quiet_NaN();
    return CreateR0WithType(type, -float_pad_value, b);
  }
  auto literal = LiteralUtil::MinValue(type);
  return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
}

inline HloInstruction* CreateLastWithType(PrimitiveType type, SpmdBuilder* b) {
  if (type == F32) {
    auto float_pad_value = std::numeric_limits<float>::quiet_NaN();
    return CreateR0WithType(type, float_pad_value, b);
  }
  auto literal = LiteralUtil::MaxValue(type);
  return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
}

// Create a binary add computation of the given type and add to the module.
HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module);

@@ -234,6 +252,16 @@ absl::optional<HloInstruction*> ExchangeHaloAndGetValidData(
HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original,
                                        absl::Span<const int64> dims);

// Checks if the computation is a NaN-safe greater-than comparison.
bool IsNanSafeGt(HloComputation* computation);

// Returns k in TopK when the input value is partitioned along the sort
// dimension.
absl::optional<int64> GetKValueInTopKWhenPartitionSortDim(HloInstruction* hlo);

// Slices the first k elements along the slice dimension.
HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
                            int64 slice_dim, int64 k);

}  // namespace spmd
}  // namespace xla
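
The comparators in the HLO test cases above, and the pattern recognized by IsNanSafeGt, order floats through a bit trick: the value is bitcast to a signed integer, and negative values are remapped so that the resulting integer order is total, with -NaN smallest and NaN largest. Below is a minimal C++ sketch of that key transform for F32 only; NanSafeSortKey is an illustrative name, not an XLA function.

#include <cstdint>
#include <cstring>
#include <iostream>
#include <limits>

// Maps a float to an int32 key so that comparing keys with plain integer ">"
// matches the NaN-safe ordering: -NaN < -Inf < ... < +Inf < NaN.
int32_t NanSafeSortKey(float f) {
  int32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));  // bitcast-convert f32 -> s32
  if (bits < 0) {                        // negative floats (sign bit set)
    uint32_t ubits;
    std::memcpy(&ubits, &f, sizeof(ubits));  // bitcast-convert f32 -> u32
    // 0x7FFFFFFF - ubits, then reinterpret as signed (two's-complement wrap),
    // mirroring the u32 subtract + bitcast-convert in the HLO comparator.
    return static_cast<int32_t>(std::numeric_limits<int32_t>::max() - ubits);
  }
  return bits;
}

int main() {
  const float neg_nan = -std::numeric_limits<float>::quiet_NaN();
  const float inf = std::numeric_limits<float>::infinity();
  std::cout << (NanSafeSortKey(neg_nan) < NanSafeSortKey(-inf)) << "\n";  // 1
  std::cout << (NanSafeSortKey(1.0f) > NanSafeSortKey(0.5f)) << "\n";     // 1
  return 0;
}

Because -NaN maps to the smallest possible key, padding values produced by CreateFirstWithType sort behind every real element, which is what makes the per-shard padding in the partitioned HandleSort path safe.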