[XLA] Add tests for select ops and while loops that produce tuples that contain predicates.
PiperOrigin-RevId: 159645900
This commit is contained in:
parent
980d3f2be6
commit
a4a4698323
@ -151,6 +151,31 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) {
|
||||
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
|
||||
ComputationBuilder b(client_, TestName());
|
||||
ComputationDataHandle v1, v2;
|
||||
|
||||
for (bool direction : {false, true}) {
|
||||
std::unique_ptr<GlobalData> v1_data =
|
||||
CreateR0Parameter<float>(0.0f, /*parameter_number=*/0, /*name=*/"v1",
|
||||
/*builder=*/&b, /*data_handle=*/&v1);
|
||||
std::unique_ptr<GlobalData> v2_data =
|
||||
CreateR0Parameter<float>(1.0f, /*parameter_number=*/1, /*name=*/"v2",
|
||||
/*builder=*/&b, /*data_handle=*/&v2);
|
||||
auto v1_gt = b.Gt(v1, v2); // false
|
||||
auto v2_gt = b.Gt(v2, v1); // true
|
||||
auto v1_v2 = b.Tuple({v1_gt, v2_gt}); // {false, true}
|
||||
auto v2_v1 = b.Tuple({v2_gt, v1_gt}); // {true, false}
|
||||
auto select = b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);
|
||||
auto expected =
|
||||
Literal::MakeTuple({Literal::CreateR0<bool>(direction).get(),
|
||||
Literal::CreateR0<bool>(!direction).get()});
|
||||
|
||||
ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()},
|
||||
error_spec_);
|
||||
}
|
||||
}
|
||||
|
||||
// Builds two new tuples from an existing tuple (by means of GetTupleElement),
|
||||
// then adds up the components of the new tuples.
|
||||
XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
|
||||
|
@ -115,6 +115,36 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) {
|
||||
ComputeAndCompareR0<int32>(&builder, 5, {});
|
||||
}
|
||||
|
||||
TEST_F(WhileTest, WhileWithPredicateResult) {
|
||||
auto result_shape = ShapeUtil::MakeShape(PRED, {});
|
||||
|
||||
// Create a computation for the condition: run until condition is true.
|
||||
Computation condition;
|
||||
{
|
||||
ComputationBuilder builder(client_, "condition");
|
||||
auto prev = builder.Parameter(0, result_shape, "prev");
|
||||
builder.Ne(builder.ConstantR0<bool>(true), prev);
|
||||
condition = builder.Build().ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
// Create a computation for the body: or condition with true.
|
||||
Computation body;
|
||||
{
|
||||
ComputationBuilder builder(client_, "body");
|
||||
auto prev = builder.Parameter(0, result_shape, "prev");
|
||||
auto result = builder.LogicalOr(prev, builder.ConstantR0<bool>(true));
|
||||
body = builder.Build().ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
// Create a While node with computations for the condition and the body.
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto init = builder.Ne(builder.ConstantR0<bool>(false),
|
||||
builder.ConstantR0<bool>(true));
|
||||
auto result = builder.While(condition, body, init);
|
||||
|
||||
ComputeAndCompareR0<bool>(&builder, true, {});
|
||||
}
|
||||
|
||||
// Tests a while node when the result type T is a vector.
|
||||
//
|
||||
// All constants are chosen to produce exact results.
|
||||
@ -282,6 +312,53 @@ TEST_F(WhileTest, WhileWithTupleResult) {
|
||||
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
|
||||
}
|
||||
|
||||
TEST_F(WhileTest, WhileWithPredicateTupleResult) {
|
||||
std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
|
||||
ShapeUtil::MakeShape(PRED, {})};
|
||||
Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
|
||||
|
||||
// Create a computation for the condition.
|
||||
// Repeat for 5 iterations.
|
||||
Computation condition;
|
||||
{
|
||||
ComputationBuilder builder(client_, "condition");
|
||||
auto prev = builder.Parameter(0, result_shape, "prev");
|
||||
auto iteration = builder.GetTupleElement(prev, 0);
|
||||
builder.Gt(builder.ConstantR0<int32>(5), iteration);
|
||||
condition = builder.Build().ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
// Create a computation for the body.
|
||||
// Add 1 to the iteration variable and or the predicate with true
|
||||
Computation body;
|
||||
{
|
||||
ComputationBuilder builder(client_, "body");
|
||||
auto prev = builder.Parameter(0, result_shape, "prev");
|
||||
auto iteration = builder.GetTupleElement(prev, 0);
|
||||
auto pred = builder.GetTupleElement(prev, 1);
|
||||
auto new_pred = builder.LogicalOr(pred, builder.ConstantR0<bool>(true));
|
||||
auto result = builder.Tuple(
|
||||
{builder.Add(iteration, builder.ConstantR0<int32>(1)), new_pred});
|
||||
body = builder.Build().ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
// Create a While node with computations for the condition and the body.
|
||||
ComputationBuilder builder(client_, "while");
|
||||
auto init = builder.Tuple({builder.ConstantR0<int32>(0),
|
||||
builder.Ne(builder.ConstantR0<bool>(false),
|
||||
builder.ConstantR0<bool>(true))});
|
||||
auto result = builder.While(condition, body, init);
|
||||
VLOG(2) << "while = "
|
||||
<< ShapeUtil::HumanString(
|
||||
*builder.GetShape(result).ConsumeValueOrDie());
|
||||
|
||||
auto expected_counter = Literal::CreateR0<int32>(5);
|
||||
auto expected_predicate = Literal::CreateR0<bool>(true);
|
||||
auto expected =
|
||||
Literal::MakeTuple({expected_counter.get(), expected_predicate.get()});
|
||||
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0));
|
||||
}
|
||||
|
||||
// Tests two while nodes when the result type T is a Tuple and the second
|
||||
// while node uses the result of the first while node which is used in two
|
||||
// nodes.
|
||||
|
Loading…
Reference in New Issue
Block a user