From 384a0fb075b2ec5c780119f98869c85810913c93 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 20 Jun 2017 17:25:07 -0700 Subject: [PATCH] Add while hlo test. PiperOrigin-RevId: 159638639 --- tensorflow/compiler/xla/tests/while_test.cc | 34 +++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 2df91974282..beaed374658 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -81,6 +81,40 @@ TEST_F(WhileTest, WhileWithScalarResult) { ComputeAndCompareR0(&builder, 5, {}); } +TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { + auto result_shape = ShapeUtil::MakeShape(S32, {}); + auto orig_shape = ShapeUtil::MakeShape(S32, {2}); + + // Create a computation for the condition: repeat for 5 iterations. + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + builder.Gt(builder.ConstantR0(5), prev); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body: add 1 to the result variable. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto input = builder.ConstantR0(1); + auto result = builder.Add(input, prev); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, TestName()); + auto init = builder.Reduce(builder.ConstantR1(2, 1), + builder.ConstantR0(0), + CreateScalarAddComputation(S32, &builder), {0}); + auto result = builder.While(condition, body, init); + auto shape = builder.GetShape(result).ConsumeValueOrDie(); + + ComputeAndCompareR0(&builder, 5, {}); +} + // Tests a while node when the result type T is a vector. // // All constants are chosen to produce exact results.