Fix xla/service tests with new Slice API (includes strides)

This commit is contained in:
DavidNorman 2017-06-14 14:26:11 +01:00 committed by Martin Wicke
parent 4f513e0555
commit fa3884883d
6 changed files with 24 additions and 19 deletions

View File

@ -522,7 +522,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({}))); HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
HloInstruction* empty_slice = HloInstruction* empty_slice =
builder.AddInstruction(HloInstruction::CreateSlice( builder.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42})); ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1}));
Shape result_shape = ShapeUtil::MakeShape(F32, {3 * kParamLength}); Shape result_shape = ShapeUtil::MakeShape(F32, {3 * kParamLength});
builder.AddInstruction(HloInstruction::CreateConcatenate( builder.AddInstruction(HloInstruction::CreateConcatenate(
result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0)); result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0));
@ -553,7 +553,7 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({}))); HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
HloInstruction* empty_slice = HloInstruction* empty_slice =
builder.AddInstruction(HloInstruction::CreateSlice( builder.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42})); ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1}));
Shape result_shape = ShapeUtil::MakeShape(F32, {0}); Shape result_shape = ShapeUtil::MakeShape(F32, {0});
builder.AddInstruction(HloInstruction::CreateConcatenate( builder.AddInstruction(HloInstruction::CreateConcatenate(
result_shape, {empty_literal, empty_slice}, 0)); result_shape, {empty_literal, empty_slice}, 0));
@ -1134,7 +1134,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param")); 0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param"));
builder.AddInstruction(HloInstruction::CreateSlice( builder.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0}, ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0},
/*limit_indices=*/{dim0, dim1})); /*limit_indices=*/{dim0, dim1}, /*slices=*/{1, 1}));
HloModule module(TestName()); HloModule module(TestName());
HloComputation* computation = module.AddEntryComputation(builder.Build()); HloComputation* computation = module.AddEntryComputation(builder.Build());
@ -1539,7 +1539,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3}); Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3});
HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice( HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6})); slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1}));
HloModule module(TestName()); HloModule module(TestName());
auto computation = module.AddEntryComputation(builder.Build()); auto computation = module.AddEntryComputation(builder.Build());

View File

@ -730,7 +730,7 @@ TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) {
auto negate = builder.AddInstruction( auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
auto slice = builder.AddInstruction( auto slice = builder.AddInstruction(
HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10})); HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
auto broadcast = builder.AddInstruction( auto broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
@ -762,7 +762,7 @@ TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) {
auto negate = builder.AddInstruction( auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
auto slice = builder.AddInstruction( auto slice = builder.AddInstruction(
HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10})); HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
auto broadcast = builder.AddInstruction( auto broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast})); builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast}));
@ -799,7 +799,7 @@ TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) {
auto tuple_element = builder.AddInstruction( auto tuple_element = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(f32vec100_, tuple, 0)); HloInstruction::CreateGetTupleElement(f32vec100_, tuple, 0));
auto slice = builder.AddInstruction( auto slice = builder.AddInstruction(
HloInstruction::CreateSlice(f32vec10_, tuple_element, {0}, {10})); HloInstruction::CreateSlice(f32vec10_, tuple_element, {0}, {10}, {1}));
auto broadcast = builder.AddInstruction( auto broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
builder.AddInstruction(HloInstruction::CreateTuple({tuple, broadcast})); builder.AddInstruction(HloInstruction::CreateTuple({tuple, broadcast}));
@ -834,7 +834,7 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) {
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
// Slice output is 10 elements. // Slice output is 10 elements.
auto slice = builder.AddInstruction( auto slice = builder.AddInstruction(
HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10})); HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
// Broadcast output is 40 elements. // Broadcast output is 40 elements.
auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {10, 4}), slice, {0})); ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));
@ -866,7 +866,7 @@ TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) {
auto negate = builder.AddInstruction( auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
auto slice = builder.AddInstruction( auto slice = builder.AddInstruction(
HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10})); HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
// Broadcast output is 40 elements. // Broadcast output is 40 elements.
auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {10, 10}), slice, {0})); ShapeUtil::MakeShape(F32, {10, 10}), slice, {0}));
@ -903,7 +903,7 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) {
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
// Slice output is 10 elements. // Slice output is 10 elements.
auto slice = builder.AddInstruction( auto slice = builder.AddInstruction(
HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10})); HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
// Broadcast output is 40 elements. // Broadcast output is 40 elements.
auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {10, 4}), slice, {0})); ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));

View File

@ -590,7 +590,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
if (update_uses_tuple_element1) { if (update_uses_tuple_element1) {
// Create a slice instruction as an additional user of 'gte1'. // Create a slice instruction as an additional user of 'gte1'.
slice = builder.AddInstruction( slice = builder.AddInstruction(
HloInstruction::CreateSlice(update_shape, gte1, {0}, {3})); HloInstruction::CreateSlice(update_shape, gte1, {0}, {3}, {1}));
update = builder.AddInstruction(HloInstruction::CreateBinary( update = builder.AddInstruction(HloInstruction::CreateBinary(
update_shape, HloOpcode::kAdd, update, slice)); update_shape, HloOpcode::kAdd, update, slice));
} }

View File

