Add while hlo test.

PiperOrigin-RevId: 159638639
This commit is contained in:
A. Unique TensorFlower 2017-06-20 17:25:07 -07:00 committed by TensorFlower Gardener
parent a4660cce81
commit 384a0fb075

View File

@ -81,6 +81,40 @@ TEST_F(WhileTest, WhileWithScalarResult) {
ComputeAndCompareR0<int32>(&builder, 5, {});
}
TEST_F(WhileTest, WhileWithScalarResultNonConstInit) {
auto result_shape = ShapeUtil::MakeShape(S32, {});
auto orig_shape = ShapeUtil::MakeShape(S32, {2});
// Create a computation for the condition: repeat for 5 iterations.
Computation condition;
{
ComputationBuilder builder(client_, "condition");
auto prev = builder.Parameter(0, result_shape, "prev");
builder.Gt(builder.ConstantR0<int32>(5), prev);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body: add 1 to the result variable.
Computation body;
{
ComputationBuilder builder(client_, "body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR0<int32>(1);
auto result = builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
ComputationBuilder builder(client_, TestName());
auto init = builder.Reduce(builder.ConstantR1<int32>(2, 1),
builder.ConstantR0<int32>(0),
CreateScalarAddComputation(S32, &builder), {0});
auto result = builder.While(condition, body, init);
auto shape = builder.GetShape(result).ConsumeValueOrDie();
ComputeAndCompareR0<int32>(&builder, 5, {});
}
// Tests a while node when the result type T is a vector.
//
// All constants are chosen to produce exact results.