diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index a4d96646a14..be3cdbca090 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -151,6 +151,31 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } +XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { + ComputationBuilder b(client_, TestName()); + ComputationDataHandle v1, v2; + + for (bool direction : {false, true}) { + std::unique_ptr v1_data = + CreateR0Parameter(0.0f, /*parameter_number=*/0, /*name=*/"v1", + /*builder=*/&b, /*data_handle=*/&v1); + std::unique_ptr v2_data = + CreateR0Parameter(1.0f, /*parameter_number=*/1, /*name=*/"v2", + /*builder=*/&b, /*data_handle=*/&v2); + auto v1_gt = b.Gt(v1, v2); // false + auto v2_gt = b.Gt(v2, v1); // true + auto v1_v2 = b.Tuple({v1_gt, v2_gt}); // {false, true} + auto v2_v1 = b.Tuple({v2_gt, v1_gt}); // {true, false} + auto select = b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); + auto expected = + Literal::MakeTuple({Literal::CreateR0(direction).get(), + Literal::CreateR0(!direction).get()}); + + ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()}, + error_spec_); + } +} + // Builds two new tuples from an existing tuple (by means of GetTupleElement), // then adds up the components of the new tuples. XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index beaed374658..ccd2a956589 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -115,6 +115,36 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { ComputeAndCompareR0(&builder, 5, {}); } +TEST_F(WhileTest, WhileWithPredicateResult) { + auto result_shape = ShapeUtil::MakeShape(PRED, {}); + + // Create a computation for the condition: run until condition is true. + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + builder.Ne(builder.ConstantR0(true), prev); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body: or condition with true. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto result = builder.LogicalOr(prev, builder.ConstantR0(true)); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, TestName()); + auto init = builder.Ne(builder.ConstantR0(false), + builder.ConstantR0(true)); + auto result = builder.While(condition, body, init); + + ComputeAndCompareR0(&builder, true, {}); +} + // Tests a while node when the result type T is a vector. // // All constants are chosen to produce exact results. @@ -282,6 +312,53 @@ TEST_F(WhileTest, WhileWithTupleResult) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } +TEST_F(WhileTest, WhileWithPredicateTupleResult) { + std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeShape(PRED, {})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for 5 iterations. + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Gt(builder.ConstantR0(5), iteration); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add 1 to the iteration variable and or the predicate with true + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto pred = builder.GetTupleElement(prev, 1); + auto new_pred = builder.LogicalOr(pred, builder.ConstantR0(true)); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), new_pred}); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple({builder.ConstantR0(0), + builder.Ne(builder.ConstantR0(false), + builder.ConstantR0(true))}); + auto result = builder.While(condition, body, init); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + + auto expected_counter = Literal::CreateR0(5); + auto expected_predicate = Literal::CreateR0(true); + auto expected = + Literal::MakeTuple({expected_counter.get(), expected_predicate.get()}); + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0)); +} + // Tests two while nodes when the result type T is a Tuple and the second // while node uses the result of the first while node which is used in two // nodes.