@ -153,6 +153,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
const int64 dimensions[] = {11, 8, 7, 5, 9}; const int64 dimensions[] = {11, 8, 7, 5, 9};
const int64 slice_start[] = {4, 2, 3, 1, 5}; const int64 slice_start[] = {4, 2, 3, 1, 5};
const int64 slice_limits[] = {10, 8, 6, 5, 9}; const int64 slice_limits[] = {10, 8, 6, 5, 9};
const int64 slice_strides[] = {1, 1, 1, 1, 1};
TF_ASSIGN_OR_ASSERT_OK(auto literal, TF_ASSIGN_OR_ASSERT_OK(auto literal,
LiteralTestUtil::CreateRandomLiteral<F32>( LiteralTestUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
@ -160,7 +161,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
HloInstruction::CreateConstant(std::move(literal))); HloInstruction::CreateConstant(std::move(literal)));
Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4}); Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4});
builder.AddInstruction(HloInstruction::CreateSlice( builder.AddInstruction(HloInstruction::CreateSlice(
shape, literal_instruction, slice_start, slice_limits)); shape, literal_instruction, slice_start, slice_limits, slice_strides));
auto module = CreateNewModule(); auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build()); auto computation = module->AddEntryComputation(builder.Build());

View File

@ -67,7 +67,8 @@ class HloRematerializationTest : public HloTestBase {
/*dimension=*/0)); /*dimension=*/0));
auto slice_1 = builder.AddInstruction(HloInstruction::CreateSlice( auto slice_1 = builder.AddInstruction(HloInstruction::CreateSlice(
vec1_shape_, concat_1, /*start_indices=*/{0}, vec1_shape_, concat_1, /*start_indices=*/{0},
/*limit_indices=*/{1})); /*limit_indices=*/{1},
/*strides=*/{1}));
auto concat_2 = builder.AddInstruction(HloInstruction::CreateConcatenate( auto concat_2 = builder.AddInstruction(HloInstruction::CreateConcatenate(
ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, slice_1}, ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, slice_1},
/*dimension=*/0)); /*dimension=*/0));
@ -75,7 +76,8 @@ class HloRematerializationTest : public HloTestBase {
// which is necessary to use this computation in a while. // which is necessary to use this computation in a while.
builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat_2, builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat_2,
/*start_indices=*/{0}, /*start_indices=*/{0},
/*limit_indices=*/{1})); /*limit_indices=*/{1},
/*strides=*/{1}));
return builder.Build(); return builder.Build();
} }
@ -103,7 +105,8 @@ class HloRematerializationTest : public HloTestBase {
HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
auto slice_1 = builder.AddInstruction( auto slice_1 = builder.AddInstruction(
HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0}, HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0},
/*limit_indices=*/{1})); /*limit_indices=*/{1},
/*strides=*/{1}));
auto while_inst = builder.AddInstruction(HloInstruction::CreateWhile( auto while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
vec1_shape_, while_cond, while_body, slice_1)); vec1_shape_, while_cond, while_body, slice_1));
auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate( auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
@ -111,7 +114,8 @@ class HloRematerializationTest : public HloTestBase {
/*dimension=*/0)); /*dimension=*/0));
builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat, builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat,
/*start_indices=*/{0}, /*start_indices=*/{0},
/*limit_indices=*/{1})); /*limit_indices=*/{1},
/*strides=*/{1}));
return builder.Build(); return builder.Build();
} }
@ -353,7 +357,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
/*dimension=*/0)); /*dimension=*/0));
builder.AddInstruction(HloInstruction::CreateSlice( builder.AddInstruction(HloInstruction::CreateSlice(
vec1024_shape_, concat, /*start_indices=*/{0}, vec1024_shape_, concat, /*start_indices=*/{0},
/*limit_indices=*/{1024})); /*limit_indices=*/{1024}, /*slices=*/{1}));
subcomputation = module->AddEmbeddedComputation(builder.Build()); subcomputation = module->AddEmbeddedComputation(builder.Build());
} }
@ -469,7 +473,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
/*dimension=*/0)); /*dimension=*/0));
builder.AddInstruction(HloInstruction::CreateSlice( builder.AddInstruction(HloInstruction::CreateSlice(
vec1024_shape_, concat, /*start_indices=*/{0}, vec1024_shape_, concat, /*start_indices=*/{0},
/*limit_indices=*/{1024})); /*limit_indices=*/{1024}, /*slices=*/{1}));
subcomputation = module->AddEmbeddedComputation(builder.Build()); subcomputation = module->AddEmbeddedComputation(builder.Build());
} }

View File

@ -585,7 +585,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest {
if (add_additional_gte0_user) { if (add_additional_gte0_user) {
// Create 'slice' as an additional user of 'input'. // Create 'slice' as an additional user of 'input'.
auto slice = builder.AddInstruction( auto slice = builder.AddInstruction(
HloInstruction::CreateSlice(update_shape, input, {0}, {3})); HloInstruction::CreateSlice(update_shape, input, {0}, {3}, {1}));
// Modify 'update' to take 'slice' output. // Modify 'update' to take 'slice' output.
update = builder.AddInstruction(HloInstruction::CreateBinary( update = builder.AddInstruction(HloInstruction::CreateBinary(
update_shape, HloOpcode::kAdd, update, slice)); update_shape, HloOpcode::kAdd, update, slice));