Add while hlo test.
PiperOrigin-RevId: 159638639
This commit is contained in:
parent
a4660cce81
commit
384a0fb075
@ -81,6 +81,40 @@ TEST_F(WhileTest, WhileWithScalarResult) {
|
||||
ComputeAndCompareR0<int32>(&builder, 5, {});
|
||||
}
|
||||
|
||||
TEST_F(WhileTest, WhileWithScalarResultNonConstInit) {
|
||||
auto result_shape = ShapeUtil::MakeShape(S32, {});
|
||||
auto orig_shape = ShapeUtil::MakeShape(S32, {2});
|
||||
|
||||
// Create a computation for the condition: repeat for 5 iterations.
|
||||
Computation condition;
|
||||
{
|
||||
ComputationBuilder builder(client_, "condition");
|
||||
auto prev = builder.Parameter(0, result_shape, "prev");
|
||||
builder.Gt(builder.ConstantR0<int32>(5), prev);
|
||||
condition = builder.Build().ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
// Create a computation for the body: add 1 to the result variable.
|
||||
Computation body;
|
||||
{
|
||||
ComputationBuilder builder(client_, "body");
|
||||
auto prev = builder.Parameter(0, result_shape, "prev");
|
||||
auto input = builder.ConstantR0<int32>(1);
|
||||
auto result = builder.Add(input, prev);
|
||||
body = builder.Build().ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
// Create a While node with computations for the condition and the body.
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto init = builder.Reduce(builder.ConstantR1<int32>(2, 1),
|
||||
builder.ConstantR0<int32>(0),
|
||||
CreateScalarAddComputation(S32, &builder), {0});
|
||||
auto result = builder.While(condition, body, init);
|
||||
auto shape = builder.GetShape(result).ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareR0<int32>(&builder, 5, {});
|
||||
}
|
||||
|
||||
// Tests a while node when the result type T is a vector.
|
||||
//
|
||||
// All constants are chosen to produce exact results.
|
||||
|
Loading…
Reference in New Issue
Block a user