Modifies kConditional to support both predicated and indexed conditionals (i.e. switch statements).

Updates lowerings for both CPU and GPU.

Adds a new tf2xla kernel for the CFv2 functional Case op which lowers to an indexed kConditional.

PiperOrigin-RevId: 235831207
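For orientation, a minimal sketch of how client code might call the new indexed overload (the helper name, constant index, and F32 scalar operands are illustrative assumptions; only the `Conditional(branch_index, branch_computations, branch_operands)` signature comes from this change):

// A minimal sketch (not part of this change): invoking the new indexed
// Conditional overload from client code. Each branch computation is assumed
// to map F32[] -> F32[].
#include <cstdint>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

xla::XlaOp IndexedCase(xla::XlaBuilder* b,
                       const std::vector<xla::XlaComputation>& branches) {
  // branch_index is an S32 scalar; out-of-range values select the last branch.
  xla::XlaOp index = xla::ConstantR0<int32_t>(b, 2);
  xla::XlaOp operand = xla::ConstantR0<float>(b, 1.0f);
  std::vector<const xla::XlaComputation*> computations;
  for (const xla::XlaComputation& branch : branches) {
    computations.push_back(&branch);
  }
  std::vector<xla::XlaOp> operands(branches.size(), operand);
  return xla::Conditional(index, absl::MakeSpan(computations), operands);
}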
Brian Patton 2019-02-26 18:06:15 -08:00 committed by TensorFlower Gardener
parent 39b741fd9a
commit 9bac04a4de
37 changed files with 1502 additions and 601 deletions

View File

@ -115,6 +115,7 @@ tf_kernel_library(
],
tags = ["optonly"],
deps = [
":case_op",
":conv_op_helpers",
":if_op",
":while_op",
@ -269,6 +270,23 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "case_op",
srcs = ["case_op.cc"],
hdrs = ["case_op.h"],
deps = [
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
# Kernels that have a dummy (no-op) implementation.
tf_kernel_library(
name = "xla_dummy_ops",

View File

@ -0,0 +1,297 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/kernels/case_op.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
XlaCaseOp::XlaCaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &branches_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_));
if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
has_token_input_output_ = false;
} else {
has_token_input_output_ = !token_input_nodes_.empty();
}
}
// TODO(b/35949885): There is duplication here with the handling of the
// while_op. Refactor the common code out/rework.
void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
xla::XlaBuilder* b = ctx->builder();
int num_branches = branches_.size();
OP_REQUIRES(ctx, num_branches >= 1,
errors::InvalidArgument("Must provide at least one case branch"));
OP_REQUIRES(ctx, input_type(0) == DT_INT32,
errors::InvalidArgument(
"branch_index argument must be a int32 for XLA compilation"));
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(0)),
errors::InvalidArgument(
"branch_index argument must be scalar for XLA compilation"));
VLOG(1) << "Building Case: " << input_types_.size() << " inputs";
std::vector<XlaCompiler::Argument> arguments(input_types_.size());
int num_resource_args = 0;
for (int i = 0; i < input_types_.size(); ++i) {
XlaCompiler::Argument& arg = arguments[i];
DataType type = ctx->input_type(i + 1);
if (type == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource));
arg.initialized = resource->initialized();
arg.kind = XlaCompiler::Argument::kResource;
arg.resource_kind = resource->kind();
arg.type = resource->type();
arg.shape = resource->shape();
OP_REQUIRES(ctx, arg.initialized,
errors::Unimplemented("Uninitialized arguments: ", arg.name));
arg.max_array_size = resource->max_array_size();
for (const auto& gradient : resource->tensor_array_gradients()) {
arg.tensor_array_gradients.insert(gradient.first);
}
arg.name = resource->name();
VLOG(2) << "Resource " << resource->name()
<< " type: " << DataTypeString(arg.type)
<< " shape: " << arg.HumanString()
<< " initialized: " << arg.initialized;
num_resource_args++;
} else {
arg.kind = XlaCompiler::Argument::kParameter;
arg.type = input_types_[i];
arg.shape = ctx->InputShape(i + 1);
VLOG(2) << "Arg type: " << DataTypeString(arg.type)
<< " shape: " << arg.HumanString();
}
}
// Compile each branch of the conditional.
XlaCompiler::CompileOptions options;
options.use_tuple_arg = true;
options.resolve_compile_time_constants = false;
options.return_updated_values_for_all_resources = true;
options.is_entry_computation = false;
options.add_token_input_output = has_token_input_output_;
XlaCompiler* compiler = ctx->compiler();
std::vector<XlaCompiler::CompilationResult> branch_results(num_branches);
std::vector<XlaCompiler::CompilationResult*> branch_results_p(num_branches);
for (int j = 0; j < num_branches; ++j) {
OP_REQUIRES_OK(ctx,
compiler->CompileFunction(options, branches_[j], arguments,
&branch_results[j]));
branch_results_p[j] = &branch_results[j];
}
bool has_tensor_array_gradients = false;
for (XlaCompiler::CompilationResult* result : branch_results_p) {
for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) {
XlaResource* resource;
OP_REQUIRES_OK(ctx,
ctx->GetResourceInput(update.input_index + 1, &resource));
XlaCompiler::Argument& arg = arguments[update.input_index];
// Add any TensorArray gradients touched by the branch computations to
// the enclosing graph.
for (const string& grad_source : update.tensor_array_gradients_accessed) {
VLOG(5) << "TensorArray " << resource->name() << " accessed gradient "
<< grad_source;
XlaResource* gradient;
OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient(
grad_source, b, &gradient));
}
// Add all of the TensorArray gradients to the argument. For simplicity,
// we always pass all known gradients.
for (const auto& gradient : resource->tensor_array_gradients()) {
arg.tensor_array_gradients.insert(gradient.first);
}
if (!resource->tensor_array_gradients().empty()) {
has_tensor_array_gradients = true;
}
}
}
// Recompile the functions to update the argument shapes for tensor arrays.
if (has_tensor_array_gradients) {
for (int j = 0; j < num_branches; ++j) {
branch_results[j] = {};
OP_REQUIRES_OK(ctx,
compiler->CompileFunction(options, branches_[j], arguments,
&branch_results[j]));
}
}
xla::Shape branch0_input_shape;
std::vector<const xla::XlaComputation*> result_computations(num_branches);
for (int j = 0; j < num_branches; ++j) {
// Check that all branches have identical input shapes.
OP_REQUIRES(ctx, branch_results[j].xla_input_shapes.size() == 1,
errors::FailedPrecondition("Expected one input shape"));
xla::Shape branch_input_shape = branch_results[j].xla_input_shapes[0];
if (j == 0) {
branch0_input_shape = branch_input_shape;
}
OP_REQUIRES(ctx, branch_input_shape.IsTuple(),
errors::FailedPrecondition("Expected tuple shape"));
OP_REQUIRES(ctx, branch_results[j].xla_input_shapes.size() == 1,
errors::FailedPrecondition("Expected one input shape"));
OP_REQUIRES(
ctx,
xla::ShapeUtil::Compatible(branch0_input_shape, branch_input_shape),
errors::InvalidArgument(
"Input shapes of 0 and ", j, " branches do not match: ",
xla::ShapeUtil::HumanString(branch0_input_shape), " vs. ",
xla::ShapeUtil::HumanString(branch_input_shape)));
// Check that all branches have identical output shapes.
OP_REQUIRES(
ctx,
xla::ShapeUtil::Compatible(branch_results[0].xla_output_shape,
branch_results[j].xla_output_shape),
errors::InvalidArgument(
"Output shapes of 0 and ", j, " branches do not match: ",
xla::ShapeUtil::HumanString(branch_results[0].xla_output_shape),
" vs. ",
xla::ShapeUtil::HumanString(branch_results[j].xla_output_shape)));
if (j == 0) {
VLOG(2) << "Input shape: "
<< xla::ShapeUtil::HumanString(branch0_input_shape);
VLOG(2) << "Output shape: "
<< xla::ShapeUtil::HumanString(
branch_results[0].xla_output_shape);
}
// We set return_updated_values_for_all_resources=true and we pass the same
// arguments to all branch computations, so the resource update count must match.
OP_REQUIRES(ctx,
branch_results[0].resource_updates.size() ==
branch_results[j].resource_updates.size(),
errors::FailedPrecondition(
"Different number of resources in 0 and ", j, " branch"));
for (int i = 0; i < branch_results[0].resource_updates.size(); ++i) {
const auto& lhs = branch_results[0].resource_updates[i];
const auto& rhs = branch_results[j].resource_updates[i];
bool equal = lhs.input_index == rhs.input_index &&
lhs.shape == rhs.shape &&
lhs.tensor_array_gradients_accessed ==
rhs.tensor_array_gradients_accessed;
OP_REQUIRES(ctx, equal,
errors::FailedPrecondition("Mismatch in resource of 0 and ",
j, " branch for resource ", i));
}
result_computations[j] = branch_results[j].computation.get();
}
// Prepare the input arg Tuple.
int num_inputs = branch_results[0].input_mapping.size();
std::vector<xla::XlaOp> inputs(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
int input_num = branch_results[0].input_mapping[i] + 1;
if (has_token_input_output_ && i == num_inputs - 1) {
// Set token input for this "case" op.
std::vector<xla::XlaOp> token_inputs;
for (const string& node_name : token_input_nodes_) {
auto token_or = compiler->GetNodeToken(node_name);
OP_REQUIRES_OK(ctx, token_or.status());
token_inputs.push_back(token_or.ValueOrDie());
}
inputs[i] = xla::AfterAll(b, token_inputs);
} else if (ctx->input_type(input_num) == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
} else {
inputs[i] = ctx->Input(i + 1);
}
}
auto input_tuple = xla::Tuple(b, inputs);
xla::XlaOp outputs =
xla::Conditional(ctx->Input(0), absl::MakeSpan(result_computations),
std::vector<xla::XlaOp>(num_branches, input_tuple));
// Sets non-variable outputs.
for (int i = 0; i < output_types_.size(); ++i) {
xla::XlaOp output_handle = xla::GetTupleElement(outputs, i);
if (VLOG_IS_ON(2)) {
LOG(INFO) << "Setting output " << i;
auto shape_or = b->GetShape(output_handle);
if (shape_or.ok()) {
LOG(INFO) << "Shape for output " << i << ": "
<< xla::ShapeUtil::HumanString(shape_or.ValueOrDie());
} else {
LOG(INFO) << "Shape unknown for output " << i;
}
}
ctx->SetOutput(i, output_handle);
}
if (has_token_input_output_) {
// Set the token output for this "Case" op. The token output is the last
// output of the XLA computation, which comes after all "normal" TF outputs
// and resource updates. For a "Case" node, the number of resource updates
// equals the number of resource args because we set
// `return_updated_values_for_all_resources` to true in the XlaCompiler
// options.
xla::XlaOp token_output =
xla::GetTupleElement(outputs, output_types_.size() + num_resource_args);
auto shape_or = b->GetShape(token_output);
OP_REQUIRES_OK(ctx, shape_or.status());
OP_REQUIRES(ctx, shape_or.ValueOrDie().IsToken(),
errors::FailedPrecondition(
"Token output is not token type: ",
xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
}
// Updates the values of any resource variables modified by the conditional
// bodies.
for (const XlaCompiler::CompilationResult& result : branch_results) {
for (int i = 0; i < result.resource_updates.size(); ++i) {
const XlaCompiler::ResourceUpdate& update = result.resource_updates[i];
XlaResource* resource;
OP_REQUIRES_OK(ctx,
ctx->GetResourceInput(update.input_index + 1, &resource));
if (update.modified) {
int pos = static_cast<int>(result.outputs.size()) + i;
OP_REQUIRES_OK(ctx,
resource->SetFromPack(
arguments[update.input_index].tensor_array_gradients,
xla::GetTupleElement(outputs, pos), b));
}
VLOG(2) << "Case variable: pos: " << update.input_index
<< " name: " << resource->name()
<< " modified: " << update.modified
<< " type: " << DataTypeString(update.type)
<< " shape: " << update.shape.DebugString();
}
}
VLOG(1) << "Done building Case";
}
REGISTER_XLA_OP(Name("Case").AllowResourceTypes(), XlaCaseOp);
} // namespace tensorflow

View File

@ -0,0 +1,62 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CASE_OP_H_
#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CASE_OP_H_
#include <string>
#include <vector>
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/types.h"
namespace tensorflow {
// This TensorFlow op provides a functional switch/case primitive.
//
// The outputs of the branches must agree on the number, types, and
// shapes of the Tensors carried around the branch bodies.
//
// Computations in branch bodies may read from and write to resource variables.
// Resource variables may be passed as arguments to the branch bodies. The
// XlaCompiler converts resource variable arguments into parameters to the XLA
// computation and moves them to the end of the parameter list, and by setting
// the `return_updated_values_for_all_resources` option we ensure that all
// variables that appear in the input also appear at the end of each branch
// body's output. This ensures that the branch bodies' output signatures match.
//
// It is the user's responsibility to ensure that each non-variable _Arg matches
// the corresponding _Retval.
class XlaCaseOp : public XlaOpKernel {
public:
explicit XlaCaseOp(OpKernelConstruction* ctx);
void Compile(XlaOpKernelContext* ctx) override;
private:
TF_DISALLOW_COPY_AND_ASSIGN(XlaCaseOp);
std::vector<NameAttrList> branches_;
DataTypeVector input_types_;
DataTypeVector output_types_;
bool has_token_input_output_;
std::vector<string> token_input_nodes_;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CASE_OP_H_

View File

@ -1880,32 +1880,46 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand,
const XlaComputation& true_computation,
const XlaOp& false_operand,
const XlaComputation& false_computation) {
// The index of true_computation must be 0 and that of false computation
// must be 1.
return Conditional(predicate, {&true_computation, &false_computation},
{true_operand, false_operand});
}
XlaOp XlaBuilder::Conditional(
const XlaOp& branch_index,
absl::Span<const XlaComputation* const> branch_computations,
absl::Span<const XlaOp> branch_operands) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& predicate_shape, GetShape(predicate));
TF_ASSIGN_OR_RETURN(const Shape& true_operand_shape,
GetShape(true_operand));
TF_ASSIGN_OR_RETURN(const ProgramShape& true_computation_shape,
true_computation.GetProgramShape());
TF_ASSIGN_OR_RETURN(const Shape& false_operand_shape,
GetShape(false_operand));
TF_ASSIGN_OR_RETURN(const ProgramShape& false_computation_shape,
false_computation.GetProgramShape());
TF_ASSIGN_OR_RETURN(
Shape shape,
ShapeInference::InferConditionalShape(
predicate_shape, true_operand_shape, false_operand_shape,
true_computation_shape, false_computation_shape));
TF_ASSIGN_OR_RETURN(const Shape& branch_index_shape,
GetShape(branch_index));
std::vector<Shape> branch_operand_shapes(branch_operands.size());
std::vector<ProgramShape> branch_computation_shapes(
branch_computations.size());
for (int j = 0; j < branch_operands.size(); ++j) {
TF_ASSIGN_OR_RETURN(branch_operand_shapes[j],
GetShape(branch_operands[j]));
TF_ASSIGN_OR_RETURN(branch_computation_shapes[j],
branch_computations[j]->GetProgramShape());
}
TF_ASSIGN_OR_RETURN(const Shape shape,
ShapeInference::InferConditionalShape(
branch_index_shape, branch_computation_shapes,
branch_operand_shapes));
*instr.mutable_shape() = shape.ToProto();
// The index of true_computation must be 0 and that of false computation
// must be 1.
AddCalledComputation(true_computation, &instr);
AddCalledComputation(false_computation, &instr);
for (const XlaComputation* branch_computation : branch_computations) {
AddCalledComputation(*branch_computation, &instr);
}
std::vector<XlaOp> operands(1, branch_index);
for (const XlaOp branch_operand : branch_operands) {
operands.emplace_back(branch_operand);
}
return AddInstruction(std::move(instr), HloOpcode::kConditional,
{predicate, true_operand, false_operand});
absl::MakeSpan(operands));
});
}
@ -3381,6 +3395,13 @@ XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
false_computation);
}
XlaOp Conditional(const XlaOp& branch_index,
absl::Span<const XlaComputation* const> branch_computations,
absl::Span<const XlaOp> branch_operands) {
return branch_index.builder()->Conditional(branch_index, branch_computations,
branch_operands);
}
XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
const int mantissa_bits) {
return operand.builder()->ReducePrecision(operand, exponent_bits,

View File

@ -529,6 +529,10 @@ class XlaBuilder {
const XlaOp& false_operand,
const XlaComputation& false_computation);
XlaOp Conditional(const XlaOp& branch_index,
absl::Span<const XlaComputation* const> branch_computations,
absl::Span<const XlaOp> branch_operands);
XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
const int mantissa_bits);
@ -946,6 +950,10 @@ class XlaBuilder {
const XlaComputation& true_computation,
const XlaOp& false_operand,
const XlaComputation& false_computation);
friend XlaOp Conditional(
const XlaOp& branch_index,
absl::Span<const XlaComputation* const> branch_computations,
absl::Span<const XlaOp> branch_operands);
friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
const int mantissa_bits);
friend XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
@ -1776,6 +1784,15 @@ XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
const XlaOp& false_operand,
const XlaComputation& false_computation);
// Enqueues either a predicated (if/else) or indexed (switch/case/default)
// conditional node onto the computation. N >= 1 branch_computations and
// branch_operands are matched by index. branch_index selects the branch that
// will be executed. Out of range branch_index uses the N-1'th
// branch_computation as default.
XlaOp Conditional(const XlaOp& branch_index,
absl::Span<const XlaComputation* const> branch_computations,
absl::Span<const XlaOp> branch_operands);
// Enqueues a ReducePrecision node onto the computation.
XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
const int mantissa_bits);

View File

@ -541,25 +541,49 @@ See also
false_computation)` </b>
Arguments | Type | Semantics
------------------- | ---------------- | ---------------------------------
------------------- | ---------------- | --------------------------------------
`pred` | `XlaOp` | Scalar of type `PRED`
`true_operand` | `XlaOp` | Argument of type `T_0`
`true_computation` | `XlaComputation` | XlaComputation of type `T_0 -> S`
`false_operand` | `XlaOp` | Argument of type `T_1`
`false_computation` | `XlaComputation` | XlaComputation of type `T_1 -> S`
`true_operand` | `XlaOp` | Argument of type $$ T_0 $$
`true_computation` | `XlaComputation` | XlaComputation of type $$ T_0 \to S$$
`false_operand` | `XlaOp` | Argument of type $$ T_1 $$
`false_computation` | `XlaComputation` | XlaComputation of type $$ T_1 \to S $$
Executes `true_computation` if `pred` is `true`, `false_computation` if `pred`
is `false`, and returns the result.
The `true_computation` must take in a single argument of type `T_0` and will be
invoked with `true_operand` which must be of the same type. The
`false_computation` must take in a single argument of type `T_1` and will be
The `true_computation` must take in a single argument of type $$ T_0 $$ and will
be invoked with `true_operand` which must be of the same type. The
`false_computation` must take in a single argument of type $$ T_1 $$ and will be
invoked with `false_operand` which must be of the same type. The type of the
returned value of `true_computation` and `false_computation` must be the same.
Note that only one of `true_computation` and `false_computation` will be
executed depending on the value of `pred`.
<b> `Conditional(branch_index, branch_computations, branch_operands)` </b>
| Arguments             | Type                           | Semantics                                                                 |
| --------------------- | ------------------------------ | ------------------------------------------------------------------------- |
| `branch_index`        | `XlaOp`                        | Scalar of type `PRED` or `S32`                                            |
| `branch_computations` | sequence of N `XlaComputation` | XlaComputations of type $$ T_0 \to S , T_1 \to S , ..., T_{N-1} \to S $$  |
| `branch_operands`     | sequence of N `XlaOp`          | Arguments of type $$ T_0 , T_1 , ..., T_{N-1} $$                          |
Executes `branch_computations[branch_index]`, and returns the result. If
`branch_index` is a `PRED`, then the `true` branch is in position 0 and the
`false` branch is in position 1. If `branch_index` is an `S32` which is < 0
or >= N, then `branch_computations[N-1]` is executed as the default branch.
Each `branch_computations[b]` must take in a single argument of type `T_b` and
will be invoked with `branch_operands[b]` which must be of the same type. The
type of the returned value of each `branch_computations[b]` must be the same.
Note that only one of the `branch_computations` will be executed depending on
the value of `branch_index`.
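The selection rule above can be summarized with a short sketch (illustrative only, not part of the documentation; `ResolveBranch` and its parameters are hypothetical):

// Sketch of which branch executes, covering both the PRED and S32 forms of
// branch_index; n is the number of branch computations.
int ResolveBranch(bool index_is_pred, bool pred_value, int branch_index, int n) {
  if (index_is_pred) {
    return pred_value ? 0 : 1;               // true -> branch 0, false -> branch 1
  }
  if (branch_index < 0 || branch_index >= n) {
    return n - 1;                            // out of range -> default branch
  }
  return branch_index;
}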
## Conv (convolution)
See also

View File

@ -3785,6 +3785,7 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
],
)

View File

@ -1620,62 +1620,46 @@ void BufferAssigner::BuildColocatedBufferSets(
AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
});
} else if (opcode == HloOpcode::kConditional) {
const HloInstruction* conditional_hlo = instruction;
const HloInstruction* conditional = instruction;
ShapeUtil::ForEachSubshape(
conditional_hlo->shape(),
[this, conditional_hlo, &points_to_analysis, colocated_buffer_sets](
conditional->shape(),
[this, conditional, &points_to_analysis, colocated_buffer_sets](
const Shape& /*subshape*/, const ShapeIndex& index) {
std::vector<const LogicalBuffer*> colocated_set;
// Add conditional.result.
AddBufferToColocatedSet(conditional_hlo, index,
points_to_analysis, &colocated_set);
// Add conditional.true_computation.root.
AddBufferToColocatedSet(
conditional_hlo->true_computation()->root_instruction(),
index, points_to_analysis, &colocated_set);
// Add conditional.false_computation.root.
AddBufferToColocatedSet(
conditional_hlo->false_computation()->root_instruction(),
index, points_to_analysis, &colocated_set);
// Add cond.result.
AddBufferToColocatedSet(conditional, index, points_to_analysis,
&colocated_set);
for (int j = 0; j < conditional->branch_count(); ++j) {
// Add each cond.branch_computation[j].root.
AddBufferToColocatedSet(
conditional->branch_computation(j)->root_instruction(),
index, points_to_analysis, &colocated_set);
}
AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
});
// Add true_operand and conditional.true_computation.parameter(0) as a
// colocated buffer set. Note that this has to be done for each subshape
// in the true_operand of the conditional.
ShapeUtil::ForEachSubshape(
conditional_hlo->operand(1)->shape(),
[this, conditional_hlo, &points_to_analysis, colocated_buffer_sets](
const Shape& /*subshape*/, const ShapeIndex& index) {
std::vector<const LogicalBuffer*> true_set;
// Add conditional.true_operand.
AddBufferToColocatedSet(conditional_hlo->operand(1), index,
points_to_analysis, &true_set);
// Add conditional.true_computation.parameter_instruction(0).
AddBufferToColocatedSet(
conditional_hlo->true_computation()->parameter_instruction(0),
index, points_to_analysis, &true_set);
AddSetToColocatedBufferSets(true_set, colocated_buffer_sets);
});
// Add false_operand and conditional.false_computation.parameter(0) as a
// colocated buffer set. Note that this has to be done for each subshape
// in the false_operand of the conditional.
ShapeUtil::ForEachSubshape(
conditional_hlo->operand(2)->shape(),
[this, conditional_hlo, &points_to_analysis, colocated_buffer_sets](
const Shape& /*subshape*/, const ShapeIndex& index) {
std::vector<const LogicalBuffer*> false_set;
// Add conditional.false_operand.
AddBufferToColocatedSet(conditional_hlo->operand(2), index,
points_to_analysis, &false_set);
// Add conditional.false_computation.parameter_instruction(0).
AddBufferToColocatedSet(
conditional_hlo->false_computation()->parameter_instruction(
0),
index, points_to_analysis, &false_set);
AddSetToColocatedBufferSets(false_set, colocated_buffer_sets);
});
for (int j = 0; j < conditional->branch_count(); ++j) {
// Add branch_operand[j] (which is operand[j+1]) and
// cond.branch_computation[j].parameter(0) as a colocated
// buffer set. Note that this has to be done for each subshape in the
// branch_operand of the case.
ShapeUtil::ForEachSubshape(
conditional->operand(j + 1)->shape(),
[this, j, conditional, &points_to_analysis,
colocated_buffer_sets](const Shape& /*subshape*/,
const ShapeIndex& index) {
std::vector<const LogicalBuffer*> branch_set;
// Add cond.operand[j+1].
AddBufferToColocatedSet(conditional->operand(j + 1), index,
points_to_analysis, &branch_set);
// Add cond.branch_computation[j].parameter_instruction(0).
AddBufferToColocatedSet(
conditional->branch_computation(j)->parameter_instruction(
0),
index, points_to_analysis, &branch_set);
AddSetToColocatedBufferSets(branch_set, colocated_buffer_sets);
});
}
}
}
}

View File

@ -33,8 +33,8 @@ limitations under the License.
namespace xla {
// Tries to replace a conditional with a call operation of the corresponding
// computation. If the given conditional has a constant predicate, tries to
// replace it with a call to its true/false computation as appropriate and then
// computation. If the given conditional has a constant branch_index, tries to
// replace it with a call to its corresponding branch computation and then
// inline that computation.
//
// Returns true if it made a change to the graph.
@ -50,24 +50,30 @@ static StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
return false;
}
if (conditional->operand(0)->opcode() != HloOpcode::kConstant) {
VLOG(2) << "Not attempting to remove conditional as its predicate is not a "
"compile-time constant: "
<< conditional->ToShortString();
return false;
}
// We can always inline a 1-branch conditional due to default branch fallback.
int branch_index = 0;
if (conditional->branch_count() > 1) {
if (conditional->operand(0)->opcode() != HloOpcode::kConstant) {
VLOG(2) << "Not attempting to remove conditional as its branch_index is "
"not a compile-time constant: "
<< conditional->ToShortString();
return false;
}
if (conditional->operand(0)->shape().element_type() == PRED) {
branch_index = conditional->operand(0)->literal().Get<bool>({}) ? 0 : 1;
} else {
branch_index = conditional->operand(0)->literal().Get<int32>({});
if (branch_index < 0 || branch_index >= conditional->branch_count()) {
branch_index = conditional->branch_count() - 1;
}
}
}
auto computation = conditional->parent();
HloInstruction* call_op;
if (conditional->operand(0)->literal().Get<bool>({})) {
call_op = computation->AddInstruction(HloInstruction::CreateCall(
conditional->shape(), {conditional->mutable_operand(1)},
conditional->true_computation()));
} else {
call_op = computation->AddInstruction(HloInstruction::CreateCall(
conditional->shape(), {conditional->mutable_operand(2)},
conditional->false_computation()));
}
call_op = computation->AddInstruction(HloInstruction::CreateCall(
conditional->shape(), {conditional->mutable_operand(branch_index + 1)},
conditional->branch_computation(branch_index)));
conditional->SetupDerivedInstruction(call_op);
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());

View File

@ -319,8 +319,7 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
<< conditional->name();
TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);
for (HloComputation* computation :
{conditional->true_computation(), conditional->false_computation()}) {
for (HloComputation* computation : conditional->branch_computations()) {
HloInstruction* root = computation->root_instruction();
std::vector<HloInstruction*> users = root->users();
TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,

View File

@ -2515,53 +2515,109 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
}
Status IrEmitter::HandleConditional(HloInstruction* conditional) {
auto pred = conditional->operand(0);
TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) &&
pred->shape().element_type() == PRED)
<< "Predicate on a Conditional must be bool; got: "
<< ShapeUtil::HumanString(pred->shape());
auto branch_index = conditional->operand(0);
int num_branches = conditional->branch_count();
TF_RET_CHECK(ShapeUtil::IsScalar(branch_index->shape()) &&
(branch_index->shape().element_type() == PRED ||
branch_index->shape().element_type() == S32))
<< "Branch index on a conditional must be scalar bool or int32; got: "
<< ShapeUtil::HumanString(branch_index->shape());
HloComputation* true_computation = conditional->true_computation();
HloComputation* false_computation = conditional->false_computation();
TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(),
true_computation->root_instruction()->shape()))
<< "Shape of conditional should be same as the shape of the true "
<< "computation; got: " << ShapeUtil::HumanString(conditional->shape())
<< " and "
<< ShapeUtil::HumanString(true_computation->root_instruction()->shape());
TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(),
false_computation->root_instruction()->shape()))
<< "Shape of conditional should be same as the shape of the false "
<< "computation; got: " << ShapeUtil::HumanString(conditional->shape())
<< " and "
<< ShapeUtil::HumanString(false_computation->root_instruction()->shape());
for (int b = 0; b < num_branches; ++b) {
HloComputation* br_computation = conditional->branch_computation(b);
TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(),
br_computation->root_instruction()->shape()))
<< "Shape of conditional should be same as the shape of the " << b
<< "th branch computation; got: "
<< ShapeUtil::HumanString(conditional->shape()) << " and "
<< ShapeUtil::HumanString(br_computation->root_instruction()->shape());
}
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional));
// Generating:
// if (pred)
// cond_result = true_computation(true_operand)
// else
// cond_result = false_computation(false_operand)
llvm::LoadInst* pred_value =
Load(GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value");
llvm::Value* pred_cond = ICmpNE(
pred_value,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
"boolean_predicate");
llvm_ir::LlvmIfData if_data =
llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_);
if (branch_index->shape().element_type() == PRED) {
// Emit an if-else to LLVM:
// if (pred)
// cond_result = true_computation(true_operand)
// else
// cond_result = false_computation(false_operand)
llvm::LoadInst* pred_value = Load(
GetIrArrayFor(branch_index).GetBasePointer(), "load_predicate_value");
llvm::Value* pred_cond =
ICmpNE(pred_value,
llvm::ConstantInt::get(
llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
"boolean_predicate");
llvm_ir::LlvmIfData if_data =
llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_);
SetToFirstInsertPoint(if_data.true_block, &b_);
EmitGlobalCall(*conditional->true_computation(),
IrName(conditional, "_true"));
SetToFirstInsertPoint(if_data.true_block, &b_);
EmitGlobalCall(*conditional->branch_computation(0),
IrName(conditional, "_true"));
SetToFirstInsertPoint(if_data.false_block, &b_);
EmitGlobalCall(*conditional->false_computation(),
IrName(conditional, "_false"));
SetToFirstInsertPoint(if_data.false_block, &b_);
EmitGlobalCall(*conditional->branch_computation(1),
IrName(conditional, "_false"));
SetToFirstInsertPoint(if_data.after_block, &b_);
SetToFirstInsertPoint(if_data.after_block, &b_);
return Status::OK();
}
// We emit a switch statement to LLVM:
// switch (branch_index) {
// default:
// result = branch_computations[num_branches-1](operands[num_branches-1]);
// break;
// case 0:
// result = branch_computations[0](operands[0]); break;
// case 1:
// result = branch_computations[1](operands[1]); break;
// ...
// case [[num_branches-2]]:
// result = branch_computations[num_branches-2](operands[num_branches-2]);
// break;
// }
llvm::LoadInst* branch_index_value = Load(
GetIrArrayFor(branch_index).GetBasePointer(), "load_branch_index_value");
auto case_block = b_.GetInsertBlock();
llvm::BasicBlock* after_block;
// Add a terminator to the case block, if necessary.
if (case_block->getTerminator() == nullptr) {
after_block = llvm_ir::CreateBasicBlock(nullptr, "case-after", &b_);
b_.SetInsertPoint(case_block);
b_.CreateBr(after_block);
} else {
after_block =
case_block->splitBasicBlock(b_.GetInsertPoint(), "case-after");
}
// Our basic block should now end with an unconditional branch. Remove it;
// we're going to replace it with a switch based branch.
case_block->getTerminator()->eraseFromParent();
// Lower the default branch computation.
auto default_block = llvm_ir::CreateBasicBlock(nullptr, "case-default", &b_);
b_.SetInsertPoint(default_block);
EmitGlobalCall(*conditional->branch_computation(num_branches - 1),
IrName(conditional, "_default"));
b_.CreateBr(after_block);
// Prepare the switch (branch_index) { ... } instruction.
b_.SetInsertPoint(case_block);
llvm::SwitchInst* case_inst =
b_.CreateSwitch(branch_index_value, default_block, num_branches - 1);
// Lower each branch's computation.
for (int b = 0; b < num_branches - 1; ++b) { // last branch is default
// Lower the case b: { ... ; break; } computation.
auto branch_block =
llvm_ir::CreateBasicBlock(nullptr, absl::StrCat("case-branch", b), &b_);
b_.SetInsertPoint(branch_block);
EmitGlobalCall(*conditional->branch_computation(b),
IrName(conditional, absl::StrCat("_branch", b)));
b_.CreateBr(after_block);
case_inst->addCase(b_.getInt32(b), branch_block);
}
SetToFirstInsertPoint(after_block, &b_);
return Status::OK();
}

View File

@ -26,7 +26,7 @@ namespace xla {
namespace {
// Helper to replace the called computation at a while-, call-, or
// Helper to replace the called computation at a while-, call-, case-, or
// conditional-instruction. This function replaces exactly one instance of
// 'computation' with 'new_computation' even if 'instruction' calls
// 'computation' more than once.
@ -49,11 +49,14 @@ void ReplaceCalledComputation(HloInstruction* instruction,
break;
}
case HloOpcode::kConditional: {
if (computation == instruction->true_computation()) {
instruction->set_true_computation(new_computation);
} else {
CHECK_EQ(computation, instruction->false_computation());
instruction->set_false_computation(new_computation);
for (int b = 0; b < instruction->branch_count(); ++b) {
if (b == instruction->branch_count() - 1) {
CHECK_EQ(computation, instruction->branch_computation(b));
}
if (computation == instruction->branch_computation(b)) {
instruction->set_branch_computation(b, new_computation);
break;
}
}
break;
}

View File

@ -24,25 +24,35 @@ namespace xla {
namespace gpu {
ConditionalThunk::ConditionalThunk(
const BufferAllocation::Slice& predicate_buffer_index,
const BufferAllocation::Slice& true_operand_buffer_index,
const BufferAllocation::Slice& false_operand_buffer_index,
ThunkSequence true_thunk_sequence, ThunkSequence false_thunk_sequence,
const BufferAllocation::Slice& branch_index_buffer_index,
absl::Span<const BufferAllocation::Slice> branch_operand_buffer_indexes,
std::vector<ThunkSequence> branch_thunk_sequences,
const HloInstruction* hlo)
: Thunk(Kind::kConditional, hlo),
predicate_buffer_index_(predicate_buffer_index),
true_operand_buffer_index_(true_operand_buffer_index),
false_operand_buffer_index_(false_operand_buffer_index),
// Pass nullptr as the HloInstruction* to the true_thunk_ and false_thunk_
// constructors because these SequentialThunks are logically "part of"
// this ConditionalThunk, and shouldn't be profiled separately from it.
true_thunk_(std::move(true_thunk_sequence), nullptr),
false_thunk_(std::move(false_thunk_sequence), nullptr) {}
branch_index_is_bool_(hlo->operand(0)->shape().element_type() == PRED),
branch_index_buffer_index_(branch_index_buffer_index),
branch_operand_buffer_indexes_(branch_operand_buffer_indexes.begin(),
branch_operand_buffer_indexes.end()) {
// Pass nullptr as the HloInstruction* to the branch_thunks_
// constructors because these SequentialThunks are logically "part of"
// this ConditionalThunk, and shouldn't be profiled separately from it.
branch_thunks_.reserve(branch_thunk_sequences.size());
for (auto& branch_thunk_sequence : branch_thunk_sequences) {
branch_thunks_.emplace_back(
new SequentialThunk(std::move(branch_thunk_sequence), nullptr));
}
}
Status ConditionalThunk::Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) {
TF_RETURN_IF_ERROR(true_thunk_.Initialize(executable, executor));
TF_RETURN_IF_ERROR(false_thunk_.Initialize(executable, executor));
if (branch_index_is_bool_) {
TF_RET_CHECK(branch_thunks_.size() == 2);
} else {
TF_RET_CHECK(!branch_thunks_.empty());
}
for (auto& branch_thunk : branch_thunks_) {
TF_RETURN_IF_ERROR(branch_thunk->Initialize(executable, executor));
}
return Status::OK();
}
@ -51,30 +61,37 @@ Status ConditionalThunk::ExecuteOnStream(
HloExecutionProfiler* profiler) {
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
// Copy the predicate value from device.
bool predicate;
se::DeviceMemoryBase predicate_address =
buffer_allocations.GetDeviceAddress(predicate_buffer_index_);
stream->ThenMemcpy(&predicate, predicate_address, sizeof(bool));
int32 branch_index = -1;
bool pred = false;
se::DeviceMemoryBase branch_index_address =
buffer_allocations.GetDeviceAddress(branch_index_buffer_index_);
if (branch_index_is_bool_) {
stream->ThenMemcpy(&pred, branch_index_address, sizeof(bool));
} else {
stream->ThenMemcpy(&branch_index, branch_index_address, sizeof(int32));
}
Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) {
return InternalError("Failed to retrieve predicate value on stream %p: %s.",
stream, block_status.error_message());
return InternalError(
"Failed to retrieve branch_index value on stream %p: %s.", stream,
block_status.error_message());
}
if (branch_index_is_bool_) {
branch_index = pred ? 0 : 1;
} else {
// Handle default scenario for branch_index not in [0, num_branches).
if (branch_index < 0 || branch_index >= hlo_instruction()->branch_count()) {
branch_index = hlo_instruction()->branch_count() - 1;
}
}
// Execute the true or the false computation depending on the value of the
// predicate.
if (predicate) {
profiler->StartHloComputation();
TF_RETURN_IF_ERROR(
true_thunk_.ExecuteOnStream(buffer_allocations, stream, profiler));
profiler->FinishHloComputation(hlo_instruction()->true_computation());
} else {
profiler->StartHloComputation();
TF_RETURN_IF_ERROR(
false_thunk_.ExecuteOnStream(buffer_allocations, stream, profiler));
profiler->FinishHloComputation(hlo_instruction()->false_computation());
}
// Execute the branch computation corresponding to the value of branch_index.
profiler->StartHloComputation();
TF_RETURN_IF_ERROR(branch_thunks_[branch_index]->ExecuteOnStream(
buffer_allocations, stream, profiler));
profiler->FinishHloComputation(
hlo_instruction()->branch_computation(branch_index));
return Status::OK();
}

View File

@ -16,6 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_
#include <memory>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
@ -38,12 +42,11 @@ namespace gpu {
// false computation share the same allocation.
class ConditionalThunk : public Thunk {
public:
ConditionalThunk(const BufferAllocation::Slice& predicate_buffer_index,
const BufferAllocation::Slice& true_operand_buffer_index,
const BufferAllocation::Slice& false_operand_buffer_index,
ThunkSequence true_thunk_sequence,
ThunkSequence false_thunk_sequence,
const HloInstruction* hlo);
ConditionalThunk(
const BufferAllocation::Slice& branch_index_buffer_index,
absl::Span<const BufferAllocation::Slice> branch_operand_buffer_indexes,
std::vector<ThunkSequence> branch_thunk_sequences,
const HloInstruction* hlo);
ConditionalThunk(const ConditionalThunk&) = delete;
ConditionalThunk& operator=(const ConditionalThunk&) = delete;
@ -55,11 +58,10 @@ class ConditionalThunk : public Thunk {
HloExecutionProfiler* profiler) override;
private:
BufferAllocation::Slice predicate_buffer_index_;
BufferAllocation::Slice true_operand_buffer_index_;
BufferAllocation::Slice false_operand_buffer_index_;
SequentialThunk true_thunk_;
SequentialThunk false_thunk_;
const bool branch_index_is_bool_;
BufferAllocation::Slice branch_index_buffer_index_;
std::vector<BufferAllocation::Slice> branch_operand_buffer_indexes_;
std::vector<std::unique_ptr<SequentialThunk>> branch_thunks_;
};
} // namespace gpu

View File

@ -2024,41 +2024,32 @@ Status CheckWhileBuffersShareAllocation(
// Checks that the buffers used in a conditional instruction are shared with the
// operands and result as follows:
// * The result buffer of the conditional should share the allocation with the
// result buffers of the true and false computations.
// * The buffer of operand 1 should share the allocation with the buffer of
// the parameter 0 instruction of the true computation.
// * The buffer of operand 2 should share the allocation with the buffer of
// the parameter 0 instruction of the false computation.
// result buffers of each branch computation.
// * The buffer of operand b+1 should share the allocation with the buffer of
// the parameter 0 instruction of the b'th computation.
Status CheckConditionalBuffersShareAllocation(
const HloInstruction* conditional,
const BufferAssignment& buffer_assignment) {
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
conditional->shape(),
[&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
conditional, conditional->true_computation()->root_instruction(),
index, buffer_assignment));
TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
conditional, conditional->false_computation()->root_instruction(),
index, buffer_assignment));
for (auto branch_computation : conditional->branch_computations()) {
TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
conditional, branch_computation->root_instruction(), index,
buffer_assignment));
}
return Status::OK();
}));
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
conditional->operand(1)->shape(),
[&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
return CheckHloBuffersShareAllocation(
conditional->operand(1),
conditional->true_computation()->parameter_instruction(0), index,
buffer_assignment);
}));
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
conditional->operand(2)->shape(),
[&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
return CheckHloBuffersShareAllocation(
conditional->operand(2),
conditional->false_computation()->parameter_instruction(0), index,
buffer_assignment);
}));
for (int j = 0; j < conditional->branch_count(); ++j) {
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
conditional->operand(j + 1)->shape(),
[&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
return CheckHloBuffersShareAllocation(
conditional->operand(j + 1),
conditional->branch_computation(j)->parameter_instruction(0),
index, buffer_assignment);
}));
}
return Status::OK();
}
@ -2111,22 +2102,20 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
TF_CHECK_OK(CheckConditionalBuffersShareAllocation(
hlo, ir_emitter_context_->buffer_assignment()));
HloComputation* true_computation = hlo->true_computation();
IrEmitterUnnested ir_emitter_true(hlo_module_config_, true_computation,
ir_emitter_context_);
TF_CHECK_OK(true_computation->Accept(&ir_emitter_true));
HloComputation* false_computation = hlo->false_computation();
IrEmitterUnnested ir_emitter_false(hlo_module_config_, false_computation,
ir_emitter_context_);
TF_CHECK_OK(false_computation->Accept(&ir_emitter_false));
std::vector<BufferAllocation::Slice> branch_operands;
std::vector<ThunkSequence> branch_thunks;
for (int j = 0; j < hlo->branch_count(); ++j) {
branch_operands.emplace_back(GetAllocationSlice(*hlo->operand(j + 1)));
HloComputation* branch_computation = hlo->branch_computation(j);
IrEmitterUnnested ir_emitter(hlo_module_config_, branch_computation,
ir_emitter_context_);
TF_CHECK_OK(branch_computation->Accept(&ir_emitter));
branch_thunks.push_back(std::move(*ir_emitter.ConsumeThunkSequence()));
}
return absl::make_unique<ConditionalThunk>(
GetAllocationSlice(*hlo->operand(0)),
GetAllocationSlice(*hlo->operand(1)),
GetAllocationSlice(*hlo->operand(2)),
std::move(*ir_emitter_true.ConsumeThunkSequence()),
std::move(*ir_emitter_false.ConsumeThunkSequence()), hlo);
GetAllocationSlice(*hlo->operand(0)), branch_operands,
std::move(branch_thunks), hlo);
}
Status IrEmitterUnnested::EmitTargetElementLoopInThunk(

View File

@ -356,9 +356,9 @@ class IrEmitterUnnested : public IrEmitter {
std::unique_ptr<Thunk> BuildForThunk(const HloInstruction* hlo,
const int64 loop_limit);
// Returns a ConditionalThunk that executes the thunk sequence for
// 'true_computation' or 'false_computation' depending on the value of the
// predicate in the given conditional instruction.
// Returns a ConditionalThunk which executes the thunk sequence for the
// 'branch_computation' corresponding to the predicate/branch_index of the
// given conditional instruction.
std::unique_ptr<Thunk> BuildConditionalThunk(const HloInstruction* hlo);
Status Postprocess(HloInstruction* hlo) override;

View File

@ -293,7 +293,7 @@ class BufferValueMap {
VLOG(3)
<< " value @ " << position << " is root of "
<< callsite.instruction()->name()
<< "; true/false branch roots must share buffer among them : "
<< "; branch computation roots must share buffer among them : "
<< cond_value.ToShortString();
aliased_buffers->push_back(GetBufferForValue(cond_value));
}

View File

@ -684,19 +684,22 @@ Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) {
}
Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) {
// Compute the cost of the true and false computations and take the maximum
// from those for each property.
// Compute the cost of the branch computations and take the maximum from those
// for each property.
TF_ASSIGN_OR_RETURN(
const Properties true_computation_properties,
ProcessUnnestedSubcomputation(conditional->true_computation()));
TF_ASSIGN_OR_RETURN(
const Properties false_computation_properties,
ProcessUnnestedSubcomputation(conditional->false_computation()));
current_properties_ = true_computation_properties;
for (const auto& property : false_computation_properties) {
if (!tensorflow::gtl::InsertIfNotPresent(&current_properties_, property)) {
current_properties_[property.first] =
std::max(current_properties_[property.first], property.second);
const Properties branch0_computation_properties,
ProcessUnnestedSubcomputation(conditional->branch_computation(0)));
current_properties_ = branch0_computation_properties;
for (int j = 1; j < conditional->branch_count(); ++j) {
TF_ASSIGN_OR_RETURN(
const Properties branch_computation_properties,
ProcessUnnestedSubcomputation(conditional->branch_computation(j)));
for (const auto& property : branch_computation_properties) {
if (!tensorflow::gtl::InsertIfNotPresent(&current_properties_,
property)) {
auto& current_property = current_properties_[property.first];
current_property = std::max(current_property, property.second);
}
}
}
current_should_compute_bottleneck_time_ = false;

View File

@ -414,11 +414,11 @@ bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
bool HloDataflowAnalysis::UpdateConditionalValueSet(
HloInstruction* conditional) {
CHECK_EQ(conditional->opcode(), HloOpcode::kConditional);
const InstructionValueSet* const inputs[] = {
&GetInstructionValueSet(
conditional->true_computation()->root_instruction()),
&GetInstructionValueSet(
conditional->false_computation()->root_instruction())};
std::vector<const InstructionValueSet*> inputs(conditional->branch_count());
for (int j = 0; j < conditional->branch_count(); ++j) {
inputs[j] = &GetInstructionValueSet(
conditional->branch_computation(j)->root_instruction());
}
if (ssa_form_) {
return Phi(conditional, inputs);
} else {
@ -546,20 +546,23 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
} else if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
CHECK_EQ(parameter->parameter_number(), 0);
auto conditional = callsite.instruction();
// Conditional has 3 operands. Operand 0 is the predicate, operand 1 is
// the argument to the true computation and operand 2 is the argument to
// the false computation.
// Conditional has branch_count+1 operands. Operand 0 is the branch_index,
// operands 1 and onward are the arguments to the branch computations.
//
// If the parameter belongs to conditional's true computation, then
// If the parameter belongs to conditional's branch 0 computation, then
// operand 1 is forwarded to this parameter instruction. If the parameter
// belongs to conditional's false computation, then operand 2 is forwarded
// to this parameter instruction.
if (parameter->parent() == conditional->true_computation()) {
inputs.push_back(&GetInstructionValueSet(conditional->operand(1)));
} else {
CHECK_EQ(parameter->parent(), conditional->false_computation());
inputs.push_back(&GetInstructionValueSet(conditional->operand(2)));
// belongs to conditional's branch 5 computation, then operand 6 is
// forwarded to this parameter instruction.
bool found_parent = false;
for (int j = 0; j < conditional->branch_count(); ++j) {
if (parameter->parent() == conditional->branch_computation(j)) {
inputs.push_back(
&GetInstructionValueSet(conditional->operand(j + 1)));
found_parent = true;
break;
}
}
CHECK(found_parent);
need_phi = true;
} else {
LOG(FATAL) << "CallContext::kSequential computations should only be "
@ -710,19 +713,17 @@ void HloDataflowAnalysis::Propagate() {
// parameter(s) of the computation need to be updated.
if (user->opcode() == HloOpcode::kConditional) {
// If operand 0 is the use of instruction, then no parameters need to be
// updated, since that is the predicate of the conditional.
// If operand 1 is the use of instruction, then the true_computation's
// parameter need to be updated.
// If operand 2 is the use of instruction, then the false_computation's
// parameter need to be updated.
// updated, since that is the branch_index of the conditional.
// If operand n+1 is the use of instruction, then the branch_computation
// n's parameter need to be updated.
//
// Note that the same instruction can be used in both operand 1 and
// operand 2.
if (user->operand(1) == instruction) {
add_to_worklist(user->true_computation()->parameter_instruction(0));
}
if (user->operand(2) == instruction) {
add_to_worklist(user->false_computation()->parameter_instruction(0));
// Note that the same instruction can be used in multiple branches'
// operands.
for (int j = 0; j < user->branch_count(); ++j) {
if (user->operand(j + 1) == instruction) {
add_to_worklist(
user->branch_computation(j)->parameter_instruction(0));
}
}
} else {
for (HloComputation* called_computation : user->called_computations()) {
@ -744,8 +745,8 @@ void HloDataflowAnalysis::Propagate() {
const CallGraphNode& call_graph_node =
call_graph_->GetNode(instruction->parent());
for (const CallSite& callsite : call_graph_node.caller_callsites()) {
if ((callsite.instruction()->opcode() == HloOpcode::kCall) ||
(callsite.instruction()->opcode() == HloOpcode::kConditional)) {
if (callsite.instruction()->opcode() == HloOpcode::kCall ||
callsite.instruction()->opcode() == HloOpcode::kConditional) {
add_to_worklist(callsite.instruction());
} else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
// Add the while itself, and the body and condition parameters.

View File

@ -1270,28 +1270,27 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
}
Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
const auto& pred = GetEvaluatedLiteralFor(conditional->operand(0));
const auto& true_computation_arg =
GetEvaluatedLiteralFor(conditional->operand(1));
const auto& false_computation_arg =
GetEvaluatedLiteralFor(conditional->operand(2));
auto* true_computation = conditional->true_computation();
auto* false_computation = conditional->false_computation();
const auto& branch_index_literal =
GetEvaluatedLiteralFor(conditional->operand(0));
int branch_index;
if (conditional->operand(0)->shape().element_type() == PRED) {
branch_index = branch_index_literal.Get<bool>({}) ? 0 : 1;
} else {
branch_index = branch_index_literal.Get<int32>({});
if (branch_index < 0 || branch_index >= conditional->branch_count()) {
branch_index = conditional->branch_count() - 1;
}
}
const auto& branch_computation_arg =
GetEvaluatedLiteralFor(conditional->operand(1 + branch_index));
HloEvaluator embedded_evaluator;
embedded_evaluator.set_dynamic_dimension_inference(
dynamic_dimension_inference_);
Literal result;
if (pred.Get<bool>({})) {
result =
embedded_evaluator.Evaluate(*true_computation, {&true_computation_arg})
.ConsumeValueOrDie();
} else {
result = embedded_evaluator
.Evaluate(*false_computation, {&false_computation_arg})
.ConsumeValueOrDie();
}
Literal result = embedded_evaluator
.Evaluate(*conditional->branch_computation(branch_index),
{&branch_computation_arg})
.ConsumeValueOrDie();
evaluated_[conditional] = std::move(result);
return Status::OK();

View File

@ -82,6 +82,15 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
const auto computations = [&computation_map, &proto](int index) {
return computation_map.at(proto.called_computation_ids(index));
};
const auto all_computations = [&computation_map, &proto]() {
std::vector<HloComputation*> result(proto.called_computation_ids_size());
std::transform(proto.called_computation_ids().begin(),
proto.called_computation_ids().end(), result.begin(),
[&computation_map](int64 computation_id) {
return computation_map.at(computation_id);
});
return result;
};
TF_RET_CHECK(
absl::c_all_of(proto.operand_ids(),
@ -163,6 +172,26 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction =
CreateConcatenate(shape, all_operands(), proto.dimensions(0));
break;
case HloOpcode::kConditional: {
TF_RET_CHECK(proto.called_computation_ids_size() > 0)
<< "conditional should have at least 1 called computation";
if (operands(0)->shape().element_type() == PRED) {
TF_RET_CHECK(proto.called_computation_ids_size() == 2)
<< "conditional should have exactly 2 called computations but got "
<< proto.called_computation_ids_size();
}
TF_RET_CHECK(proto.operand_ids_size() ==
proto.called_computation_ids_size() + 1)
<< "conditional should have one branch_index operand plus one "
"operand per called computation but got "
<< proto.operand_ids_size() << " operands for "
<< proto.called_computation_ids_size() << " branch computations";
auto cond_operands = all_operands();
instruction =
CreateConditional(shape, cond_operands[0], all_computations(),
absl::MakeSpan(cond_operands).subspan(1));
break;
}
case HloOpcode::kReduce:
TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
<< "Reduce instruction should have an even number of operands but "
@ -898,6 +927,21 @@ HloInstruction::CreateAddDependency(HloInstruction* data_operand,
return instruction;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
const Shape& shape, HloInstruction* branch_index,
absl::Span<HloComputation* const> branch_computations,
absl::Span<HloInstruction* const> branch_computation_args) {
auto instruction =
absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
instruction->AppendOperand(branch_index);
CHECK_EQ(branch_computations.size(), branch_computation_args.size());
for (int i = 0; i < branch_computations.size(); ++i) {
instruction->called_computations_.push_back(branch_computations[i]);
instruction->AppendOperand(branch_computation_args[i]);
}
return instruction;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
const Shape& shape, HloInstruction* operand,
absl::Span<const int64> start_indices,
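
A hedged usage sketch of the CreateConditional overload added above; branch0, branch1, branch2, arg0, arg1, arg2, index_s32 and result_shape are hypothetical placeholders. The factory takes one branch-index operand plus exactly one argument operand per branch computation.

// Sketch only: build an indexed conditional with three branches; each branch
// computation consumes the single argument operand paired with it.
std::vector<HloComputation*> branches = {branch0, branch1, branch2};
std::vector<HloInstruction*> branch_args = {arg0, arg1, arg2};
std::unique_ptr<HloInstruction> conditional =
    HloInstruction::CreateConditional(result_shape, /*branch_index=*/index_s32,
                                      absl::MakeSpan(branches),
                                      absl::MakeSpan(branch_args));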
@ -1397,10 +1441,10 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
break;
case HloOpcode::kConditional:
CHECK_EQ(new_operands.size(), 3);
clone = CreateConditional(shape, new_operands[0], new_operands[1],
true_computation(), new_operands[2],
false_computation());
CHECK_EQ(new_operands.size(), branch_count() + 1);
clone = CreateConditional(shape, new_operands[0],
absl::MakeSpan(branch_computations()),
new_operands.subspan(1));
break;
case HloOpcode::kAfterAll:
if (new_operands.empty()) {
@ -1711,16 +1755,16 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kCall:
return eq_computations(to_apply(), other.to_apply());
case HloOpcode::kConditional:
return eq_computations(true_computation(), other.true_computation()) &&
eq_computations(false_computation(), other.false_computation());
case HloOpcode::kWhile: {
if (eq_computations(while_body(), other.while_body()) &&
eq_computations(while_condition(), other.while_condition())) {
return true;
for (int j = 0; j < branch_count(); ++j) {
if (!eq_computations(branch_computation(j),
other.branch_computation(j))) {
return false;
}
}
return false;
}
return true;
case HloOpcode::kWhile:
return (eq_computations(while_body(), other.while_body()) &&
eq_computations(while_condition(), other.while_condition()));
// Ops migrated to subclasses should never come to this line.
// TODO(b/80131774): Remove this switch when migration is complete.
@ -1983,28 +2027,41 @@ HloInstruction* HloInstruction::while_init() const {
HloComputation* HloInstruction::true_computation() const {
CHECK_EQ(HloOpcode::kConditional, opcode_);
CHECK_EQ(PRED, operand(0)->shape().element_type());
return called_computations_[kTrueComputationIndex];
}
HloComputation* HloInstruction::false_computation() const {
CHECK_EQ(HloOpcode::kConditional, opcode_);
CHECK_EQ(PRED, operand(0)->shape().element_type());
return called_computations_[kFalseComputationIndex];
}
void HloInstruction::set_true_computation(HloComputation* true_computation) {
// Don't allow changing the computation for fused instructions so we don't
// have to recompute called_instructions for the entire fusion instruction.
CHECK(!IsFused());
CHECK_EQ(HloOpcode::kConditional, opcode_);
called_computations_[kTrueComputationIndex] = true_computation;
const std::vector<HloComputation*>& HloInstruction::branch_computations()
const {
CHECK(HloOpcode::kConditional == opcode_);
return called_computations_;
}
void HloInstruction::set_false_computation(HloComputation* false_computation) {
int HloInstruction::branch_count() const {
CHECK(HloOpcode::kConditional == opcode_);
return called_computations_.size();
}
HloComputation* HloInstruction::branch_computation(int b) const {
CHECK(HloOpcode::kConditional == opcode_);
CHECK_GE(b, 0);
CHECK_LT(b, called_computations_.size());
return called_computations_[b];
}
void HloInstruction::set_branch_computation(int b,
HloComputation* computation) {
// Don't allow changing the computation for fused instructions so we don't
// have to recompute called_instructions for the entire fusion instruction.
CHECK(!IsFused());
CHECK_EQ(HloOpcode::kConditional, opcode_);
called_computations_[kFalseComputationIndex] = false_computation;
called_computations_[b] = computation;
}
string HloInstruction::SignatureString() const {
@ -2207,10 +2264,21 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
extra.push_back(
StrCat("scatter=", PrintName(scatter()->name(), options)));
} else if (opcode() == HloOpcode::kConditional) {
extra.push_back(StrCat("true_computation=",
PrintName(true_computation()->name(), options)));
extra.push_back(StrCat("false_computation=",
PrintName(false_computation()->name(), options)));
if (operand(0)->shape().element_type() == PRED) {
extra.push_back(StrCat("true_computation=",
PrintName(true_computation()->name(), options)));
extra.push_back(
StrCat("false_computation=",
PrintName(false_computation()->name(), options)));
} else {
extra.push_back(StrCat(
"branch_computations={",
StrJoin(branch_computations(), ", ",
[&](string* out, const HloComputation* computation) {
StrAppend(out, PrintName(computation->name(), options));
}),
"}"));
}
} else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
opcode() == HloOpcode::kReduceWindow ||
opcode() == HloOpcode::kReduce ||
@ -2242,10 +2310,20 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options)));
break;
case HloOpcode::kConditional:
extra.push_back(StrCat("true_computation=\n",
true_computation()->ToString(new_options)));
extra.push_back(StrCat("false_computation=\n",
false_computation()->ToString(new_options)));
if (operand(0)->shape().element_type() == PRED) {
extra.push_back(StrCat("true_computation=\n",
true_computation()->ToString(new_options)));
extra.push_back(StrCat("false_computation=\n",
false_computation()->ToString(new_options)));
} else {
extra.push_back(StrCat(
"branch_computations={\n",
StrJoin(branch_computations(), ",\n",
[&](string* out, const HloComputation* computation) {
StrAppend(out, computation->ToString(new_options));
}),
"\n}"));
}
break;
case HloOpcode::kCall:
case HloOpcode::kMap:


@ -711,6 +711,11 @@ class HloInstruction {
HloInstruction* true_computation_arg, HloComputation* true_computation,
HloInstruction* false_computation_arg, HloComputation* false_computation);
static std::unique_ptr<HloInstruction> CreateConditional(
const Shape& shape, HloInstruction* branch_index,
absl::Span<HloComputation* const> branch_computations,
absl::Span<HloInstruction* const> branch_computation_args);
static std::unique_ptr<HloInstruction> CreateGather(
const Shape& shape, HloInstruction* operand,
HloInstruction* start_indices,
@ -1057,14 +1062,23 @@ class HloInstruction {
HloInstruction* while_init() const;
// Gets/sets the true and false HloComputation for Conditional. The setters
// should only be called by HloModule or HloComputation methods.
// Gets/sets the true and false HloComputation for Conditional.
//
// Precondition: The instruction is a Conditional instruction.
// Precondition: The instruction is a predicated Conditional instruction.
HloComputation* true_computation() const;
HloComputation* false_computation() const;
void set_true_computation(HloComputation* true_computation);
void set_false_computation(HloComputation* false_computation);
// Gets the branch HloComputations for Conditional.
//
// Precondition: The instruction is a Conditional instruction.
const std::vector<HloComputation*>& branch_computations() const;
int branch_count() const;
HloComputation* branch_computation(int b) const;
// Sets a branch HloComputation for Conditional.
// The setter should only be called by HloModule or HloComputation methods.
//
// Precondition: The instruction is a Conditional instruction.
void set_branch_computation(int b, HloComputation* computation);
// Returns a string for the signature of this instruction if considered as a
// function, e.g. the signature of an F32 add is (F32, F32) -> F32.


@ -158,17 +158,12 @@ void HloModule::ReplaceComputations(
break;
}
case HloOpcode::kConditional: {
HloComputation* new_true_computation =
tensorflow::gtl::FindWithDefault(
replacements, instruction->true_computation(), nullptr);
if (new_true_computation != nullptr) {
instruction->set_true_computation(new_true_computation);
}
HloComputation* new_false_computation =
tensorflow::gtl::FindWithDefault(
replacements, instruction->false_computation(), nullptr);
if (new_false_computation != nullptr) {
instruction->set_false_computation(new_false_computation);
for (int b = 0; b < instruction->branch_count(); ++b) {
HloComputation* new_computation = tensorflow::gtl::FindWithDefault(
replacements, instruction->branch_computation(b), nullptr);
if (new_computation != nullptr) {
instruction->set_branch_computation(b, new_computation);
}
}
break;
}


@ -45,11 +45,8 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const {
case ComputationKind::kWhileBody:
repr += ":WHILE_BODY";
break;
case ComputationKind::kConditionalTrue:
repr += ":CONDITIONAL_TRUE";
break;
case ComputationKind::kConditionalFalse:
repr += ":CONDITIONAL_FALSE";
case ComputationKind::kConditionalBranch:
repr += absl::StrCat(":CONDITIONAL_BRANCH_", index_);
break;
case ComputationKind::kCallFunction:
repr += ":CALL";
@ -307,10 +304,10 @@ Status HloModuleGroupMetadata::RecordInstructions() {
tracked_instructions_[hlo->while_body()] =
TrackedInstruction(hlo, ComputationKind::kWhileBody);
} else if (hlo->opcode() == HloOpcode::kConditional) {
tracked_instructions_[hlo->true_computation()] =
TrackedInstruction(hlo, ComputationKind::kConditionalTrue);
tracked_instructions_[hlo->false_computation()] =
TrackedInstruction(hlo, ComputationKind::kConditionalFalse);
for (int b = 0; b < hlo->branch_count(); ++b) {
tracked_instructions_[hlo->branch_computation(b)] =
TrackedInstruction(hlo, ComputationKind::kConditionalBranch, b);
}
} else if (hlo->opcode() == HloOpcode::kCall) {
tracked_instructions_[hlo->to_apply()] =
TrackedInstruction(hlo, ComputationKind::kCallFunction);


@ -67,8 +67,7 @@ class HloModuleGroupMetadata {
kInvalid,
kWhileCondition,
kWhileBody,
kConditionalTrue,
kConditionalFalse,
kConditionalBranch,
kCallFunction,
};
@ -80,12 +79,13 @@ class HloModuleGroupMetadata {
class TrackedInstruction {
public:
TrackedInstruction() = default;
TrackedInstruction(HloInstruction* instruction, ComputationKind kind)
: instruction_(instruction), kind_(kind) {}
TrackedInstruction(HloInstruction* instruction, ComputationKind kind,
int index = -1)
: instruction_(instruction), kind_(kind), index_(index) {}
bool operator==(const TrackedInstruction& rhs) const {
return instruction_->opcode() == rhs.instruction_->opcode() &&
kind_ == rhs.kind_;
kind_ == rhs.kind_ && index_ == rhs.index_;
}
bool operator!=(const TrackedInstruction& rhs) const {
return !operator==(rhs);
@ -98,6 +98,7 @@ class HloModuleGroupMetadata {
private:
HloInstruction* instruction_ = nullptr;
ComputationKind kind_ = ComputationKind::kInvalid;
int index_ = -1;
};
// Represents a channel and the instructions that form the channel.


@ -67,7 +67,7 @@ namespace xla {
V(kClz, "count-leading-zeros", 1) \
V(kComplex, "complex", 2) \
V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \
V(kConditional, "conditional", 3) \
V(kConditional, "conditional", kHloOpcodeIsVariadic) \
V(kConstant, "constant", 0) \
V(kConvert, "convert", 1) \
V(kConvolution, "convolution", 2) \


@ -59,6 +59,7 @@ TEST(HloOpcodeTest, OpcodeProperties) {
case HloOpcode::kAllToAll:
case HloOpcode::kCall:
case HloOpcode::kConcatenate:
case HloOpcode::kConditional:
case HloOpcode::kCustomCall:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:


@ -66,24 +66,31 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a,
}
}
// If the common ancestor is a conditional instruction, even though the true
// and false computations are not really ordered per-se, we define the true
// computation to be ordered before the false one.
// This ensures that buffers can still be shared among the two computations
// If the common ancestor is a conditional instruction, even though the branch
// computations are not really ordered per se, we define the 0th branch
// computation to be ordered before the 1st one, before the 2nd and so forth.
// This ensures that buffers can still be shared among branch computations
// as they will forcibly have disjoint liveness.
if (a_ancestor == b_ancestor &&
a_ancestor->opcode() == HloOpcode::kConditional) {
const HloComputation* true_computation = a_ancestor->true_computation();
const HloComputation* false_computation = a_ancestor->false_computation();
if (call_graph_->InstructionIsNestedIn(a, true_computation) &&
call_graph_->InstructionIsNestedIn(b, false_computation)) {
(a_ancestor->opcode() == HloOpcode::kConditional)) {
int a_branch = -1;
int b_branch = -1;
for (int j = 0; j < a_ancestor->branch_count(); ++j) {
if (call_graph_->InstructionIsNestedIn(
a, a_ancestor->branch_computation(j))) {
a_branch = j;
}
if (call_graph_->InstructionIsNestedIn(
b, a_ancestor->branch_computation(j))) {
b_branch = j;
}
}
if (a_branch != -1 && a_branch < b_branch) {
return true;
}
// If 'b' is the conditional ancestor, and 'a' is within the true or false
// computations, 'a' executes before 'b'.
if (b == a_ancestor &&
(call_graph_->InstructionIsNestedIn(a, true_computation) ||
call_graph_->InstructionIsNestedIn(a, false_computation))) {
// If 'b' is the conditional ancestor, and 'a' is within a branch
// computation, 'a' executes before 'b'.
if (b == a_ancestor && a_branch != -1) {
return true;
}
}
@ -144,17 +151,17 @@ bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
b.defining_instruction()->while_condition()))) {
return true;
}
// If 'b' is a conditional phi and 'a' is in the true or false computation,
// then 'a' executes before 'b'.
// If 'b' is a conditional phi and 'a' is in some branch computation, then 'a'
// executes before 'b'.
if (b.is_phi() &&
b.defining_instruction()->opcode() == HloOpcode::kConditional &&
(call_graph_->InstructionIsNestedIn(
a.defining_instruction(),
b.defining_instruction()->true_computation()) ||
call_graph_->InstructionIsNestedIn(
a.defining_instruction(),
b.defining_instruction()->false_computation()))) {
return true;
b.defining_instruction()->opcode() == HloOpcode::kConditional) {
for (int j = 0; j < b.defining_instruction()->branch_count(); ++j) {
if (call_graph_->InstructionIsNestedIn(
a.defining_instruction(),
b.defining_instruction()->branch_computation(j))) {
return true;
}
}
}
return ExecutesBefore(a.defining_instruction(), b.defining_instruction());
}
@ -225,17 +232,14 @@ bool HloOrdering::UseIsBeforeValueDefinition(
if (use.instruction->opcode() == HloOpcode::kConditional) {
const HloInstruction* conditional = use.instruction;
if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
conditional->true_computation())) {
VLOG(4) << " use is conditional " << use.instruction->name()
<< " and def is in TRUE computation";
return true;
}
if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
conditional->false_computation())) {
VLOG(4) << " use is conditional " << use.instruction->name()
<< " and def is in FALSE computation";
return true;
for (int j = 0; j < conditional->branch_count(); ++j) {
if (call_graph_->InstructionIsNestedIn(
value.defining_instruction(),
conditional->branch_computation(j))) {
VLOG(4) << " use is conditional " << use.instruction->name()
<< " and def is in " << j << "th branch computation";
return true;
}
}
if (value.defining_instruction() == use.instruction) {
VLOG(4) << " use is conditional " << use << " and def is "


@ -22,6 +22,7 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
@ -180,6 +181,7 @@ class HloParser {
kBracedInt64List,
kBracedInt64ListList,
kHloComputation,
kBracedHloComputationList,
kFftType,
kWindow,
kConvolutionDimensionNumbers,
@ -276,6 +278,8 @@ class HloParser {
bool ParseSliceRanges(SliceRanges* result);
bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result);
bool ParseHloComputation(HloComputation** result);
bool ParseHloComputationList(std::vector<HloComputation*>* result);
bool ParseShapeList(std::vector<Shape>* result);
bool ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim, std::vector<int64>* result);
@ -1436,18 +1440,36 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
case HloOpcode::kConditional: {
optional<HloComputation*> true_computation;
optional<HloComputation*> false_computation;
attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation,
&true_computation};
attrs["false_computation"] = {/*required=*/true, AttrTy::kHloComputation,
&false_computation};
if (!ParseOperands(&operands, /*expected_size=*/3) ||
!ParseAttributes(attrs)) {
optional<std::vector<HloComputation*>> branch_computations;
if (!ParseOperands(&operands)) {
return false;
}
const bool branch_index_is_bool =
operands[0]->shape().element_type() == PRED;
if (branch_index_is_bool) {
attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation,
&true_computation};
attrs["false_computation"] = {
/*required=*/true, AttrTy::kHloComputation, &false_computation};
} else {
attrs["branch_computations"] = {/*required=*/true,
AttrTy::kBracedHloComputationList,
&branch_computations};
}
if (!ParseAttributes(attrs)) {
return false;
}
if (branch_index_is_bool) {
branch_computations.emplace({*true_computation, *false_computation});
}
if (branch_computations->empty() ||
operands.size() != branch_computations->size() + 1) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateConditional(
shape, /*pred=*/operands[0],
/*true_computation_arg=*/operands[1], *true_computation,
/*false_computation_arg=*/operands[2], *false_computation));
shape, /*branch_index=*/operands[0],
absl::MakeSpan(*branch_computations),
absl::MakeSpan(operands).subspan(1)));
break;
}
case HloOpcode::kCustomCall: {
@ -2683,20 +2705,21 @@ bool HloParser::ParseAttributeHelper(
}
case AttrTy::kHloComputation: {
HloComputation* result = nullptr;
if (lexer_.GetKind() == TokKind::kLbrace) {
// This means it is a nested computation.
if (!ParseInstructionList(&result, /*computation_name=*/"_")) {
return false;
}
} else {
// This means it is a computation name.
if (!ParseComputationName(&result)) {
return false;
}
if (!ParseHloComputation(&result)) {
return false;
}
static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kBracedHloComputationList: {
std::vector<HloComputation*> result;
if (!ParseHloComputationList(&result)) {
return false;
}
static_cast<optional<std::vector<HloComputation*>>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kFftType: {
FftType result;
if (!ParseFftType(&result)) {
@ -3230,6 +3253,29 @@ bool HloParser::ParsePrecisionList(
parse_and_add_item);
}
bool HloParser::ParseHloComputation(HloComputation** result) {
if (lexer_.GetKind() == TokKind::kLbrace) {
// This means it is a nested computation.
return ParseInstructionList(result, /*computation_name=*/"_");
}
// This means it is a computation name.
return ParseComputationName(result);
}
bool HloParser::ParseHloComputationList(std::vector<HloComputation*>* result) {
auto parse_and_add_item = [&]() {
HloComputation* computation;
if (!ParseHloComputation(&computation)) {
return false;
}
LOG(INFO) << "parsed computation " << computation->name();
result->push_back(computation);
return true;
};
return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
parse_and_add_item);
}
// shapelist ::= '{' shapes '}'
// precision_elements
// ::= /*empty*/


@ -952,6 +952,36 @@ ENTRY %ParseC128Literal () -> c128[2] {
ROOT %c = c128[2]{0} constant({(1, 2), (-inf, nan)})
}
)"
},
// Indexed Conditional
{
"IndexedConditional",
R"(HloModule indexed_conditional
%Negate (x: f32[]) -> f32[] {
%x = f32[] parameter(0)
ROOT %negate = f32[] negate(f32[] %x)
}
%Identity (y: f32[]) -> f32[] {
%y = f32[] parameter(0)
ROOT %copy = f32[] copy(f32[] %y)
}
%Floor (z: f32[]) -> f32[] {
%z = f32[] parameter(0)
ROOT %floor = f32[] floor(f32[] %z)
}
ENTRY %Parameters1.v4 () -> f32[] {
%constant = s32[] constant(1)
%constant.1 = f32[] constant(56)
%constant.2 = f32[] constant(12)
%constant.3 = f32[] constant(13)
ROOT %conditional = f32[] conditional(s32[] %constant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%Negate, %Identity, %Floor}
}
)"
},
});
@ -1191,10 +1221,40 @@ ENTRY Sort {
)"
},
// Conditional
// Indexed Conditional
{
"Conditional",
R"(HloModule conditional
"IndexedConditional",
R"(HloModule indexed_conditional
Negate {
x = f32[] parameter(0)
ROOT negate = f32[] negate(x)
}
Identity {
y = f32[] parameter(0)
ROOT copy = f32[] copy(y)
}
Floor {
z = f32[] parameter(0)
ROOT floor = f32[] floor(z)
}
ENTRY Parameters1.v4 {
constant = s32[] constant(1)
constant.1 = f32[] constant(56)
constant.2 = f32[] constant(12)
constant.3 = f32[] constant(13)
ROOT conditional = f32[] conditional(constant, constant.1, constant.2, constant.3), branch_computations={Negate, Identity, Floor}
}
)"
},
// Predicated Conditional
{
"PredicatedConditional",
R"(HloModule pred_conditional
Negate {
x = f32[] parameter(0)
@ -2317,6 +2377,31 @@ TEST(HloParserSingleOpTest, CanonicalOpWithNested) {
text);
}
TEST(HloParserSingleOpTest, CanonicalOpIndexedConditionalInlinedBranches) {
const string text =
R"(f32[5,10]{1,0} conditional(s32[], f32[5,10]{1,0}, f32[5,10]{1,0}, f32[5,10]{1,0}), branch_computations={
{
tmp_0 = f32[5,10]{1,0} parameter(0)
ROOT tmp_1 = f32[5,10]{1,0} ceil(f32[5,10]{1,0} tmp_0)
},
{
tmp_0 = f32[5,10]{1,0} parameter(0)
ROOT tmp_1 = f32[5,10]{1,0} floor(f32[5,10]{1,0} tmp_0)
},
{
tmp_0 = f32[5,10]{1,0} parameter(0)
ROOT tmp_1 = f32[5,10]{1,0} copy(f32[5,10]{1,0} tmp_0)
}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
const HloComputation* computation = module->entry_computation();
ASSERT_NE(computation, nullptr);
EXPECT_EQ(
computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
text);
}
TEST(HloParserSingleOpTest, SingleOpWithNested) {
const string text =
R"(%fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %p0, f32[2]{0} %p1), kind=kLoop, calls=


@ -667,20 +667,22 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
}
Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
TF_RETURN_IF_ERROR(
CheckParameterCount(conditional, conditional->true_computation(), 1));
TF_RETURN_IF_ERROR(
CheckParameterCount(conditional, conditional->false_computation(), 1));
TF_RETURN_IF_ERROR(CheckOperandAndParameter(
conditional, 1, conditional->true_computation(), 0));
TF_RETURN_IF_ERROR(CheckOperandAndParameter(
conditional, 2, conditional->false_computation(), 0));
TF_RETURN_IF_ERROR(
CheckShape(conditional,
conditional->true_computation()->root_instruction()->shape()));
TF_RETURN_IF_ERROR(CheckShape(
conditional,
conditional->false_computation()->root_instruction()->shape()));
const int num_branches = conditional->branch_count();
if (conditional->operand(0)->shape().element_type() == PRED) {
TF_RET_CHECK(num_branches == 2);
} else {
TF_RET_CHECK(num_branches >= 1);
}
TF_RETURN_IF_ERROR(CheckOperandCount(conditional, num_branches + 1));
for (int j = 0; j < num_branches; ++j) {
TF_RETURN_IF_ERROR(CheckParameterCount(
conditional, conditional->branch_computation(j), 1));
TF_RETURN_IF_ERROR(CheckOperandAndParameter(
conditional, j + 1, conditional->branch_computation(j), 0));
TF_RETURN_IF_ERROR(CheckShape(
conditional,
conditional->branch_computation(j)->root_instruction()->shape()));
}
return Status::OK();
}
@ -1370,17 +1372,13 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
}
Status HandleConditional(HloInstruction* conditional) override {
if (conditional->true_computation()->num_parameters() != 1) {
return FailedPrecondition(
"True computation %s of %s must have 1 parameter insted of %d",
conditional->true_computation()->name(), conditional->ToString(),
conditional->true_computation()->num_parameters());
}
if (conditional->false_computation()->num_parameters() != 1) {
return FailedPrecondition(
"False computation %s of %s must have 1 parameter insted of %d",
conditional->false_computation()->name(), conditional->ToString(),
conditional->false_computation()->num_parameters());
for (int b = 0; b < conditional->branch_count(); ++b) {
if (conditional->branch_computation(b)->num_parameters() != 1) {
return FailedPrecondition(
"Branch computation %s of %s must have 1 parameter insted of %d",
conditional->branch_computation(b)->name(), conditional->ToString(),
conditional->branch_computation(b)->num_parameters());
}
}
return Status::OK();
}


@ -588,48 +588,40 @@ Status LayoutAssignment::AddMandatoryConstraints(
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
body_layout.result_shape(), instruction, 0));
} else if (instruction->opcode() == HloOpcode::kConditional) {
// The layout of the true and false computations must match, and must
// The layout of the branch computations must match, and must
// be the layout of the kConditional instruction.
TF_RET_CHECK(instruction->operand_count() == 3);
ComputationLayout& branch0_computation_layout =
FindOrDie(computation_layouts_, instruction->branch_computation(0));
for (int j = 0; j < instruction->branch_count(); ++j) {
TF_RET_CHECK(instruction->branch_computation(j)->num_parameters() == 1);
ComputationLayout& branch_computation_layout =
FindOrDie(computation_layouts_, instruction->branch_computation(j));
HloComputation* true_computation = instruction->true_computation();
HloComputation* false_computation = instruction->false_computation();
const HloInstruction* true_operand = instruction->operand(1);
const HloInstruction* false_operand = instruction->operand(2);
TF_RET_CHECK(true_computation->num_parameters() == 1);
TF_RET_CHECK(false_computation->num_parameters() == 1);
ComputationLayout& true_computation_layout =
FindOrDie(computation_layouts_, true_computation);
ComputationLayout& false_computation_layout =
FindOrDie(computation_layouts_, false_computation);
DCHECK(ShapeUtil::Compatible(true_operand->shape(),
true_computation_layout.parameter_shape(0)));
DCHECK(ShapeUtil::Compatible(
false_operand->shape(), false_computation_layout.parameter_shape(0)));
if (true_computation_layout.result_layout() !=
false_computation_layout.result_layout()) {
// We assign layouts in DFS fashion, so the true and false computations
// might have negotiated a different layout. But for the conditional
// instruction POV the layout must match, so we run again on the false
// computation, this time with proper computation layout.
VLOG(2) << "Reset %conditional false computation result layout: "
"false_computation="
<< false_computation->name()
<< " conditional=" << instruction->name() << " shape="
<< true_computation_layout.result_layout().ToString();
*false_computation_layout.mutable_result_layout() =
true_computation_layout.result_layout();
DCHECK(ShapeUtil::Compatible(
instruction->operand(j + 1)->shape(),
branch_computation_layout.parameter_shape(0)));
if (branch0_computation_layout.result_layout() !=
branch_computation_layout.result_layout()) {
// We assign layouts in DFS fashion, so the br_0 and br_j
// computations might have negotiated a different layout. But for the
// case instruction POV the layout must match, so we run again
// on the br_j computation, this time with proper computation layout.
VLOG(2) << "Reset %conditional branch " << j
<< " computation result layout: branch_computation="
<< instruction->branch_computation(j)->name()
<< " case=" << instruction->name() << " shape="
<< branch0_computation_layout.result_layout().ToString();
*branch_computation_layout.mutable_result_layout() =
branch0_computation_layout.result_layout();
}
if (j == 0) {
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
branch0_computation_layout.result_shape(), instruction));
}
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
branch_computation_layout.parameter_shape(0), instruction, j + 1,
/*mandatory=*/true));
}
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
true_computation_layout.result_shape(), instruction));
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
true_computation_layout.parameter_shape(0), instruction, 1,
/*mandatory=*/true));
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
false_computation_layout.parameter_shape(0), instruction, 2,
/*mandatory=*/true));
}
}
// Finally set the result layout to match ComputationLayout, if there is one.
@ -699,28 +691,21 @@ Status CheckWhileLayout(HloInstruction* while_inst,
Status CheckConditionalLayout(
HloInstruction* instruction,
const ComputationLayout& true_computation_layout,
const ComputationLayout& false_computation_layout) {
HloComputation* true_computation = instruction->true_computation();
HloComputation* false_computation = instruction->false_computation();
const HloInstruction* true_operand = instruction->operand(1);
const HloInstruction* false_operand = instruction->operand(2);
TF_RET_CHECK(true_computation_layout.result_layout() ==
false_computation_layout.result_layout());
TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape(
instruction->shape()));
TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape(
true_computation->root_instruction()->shape()));
TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape(
instruction->shape()));
TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape(
false_computation->root_instruction()->shape()));
TF_RET_CHECK(true_computation_layout.parameter_layout(0).MatchesLayoutInShape(
true_operand->shape()));
TF_RET_CHECK(
false_computation_layout.parameter_layout(0).MatchesLayoutInShape(
false_operand->shape()));
absl::Span<const ComputationLayout> branch_computation_layouts) {
for (int j = 0; j < instruction->branch_count(); ++j) {
const HloInstruction* branch_operand = instruction->operand(j + 1);
TF_RET_CHECK(branch_computation_layouts[0].result_layout() ==
branch_computation_layouts[j].result_layout());
TF_RET_CHECK(
branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
instruction->shape()));
TF_RET_CHECK(
branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
instruction->branch_computation(j)->root_instruction()->shape()));
TF_RET_CHECK(
branch_computation_layouts[j].parameter_layout(0).MatchesLayoutInShape(
branch_operand->shape()));
}
return Status::OK();
}
@ -937,13 +922,16 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
FindOrDie(computation_layouts_, instruction->while_condition()),
FindOrDie(computation_layouts_, instruction->while_body())));
break;
case HloOpcode::kConditional:
case HloOpcode::kConditional: {
std::vector<ComputationLayout> branch_computation_layouts;
for (auto branch_computation : instruction->branch_computations()) {
branch_computation_layouts.emplace_back(
FindOrDie(computation_layouts_, branch_computation));
}
TF_RETURN_IF_ERROR(CheckConditionalLayout(
instruction,
FindOrDie(computation_layouts_, instruction->true_computation()),
FindOrDie(computation_layouts_,
instruction->false_computation())));
instruction, absl::MakeSpan(branch_computation_layouts)));
break;
}
default:
break;
}


@ -1291,16 +1291,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
TF_RETURN_IF_ERROR(
ExpectArray(scale_shape, "scale input of batch norm inference"));
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) ==
Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) ==
Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape) ==
Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape) ==
Status::OK());
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape));
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
TF_RETURN_IF_ERROR(
ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape));
if (feature_index >= operand_shape.rank()) {
return InvalidArgument(
@ -2562,59 +2558,55 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferConditionalShape(
const Shape& predicate, const Shape& true_operand,
const Shape& false_operand, const ProgramShape& true_computation,
const ProgramShape& false_computation) {
if (!ShapeUtil::Equal(predicate, ShapeUtil::MakeShape(PRED, {}))) {
return InvalidArgument("Predicate must be a boolean; got %s.",
ShapeUtil::HumanString(predicate));
const Shape& branch_index,
absl::Span<const ProgramShape> branch_computations,
absl::Span<const Shape> branch_operands) {
if (!ShapeUtil::Equal(branch_index, ShapeUtil::MakeShape(PRED, {})) &&
!ShapeUtil::Equal(branch_index, ShapeUtil::MakeShape(S32, {}))) {
return InvalidArgument("branch_index must be bool or int32; got %s.",
ShapeUtil::HumanString(branch_index));
}
if (branch_index.element_type() == PRED) {
TF_RET_CHECK(2 == branch_computations.size());
} else {
TF_RET_CHECK(!branch_computations.empty());
}
TF_RET_CHECK(branch_computations.size() == branch_operands.size());
if (true_computation.parameters_size() != 1) {
return InvalidArgument("true_computation must take 1 argument; got %d.",
true_computation.parameters_size());
}
if (!ShapeUtil::Compatible(true_computation.parameters(0), true_operand)) {
auto true_shape_string = [&]() {
return StrFormat("true_operand: %s; true_computation: %s",
ShapeUtil::HumanString(true_operand),
ShapeUtil::HumanString(true_computation));
};
return InvalidArgument(
"true_operand must match the shape of the only parameter of "
"true_computation: got %s.",
true_shape_string());
}
for (int j = 0; j < branch_computations.size(); ++j) {
if (branch_computations[j].parameters_size() != 1) {
return InvalidArgument(
"branch computation %d must take 1 argument; got %d.", j,
branch_computations[j].parameters_size());
}
if (!ShapeUtil::Compatible(branch_computations[j].parameters(0),
branch_operands[j])) {
auto shape_string = [&]() {
return StrFormat("operand: %s; computation: %s",
ShapeUtil::HumanString(branch_operands[j]),
ShapeUtil::HumanString(branch_computations[j]));
};
return InvalidArgument(
"branch operand %d must match the shape of the only parameter of "
"branch computation %d: got %s.",
j, j, shape_string());
}
if (false_computation.parameters_size() != 1) {
return InvalidArgument("false_computation must take 1 argument; got %d.",
false_computation.parameters_size());
if (!ShapeUtil::Compatible(branch_computations[0].result(),
branch_computations[j].result())) {
auto shape_string = [&]() {
return StrFormat(
"branch 0 computation result: %s; branch %d computation result: %s",
ShapeUtil::HumanString(branch_computations[0].result()), j,
ShapeUtil::HumanString(branch_computations[j].result()));
};
return InvalidArgument(
"the result of branch 0 computation and branch %d computation must "
"have the same shape: got %s.",
j, shape_string());
}
}
if (!ShapeUtil::Compatible(false_computation.parameters(0), false_operand)) {
auto false_shape_string = [&]() {
return StrFormat("false_operand: %s; false_computation: %s",
ShapeUtil::HumanString(false_operand),
ShapeUtil::HumanString(false_computation));
};
return InvalidArgument(
"false_operand must match the shape of the only parameter of "
"false_computation: got %s.",
false_shape_string());
}
if (!ShapeUtil::Compatible(true_computation.result(),
false_computation.result())) {
auto shape_string = [&]() {
return StrFormat(
"true_computation result: %s; false_computation result: %s.",
ShapeUtil::HumanString(true_computation.result()),
ShapeUtil::HumanString(false_computation.result()));
};
return InvalidArgument(
"the result of true_computation and false_computation must have the "
"same shape: got %s.",
shape_string());
}
return true_computation.result();
return branch_computations[0].result();
}
/* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(


@ -208,11 +208,11 @@ class ShapeInference {
const ProgramShape& body,
const Shape& init);
// Infers the shape produced by a conditional operation.
// Infers the shape produced by a predicated or indexed conditional operation.
static StatusOr<Shape> InferConditionalShape(
const Shape& predicate, const Shape& true_operand,
const Shape& false_operand, const ProgramShape& true_computation,
const ProgramShape& false_computation);
const Shape& branch_index,
absl::Span<const ProgramShape> branch_computations,
absl::Span<const Shape> branch_operands);
// Infers the shape produced by a broadcast operation.
static StatusOr<Shape> InferBroadcastShape(


@ -1582,79 +1582,166 @@ TEST_F(ShapeInferenceTest, Rank1Transpose) {
ShapeUtil::Compatible(inferred_shape, ShapeUtil::MakeShape(F32, {5})));
}
TEST_F(ShapeInferenceTest, Conditional) {
TEST_F(ShapeInferenceTest, ConditionalPred) {
auto inferred_status0 = ShapeInference::InferConditionalShape(
pred_, vector_32_, vector_64_,
ShapeUtil::MakeProgramShape({vector_32_}, f32_),
ShapeUtil::MakeProgramShape({vector_64_}, f32_));
pred_,
{ShapeUtil::MakeProgramShape({vector_32_}, f32_),
ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
{vector_32_, vector_64_});
EXPECT_IS_OK(inferred_status0.status());
EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));
auto inferred_status1 = ShapeInference::InferConditionalShape(
pred_, matrix_32_48_, vector_32_,
ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
ShapeUtil::MakeProgramShape({vector_32_}, vector_64_));
pred_,
{ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)},
{matrix_32_48_, vector_32_});
EXPECT_IS_OK(inferred_status1.status());
EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));
auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
auto inferred_status2 = ShapeInference::InferConditionalShape(
pred_, matrix_32_48_, tuple_f32_v32,
ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_));
pred_,
{ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)},
{matrix_32_48_, tuple_f32_v32});
EXPECT_IS_OK(inferred_status2.status());
EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));
auto inferred_status_error0 = ShapeInference::InferConditionalShape(
s32_, vector_32_, vector_64_,
ShapeUtil::MakeProgramShape({vector_32_}, f32_),
ShapeUtil::MakeProgramShape({vector_64_}, f32_));
f32_,
{ShapeUtil::MakeProgramShape({vector_32_}, f32_),
ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
{vector_32_, vector_64_});
EXPECT_FALSE(inferred_status_error0.ok());
EXPECT_THAT(inferred_status_error0.status().error_message(),
HasSubstr("Predicate must be a boolean"));
HasSubstr("must be bool or int32"));
auto inferred_status_error1 = ShapeInference::InferConditionalShape(
pred_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_,
ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_));
pred_,
{ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)},
{ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_});
EXPECT_FALSE(inferred_status_error1.ok());
EXPECT_THAT(inferred_status_error1.status().error_message(),
HasSubstr("true_computation must take 1 argument"));
HasSubstr("branch computation 0 must take 1 argument"));
auto inferred_status_error2 = ShapeInference::InferConditionalShape(
pred_, vector_32_, vector_64_,
ShapeUtil::MakeProgramShape({vector_64_}, f32_),
ShapeUtil::MakeProgramShape({vector_64_}, f32_));
pred_,
{ShapeUtil::MakeProgramShape({vector_64_}, f32_),
ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
{vector_32_, vector_64_});
EXPECT_FALSE(inferred_status_error2.ok());
EXPECT_THAT(inferred_status_error2.status().error_message(),
HasSubstr("true_operand must match the shape of the only "
"parameter of true_computation"));
HasSubstr("branch operand 0 must match the shape of the only "
"parameter of branch computation 0"));
auto inferred_status_error3 = ShapeInference::InferConditionalShape(
pred_, matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}),
ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_));
pred_,
{ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)},
{matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_})});
EXPECT_FALSE(inferred_status_error3.ok());
EXPECT_THAT(inferred_status_error3.status().error_message(),
HasSubstr("false_computation must take 1 argument"));
HasSubstr("branch computation 1 must take 1 argument"));
auto inferred_status_error4 = ShapeInference::InferConditionalShape(
pred_, vector_32_, vector_64_,
ShapeUtil::MakeProgramShape({vector_32_}, f32_),
ShapeUtil::MakeProgramShape({vector_32_}, f32_));
pred_,
{ShapeUtil::MakeProgramShape({vector_32_}, f32_),
ShapeUtil::MakeProgramShape({vector_32_}, f32_)},
{vector_32_, vector_64_});
EXPECT_FALSE(inferred_status_error4.ok());
EXPECT_THAT(inferred_status_error4.status().error_message(),
HasSubstr("false_operand must match the shape of the only "
"parameter of false_computation"));
HasSubstr("branch operand 1 must match the shape of the only "
"parameter of branch computation 1"));
auto inferred_status_error5 = ShapeInference::InferConditionalShape(
pred_, vector_32_, vector_64_,
ShapeUtil::MakeProgramShape({vector_32_}, f32_),
ShapeUtil::MakeProgramShape({vector_64_}, vector_32_));
pred_,
{ShapeUtil::MakeProgramShape({vector_32_}, f32_),
ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)},
{vector_32_, vector_64_});
EXPECT_FALSE(inferred_status_error5.ok());
EXPECT_THAT(inferred_status_error5.status().error_message(),
HasSubstr("the result of true_computation and false_computation "
"must have the same shape"));
HasSubstr("the result of branch 0 computation and branch 1 "
"computation must have the same shape"));
}
TEST_F(ShapeInferenceTest, ConditionalIndexed) {
auto r0s32 = ShapeUtil::MakeShape(S32, {});
auto inferred_status0 = ShapeInference::InferConditionalShape(
r0s32,
{ShapeUtil::MakeProgramShape({vector_32_}, f32_),
ShapeUtil::MakeProgramShape({vector_64_}, f32_),
ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
{vector_32_, vector_64_, vector_64_});
EXPECT_IS_OK(inferred_status0.status());
EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));
auto inferred_status1 = ShapeInference::InferConditionalShape(
r0s32,
{ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
ShapeUtil::MakeProgramShape({vector_32_}, vector_64_),
ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_)},
{matrix_32_48_, vector_32_, matrix_32_48_});
EXPECT_IS_OK(inferred_status1.status());
EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));
auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
auto inferred_status2 = ShapeInference::InferConditionalShape(
r0s32, {ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)},
{tuple_f32_v32});
EXPECT_IS_OK(inferred_status2.status());
EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));
auto inferred_status_error0 = ShapeInference::InferConditionalShape(
pred_,
{ShapeUtil::MakeProgramShape({vector_32_}, f32_),
ShapeUtil::MakeProgramShape({vector_32_}, f32_),
ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
{vector_32_, vector_32_, vector_64_});
EXPECT_FALSE(inferred_status_error0.ok());
EXPECT_THAT(inferred_status_error0.status().error_message(),
HasSubstr("2 == branch_computations.size()"));
auto inferred_status_error1 = ShapeInference::InferConditionalShape(
r0s32,
{ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)},
{matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}),
matrix_32_48_});
EXPECT_FALSE(inferred_status_error1.ok());
EXPECT_THAT(inferred_status_error1.status().error_message(),
HasSubstr("branch computation 1 must take 1 argument"));
auto inferred_status_error2 = ShapeInference::InferConditionalShape(
r0s32,
{ShapeUtil::MakeProgramShape({r0s32}, f32_),
ShapeUtil::MakeProgramShape({vector_32_}, f32_),
ShapeUtil::MakeProgramShape({vector_32_}, f32_)},
{r0s32, vector_32_, vector_64_});
EXPECT_FALSE(inferred_status_error2.ok());
EXPECT_THAT(inferred_status_error2.status().error_message(),
HasSubstr("branch operand 2 must match the shape of the only "
"parameter of branch computation 2"));
auto inferred_status_error3 = ShapeInference::InferConditionalShape(
r0s32,
{ShapeUtil::MakeProgramShape({vector_32_}, f32_),
ShapeUtil::MakeProgramShape({vector_32_}, f32_),
ShapeUtil::MakeProgramShape({vector_32_}, f32_),
ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)},
{vector_32_, vector_32_, vector_32_, vector_64_});
EXPECT_FALSE(inferred_status_error3.ok());
EXPECT_THAT(inferred_status_error3.status().error_message(),
HasSubstr("the result of branch 0 computation and branch 3 "
"computation must have the same shape"));
auto inferred_status_error4 =
ShapeInference::InferConditionalShape(r0s32, {}, {});
EXPECT_FALSE(inferred_status_error4.ok());
EXPECT_THAT(inferred_status_error4.status().error_message(),
HasSubstr("!branch_computations.empty()"));
}
TEST_F(ShapeInferenceTest, BadSlice) {


@ -552,6 +552,7 @@ xla_test(
xla_test(
name = "conditional_test",
srcs = ["conditional_test.cc"],
shard_count = 2,
deps = [
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",


@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <random>
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@ -169,6 +170,11 @@ class ConditionalOpTest : public ClientLibraryTestBase {
ErrorSpec error_spec_{0.001};
};
// Test fixture to run indexed conditional (switch/case) tests with varying
// number of branches.
class CaseOpTest : public ConditionalOpTest,
public ::testing::WithParamInterface<int> {};
// Test true and false computations that do not take any parameters.
XLA_TEST_F(ConditionalOpTest, Parameters0) {
XlaBuilder builder(TestName());
@ -182,6 +188,36 @@ XLA_TEST_F(ConditionalOpTest, Parameters0) {
ComputeAndCompareR0<float>(&builder, 56.0f, {pred_arg.get()}, error_spec_);
}
// Test branch computations that do not take any parameters.
XLA_TEST_P(CaseOpTest, Parameters0) {
int num_branches = GetParam();
for (int bi = -1; bi <= num_branches; ++bi) {
SCOPED_TRACE(bi);
XlaBuilder builder(TestName());
XlaOp branch_index;
auto branch_index_arg = CreateR0Parameter<int32>(bi, 0, "branch_index_arg",
&builder, &branch_index);
auto operand = Tuple(&builder, {});
std::vector<XlaOp> operands(num_branches, operand);
std::vector<XlaComputation> branches;
branches.reserve(num_branches);
std::vector<const XlaComputation*> branches_p(num_branches);
for (int i = 0; i < num_branches; ++i) {
branches.emplace_back(
CreateR0ConstantComputation(static_cast<float>(i) * 10));
branches_p[i] = &branches[i];
}
Conditional(branch_index, branches_p, operands);
float expected = 10 * static_cast<float>((bi < 0 || bi >= num_branches)
? num_branches - 1
: bi);
ComputeAndCompareR0<float>(&builder, expected, {branch_index_arg.get()},
error_spec_);
}
}
// Test true and false computations that take in 1 parameter.
XLA_TEST_F(ConditionalOpTest, Parameters1) {
XlaBuilder builder(TestName());
@ -195,6 +231,45 @@ XLA_TEST_F(ConditionalOpTest, Parameters1) {
ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
}
// Test branch computations that take in 1 parameter.
XLA_TEST_P(CaseOpTest, Parameters1) {
int num_branches = GetParam();
for (int bi = -1; bi <= num_branches; ++bi) {
SCOPED_TRACE(bi);
XlaBuilder builder(TestName());
XlaOp branch_index;
auto branch_index_arg = CreateR0Parameter<int32>(bi, 0, "branch_index_arg",
&builder, &branch_index);
auto make_branch = [&builder, this](int i) {
auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i));
Add(ConstantR0<float>(sb.get(), static_cast<float>(i)),
Parameter(sb.get(), 0, r0f32_, "p0"));
return sb->BuildAndNoteError();
};
std::vector<XlaComputation> branches;
branches.reserve(num_branches);
std::vector<const XlaComputation*> branches_p(num_branches);
std::vector<XlaOp> operands;
operands.reserve(num_branches);
std::vector<float> expecteds(num_branches);
for (int i = 0; i < num_branches; ++i) {
branches.emplace_back(make_branch(i));
branches_p[i] = &branches[i];
auto fi = static_cast<float>(i);
operands.emplace_back(ConstantR0<float>(&builder, 10 * fi + 7));
expecteds[i] = 10 * fi + 7 + fi;
}
Conditional(branch_index, branches_p, operands);
float expected = (bi < 0 || bi >= num_branches)
? expecteds[num_branches - 1]
: expecteds[bi];
ComputeAndCompareR0<float>(&builder, expected, {branch_index_arg.get()},
error_spec_);
}
}
// Test conditional with two different computations in the true and false cases
// that take in different arguments.
XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) {
@ -331,6 +406,46 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) {
error_spec_);
}
// Test branch computations that take in 2 array parameters.
XLA_TEST_P(CaseOpTest, Parameters2Array) {
int num_branches = GetParam();
for (int bi = -1; bi <= num_branches; ++bi) {
SCOPED_TRACE(bi);
XlaBuilder builder(TestName());
XlaOp branch_index;
auto branch_index_arg =
CreateR0Parameter<int32>(bi, 0, "pred", &builder, &branch_index);
auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
auto operands = Tuple(&builder, {operand1, operand2});
auto make_branch = [&builder, this](int i) {
auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i));
auto p = Parameter(sb.get(), 0, tuple_2_r1s2f32_, "p0");
Add(Mul(ConstantR0<float>(sb.get(), static_cast<float>(i)),
GetTupleElement(p, 0)),
GetTupleElement(p, 1));
return sb->BuildAndNoteError();
};
std::vector<XlaComputation> branches;
branches.reserve(num_branches);
std::vector<const XlaComputation*> branches_p(num_branches);
for (int i = 0; i < num_branches; ++i) {
branches.emplace_back(make_branch(i));
branches_p[i] = &branches[i];
}
Conditional(branch_index, branches_p,
std::vector<XlaOp>(num_branches, operands));
auto modified_bi = static_cast<float>(
(bi < 0 || bi >= num_branches) ? num_branches - 1 : bi);
ComputeAndCompareR1<float>(
&builder, {24.0f * modified_bi + 10, 56.0f * modified_bi + 11},
{branch_index_arg.get()}, error_spec_);
}
}
INSTANTIATE_TEST_SUITE_P(CaseOpTest_Instantiation, CaseOpTest,
::testing::Values(1, 2, 3, 4, 5));
// Test true and false computations that take in 2 array parameters and
// predicate is false.
XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) {
@ -582,8 +697,8 @@ XLA_TEST_F(ConditionalOpTest, ShapeMismatch) {
auto result = builder.Build();
EXPECT_FALSE(result.ok());
EXPECT_THAT(result.status().error_message(),
::testing::HasSubstr("true_operand must match the shape of the "
"only parameter of true_computation"));
::testing::HasSubstr("operand 0 must match the shape of the "
"only parameter of branch computation 0"));
}
XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {