Brian Patton 9bac04a4de Modifies kConditional to support both predicated and indexed conditionals (i.e. switch statements).
Updates lowerings for both CPU and GPU.

Adds a new tf2xla kernel for the CFv2 functional Case op which lowers to an indexed kConditional.

PiperOrigin-RevId: 235831207
2019-02-26 18:15:22 -08:00

810 lines
32 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <random>
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
namespace xla {
namespace {
class ConditionalOpTest : public ClientLibraryTestBase {
 protected:
  // Returns a computation named "Constant" that takes an (unused) empty-tuple
  // parameter and produces the F32 scalar constant `value`.
  XlaComputation CreateR0ConstantComputation(float value) {
    XlaBuilder builder("Constant");
    Parameter(&builder, 0, empty_tuple_, "tuple");
    ConstantR0<float>(&builder, value);
    return BuildAndExpectOk(&builder);
  }

  // Returns a computation named "Identity" whose only instruction is its F32
  // scalar parameter, i.e. it returns that parameter unchanged.
  XlaComputation CreateR0IdentityComputation() {
    XlaBuilder builder("Identity");
    Parameter(&builder, 0, r0f32_, "x");
    return BuildAndExpectOk(&builder);
  }

  // Returns a computation named "Ceil" computing Ceil(param) for a single
  // parameter of `shape`.
  XlaComputation CreateCeilComputation(const Shape& shape) {
    return CreateElementwiseComputation(
        "Ceil", shape, [](const XlaOp& x) { return Ceil(x); });
  }
  XlaComputation CreateR0CeilComputation() {
    return CreateCeilComputation(r0f32_);
  }
  XlaComputation CreateR1CeilComputation() {
    return CreateCeilComputation(r1s2f32_);
  }

  // Returns a computation named "Floor" computing Floor(param) for a single
  // parameter of `shape`.
  XlaComputation CreateFloorComputation(const Shape& shape) {
    return CreateElementwiseComputation(
        "Floor", shape, [](const XlaOp& x) { return Floor(x); });
  }
  XlaComputation CreateR0FloorComputation() {
    return CreateFloorComputation(r0f32_);
  }
  XlaComputation CreateR1FloorComputation() {
    return CreateFloorComputation(r1s2f32_);
  }

  // Returns a computation named `computation_name` mapping its 2-tuple
  // parameter (x, y) of `tuple_shape` to the tuple (Ceil(x), Ceil(y)).
  XlaComputation CreateTupleCeilComputation(const string& computation_name,
                                            const Shape& tuple_shape) {
    return CreateTupleElementwiseComputation(
        computation_name, tuple_shape, [](const XlaOp& x) { return Ceil(x); });
  }
  XlaComputation CreateR0TupleCeilComputation() {
    return CreateTupleCeilComputation("CeilR0", tuple_2_r0f32_);
  }
  XlaComputation CreateR1TupleCeilComputation() {
    return CreateTupleCeilComputation("CeilR1", tuple_2_r1s2f32_);
  }

  // Returns a computation named `computation_name` mapping its 2-tuple
  // parameter (x, y) of `tuple_shape` to the tuple (Floor(x), Floor(y)).
  XlaComputation CreateTupleFloorComputation(const string& computation_name,
                                             const Shape& tuple_shape) {
    return CreateTupleElementwiseComputation(
        computation_name, tuple_shape, [](const XlaOp& x) { return Floor(x); });
  }
  XlaComputation CreateR0TupleFloorComputation() {
    return CreateTupleFloorComputation("FloorR0", tuple_2_r0f32_);
  }
  XlaComputation CreateR1TupleFloorComputation() {
    return CreateTupleFloorComputation("FloorR1", tuple_2_r1s2f32_);
  }

  // Returns a computation named `computation_name` mapping its 2-tuple
  // parameter (x, y) of `tuple_shape` to x + y.
  XlaComputation CreateTupleAddComputation(const string& computation_name,
                                           const Shape& tuple_shape) {
    return CreateTupleBinaryComputation(
        computation_name, tuple_shape,
        [](const XlaOp& x, const XlaOp& y) { return Add(x, y); });
  }
  XlaComputation CreateR0TupleAddComputation() {
    return CreateTupleAddComputation("AddR0", tuple_2_r0f32_);
  }
  XlaComputation CreateR1TupleAddComputation() {
    return CreateTupleAddComputation("AddR1", tuple_2_r1s2f32_);
  }

  // Returns a computation named `computation_name` mapping its 2-tuple
  // parameter (x, y) of `tuple_shape` to x - y.
  XlaComputation CreateTupleSubComputation(const string& computation_name,
                                           const Shape& tuple_shape) {
    return CreateTupleBinaryComputation(
        computation_name, tuple_shape,
        [](const XlaOp& x, const XlaOp& y) { return Sub(x, y); });
  }
  XlaComputation CreateR0TupleSubComputation() {
    return CreateTupleSubComputation("SubR0", tuple_2_r0f32_);
  }
  XlaComputation CreateR1TupleSubComputation() {
    return CreateTupleSubComputation("SubR1", tuple_2_r1s2f32_);
  }

  Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
  Shape r1s2f32_ = ShapeUtil::MakeShape(F32, {2});
  Shape tuple_2_r0f32_ = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})});
  Shape tuple_2_r1s2f32_ = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(F32, {2})});
  Shape empty_tuple_ = ShapeUtil::MakeTupleShape({});
  ErrorSpec error_spec_{0.001};

 private:
  // Builds `builder` and records a non-fatal test failure if building fails.
  // ConsumeValueOrDie() still aborts in that case, because the callers have
  // to return a computation.
  static XlaComputation BuildAndExpectOk(XlaBuilder* builder) {
    auto build_status = builder->Build();
    EXPECT_IS_OK(build_status.status());
    return build_status.ConsumeValueOrDie();
  }

  // Builds a computation named `name` with one parameter "param" of `shape`.
  // Its root is whatever `fn(param)` adds last.
  template <typename UnaryFn>
  static XlaComputation CreateElementwiseComputation(const string& name,
                                                     const Shape& shape,
                                                     UnaryFn fn) {
    XlaBuilder builder(name);
    auto param = Parameter(&builder, 0, shape, "param");
    fn(param);
    return BuildAndExpectOk(&builder);
  }

  // Builds a computation named `name` whose 2-tuple parameter "tuple" (x, y)
  // is mapped to the tuple (fn(x), fn(y)).
  template <typename UnaryFn>
  static XlaComputation CreateTupleElementwiseComputation(
      const string& name, const Shape& tuple_shape, UnaryFn fn) {
    XlaBuilder builder(name);
    auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
    auto x = GetTupleElement(tuple, 0);
    auto y = GetTupleElement(tuple, 1);
    auto fx = fn(x);
    auto fy = fn(y);
    // Emitted last so that the tuple is the computation's root.
    Tuple(&builder, {fx, fy});
    return BuildAndExpectOk(&builder);
  }

  // Builds a computation named `name` whose 2-tuple parameter "tuple" (x, y)
  // is mapped to fn(x, y).
  template <typename BinaryFn>
  static XlaComputation CreateTupleBinaryComputation(const string& name,
                                                     const Shape& tuple_shape,
                                                     BinaryFn fn) {
    XlaBuilder builder(name);
    auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
    auto x = GetTupleElement(tuple, 0);
    auto y = GetTupleElement(tuple, 1);
    fn(x, y);
    return BuildAndExpectOk(&builder);
  }
};
// Parameterized fixture for indexed conditional (switch/case) tests. GetParam()
// is the number of branch computations each test builds.
struct CaseOpTest : ConditionalOpTest, ::testing::WithParamInterface<int> {};
// Test true and false computations that do not take any parameters. The
// predicate is true, so the true branch's constant (56) is expected.
XLA_TEST_F(ConditionalOpTest, Parameters0) {
  XlaBuilder b(TestName());
  XlaOp predicate;
  auto predicate_data =
      CreateR0Parameter<bool>(true, 0, "pred", &b, &predicate);
  XlaOp no_operands = Tuple(&b, {});
  const XlaComputation on_true = CreateR0ConstantComputation(56.0f);
  const XlaComputation on_false = CreateR0ConstantComputation(12.0f);
  Conditional(predicate, no_operands, on_true, no_operands, on_false);
  ComputeAndCompareR0<float>(&b, 56.0f, {predicate_data.get()}, error_spec_);
}
// Test branch computations that do not take any parameters. Branch i returns
// 10 * i. The indices -1 and num_branches are out of range and are expected to
// select the last branch.
XLA_TEST_P(CaseOpTest, Parameters0) {
  const int num_branches = GetParam();
  const int last_branch = num_branches - 1;
  for (int index = -1; index <= num_branches; ++index) {
    SCOPED_TRACE(index);
    XlaBuilder b(TestName());
    XlaOp index_op;
    auto index_data = CreateR0Parameter<int32>(index, 0, "branch_index_arg",
                                               &b, &index_op);
    XlaOp empty = Tuple(&b, {});

    std::vector<XlaComputation> computations;
    computations.reserve(num_branches);
    for (int i = 0; i < num_branches; ++i) {
      computations.push_back(
          CreateR0ConstantComputation(static_cast<float>(i) * 10));
    }
    // Pointers are taken only after `computations` is fully populated, so
    // none of them can be invalidated by a reallocation.
    std::vector<const XlaComputation*> computation_ptrs;
    computation_ptrs.reserve(num_branches);
    for (const XlaComputation& computation : computations) {
      computation_ptrs.push_back(&computation);
    }
    Conditional(index_op, computation_ptrs,
                std::vector<XlaOp>(num_branches, empty));

    const bool in_range = index >= 0 && index < num_branches;
    const float expected =
        10 * static_cast<float>(in_range ? index : last_branch);
    ComputeAndCompareR0<float>(&b, expected, {index_data.get()}, error_spec_);
  }
}
// Test true and false computations that take in 1 parameter. Both branches are
// the identity. The predicate is false, so the false operand (12) is expected.
XLA_TEST_F(ConditionalOpTest, Parameters1) {
  XlaBuilder b(TestName());
  XlaOp predicate;
  auto predicate_data =
      CreateR0Parameter<bool>(false, 0, "pred", &b, &predicate);
  XlaOp on_true_arg = ConstantR0<float>(&b, 56.0f);
  XlaOp on_false_arg = ConstantR0<float>(&b, 12.0f);
  const XlaComputation identity = CreateR0IdentityComputation();
  Conditional(predicate, on_true_arg, identity, on_false_arg, identity);
  ComputeAndCompareR0<float>(&b, 12.0f, {predicate_data.get()}, error_spec_);
}
// Test branch computations that take in 1 parameter. Branch i computes i + p0
// and receives the operand 10 * i + 7. Out-of-range indices are expected to
// select the last branch.
XLA_TEST_P(CaseOpTest, Parameters1) {
  const int num_branches = GetParam();
  for (int index = -1; index <= num_branches; ++index) {
    SCOPED_TRACE(index);
    XlaBuilder b(TestName());
    XlaOp index_op;
    auto index_data = CreateR0Parameter<int32>(index, 0, "branch_index_arg",
                                               &b, &index_op);

    std::vector<XlaComputation> computations;
    std::vector<XlaOp> branch_operands;
    std::vector<float> expected_per_branch;
    computations.reserve(num_branches);
    branch_operands.reserve(num_branches);
    expected_per_branch.reserve(num_branches);
    for (int i = 0; i < num_branches; ++i) {
      const float offset = static_cast<float>(i);
      auto sub = b.CreateSubBuilder(absl::StrCat("branch_", i));
      Add(ConstantR0<float>(sub.get(), offset),
          Parameter(sub.get(), 0, r0f32_, "p0"));
      computations.push_back(sub->BuildAndNoteError());

      const float input = 10 * offset + 7;
      branch_operands.push_back(ConstantR0<float>(&b, input));
      expected_per_branch.push_back(input + offset);
    }

    std::vector<const XlaComputation*> computation_ptrs;
    computation_ptrs.reserve(num_branches);
    for (const XlaComputation& computation : computations) {
      computation_ptrs.push_back(&computation);
    }
    Conditional(index_op, computation_ptrs, branch_operands);

    const int selected =
        (index >= 0 && index < num_branches) ? index : num_branches - 1;
    ComputeAndCompareR0<float>(&b, expected_per_branch[selected],
                               {index_data.get()}, error_spec_);
  }
}
// Test conditional with two different computations (ceil / floor) in the true
// and false cases that take in different arguments. With a false predicate,
// floor(12.6) == 12 is expected.
XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) {
  XlaBuilder b(TestName());
  XlaOp predicate;
  auto predicate_data =
      CreateR0Parameter<bool>(false, 0, "pred", &b, &predicate);
  XlaOp true_arg = ConstantR0<float>(&b, 56.4f);
  XlaOp false_arg = ConstantR0<float>(&b, 12.6f);
  const XlaComputation ceil_fn = CreateR0CeilComputation();
  const XlaComputation floor_fn = CreateR0FloorComputation();
  Conditional(predicate, true_arg, ceil_fn, false_arg, floor_fn);
  ComputeAndCompareR0<float>(&b, 12.0f, {predicate_data.get()}, error_spec_);
}
// Test conditional with two different computations (ceil / floor) in the true
// and false cases that take in the same argument. With a false predicate,
// floor(12.6) == 12 is expected.
XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) {
  XlaBuilder b(TestName());
  XlaOp predicate;
  auto predicate_data =
      CreateR0Parameter<bool>(false, 0, "pred", &b, &predicate);
  XlaOp shared_arg = ConstantR0<float>(&b, 12.6f);
  const XlaComputation ceil_fn = CreateR0CeilComputation();
  const XlaComputation floor_fn = CreateR0FloorComputation();
  Conditional(predicate, shared_arg, ceil_fn, shared_arg, floor_fn);
  ComputeAndCompareR0<float>(&b, 12.0f, {predicate_data.get()}, error_spec_);
}
// Test conditional with the same computation in the true and false cases but
// that take in different arguments. With a false predicate, floor(12.6) == 12
// is expected.
XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) {
  XlaBuilder b(TestName());
  XlaOp predicate;
  auto predicate_data =
      CreateR0Parameter<bool>(false, 0, "pred", &b, &predicate);
  XlaOp true_arg = ConstantR0<float>(&b, 56.4f);
  XlaOp false_arg = ConstantR0<float>(&b, 12.6f);
  const XlaComputation floor_fn = CreateR0FloorComputation();
  Conditional(predicate, true_arg, floor_fn, false_arg, floor_fn);
  ComputeAndCompareR0<float>(&b, 12.0f, {predicate_data.get()}, error_spec_);
}
// Test conditional with the same computation in the true and false cases that
// take in the same argument. floor(12.6) == 12 is expected.
XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) {
  XlaBuilder b(TestName());
  XlaOp predicate;
  auto predicate_data =
      CreateR0Parameter<bool>(false, 0, "pred", &b, &predicate);
  XlaOp shared_arg = ConstantR0<float>(&b, 12.6f);
  const XlaComputation floor_fn = CreateR0FloorComputation();
  Conditional(predicate, shared_arg, floor_fn, shared_arg, floor_fn);
  ComputeAndCompareR0<float>(&b, 12.0f, {predicate_data.get()}, error_spec_);
}
// Test conditional with different instances of the same computation in the true
// and false cases: each branch gets its own separately built floor computation.
// With a false predicate, floor(12.6) == 12 is expected.
XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) {
  XlaBuilder b(TestName());
  XlaOp predicate;
  auto predicate_data =
      CreateR0Parameter<bool>(false, 0, "pred", &b, &predicate);
  XlaOp true_arg = ConstantR0<float>(&b, 56.4f);
  XlaOp false_arg = ConstantR0<float>(&b, 12.6f);
  const XlaComputation floor_for_true = CreateR0FloorComputation();
  const XlaComputation floor_for_false = CreateR0FloorComputation();
  Conditional(predicate, true_arg, floor_for_true, false_arg, floor_for_false);
  ComputeAndCompareR0<float>(&b, 12.0f, {predicate_data.get()}, error_spec_);
}
// Test the case when a call invokes a computation that contains a conditional.
// The called computation selects ceil(param1) or floor(param2) based on
// param0. With a false predicate, floor(12.6) == 12 is expected.
XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) {
  Shape r0bool = ShapeUtil::MakeShape(PRED, {});
  XlaBuilder inner_builder(TestName() + ".inner_conditional");
  auto pred_cond = Parameter(&inner_builder, 0, r0bool, "param0");
  auto true_operand = Parameter(&inner_builder, 1, r0f32_, "param1");
  auto false_operand = Parameter(&inner_builder, 2, r0f32_, "param2");
  Conditional(pred_cond, true_operand, CreateR0CeilComputation(), false_operand,
              CreateR0FloorComputation());
  auto inner_builder_result = inner_builder.Build();
  // Report a failed build as a test failure, as the other tests that build
  // nested computations do, rather than only crashing in ConsumeValueOrDie().
  EXPECT_IS_OK(inner_builder_result.status());
  XlaBuilder builder(TestName());
  XlaOp pred;
  auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
  auto operand1 = ConstantR0<float>(&builder, 56.4f);
  auto operand2 = ConstantR0<float>(&builder, 12.6f);
  Call(&builder, inner_builder_result.ConsumeValueOrDie(),
       {pred, operand1, operand2});
  ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
}
// Test true and false computations that take in 2 parameters and predicate is
// true. The true branch adds the tuple elements, so 56 + 12 == 68.
XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) {
  XlaBuilder b(TestName());
  XlaOp predicate;
  auto predicate_data =
      CreateR0Parameter<bool>(true, 0, "pred", &b, &predicate);
  XlaOp pair = Tuple(&b, {ConstantR0<float>(&b, 56.0f),
                          ConstantR0<float>(&b, 12.0f)});
  const XlaComputation add_fn = CreateR0TupleAddComputation();
  const XlaComputation sub_fn = CreateR0TupleSubComputation();
  Conditional(predicate, pair, add_fn, pair, sub_fn);
  ComputeAndCompareR0<float>(&b, 68.0f, {predicate_data.get()}, error_spec_);
}
// Test true and false computations that take in 2 parameters and predicate is
// false. The false branch subtracts the tuple elements, so 56 - 12 == 44.
XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) {
  XlaBuilder b(TestName());
  XlaOp predicate;
  auto predicate_data =
      CreateR0Parameter<bool>(false, 0, "pred", &b, &predicate);
  XlaOp pair = Tuple(&b, {ConstantR0<float>(&b, 56.0f),
                          ConstantR0<float>(&b, 12.0f)});
  const XlaComputation add_fn = CreateR0TupleAddComputation();
  const XlaComputation sub_fn = CreateR0TupleSubComputation();
  Conditional(predicate, pair, add_fn, pair, sub_fn);
  ComputeAndCompareR0<float>(&b, 44.0f, {predicate_data.get()}, error_spec_);
}
// Test true and false computations that take in 2 array parameters and
// predicate is true. The true branch adds elementwise: {24+10, 56+11}.
XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) {
  XlaBuilder b(TestName());
  XlaOp predicate;
  auto predicate_data =
      CreateR0Parameter<bool>(true, 0, "pred", &b, &predicate);
  XlaOp pair = Tuple(&b, {ConstantR1<float>(&b, {24.0f, 56.0f}),
                          ConstantR1<float>(&b, {10.0f, 11.0f})});
  const XlaComputation add_fn = CreateR1TupleAddComputation();
  const XlaComputation sub_fn = CreateR1TupleSubComputation();
  Conditional(predicate, pair, add_fn, pair, sub_fn);
  ComputeAndCompareR1<float>(&b, {34.0f, 67.0f}, {predicate_data.get()},
                             error_spec_);
}
// Test branch computations that take in 2 array parameters. Branch i computes
// i * p0 + p1 on the tuple (p0, p1) = ({24, 56}, {10, 11}). The out-of-range
// indices -1 and num_branches are expected to select the last branch.
XLA_TEST_P(CaseOpTest, Parameters2Array) {
  int num_branches = GetParam();
  for (int bi = -1; bi <= num_branches; ++bi) {
    SCOPED_TRACE(bi);
    XlaBuilder builder(TestName());
    XlaOp branch_index;
    // This parameter is the int32 branch index, not a predicate, so it is
    // named like the branch index in the other CaseOpTest tests.
    auto branch_index_arg = CreateR0Parameter<int32>(
        bi, 0, "branch_index_arg", &builder, &branch_index);
    auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
    auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
    auto operands = Tuple(&builder, {operand1, operand2});
    // Branch i: i * GetTupleElement(p0, 0) + GetTupleElement(p0, 1).
    auto make_branch = [&builder, this](int i) {
      auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i));
      auto p = Parameter(sb.get(), 0, tuple_2_r1s2f32_, "p0");
      Add(Mul(ConstantR0<float>(sb.get(), static_cast<float>(i)),
              GetTupleElement(p, 0)),
          GetTupleElement(p, 1));
      return sb->BuildAndNoteError();
    };
    std::vector<XlaComputation> branches;
    // Reserved up front so the addresses stored in branches_p stay valid.
    branches.reserve(num_branches);
    std::vector<const XlaComputation*> branches_p(num_branches);
    for (int i = 0; i < num_branches; ++i) {
      branches.emplace_back(make_branch(i));
      branches_p[i] = &branches[i];
    }
    Conditional(branch_index, branches_p,
                std::vector<XlaOp>(num_branches, operands));
    auto modified_bi = static_cast<float>(
        (bi < 0 || bi >= num_branches) ? num_branches - 1 : bi);
    ComputeAndCompareR1<float>(
        &builder, {24.0f * modified_bi + 10, 56.0f * modified_bi + 11},
        {branch_index_arg.get()}, error_spec_);
  }
}
// Runs every CaseOpTest with 1 through 5 branch computations (Range's upper
// bound is exclusive).
INSTANTIATE_TEST_SUITE_P(CaseOpTest_Instantiation, CaseOpTest,
                         ::testing::Range(1, 6));
