Fix xla/service tests with new Slice API (includes strides)
This commit is contained in:
parent
4f513e0555
commit
fa3884883d
@ -522,7 +522,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
|
||||
HloInstruction* empty_slice =
|
||||
builder.AddInstruction(HloInstruction::CreateSlice(
|
||||
ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}));
|
||||
ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1}));
|
||||
Shape result_shape = ShapeUtil::MakeShape(F32, {3 * kParamLength});
|
||||
builder.AddInstruction(HloInstruction::CreateConcatenate(
|
||||
result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0));
|
||||
@ -553,7 +553,7 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
|
||||
HloInstruction* empty_slice =
|
||||
builder.AddInstruction(HloInstruction::CreateSlice(
|
||||
ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}));
|
||||
ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1}));
|
||||
Shape result_shape = ShapeUtil::MakeShape(F32, {0});
|
||||
builder.AddInstruction(HloInstruction::CreateConcatenate(
|
||||
result_shape, {empty_literal, empty_slice}, 0));
|
||||
@ -1134,7 +1134,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
|
||||
0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param"));
|
||||
builder.AddInstruction(HloInstruction::CreateSlice(
|
||||
ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0},
|
||||
/*limit_indices=*/{dim0, dim1}));
|
||||
/*limit_indices=*/{dim0, dim1}, /*slices=*/{1, 1}));
|
||||
|
||||
HloModule module(TestName());
|
||||
HloComputation* computation = module.AddEntryComputation(builder.Build());
|
||||
@ -1539,7 +1539,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
|
||||
|
||||
Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3});
|
||||
HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
|
||||
slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}));
|
||||
slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1}));
|
||||
|
||||
HloModule module(TestName());
|
||||
auto computation = module.AddEntryComputation(builder.Build());
|
||||
|
@ -730,7 +730,7 @@ TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) {
|
||||
auto negate = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
|
||||
auto slice = builder.AddInstruction(
|
||||
HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
|
||||
HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
|
||||
auto broadcast = builder.AddInstruction(
|
||||
HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
|
||||
|
||||
@ -762,7 +762,7 @@ TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) {
|
||||
auto negate = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
|
||||
auto slice = builder.AddInstruction(
|
||||
HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
|
||||
HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
|
||||
auto broadcast = builder.AddInstruction(
|
||||
HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast}));
|
||||
@ -799,7 +799,7 @@ TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) {
|
||||
auto tuple_element = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(f32vec100_, tuple, 0));
|
||||
auto slice = builder.AddInstruction(
|
||||
HloInstruction::CreateSlice(f32vec10_, tuple_element, {0}, {10}));
|
||||
HloInstruction::CreateSlice(f32vec10_, tuple_element, {0}, {10}, {1}));
|
||||
auto broadcast = builder.AddInstruction(
|
||||
HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({tuple, broadcast}));
|
||||
@ -834,7 +834,7 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) {
|
||||
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
|
||||
// Slice output is 10 elements.
|
||||
auto slice = builder.AddInstruction(
|
||||
HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
|
||||
HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
|
||||
// Broadcast output is 40 elements.
|
||||
auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
|
||||
ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));
|
||||
@ -866,7 +866,7 @@ TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) {
|
||||
auto negate = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
|
||||
auto slice = builder.AddInstruction(
|
||||
HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
|
||||
HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
|
||||
// Broadcast output is 40 elements.
|
||||
auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
|
||||
ShapeUtil::MakeShape(F32, {10, 10}), slice, {0}));
|
||||
@ -903,7 +903,7 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) {
|
||||
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
|
||||
// Slice output is 10 elements.
|
||||
auto slice = builder.AddInstruction(
|
||||
HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
|
||||
HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
|
||||
// Broadcast output is 40 elements.
|
||||
auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
|
||||
ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));
|
||||
|
@ -590,7 +590,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
|
||||
if (update_uses_tuple_element1) {
|
||||
// Create a slice instruction as an additional user of 'gte1'.
|
||||
slice = builder.AddInstruction(
|
||||
HloInstruction::CreateSlice(update_shape, gte1, {0}, {3}));
|
||||
HloInstruction::CreateSlice(update_shape, gte1, {0}, {3}, {1}));
|
||||
update = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
update_shape, HloOpcode::kAdd, update, slice));
|
||||
}
|
||||
|
@ -153,6 +153,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
|
||||
const int64 dimensions[] = {11, 8, 7, 5, 9};
|
||||
const int64 slice_start[] = {4, 2, 3, 1, 5};
|
||||
const int64 slice_limits[] = {10, 8, 6, 5, 9};
|
||||
const int64 slice_strides[] = {1, 1, 1, 1, 1};
|
||||
TF_ASSIGN_OR_ASSERT_OK(auto literal,
|
||||
LiteralTestUtil::CreateRandomLiteral<F32>(
|
||||
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
|
||||
@ -160,7 +161,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
|
||||
HloInstruction::CreateConstant(std::move(literal)));
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4});
|
||||
builder.AddInstruction(HloInstruction::CreateSlice(
|
||||
shape, literal_instruction, slice_start, slice_limits));
|
||||
shape, literal_instruction, slice_start, slice_limits, slice_strides));
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
|
@ -67,7 +67,8 @@ class HloRematerializationTest : public HloTestBase {
|
||||
/*dimension=*/0));
|
||||
auto slice_1 = builder.AddInstruction(HloInstruction::CreateSlice(
|
||||
vec1_shape_, concat_1, /*start_indices=*/{0},
|
||||
/*limit_indices=*/{1}));
|
||||
/*limit_indices=*/{1},
|
||||
/*strides=*/{1}));
|
||||
auto concat_2 = builder.AddInstruction(HloInstruction::CreateConcatenate(
|
||||
ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, slice_1},
|
||||
/*dimension=*/0));
|
||||
@ -75,7 +76,8 @@ class HloRematerializationTest : public HloTestBase {
|
||||
// which is necessary to use this computation in a while.
|
||||
builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat_2,
|
||||
/*start_indices=*/{0},
|
||||
/*limit_indices=*/{1}));
|
||||
/*limit_indices=*/{1},
|
||||
/*strides=*/{1}));
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
@ -103,7 +105,8 @@ class HloRematerializationTest : public HloTestBase {
|
||||
HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
|
||||
auto slice_1 = builder.AddInstruction(
|
||||
HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0},
|
||||
/*limit_indices=*/{1}));
|
||||
/*limit_indices=*/{1},
|
||||
/*strides=*/{1}));
|
||||
auto while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
|
||||
vec1_shape_, while_cond, while_body, slice_1));
|
||||
auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
|
||||
@ -111,7 +114,8 @@ class HloRematerializationTest : public HloTestBase {
|
||||
/*dimension=*/0));
|
||||
builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat,
|
||||
/*start_indices=*/{0},
|
||||
/*limit_indices=*/{1}));
|
||||
/*limit_indices=*/{1},
|
||||
/*strides=*/{1}));
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
@ -353,7 +357,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
|
||||
/*dimension=*/0));
|
||||
builder.AddInstruction(HloInstruction::CreateSlice(
|
||||
vec1024_shape_, concat, /*start_indices=*/{0},
|
||||
/*limit_indices=*/{1024}));
|
||||
/*limit_indices=*/{1024}, /*slices=*/{1}));
|
||||
subcomputation = module->AddEmbeddedComputation(builder.Build());
|
||||
}
|
||||
|
||||
@ -469,7 +473,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
|
||||
/*dimension=*/0));
|
||||
builder.AddInstruction(HloInstruction::CreateSlice(
|
||||
vec1024_shape_, concat, /*start_indices=*/{0},
|
||||
/*limit_indices=*/{1024}));
|
||||
/*limit_indices=*/{1024}, /*slices=*/{1}));
|
||||
subcomputation = module->AddEmbeddedComputation(builder.Build());
|
||||
}
|
||||
|
||||
|
@ -585,7 +585,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest {
|
||||
if (add_additional_gte0_user) {
|
||||
// Create 'slice' as an additional user of 'input'.
|
||||
auto slice = builder.AddInstruction(
|
||||
HloInstruction::CreateSlice(update_shape, input, {0}, {3}));
|
||||
HloInstruction::CreateSlice(update_shape, input, {0}, {3}, {1}));
|
||||
// Modify 'update' to take 'slice' output.
|
||||
update = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
update_shape, HloOpcode::kAdd, update, slice));
|
||||
|
Loading…
Reference in New Issue
Block a user