[XLA] Add tests for select ops and while loops that produce tuples that contain predicates.

PiperOrigin-RevId: 159645900
This commit is contained in:
A. Unique TensorFlower 2017-06-20 18:49:28 -07:00 committed by TensorFlower Gardener
parent 980d3f2be6
commit a4a4698323
2 changed files with 102 additions and 0 deletions

View File

@ -151,6 +151,31 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) {
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
ComputationBuilder b(client_, TestName());
ComputationDataHandle v1, v2;
for (bool direction : {false, true}) {
std::unique_ptr<GlobalData> v1_data =
CreateR0Parameter<float>(0.0f, /*parameter_number=*/0, /*name=*/"v1",
/*builder=*/&b, /*data_handle=*/&v1);
std::unique_ptr<GlobalData> v2_data =
CreateR0Parameter<float>(1.0f, /*parameter_number=*/1, /*name=*/"v2",
/*builder=*/&b, /*data_handle=*/&v2);
auto v1_gt = b.Gt(v1, v2); // false
auto v2_gt = b.Gt(v2, v1); // true
auto v1_v2 = b.Tuple({v1_gt, v2_gt}); // {false, true}
auto v2_v1 = b.Tuple({v2_gt, v1_gt}); // {true, false}
auto select = b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);
auto expected =
Literal::MakeTuple({Literal::CreateR0<bool>(direction).get(),
Literal::CreateR0<bool>(!direction).get()});
ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()},
error_spec_);
}
}
// Builds two new tuples from an existing tuple (by means of GetTupleElement),
// then adds up the components of the new tuples.
XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {

View File

@ -115,6 +115,36 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) {
ComputeAndCompareR0<int32>(&builder, 5, {});
}
TEST_F(WhileTest, WhileWithPredicateResult) {
auto result_shape = ShapeUtil::MakeShape(PRED, {});
// Create a computation for the condition: run until condition is true.
Computation condition;
{
ComputationBuilder builder(client_, "condition");
auto prev = builder.Parameter(0, result_shape, "prev");
builder.Ne(builder.ConstantR0<bool>(true), prev);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body: or condition with true.
Computation body;
{
ComputationBuilder builder(client_, "body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto result = builder.LogicalOr(prev, builder.ConstantR0<bool>(true));
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
ComputationBuilder builder(client_, TestName());
auto init = builder.Ne(builder.ConstantR0<bool>(false),
builder.ConstantR0<bool>(true));
auto result = builder.While(condition, body, init);
ComputeAndCompareR0<bool>(&builder, true, {});
}
// Tests a while node when the result type T is a vector.
//
// All constants are chosen to produce exact results.
@ -282,6 +312,53 @@ TEST_F(WhileTest, WhileWithTupleResult) {
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
}
TEST_F(WhileTest, WhileWithPredicateTupleResult) {
std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
ShapeUtil::MakeShape(PRED, {})};
Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
// Create a computation for the condition.
// Repeat for 5 iterations.
Computation condition;
{
ComputationBuilder builder(client_, "condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Gt(builder.ConstantR0<int32>(5), iteration);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body.
// Add 1 to the iteration variable and or the predicate with true
Computation body;
{
ComputationBuilder builder(client_, "body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto pred = builder.GetTupleElement(prev, 1);
auto new_pred = builder.LogicalOr(pred, builder.ConstantR0<bool>(true));
auto result = builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), new_pred});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
ComputationBuilder builder(client_, "while");
auto init = builder.Tuple({builder.ConstantR0<int32>(0),
builder.Ne(builder.ConstantR0<bool>(false),
builder.ConstantR0<bool>(true))});
auto result = builder.While(condition, body, init);
VLOG(2) << "while = "
<< ShapeUtil::HumanString(
*builder.GetShape(result).ConsumeValueOrDie());
auto expected_counter = Literal::CreateR0<int32>(5);
auto expected_predicate = Literal::CreateR0<bool>(true);
auto expected =
Literal::MakeTuple({expected_counter.get(), expected_predicate.get()});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0));
}
// Tests two while nodes when the result type T is a Tuple and the second
// while node uses the result of the first while node which is used in two
// nodes.