// Test true and false computations that take in 2 array parameters and
// predicate is false. The false branch subtracts elementwise: {24-10, 56-11}.
XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) {
  XlaBuilder b(TestName());
  XlaOp predicate;
  auto predicate_data =
      CreateR0Parameter<bool>(false, 0, "pred", &b, &predicate);
  XlaOp pair = Tuple(&b, {ConstantR1<float>(&b, {24.0f, 56.0f}),
                          ConstantR1<float>(&b, {10.0f, 11.0f})});
  const XlaComputation add_fn = CreateR1TupleAddComputation();
  const XlaComputation sub_fn = CreateR1TupleSubComputation();
  Conditional(predicate, pair, add_fn, pair, sub_fn);
  ComputeAndCompareR1<float>(&b, {14.0f, 45.0f}, {predicate_data.get()},
                             error_spec_);
}
// Test true and false computations that return a tuple of scalars. With a false
// predicate the floor branch runs: (floor(12.2), floor(25.6)) == (12, 25).
XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
  XlaBuilder b(TestName());
  XlaOp predicate;
  auto predicate_data =
      CreateR0Parameter<bool>(false, 0, "pred", &b, &predicate);
  XlaOp pair = Tuple(&b, {ConstantR0<float>(&b, 12.2f),
                          ConstantR0<float>(&b, 25.6f)});
  const XlaComputation ceil_fn = CreateR0TupleCeilComputation();
  const XlaComputation floor_fn = CreateR0TupleFloorComputation();
  Conditional(predicate, pair, ceil_fn, pair, floor_fn);
  const Literal expected =
      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<float>(12.0f),
                                        LiteralUtil::CreateR0<float>(25.0f)});
  ComputeAndCompareTuple(&b, expected, {predicate_data.get()}, error_spec_);
}
// Test true and false computations that return a tuple of arrays. With a true
// predicate the ceil branch runs elementwise on both arrays.
XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
  XlaBuilder b(TestName());
  XlaOp predicate;
  auto predicate_data =
      CreateR0Parameter<bool>(true, 0, "pred", &b, &predicate);
  XlaOp pair = Tuple(&b, {ConstantR1<float>(&b, {12.2f, 15.8f}),
                          ConstantR1<float>(&b, {25.6f, 29.2f})});
  const XlaComputation ceil_fn = CreateR1TupleCeilComputation();
  const XlaComputation floor_fn = CreateR1TupleFloorComputation();
  Conditional(predicate, pair, ceil_fn, pair, floor_fn);
  const Literal expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<float>({13.0f, 16.0f}),
       LiteralUtil::CreateR1<float>({26.0f, 30.0f})});
  ComputeAndCompareTuple(&b, expected, {predicate_data.get()}, error_spec_);
}
// Test true and false computations that return a tuple of a predicate, a
// scalar, and an array. The predicate is true, so the true branch's constants
// are expected.
XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
  // Builds a branch that ignores its empty-tuple parameter and returns the
  // constant tuple (flag, scalar, {a0, a1}).
  auto build_branch = [this](const char* suffix, bool flag, float scalar,
                             float a0, float a1) {
    XlaBuilder branch_builder(TestName() + suffix);
    Parameter(&branch_builder, 0, empty_tuple_, "tuple");
    auto flag_op = ConstantR0<bool>(&branch_builder, flag);
    auto scalar_op = ConstantR0<float>(&branch_builder, scalar);
    auto array_op = ConstantR1<float>(&branch_builder, {a0, a1});
    Tuple(&branch_builder, {flag_op, scalar_op, array_op});
    auto built = branch_builder.Build();
    EXPECT_IS_OK(built.status());
    return built;
  };
  auto true_branch = build_branch(".true", true, 12.2f, 12.8f, 14.6f);
  auto false_branch = build_branch(".false", false, 25.6f, 26.4f, 32.6f);

  XlaBuilder b(TestName());
  XlaOp predicate;
  auto predicate_data =
      CreateR0Parameter<bool>(true, 0, "pred", &b, &predicate);
  XlaOp no_operands = Tuple(&b, {});
  Conditional(predicate, no_operands, true_branch.ConsumeValueOrDie(),
              no_operands, false_branch.ConsumeValueOrDie());
  const Literal expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<bool>(true), LiteralUtil::CreateR0<float>(12.2f),
       LiteralUtil::CreateR1<float>({12.8f, 14.6f})});
  ComputeAndCompareTuple(&b, expected, {predicate_data.get()}, error_spec_);
}
// Test true and false computations that return a nested tuple of the form
// ((scalar, array), (array, scalar)). The predicate is false, so the false
// branch's constants are expected.
XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
  // Builds a branch that ignores its empty-tuple parameter and returns
  // ((s0, {a0, a1}), ({b0, b1}, s1)).
  auto build_branch = [this](const char* suffix, float s0, float a0, float a1,
                             float b0, float b1, float s1) {
    XlaBuilder branch_builder(TestName() + suffix);
    Parameter(&branch_builder, 0, empty_tuple_, "tuple");
    auto first_scalar = ConstantR0<float>(&branch_builder, s0);
    auto first_array = ConstantR1<float>(&branch_builder, {a0, a1});
    auto second_array = ConstantR1<float>(&branch_builder, {b0, b1});
    auto second_scalar = ConstantR0<float>(&branch_builder, s1);
    Tuple(&branch_builder,
          {Tuple(&branch_builder, {first_scalar, first_array}),
           Tuple(&branch_builder, {second_array, second_scalar})});
    auto built = branch_builder.Build();
    EXPECT_IS_OK(built.status());
    return built;
  };
  auto true_branch =
      build_branch(".true", 12.2f, 12.8f, 14.6f, 25.4f, 29.8f, 35.6f);
  auto false_branch =
      build_branch(".false", 46.6f, 54.4f, 58.4f, 62.1f, 67.4f, 9.3f);

  XlaBuilder b(TestName());
  XlaOp predicate;
  auto predicate_data =
      CreateR0Parameter<bool>(false, 0, "pred", &b, &predicate);
  XlaOp no_operands = Tuple(&b, {});
  Conditional(predicate, no_operands, true_branch.ConsumeValueOrDie(),
              no_operands, false_branch.ConsumeValueOrDie());
  const Literal expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::MakeTupleFromSlices(
           {LiteralUtil::CreateR0<float>(46.6f),
            LiteralUtil::CreateR1<float>({54.4f, 58.4f})}),
       LiteralUtil::MakeTupleFromSlices(
           {LiteralUtil::CreateR1<float>({62.1f, 67.4f}),
            LiteralUtil::CreateR0<float>(9.3f)})});
  ComputeAndCompareTuple(&b, expected, {predicate_data.get()}, error_spec_);
}
// Test conditional that takes in scalar operands in the form of external
// params. The predicate is true, so ceil(56.3) == 57 is expected.
XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) {
  // The unused `r0bool` shape local that used to be declared here was removed.
  XlaBuilder builder(TestName());
  XlaOp pred, operand1, operand2;
  auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
  auto operand1_param =
      CreateR0Parameter<float>(56.3f, 1, "operand1", &builder, &operand1);
  auto operand2_param =
      CreateR0Parameter<float>(12.7f, 2, "operand2", &builder, &operand2);
  Conditional(pred, operand1, CreateR0CeilComputation(), operand2,
              CreateR0FloorComputation());
  ComputeAndCompareR0<float>(
      &builder, 57.0f,
      {pred_arg.get(), operand1_param.get(), operand2_param.get()},
      error_spec_);
}
// Test conditional that takes in array operands in the form of external params.
// The predicate is false, so floor({10.2, 11.6}) == {10, 11} is expected.
XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) {
  // The unused `r0bool` shape local that used to be declared here was removed.
  XlaBuilder builder(TestName());
  XlaOp pred, operand1, operand2;
  auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
  auto operand1_param = CreateR1Parameter<float>({24.3f, 56.7f}, 1, "operand1",
                                                 &builder, &operand1);
  auto operand2_param = CreateR1Parameter<float>({10.2f, 11.6f}, 2, "operand2",
                                                 &builder, &operand2);
  Conditional(pred, operand1, CreateR1CeilComputation(), operand2,
              CreateR1FloorComputation());
  ComputeAndCompareR1<float>(
      &builder, {10.0f, 11.0f},
      {pred_arg.get(), operand1_param.get(), operand2_param.get()},
      error_spec_);
}
// Test the case where one conditional is nested within another. The outer
// predicate (true) selects the inner conditional. The inner predicate (false)
// selects floor(12.2) == 12.
XLA_TEST_F(ConditionalOpTest, NestedConditionals) {
  XlaComputation inner;
  {
    XlaBuilder inner_builder(TestName() + ".inner_conditional");
    const Shape r0bool = ShapeUtil::MakeShape(PRED, {});
    const Shape packed = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});
    // The single tuple parameter packs (predicate, true value, false value).
    auto packed_param = Parameter(&inner_builder, 0, packed, "param0");
    auto inner_pred = GetTupleElement(packed_param, 0);
    auto on_true = GetTupleElement(packed_param, 1);
    auto on_false = GetTupleElement(packed_param, 2);
    Conditional(inner_pred, on_true, CreateR0CeilComputation(), on_false,
                CreateR0FloorComputation());
    auto built = inner_builder.Build();
    EXPECT_IS_OK(built.status());
    inner = built.ConsumeValueOrDie();
  }

  XlaBuilder b(TestName());
  XlaOp outer_pred, inner_pred_op;
  auto outer_pred_data =
      CreateR0Parameter<bool>(true, 0, "pred1", &b, &outer_pred);
  auto inner_pred_data =
      CreateR0Parameter<bool>(false, 1, "pred2", &b, &inner_pred_op);
  XlaOp x = ConstantR0<float>(&b, 1.1f);
  XlaOp y = ConstantR0<float>(&b, 12.2f);
  XlaOp fallback = ConstantR0<float>(&b, 43.3f);
  XlaOp packed_operand = Tuple(&b, {inner_pred_op, x, y});
  Conditional(outer_pred, packed_operand, inner, fallback,
              CreateR0IdentityComputation());
  ComputeAndCompareR0<float>(&b, 12.0f,
                             {outer_pred_data.get(), inner_pred_data.get()},
                             error_spec_);
}
// Test a conditional inside a computation that is invoked through Call. The
// callee unpacks (pred, 1.1, 12.2). With a false predicate, floor(12.2) == 12
// is expected.
XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) {
  XlaComputation callee;
  {
    XlaBuilder inner_builder(TestName() + ".inner_conditional");
    const Shape r0bool = ShapeUtil::MakeShape(PRED, {});
    const Shape packed = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});
    // The single tuple parameter packs (predicate, true value, false value).
    auto packed_param = Parameter(&inner_builder, 0, packed, "param0");
    auto inner_pred = GetTupleElement(packed_param, 0);
    auto on_true = GetTupleElement(packed_param, 1);
    auto on_false = GetTupleElement(packed_param, 2);
    Conditional(inner_pred, on_true, CreateR0CeilComputation(), on_false,
                CreateR0FloorComputation());
    auto built = inner_builder.Build();
    EXPECT_IS_OK(built.status());
    callee = built.ConsumeValueOrDie();
  }

  XlaBuilder b(TestName());
  XlaOp predicate;
  auto predicate_data =
      CreateR0Parameter<bool>(false, 0, "pred", &b, &predicate);
  XlaOp x = ConstantR0<float>(&b, 1.1f);
  XlaOp y = ConstantR0<float>(&b, 12.2f);
  XlaOp packed_operand = Tuple(&b, {predicate, x, y});
  Call(&b, callee, {packed_operand});
  ComputeAndCompareR0<float>(&b, 12.0f, {predicate_data.get()}, error_spec_);
}
// Test a mismatch in the shape of the true operand and true computation. The
// operand is a tuple of F32 scalars, while the true branch expects a tuple of
// F32[2] arrays, so Build() must fail with a descriptive error.
XLA_TEST_F(ConditionalOpTest, ShapeMismatch) {
  XlaBuilder b(TestName());
  XlaOp predicate = ConstantR0<bool>(&b, true);
  XlaOp scalar_pair = Tuple(&b, {ConstantR0<float>(&b, 56.0f),
                                 ConstantR0<float>(&b, 12.0f)});
  const XlaComputation r1_add = CreateR1TupleAddComputation();
  const XlaComputation r0_sub = CreateR0TupleSubComputation();
  Conditional(predicate, scalar_pair, r1_add, scalar_pair, r0_sub);
  const auto built = b.Build();
  EXPECT_FALSE(built.ok());
  EXPECT_THAT(built.status().error_message(),
              ::testing::HasSubstr("operand 0 must match the shape of the "
                                   "only parameter of branch computation 0"));
}
// Test two back-to-back conditionals, each of which either swaps or forwards an
// (x, y) tuple. The first forwards when x < y and swaps otherwise. The second
// swaps when x >= y and forwards otherwise. The overall result is therefore
// always the original (x, y).
XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
  const Shape pair_shape = ShapeUtil::MakeTupleShape({r0f32_, r0f32_});
  // Builds a computation over a 2-tuple parameter named `param_name` that
  // returns its elements either swapped or in the original order.
  auto make_pair_computation = [&](const char* suffix, const char* param_name,
                                   bool swap) {
    XlaBuilder b(TestName() + suffix);
    auto p = Parameter(&b, 0, pair_shape, param_name);
    auto first = GetTupleElement(p, 0);
    auto second = GetTupleElement(p, 1);
    if (swap) {
      Tuple(&b, {second, first});
    } else {
      Tuple(&b, {first, second});
    }
    return b.Build().ConsumeValueOrDie();
  };
  const XlaComputation swapper =
      make_pair_computation(".swapper", "sp0", /*swap=*/true);
  const XlaComputation forwarder =
      make_pair_computation(".forwarder", "fp0", /*swap=*/false);

  XlaComputation round_trip;
  {
    XlaBuilder b(TestName() + ".main");
    auto p = Parameter(&b, 0, pair_shape, "mp0");
    auto x = GetTupleElement(p, 0);
    auto y = GetTupleElement(p, 1);
    auto first_pass = Conditional(Lt(x, y), p, forwarder, p, swapper);
    Conditional(Ge(x, y), first_pass, swapper, first_pass, forwarder);
    round_trip = b.Build().ConsumeValueOrDie();
  }

  auto expect_unchanged = [&](float a, float b) {
    XlaBuilder builder(TestName());
    XlaOp x_op, y_op;
    auto x_data = CreateR0Parameter<float>(a, 0, "x", &builder, &x_op);
    auto y_data = CreateR0Parameter<float>(b, 1, "y", &builder, &y_op);
    Call(&builder, round_trip, {Tuple(&builder, {x_op, y_op})});
    ComputeAndCompareTuple(
        &builder,
        LiteralUtil::MakeTupleFromSlices(
            {LiteralUtil::CreateR0<float>(a), LiteralUtil::CreateR0<float>(b)}),
        {x_data.get(), y_data.get()}, error_spec_);
  };
  expect_unchanged(3.11f, 9.4f);    // x < y
  expect_unchanged(11.24f, 5.55f);  // x >= y
}
// Test conditional that duplicates tuple elements in the then and else
// computations. This is a regression test for b/112550242.
XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
  const Shape scalar = ShapeUtil::MakeShape(S32, {});
  const Shape tuple2 = ShapeUtil::MakeTupleShape({scalar, scalar});
  // Builds a branch returning (e0, e1, e_dup), where e_dup repeats element
  // `dup_index` of the 2-tuple parameter.
  auto make_branch = [&](const char* suffix, const char* param_name,
                         int dup_index) {
    XlaBuilder b(TestName() + suffix);
    auto p = Parameter(&b, 0, tuple2, param_name);
    auto e0 = GetTupleElement(p, 0);
    auto e1 = GetTupleElement(p, 1);
    Tuple(&b, {e0, e1, dup_index == 0 ? e0 : e1});
    return b.Build().ConsumeValueOrDie();
  };
  const XlaComputation then_comp = make_branch(".then", "then.p", 0);
  const XlaComputation else_comp = make_branch(".else", "else.p", 1);

  // Runs the conditional once with `pred_value` as the runtime predicate. No
  // expected literal is passed; ComputeAndCompare derives the expectation
  // itself (presumably from a reference backend; see ClientLibraryTestBase).
  auto run_with_predicate = [&](bool pred_value) {
    std::vector<Literal> args;
    args.push_back(
        LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
                                          LiteralUtil::CreateR0<int32>(-42)}));
    args.push_back(LiteralUtil::CreateR0<bool>(pred_value));
    XlaBuilder builder(TestName() + ".main");
    auto p = Parameter(&builder, 0, tuple2, "p0");
    auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
    Conditional(p_pred, p, then_comp, p, else_comp);
    ComputeAndCompare(&builder, args);
  };
  run_with_predicate(true);   // Pred is true case.
  run_with_predicate(false);  // Pred is false case.
}
} // namespace
} // namespace xla