Add while hlo test.
PiperOrigin-RevId: 159638639
This commit is contained in:
parent
a4660cce81
commit
384a0fb075
@ -81,6 +81,40 @@ TEST_F(WhileTest, WhileWithScalarResult) {
|
|||||||
ComputeAndCompareR0<int32>(&builder, 5, {});
|
ComputeAndCompareR0<int32>(&builder, 5, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(WhileTest, WhileWithScalarResultNonConstInit) {
|
||||||
|
auto result_shape = ShapeUtil::MakeShape(S32, {});
|
||||||
|
auto orig_shape = ShapeUtil::MakeShape(S32, {2});
|
||||||
|
|
||||||
|
// Create a computation for the condition: repeat for 5 iterations.
|
||||||
|
Computation condition;
|
||||||
|
{
|
||||||
|
ComputationBuilder builder(client_, "condition");
|
||||||
|
auto prev = builder.Parameter(0, result_shape, "prev");
|
||||||
|
builder.Gt(builder.ConstantR0<int32>(5), prev);
|
||||||
|
condition = builder.Build().ConsumeValueOrDie();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a computation for the body: add 1 to the result variable.
|
||||||
|
Computation body;
|
||||||
|
{
|
||||||
|
ComputationBuilder builder(client_, "body");
|
||||||
|
auto prev = builder.Parameter(0, result_shape, "prev");
|
||||||
|
auto input = builder.ConstantR0<int32>(1);
|
||||||
|
auto result = builder.Add(input, prev);
|
||||||
|
body = builder.Build().ConsumeValueOrDie();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a While node with computations for the condition and the body.
|
||||||
|
ComputationBuilder builder(client_, TestName());
|
||||||
|
auto init = builder.Reduce(builder.ConstantR1<int32>(2, 1),
|
||||||
|
builder.ConstantR0<int32>(0),
|
||||||
|
CreateScalarAddComputation(S32, &builder), {0});
|
||||||
|
auto result = builder.While(condition, body, init);
|
||||||
|
auto shape = builder.GetShape(result).ConsumeValueOrDie();
|
||||||
|
|
||||||
|
ComputeAndCompareR0<int32>(&builder, 5, {});
|
||||||
|
}
|
||||||
|
|
||||||
// Tests a while node when the result type T is a vector.
|
// Tests a while node when the result type T is a vector.
|
||||||
//
|
//
|
||||||
// All constants are chosen to produce exact results.
|
// All constants are chosen to produce exact results.
|
||||||
|
Loading…
Reference in New Issue
Block a user