Modifies kConditional to support both predicated and indexed conditionals (i.e. switch statements).
Updates lowerings for both CPU and GPU. Adds a new tf2xla kernel for the CFv2 functional Case op which lowers to an indexed kConditional. PiperOrigin-RevId: 235831207
This commit is contained in:
		
							parent
							
								
									39b741fd9a
								
							
						
					
					
						commit
						9bac04a4de
					
				| @ -115,6 +115,7 @@ tf_kernel_library( | ||||
|     ], | ||||
|     tags = ["optonly"], | ||||
|     deps = [ | ||||
|         ":case_op", | ||||
|         ":conv_op_helpers", | ||||
|         ":if_op", | ||||
|         ":while_op", | ||||
| @ -269,6 +270,23 @@ tf_kernel_library( | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| tf_kernel_library( | ||||
|     name = "case_op", | ||||
|     srcs = ["case_op.cc"], | ||||
|     hdrs = ["case_op.h"], | ||||
|     deps = [ | ||||
|         "//tensorflow/compiler/tf2xla:common", | ||||
|         "//tensorflow/compiler/tf2xla:side_effect_util", | ||||
|         "//tensorflow/compiler/tf2xla:xla_compiler", | ||||
|         "//tensorflow/compiler/tf2xla/ops:xla_ops", | ||||
|         "//tensorflow/compiler/xla:literal", | ||||
|         "//tensorflow/compiler/xla/client:xla_builder", | ||||
|         "//tensorflow/core:framework", | ||||
|         "//tensorflow/core:lib", | ||||
|         "//tensorflow/core:protos_all_cc", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| # Kernels that have a dummy (no-op) implementation. | ||||
| tf_kernel_library( | ||||
|     name = "xla_dummy_ops", | ||||
|  | ||||
							
								
								
									
										297
									
								
								tensorflow/compiler/tf2xla/kernels/case_op.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										297
									
								
								tensorflow/compiler/tf2xla/kernels/case_op.cc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,297 @@ | ||||
| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "tensorflow/compiler/tf2xla/kernels/case_op.h" | ||||
| 
 | ||||
| #include "tensorflow/compiler/tf2xla/shape_util.h" | ||||
| #include "tensorflow/compiler/tf2xla/side_effect_util.h" | ||||
| #include "tensorflow/compiler/tf2xla/xla_context.h" | ||||
| #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" | ||||
| #include "tensorflow/compiler/tf2xla/xla_op_registry.h" | ||||
| #include "tensorflow/compiler/xla/client/xla_builder.h" | ||||
| #include "tensorflow/core/lib/core/errors.h" | ||||
| 
 | ||||
| namespace tensorflow { | ||||
| 
 | ||||
| XlaCaseOp::XlaCaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { | ||||
|   OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &branches_)); | ||||
|   OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_)); | ||||
|   OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_)); | ||||
|   if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) { | ||||
|     has_token_input_output_ = false; | ||||
|   } else { | ||||
|     has_token_input_output_ = !token_input_nodes_.empty(); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| // TODO(b/35949885): There is duplication here with the handling of the
 | ||||
| // while_op. Refactor the common code out/rework.
 | ||||
| void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { | ||||
|   xla::XlaBuilder* b = ctx->builder(); | ||||
|   int num_branches = branches_.size(); | ||||
|   OP_REQUIRES(ctx, num_branches >= 1, | ||||
|               errors::InvalidArgument("Must provide at least one case branch")); | ||||
|   OP_REQUIRES(ctx, input_type(0) == DT_INT32, | ||||
|               errors::InvalidArgument( | ||||
|                   "branch_index argument must be a int32 for XLA compilation")); | ||||
|   OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(0)), | ||||
|               errors::InvalidArgument( | ||||
|                   "branch_index argument must be scalar for XLA compilation")); | ||||
| 
 | ||||
|   VLOG(1) << "Building Case: " << input_types_.size() << " inputs"; | ||||
| 
 | ||||
|   std::vector<XlaCompiler::Argument> arguments(input_types_.size()); | ||||
|   int num_resource_args = 0; | ||||
|   for (int i = 0; i < input_types_.size(); ++i) { | ||||
|     XlaCompiler::Argument& arg = arguments[i]; | ||||
|     DataType type = ctx->input_type(i + 1); | ||||
| 
 | ||||
|     if (type == DT_RESOURCE) { | ||||
|       XlaResource* resource; | ||||
|       OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource)); | ||||
| 
 | ||||
|       arg.initialized = resource->initialized(); | ||||
|       arg.kind = XlaCompiler::Argument::kResource; | ||||
|       arg.resource_kind = resource->kind(); | ||||
| 
 | ||||
|       arg.type = resource->type(); | ||||
|       arg.shape = resource->shape(); | ||||
|       OP_REQUIRES(ctx, arg.initialized, | ||||
|                   errors::Unimplemented("Uninitialized arguments: ", arg.name)); | ||||
|       arg.max_array_size = resource->max_array_size(); | ||||
|       for (const auto& gradient : resource->tensor_array_gradients()) { | ||||
|         arg.tensor_array_gradients.insert(gradient.first); | ||||
|       } | ||||
|       arg.name = resource->name(); | ||||
|       VLOG(2) << "Resource " << resource->name() | ||||
|               << " type: " << DataTypeString(arg.type) | ||||
|               << " shape: " << arg.HumanString() | ||||
|               << " initialized: " << arg.initialized; | ||||
| 
 | ||||
|       num_resource_args++; | ||||
|     } else { | ||||
|       arg.kind = XlaCompiler::Argument::kParameter; | ||||
|       arg.type = input_types_[i]; | ||||
|       arg.shape = ctx->InputShape(i + 1); | ||||
|       VLOG(2) << "Arg type: " << DataTypeString(arg.type) | ||||
|               << " shape: " << arg.HumanString(); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // Compile each branch of the conditional.
 | ||||
|   XlaCompiler::CompileOptions options; | ||||
|   options.use_tuple_arg = true; | ||||
|   options.resolve_compile_time_constants = false; | ||||
|   options.return_updated_values_for_all_resources = true; | ||||
|   options.is_entry_computation = false; | ||||
|   options.add_token_input_output = has_token_input_output_; | ||||
|   XlaCompiler* compiler = ctx->compiler(); | ||||
| 
 | ||||
|   std::vector<XlaCompiler::CompilationResult> branch_results(num_branches); | ||||
|   std::vector<XlaCompiler::CompilationResult*> branch_results_p(num_branches); | ||||
|   for (int j = 0; j < num_branches; ++j) { | ||||
|     OP_REQUIRES_OK(ctx, | ||||
|                    compiler->CompileFunction(options, branches_[j], arguments, | ||||
|                                              &branch_results[j])); | ||||
|     branch_results_p[j] = &branch_results[j]; | ||||
|   } | ||||
| 
 | ||||
|   bool has_tensor_array_gradients = false; | ||||
|   for (XlaCompiler::CompilationResult* result : branch_results_p) { | ||||
|     for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) { | ||||
|       XlaResource* resource; | ||||
|       OP_REQUIRES_OK(ctx, | ||||
|                      ctx->GetResourceInput(update.input_index + 1, &resource)); | ||||
|       XlaCompiler::Argument& arg = arguments[update.input_index]; | ||||
| 
 | ||||
|       // Add any TensorArray gradients touched by the then/else computation to
 | ||||
|       // the enclosing graph.
 | ||||
|       for (const string& grad_source : update.tensor_array_gradients_accessed) { | ||||
|         VLOG(5) << "TensorArray " << resource->name() << " accessed gradient " | ||||
|                 << grad_source; | ||||
|         XlaResource* gradient; | ||||
|         OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient( | ||||
|                                 grad_source, b, &gradient)); | ||||
|       } | ||||
|       // Add all of the TensorArray gradients to the argument. For simplicity,
 | ||||
|       // we always pass all known gradients.
 | ||||
|       for (const auto& gradient : resource->tensor_array_gradients()) { | ||||
|         arg.tensor_array_gradients.insert(gradient.first); | ||||
|       } | ||||
|       if (!resource->tensor_array_gradients().empty()) { | ||||
|         has_tensor_array_gradients = true; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // Recompile the functions to update the argument shapes for tensor arrays.
 | ||||
|   if (has_tensor_array_gradients) { | ||||
|     for (int j = 0; j < num_branches; ++j) { | ||||
|       branch_results[j] = {}; | ||||
|       OP_REQUIRES_OK(ctx, | ||||
|                      compiler->CompileFunction(options, branches_[j], arguments, | ||||
|                                                &branch_results[j])); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   xla::Shape branch0_input_shape; | ||||
|   std::vector<const xla::XlaComputation*> result_computations(num_branches); | ||||
|   for (int j = 0; j < num_branches; ++j) { | ||||
|     // Check that all branches have identical input shapes.
 | ||||
|     OP_REQUIRES(ctx, branch_results[j].xla_input_shapes.size() == 1, | ||||
|                 errors::FailedPrecondition("Expected one input shape")); | ||||
|     xla::Shape branch_input_shape = branch_results[j].xla_input_shapes[0]; | ||||
|     if (j == 0) { | ||||
|       branch0_input_shape = branch_input_shape; | ||||
|     } | ||||
|     OP_REQUIRES(ctx, branch_input_shape.IsTuple(), | ||||
|                 errors::FailedPrecondition("Expected tuple shape")); | ||||
|     OP_REQUIRES(ctx, branch_results[j].xla_input_shapes.size() == 1, | ||||
|                 errors::FailedPrecondition("Expected one input shape")); | ||||
|     OP_REQUIRES( | ||||
|         ctx, | ||||
|         xla::ShapeUtil::Compatible(branch0_input_shape, branch_input_shape), | ||||
|         errors::InvalidArgument( | ||||
|             "Input shapes of 0 and ", j, " branches do not match: ", | ||||
|             xla::ShapeUtil::HumanString(branch0_input_shape), " vs. ", | ||||
|             xla::ShapeUtil::HumanString(branch_input_shape))); | ||||
| 
 | ||||
|     // Check that all branches have identical output shapes.
 | ||||
|     OP_REQUIRES( | ||||
|         ctx, | ||||
|         xla::ShapeUtil::Compatible(branch_results[0].xla_output_shape, | ||||
|                                    branch_results[j].xla_output_shape), | ||||
|         errors::InvalidArgument( | ||||
|             "Output shapes of 0 and ", j, " branches do not match: ", | ||||
|             xla::ShapeUtil::HumanString(branch_results[0].xla_output_shape), | ||||
|             " vs. ", | ||||
|             xla::ShapeUtil::HumanString(branch_results[j].xla_output_shape))); | ||||
| 
 | ||||
|     if (j == 0) { | ||||
|       VLOG(2) << "Input shape: " | ||||
|               << xla::ShapeUtil::HumanString(branch0_input_shape); | ||||
|       VLOG(2) << "Output shape: " | ||||
|               << xla::ShapeUtil::HumanString( | ||||
|                      branch_results[0].xla_output_shape); | ||||
|     } | ||||
| 
 | ||||
|     // We set return_updated_values_for_all_resources=true and we pass the same
 | ||||
|     // arguments to both computations, so the resource update count must match.
 | ||||
|     OP_REQUIRES(ctx, | ||||
|                 branch_results[0].resource_updates.size() == | ||||
|                     branch_results[j].resource_updates.size(), | ||||
|                 errors::FailedPrecondition( | ||||
|                     "Different number of resources in 0 and ", j, " branch")); | ||||
|     for (int i = 0; i < branch_results[0].resource_updates.size(); ++i) { | ||||
|       const auto& lhs = branch_results[0].resource_updates[i]; | ||||
|       const auto& rhs = branch_results[j].resource_updates[i]; | ||||
|       bool equal = lhs.input_index == rhs.input_index && | ||||
|                    lhs.shape == rhs.shape && | ||||
|                    lhs.tensor_array_gradients_accessed == | ||||
|                        rhs.tensor_array_gradients_accessed; | ||||
|       OP_REQUIRES(ctx, equal, | ||||
|                   errors::FailedPrecondition("Mismatch in resource of 0 and ", | ||||
|                                              j, " branch for resource ", i)); | ||||
|     } | ||||
|     result_computations[j] = branch_results[j].computation.get(); | ||||
|   } | ||||
| 
 | ||||
|   // Prepare the input arg Tuple.
 | ||||
|   int num_inputs = branch_results[0].input_mapping.size(); | ||||
|   std::vector<xla::XlaOp> inputs(num_inputs); | ||||
|   for (int i = 0; i < num_inputs; ++i) { | ||||
|     int input_num = branch_results[0].input_mapping[i] + 1; | ||||
|     if (has_token_input_output_ && i == num_inputs - 1) { | ||||
|       // Set token input for this "case" op.
 | ||||
|       std::vector<xla::XlaOp> token_inputs; | ||||
|       for (const string& node_name : token_input_nodes_) { | ||||
|         auto token_or = compiler->GetNodeToken(node_name); | ||||
|         OP_REQUIRES_OK(ctx, token_or.status()); | ||||
|         token_inputs.push_back(token_or.ValueOrDie()); | ||||
|       } | ||||
|       inputs[i] = xla::AfterAll(b, token_inputs); | ||||
|     } else if (ctx->input_type(input_num) == DT_RESOURCE) { | ||||
|       XlaResource* resource; | ||||
|       OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); | ||||
|       OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); | ||||
|     } else { | ||||
|       inputs[i] = ctx->Input(i + 1); | ||||
|     } | ||||
|   } | ||||
|   auto input_tuple = xla::Tuple(b, inputs); | ||||
| 
 | ||||
|   xla::XlaOp outputs = | ||||
|       xla::Conditional(ctx->Input(0), absl::MakeSpan(result_computations), | ||||
|                        std::vector<xla::XlaOp>(num_branches, input_tuple)); | ||||
|   // Sets non-variable outputs.
 | ||||
|   for (int i = 0; i < output_types_.size(); ++i) { | ||||
|     xla::XlaOp output_handle = xla::GetTupleElement(outputs, i); | ||||
|     if (VLOG_IS_ON(2)) { | ||||
|       LOG(INFO) << "Setting output " << i; | ||||
|       auto shape_or = b->GetShape(output_handle); | ||||
|       if (shape_or.ok()) { | ||||
|         LOG(INFO) << "Shape for output " << i << ": " | ||||
|                   << xla::ShapeUtil::HumanString(shape_or.ValueOrDie()); | ||||
|       } else { | ||||
|         LOG(INFO) << "Shape unknown for output " << i; | ||||
|       } | ||||
|     } | ||||
|     ctx->SetOutput(i, output_handle); | ||||
|   } | ||||
|   if (has_token_input_output_) { | ||||
|     // Set token output for this "Case" op. Token output is the last output of
 | ||||
|     // XLA computation, which comes after all "normal" TF outputs and resource
 | ||||
|     // updates. For "Case" node, num of resource updates equals to number of
 | ||||
|     // resource args because we set `return_updated_values_for_all_resources`
 | ||||
|     // to true in XlaCompiler option.
 | ||||
|     xla::XlaOp token_output = | ||||
|         xla::GetTupleElement(outputs, output_types_.size() + num_resource_args); | ||||
|     auto shape_or = b->GetShape(token_output); | ||||
|     OP_REQUIRES_OK(ctx, shape_or.status()); | ||||
|     OP_REQUIRES(ctx, shape_or.ValueOrDie().IsToken(), | ||||
|                 errors::FailedPrecondition( | ||||
|                     "Token output is not token type: ", | ||||
|                     xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); | ||||
|     OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output)); | ||||
|   } | ||||
| 
 | ||||
|   // Updates the values of any resource variables modified by the conditional
 | ||||
|   // bodies.
 | ||||
|   for (const XlaCompiler::CompilationResult& result : branch_results) { | ||||
|     for (int i = 0; i < result.resource_updates.size(); ++i) { | ||||
|       const XlaCompiler::ResourceUpdate& update = result.resource_updates[i]; | ||||
|       XlaResource* resource; | ||||
|       OP_REQUIRES_OK(ctx, | ||||
|                      ctx->GetResourceInput(update.input_index + 1, &resource)); | ||||
|       if (update.modified) { | ||||
|         int pos = static_cast<int>(result.outputs.size()) + i; | ||||
|         OP_REQUIRES_OK(ctx, | ||||
|                        resource->SetFromPack( | ||||
|                            arguments[update.input_index].tensor_array_gradients, | ||||
|                            xla::GetTupleElement(outputs, pos), b)); | ||||
|       } | ||||
|       VLOG(2) << "Case variable: pos: " << update.input_index | ||||
|               << " name: " << resource->name() | ||||
|               << " modified: " << update.modified | ||||
|               << " type: " << DataTypeString(update.type) | ||||
|               << " shape: " << update.shape.DebugString(); | ||||
|     } | ||||
|   } | ||||
|   VLOG(1) << "Done building Case"; | ||||
| } | ||||
| 
 | ||||
| REGISTER_XLA_OP(Name("Case").AllowResourceTypes(), XlaCaseOp); | ||||
| 
 | ||||
| }  // namespace tensorflow
 | ||||
							
								
								
									
										62
									
								
								tensorflow/compiler/tf2xla/kernels/case_op.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								tensorflow/compiler/tf2xla/kernels/case_op.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,62 @@ | ||||
| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CASE_OP_H_ | ||||
| #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CASE_OP_H_ | ||||
| 
 | ||||
| #include <string> | ||||
| #include <vector> | ||||
| 
 | ||||
| #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" | ||||
| #include "tensorflow/core/framework/attr_value.pb.h" | ||||
| #include "tensorflow/core/framework/types.h" | ||||
| 
 | ||||
| namespace tensorflow { | ||||
| 
 | ||||
| // This TensorFlow op provides a functional switch/case primitive.
 | ||||
| //
 | ||||
| // The outputs of the branches must agree on the number, types, and
 | ||||
| // shapes of the Tensors carried around the branch bodies.
 | ||||
| //
 | ||||
| // Computations in branch bodies may read from and write to resource variables.
 | ||||
| // Resource variables may be passed as arguments to the branch function's
 | ||||
| // bodies. The XlaCompiler converts resource variable arguments
 | ||||
| // into parameters to the XLA computation and moves them to the end of the
 | ||||
| // parameter list, and by using the `return_updated_values_for_all_variables`
 | ||||
| // we ensure that all variables that appear in the input also appear at the
 | ||||
| // end of the branch bodies output. This ensures the branch bodies output
 | ||||
| // signatures match.
 | ||||
| //
 | ||||
| // It is the user's responsibility to ensure that each non-variable _Arg matches
 | ||||
| // the corresponding _Retval.
 | ||||
class XlaCaseOp : public XlaOpKernel {
 public:
  explicit XlaCaseOp(OpKernelConstruction* ctx);

  // Compiles the op: lowers the functional Case to an indexed
  // xla::Conditional over the compiled branch computations.
  void Compile(XlaOpKernelContext* ctx) override;

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(XlaCaseOp);

  // Branch functions, one per case, from the "branches" attr.
  std::vector<NameAttrList> branches_;
  // Types of the branch operands (op inputs 1..N), from the "Tin" attr.
  DataTypeVector input_types_;
  // Types of the op outputs, from the "Tout" attr.
  DataTypeVector output_types_;
  // True iff the token-input-nodes attr is present and non-empty; enables
  // token plumbing for side-effect ordering.
  bool has_token_input_output_;
  // Names of the nodes whose tokens this op must consume.
  std::vector<string> token_input_nodes_;
};
| 
 | ||||
| }  // namespace tensorflow
 | ||||
| 
 | ||||
| #endif  // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CASE_OP_H_
 | ||||
| @ -1880,32 +1880,46 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, | ||||
|                               const XlaComputation& true_computation, | ||||
|                               const XlaOp& false_operand, | ||||
|                               const XlaComputation& false_computation) { | ||||
|   // The index of true_computation must be 0 and that of false computation
 | ||||
|   // must be 1.
 | ||||
|   return Conditional(predicate, {&true_computation, &false_computation}, | ||||
|                      {true_operand, false_operand}); | ||||
| } | ||||
| 
 | ||||
| XlaOp XlaBuilder::Conditional( | ||||
|     const XlaOp& branch_index, | ||||
|     absl::Span<const XlaComputation* const> branch_computations, | ||||
|     absl::Span<const XlaOp> branch_operands) { | ||||
|   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { | ||||
|     HloInstructionProto instr; | ||||
| 
 | ||||
|     TF_ASSIGN_OR_RETURN(const Shape& predicate_shape, GetShape(predicate)); | ||||
|     TF_ASSIGN_OR_RETURN(const Shape& true_operand_shape, | ||||
|                         GetShape(true_operand)); | ||||
|     TF_ASSIGN_OR_RETURN(const ProgramShape& true_computation_shape, | ||||
|                         true_computation.GetProgramShape()); | ||||
|     TF_ASSIGN_OR_RETURN(const Shape& false_operand_shape, | ||||
|                         GetShape(false_operand)); | ||||
|     TF_ASSIGN_OR_RETURN(const ProgramShape& false_computation_shape, | ||||
|                         false_computation.GetProgramShape()); | ||||
|     TF_ASSIGN_OR_RETURN( | ||||
|         Shape shape, | ||||
|         ShapeInference::InferConditionalShape( | ||||
|             predicate_shape, true_operand_shape, false_operand_shape, | ||||
|             true_computation_shape, false_computation_shape)); | ||||
|     TF_ASSIGN_OR_RETURN(const Shape& branch_index_shape, | ||||
|                         GetShape(branch_index)); | ||||
|     std::vector<Shape> branch_operand_shapes(branch_operands.size()); | ||||
|     std::vector<ProgramShape> branch_computation_shapes( | ||||
|         branch_computations.size()); | ||||
|     for (int j = 0; j < branch_operands.size(); ++j) { | ||||
|       TF_ASSIGN_OR_RETURN(branch_operand_shapes[j], | ||||
|                           GetShape(branch_operands[j])); | ||||
|       TF_ASSIGN_OR_RETURN(branch_computation_shapes[j], | ||||
|                           branch_computations[j]->GetProgramShape()); | ||||
|     } | ||||
|     TF_ASSIGN_OR_RETURN(const Shape shape, | ||||
|                         ShapeInference::InferConditionalShape( | ||||
|                             branch_index_shape, branch_computation_shapes, | ||||
|                             branch_operand_shapes)); | ||||
|     *instr.mutable_shape() = shape.ToProto(); | ||||
| 
 | ||||
|     // The index of true_computation must be 0 and that of false computation
 | ||||
|     // must be 1.
 | ||||
|     AddCalledComputation(true_computation, &instr); | ||||
|     AddCalledComputation(false_computation, &instr); | ||||
|     for (const XlaComputation* branch_computation : branch_computations) { | ||||
|       AddCalledComputation(*branch_computation, &instr); | ||||
|     } | ||||
| 
 | ||||
|     std::vector<XlaOp> operands(1, branch_index); | ||||
|     for (const XlaOp branch_operand : branch_operands) { | ||||
|       operands.emplace_back(branch_operand); | ||||
|     } | ||||
|     return AddInstruction(std::move(instr), HloOpcode::kConditional, | ||||
|                           {predicate, true_operand, false_operand}); | ||||
|                           absl::MakeSpan(operands)); | ||||
|   }); | ||||
| } | ||||
| 
 | ||||
| @ -3381,6 +3395,13 @@ XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, | ||||
|                                           false_computation); | ||||
| } | ||||
| 
 | ||||
| XlaOp Conditional(const XlaOp& branch_index, | ||||
|                   absl::Span<const XlaComputation* const> branch_computations, | ||||
|                   absl::Span<const XlaOp> branch_operands) { | ||||
|   return branch_index.builder()->Conditional(branch_index, branch_computations, | ||||
|                                              branch_operands); | ||||
| } | ||||
| 
 | ||||
| XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, | ||||
|                       const int mantissa_bits) { | ||||
|   return operand.builder()->ReducePrecision(operand, exponent_bits, | ||||
|  | ||||
| @ -529,6 +529,10 @@ class XlaBuilder { | ||||
|                     const XlaOp& false_operand, | ||||
|                     const XlaComputation& false_computation); | ||||
| 
 | ||||
|   XlaOp Conditional(const XlaOp& branch_index, | ||||
|                     absl::Span<const XlaComputation* const> branch_computations, | ||||
|                     absl::Span<const XlaOp> branch_operands); | ||||
| 
 | ||||
|   XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, | ||||
|                         const int mantissa_bits); | ||||
| 
 | ||||
| @ -946,6 +950,10 @@ class XlaBuilder { | ||||
|                            const XlaComputation& true_computation, | ||||
|                            const XlaOp& false_operand, | ||||
|                            const XlaComputation& false_computation); | ||||
|   friend XlaOp Conditional( | ||||
|       const XlaOp& branch_index, | ||||
|       absl::Span<const XlaComputation* const> branch_computations, | ||||
|       absl::Span<const XlaOp> branch_operands); | ||||
|   friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, | ||||
|                                const int mantissa_bits); | ||||
|   friend XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, | ||||
| @ -1776,6 +1784,15 @@ XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, | ||||
|                   const XlaOp& false_operand, | ||||
|                   const XlaComputation& false_computation); | ||||
| 
 | ||||
| // Enqueues either a predicated (if/else) or indexed (switch/case/default)
 | ||||
| // conditional node onto the computation. N >= 1 branch_computations and
 | ||||
| // branch_operands are matched by index. branch_index selects the branch that
 | ||||
| // will be executed. Out of range branch_index uses the N-1'th
 | ||||
| // branch_computation as default.
 | ||||
| XlaOp Conditional(const XlaOp& branch_index, | ||||
|                   absl::Span<const XlaComputation* const> branch_computations, | ||||
|                   absl::Span<const XlaOp> branch_operands); | ||||
| 
 | ||||
| // Enqueues a ReducePrecision node onto the computation.
 | ||||
| XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, | ||||
|                       const int mantissa_bits); | ||||
|  | ||||
| @ -541,25 +541,49 @@ See also | ||||
| false_computation)` </b> | ||||
| 
 | ||||
| Arguments           | Type             | Semantics | ||||
| ------------------- | ---------------- | --------------------------------- | ||||
| ------------------- | ---------------- | -------------------------------------- | ||||
| `pred`              | `XlaOp`          | Scalar of type `PRED` | ||||
| `true_operand`      | `XlaOp`          | Argument of type `T_0` | ||||
| `true_computation`  | `XlaComputation` | XlaComputation of type `T_0 -> S` | ||||
| `false_operand`     | `XlaOp`          | Argument of type `T_1` | ||||
| `false_computation` | `XlaComputation` | XlaComputation of type `T_1 -> S` | ||||
| `true_operand`      | `XlaOp`          | Argument of type $$ T_0 $$ | ||||
| `true_computation`  | `XlaComputation` | XlaComputation of type $$ T_0 \to S$$ | ||||
| `false_operand`     | `XlaOp`          | Argument of type $$ T_1 $$ | ||||
| `false_computation` | `XlaComputation` | XlaComputation of type $$ T_1 \to S $$ | ||||
| 
 | ||||
| Executes `true_computation` if `pred` is `true`, `false_computation` if `pred` | ||||
| is `false`, and returns the result. | ||||
| 
 | ||||
| The `true_computation` must take in a single argument of type `T_0` and will be | ||||
| invoked with `true_operand` which must be of the same type. The | ||||
| `false_computation` must take in a single argument of type `T_1` and will be | ||||
| The `true_computation` must take in a single argument of type $$ T_0 $$ and will | ||||
| be invoked with `true_operand` which must be of the same type. The | ||||
| `false_computation` must take in a single argument of type $$ T_1 $$ and will be | ||||
| invoked with `false_operand` which must be of the same type. The type of the | ||||
| returned value of `true_computation` and `false_computation` must be the same. | ||||
| 
 | ||||
| Note that only one of `true_computation` and `false_computation` will be | ||||
| executed depending on the value of `pred`. | ||||
| 
 | ||||
| <b> `Conditional(branch_index, branch_computations, branch_operands)` </b> | ||||
| 
 | ||||
| | Arguments             | Type                  | Semantics                    | | ||||
| | --------------------- | --------------------- | ---------------------------- | | ||||
| | `branch_index`        | `XlaOp`               | Scalar of type `PRED` or     | | ||||
| :                       :                       : `S32`                        : | ||||
| | `branch_computations` | sequence of N         | XlaComputations of type $$   | | ||||
| :                       : `XlaComputation`      : T_0 \to S , T_1 \to S , ..., : | ||||
| :                       :                       : T_{N-1} \to S $$             : | ||||
| | `branch_operands`     | sequence of N `XlaOp` | Arguments of type $$ T_0 ,   | | ||||
| :                       :                       : T_1 , ..., T_{N-1} $$        : | ||||
| 
 | ||||
| Executes `branch_computations[branch_index]`, and returns the result. If | ||||
| `branch_index` is a `PRED`, then the `true` branch is in position 0 and the | ||||
| `false` branch is in position 1. If `branch_index` is an `S32` which is < 0 | ||||
| or >= N, then `branch_computations[N-1]` is executed as the default branch. | ||||
| 
 | ||||
| Each `branch_computations[b]` must take in a single argument of type `T_b` and | ||||
| will be invoked with `branch_operands[b]` which must be of the same type. The | ||||
| type of the returned value of each `branch_computations[b]` must be the same. | ||||
| 
 | ||||
| Note that only one of the `branch_computations` will be executed depending on | ||||
| the value of `branch_index`. | ||||
| 
 | ||||
| ## Conv (convolution) | ||||
| 
 | ||||
| See also | ||||
|  | ||||
| @ -3785,6 +3785,7 @@ cc_library( | ||||
|         "@com_google_absl//absl/memory", | ||||
|         "@com_google_absl//absl/strings", | ||||
|         "@com_google_absl//absl/strings:str_format", | ||||
|         "@com_google_absl//absl/types:span", | ||||
|         "@com_google_absl//absl/types:variant", | ||||
|     ], | ||||
| ) | ||||
|  | ||||
| @ -1620,62 +1620,46 @@ void BufferAssigner::BuildColocatedBufferSets( | ||||
|               AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); | ||||
|             }); | ||||
|       } else if (opcode == HloOpcode::kConditional) { | ||||
|         const HloInstruction* conditional_hlo = instruction; | ||||
|         const HloInstruction* conditional = instruction; | ||||
|         ShapeUtil::ForEachSubshape( | ||||
|             conditional_hlo->shape(), | ||||
|             [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets]( | ||||
|             conditional->shape(), | ||||
|             [this, conditional, &points_to_analysis, colocated_buffer_sets]( | ||||
|                 const Shape& /*subshape*/, const ShapeIndex& index) { | ||||
|               std::vector<const LogicalBuffer*> colocated_set; | ||||
|               // Add conditional.result.
 | ||||
|               AddBufferToColocatedSet(conditional_hlo, index, | ||||
|                                       points_to_analysis, &colocated_set); | ||||
|               // Add conditional.true_computation.root.
 | ||||
|               AddBufferToColocatedSet( | ||||
|                   conditional_hlo->true_computation()->root_instruction(), | ||||
|                   index, points_to_analysis, &colocated_set); | ||||
|               // Add conditional.false_computation.root.
 | ||||
|               AddBufferToColocatedSet( | ||||
|                   conditional_hlo->false_computation()->root_instruction(), | ||||
|                   index, points_to_analysis, &colocated_set); | ||||
|               // Add cond.result.
 | ||||
|               AddBufferToColocatedSet(conditional, index, points_to_analysis, | ||||
|                                       &colocated_set); | ||||
|               for (int j = 0; j < conditional->branch_count(); ++j) { | ||||
|                 // Add each cond.branch_computation[j].root.
 | ||||
|                 AddBufferToColocatedSet( | ||||
|                     conditional->branch_computation(j)->root_instruction(), | ||||
|                     index, points_to_analysis, &colocated_set); | ||||
|               } | ||||
|               AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); | ||||
|             }); | ||||
| 
 | ||||
|         // Add true_operand and conditional.true_computation.parameter(0) as a
 | ||||
|         // colocated buffer set. Note that this has to be done for each subshape
 | ||||
|         // in the true_operand of the conditional.
 | ||||
|         ShapeUtil::ForEachSubshape( | ||||
|             conditional_hlo->operand(1)->shape(), | ||||
|             [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets]( | ||||
|                 const Shape& /*subshape*/, const ShapeIndex& index) { | ||||
|               std::vector<const LogicalBuffer*> true_set; | ||||
|               // Add conditional.true_operand.
 | ||||
|               AddBufferToColocatedSet(conditional_hlo->operand(1), index, | ||||
|                                       points_to_analysis, &true_set); | ||||
|               // Add conditional.true_computation.parameter_instruction(0).
 | ||||
|               AddBufferToColocatedSet( | ||||
|                   conditional_hlo->true_computation()->parameter_instruction(0), | ||||
|                   index, points_to_analysis, &true_set); | ||||
|               AddSetToColocatedBufferSets(true_set, colocated_buffer_sets); | ||||
|             }); | ||||
| 
 | ||||
|         // Add false_operand and conditional.false_computation.parameter(0) as a
 | ||||
|         // colocated buffer set. Note that this has to be done for each subshape
 | ||||
|         // in the false_operand of the conditional.
 | ||||
|         ShapeUtil::ForEachSubshape( | ||||
|             conditional_hlo->operand(2)->shape(), | ||||
|             [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets]( | ||||
|                 const Shape& /*subshape*/, const ShapeIndex& index) { | ||||
|               std::vector<const LogicalBuffer*> false_set; | ||||
|               // Add conditional.false_operand.
 | ||||
|               AddBufferToColocatedSet(conditional_hlo->operand(2), index, | ||||
|                                       points_to_analysis, &false_set); | ||||
|               // Add conditional.false_computation.parameter_instruction(0).
 | ||||
|               AddBufferToColocatedSet( | ||||
|                   conditional_hlo->false_computation()->parameter_instruction( | ||||
|                       0), | ||||
|                   index, points_to_analysis, &false_set); | ||||
|               AddSetToColocatedBufferSets(false_set, colocated_buffer_sets); | ||||
|             }); | ||||
|         for (int j = 0; j < conditional->branch_count(); ++j) { | ||||
|           // Add branch_operand[j] (which is operand[j+1]) and
 | ||||
|           // cond.branch_computation[j].parameter(0) as a colocated
 | ||||
|           // buffer set. Note that this has to be done for each subshape in the
 | ||||
|           // branch_operand of the case.
 | ||||
|           ShapeUtil::ForEachSubshape( | ||||
|               conditional->operand(j + 1)->shape(), | ||||
|               [this, j, conditional, &points_to_analysis, | ||||
|                colocated_buffer_sets](const Shape& /*subshape*/, | ||||
|                                       const ShapeIndex& index) { | ||||
|                 std::vector<const LogicalBuffer*> branch_set; | ||||
|                 // Add cond.operand[j+1].
 | ||||
|                 AddBufferToColocatedSet(conditional->operand(j + 1), index, | ||||
|                                         points_to_analysis, &branch_set); | ||||
|                 // Add cond.branch_computation[j].parameter_instruction(0).
 | ||||
|                 AddBufferToColocatedSet( | ||||
|                     conditional->branch_computation(j)->parameter_instruction( | ||||
|                         0), | ||||
|                     index, points_to_analysis, &branch_set); | ||||
|                 AddSetToColocatedBufferSets(branch_set, colocated_buffer_sets); | ||||
|               }); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
| @ -33,8 +33,8 @@ limitations under the License. | ||||
| namespace xla { | ||||
| 
 | ||||
| // Tries to replace a conditional with a call operation of the corresponding
 | ||||
| // computation. If the given conditional has a constant predicate, tries to
 | ||||
| // replace it with a call to its true/false computation as appropriate and then
 | ||||
| // computation. If the given conditional has a constant branch_index, tries to
 | ||||
| // replace it with a call to its corresponding branch computation and then
 | ||||
| // inline that computation.
 | ||||
| //
 | ||||
| // Returns true if it made a change to the graph.
 | ||||
| @ -50,24 +50,30 @@ static StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) { | ||||
|     return false; | ||||
|   } | ||||
| 
 | ||||
|   if (conditional->operand(0)->opcode() != HloOpcode::kConstant) { | ||||
|     VLOG(2) << "Not attempting to remove conditional as its predicate is not a " | ||||
|                "compile-time constant: " | ||||
|             << conditional->ToShortString(); | ||||
|     return false; | ||||
|   } | ||||
|   // We can always inline a 1-branch conditional due to default branch fallback.
 | ||||
|   int branch_index = 0; | ||||
|   if (conditional->branch_count() > 1) { | ||||
|     if (conditional->operand(0)->opcode() != HloOpcode::kConstant) { | ||||
|       VLOG(2) << "Not attempting to remove conditional as its branch_index is " | ||||
|                  "not a compile-time constant: " | ||||
|               << conditional->ToShortString(); | ||||
|       return false; | ||||
|     } | ||||
| 
 | ||||
|     if (conditional->operand(0)->shape().element_type() == PRED) { | ||||
|       branch_index = conditional->operand(0)->literal().Get<bool>({}) ? 0 : 1; | ||||
|     } else { | ||||
|       branch_index = conditional->operand(0)->literal().Get<int32>({}); | ||||
|       if (branch_index < 0 || branch_index >= conditional->branch_count()) { | ||||
|         branch_index = conditional->branch_count() - 1; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   auto computation = conditional->parent(); | ||||
|   HloInstruction* call_op; | ||||
|   if (conditional->operand(0)->literal().Get<bool>({})) { | ||||
|     call_op = computation->AddInstruction(HloInstruction::CreateCall( | ||||
|         conditional->shape(), {conditional->mutable_operand(1)}, | ||||
|         conditional->true_computation())); | ||||
|   } else { | ||||
|     call_op = computation->AddInstruction(HloInstruction::CreateCall( | ||||
|         conditional->shape(), {conditional->mutable_operand(2)}, | ||||
|         conditional->false_computation())); | ||||
|   } | ||||
|   call_op = computation->AddInstruction(HloInstruction::CreateCall( | ||||
|       conditional->shape(), {conditional->mutable_operand(branch_index + 1)}, | ||||
|       conditional->branch_computation(branch_index))); | ||||
|   conditional->SetupDerivedInstruction(call_op); | ||||
|   TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op)); | ||||
|   TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status()); | ||||
|  | ||||
| @ -319,8 +319,7 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, | ||||
|           << conditional->name(); | ||||
|   TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional); | ||||
| 
 | ||||
|   for (HloComputation* computation : | ||||
|        {conditional->true_computation(), conditional->false_computation()}) { | ||||
|   for (HloComputation* computation : conditional->branch_computations()) { | ||||
|     HloInstruction* root = computation->root_instruction(); | ||||
|     std::vector<HloInstruction*> users = root->users(); | ||||
|     TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, | ||||
|  | ||||
| @ -2515,53 +2515,109 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { | ||||
| } | ||||
| 
 | ||||
| Status IrEmitter::HandleConditional(HloInstruction* conditional) { | ||||
|   auto pred = conditional->operand(0); | ||||
|   TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) && | ||||
|                pred->shape().element_type() == PRED) | ||||
|       << "Predicate on a Conditional must be bool; got: " | ||||
|       << ShapeUtil::HumanString(pred->shape()); | ||||
|   auto branch_index = conditional->operand(0); | ||||
|   int num_branches = conditional->branch_count(); | ||||
|   TF_RET_CHECK(ShapeUtil::IsScalar(branch_index->shape()) && | ||||
|                (branch_index->shape().element_type() == PRED || | ||||
|                 branch_index->shape().element_type() == S32)) | ||||
|       << "Branch index on a conditional must be scalar bool or int32; got: " | ||||
|       << ShapeUtil::HumanString(branch_index->shape()); | ||||
| 
 | ||||
|   HloComputation* true_computation = conditional->true_computation(); | ||||
|   HloComputation* false_computation = conditional->false_computation(); | ||||
|   TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(), | ||||
|                                 true_computation->root_instruction()->shape())) | ||||
|       << "Shape of conditional should be same as the shape of the true " | ||||
|       << "computation; got: " << ShapeUtil::HumanString(conditional->shape()) | ||||
|       << " and " | ||||
|       << ShapeUtil::HumanString(true_computation->root_instruction()->shape()); | ||||
| 
 | ||||
|   TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(), | ||||
|                                 false_computation->root_instruction()->shape())) | ||||
|       << "Shape of conditional should be same as the shape of the false " | ||||
|       << "computation; got: " << ShapeUtil::HumanString(conditional->shape()) | ||||
|       << " and " | ||||
|       << ShapeUtil::HumanString(false_computation->root_instruction()->shape()); | ||||
|   for (int b = 0; b < num_branches; ++b) { | ||||
|     HloComputation* br_computation = conditional->branch_computation(b); | ||||
|     TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(), | ||||
|                                   br_computation->root_instruction()->shape())) | ||||
|         << "Shape of conditional should be same as the shape of the " << b | ||||
|         << "th branch computation; got: " | ||||
|         << ShapeUtil::HumanString(conditional->shape()) << " and " | ||||
|         << ShapeUtil::HumanString(br_computation->root_instruction()->shape()); | ||||
|   } | ||||
| 
 | ||||
|   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional)); | ||||
| 
 | ||||
|   // Generating:
 | ||||
|   //   if (pred)
 | ||||
|   //     cond_result = true_computation(true_operand)
 | ||||
|   //   else
 | ||||
|   //     cond_result = false_computation(false_operand)
 | ||||
|   llvm::LoadInst* pred_value = | ||||
|       Load(GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value"); | ||||
|   llvm::Value* pred_cond = ICmpNE( | ||||
|       pred_value, | ||||
|       llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), | ||||
|       "boolean_predicate"); | ||||
|   llvm_ir::LlvmIfData if_data = | ||||
|       llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_); | ||||
|   if (branch_index->shape().element_type() == PRED) { | ||||
|     // Emit an if-else to LLVM:
 | ||||
|     //   if (pred)
 | ||||
|     //     cond_result = true_computation(true_operand)
 | ||||
|     //   else
 | ||||
|     //     cond_result = false_computation(false_operand)
 | ||||
|     llvm::LoadInst* pred_value = Load( | ||||
|         GetIrArrayFor(branch_index).GetBasePointer(), "load_predicate_value"); | ||||
|     llvm::Value* pred_cond = | ||||
|         ICmpNE(pred_value, | ||||
|                llvm::ConstantInt::get( | ||||
|                    llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), | ||||
|                "boolean_predicate"); | ||||
|     llvm_ir::LlvmIfData if_data = | ||||
|         llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_); | ||||
| 
 | ||||
|   SetToFirstInsertPoint(if_data.true_block, &b_); | ||||
|   EmitGlobalCall(*conditional->true_computation(), | ||||
|                  IrName(conditional, "_true")); | ||||
|     SetToFirstInsertPoint(if_data.true_block, &b_); | ||||
|     EmitGlobalCall(*conditional->branch_computation(0), | ||||
|                    IrName(conditional, "_true")); | ||||
| 
 | ||||
|   SetToFirstInsertPoint(if_data.false_block, &b_); | ||||
|   EmitGlobalCall(*conditional->false_computation(), | ||||
|                  IrName(conditional, "_false")); | ||||
|     SetToFirstInsertPoint(if_data.false_block, &b_); | ||||
|     EmitGlobalCall(*conditional->branch_computation(1), | ||||
|                    IrName(conditional, "_false")); | ||||
| 
 | ||||
|   SetToFirstInsertPoint(if_data.after_block, &b_); | ||||
|     SetToFirstInsertPoint(if_data.after_block, &b_); | ||||
|     return Status::OK(); | ||||
|   } | ||||
|   // We emit a switch statement to LLVM:
 | ||||
|   // switch (branch_index) {
 | ||||
|   //   default:
 | ||||
|   //     result = branch_computations[num_branches-1](operands[num_branches-1]);
 | ||||
|   //     break;
 | ||||
|   //   case 0:
 | ||||
|   //     result = branch_computations[0](operands[0]); break;
 | ||||
|   //   case 1:
 | ||||
|   //     result = branch_computations[1](operands[1]); break;
 | ||||
|   //   ...
 | ||||
|   //   case [[num_branches-2]]:
 | ||||
|   //     result = branch_computations[num_branches-2](operands[num_branches-2]);
 | ||||
|   //     break;
 | ||||
|   // }
 | ||||
|   llvm::LoadInst* branch_index_value = Load( | ||||
|       GetIrArrayFor(branch_index).GetBasePointer(), "load_branch_index_value"); | ||||
| 
 | ||||
|   auto case_block = b_.GetInsertBlock(); | ||||
|   llvm::BasicBlock* after_block; | ||||
|   // Add a terminator to the case block, if necessary.
 | ||||
|   if (case_block->getTerminator() == nullptr) { | ||||
|     after_block = llvm_ir::CreateBasicBlock(nullptr, "case-after", &b_); | ||||
|     b_.SetInsertPoint(case_block); | ||||
|     b_.CreateBr(after_block); | ||||
|   } else { | ||||
|     after_block = | ||||
|         case_block->splitBasicBlock(b_.GetInsertPoint(), "case-after"); | ||||
|   } | ||||
|   // Our basic block should now end with an unconditional branch.  Remove it;
 | ||||
|   // we're going to replace it with a switch based branch.
 | ||||
|   case_block->getTerminator()->eraseFromParent(); | ||||
| 
 | ||||
|   // Lower the default branch computation.
 | ||||
|   auto default_block = llvm_ir::CreateBasicBlock(nullptr, "case-default", &b_); | ||||
|   b_.SetInsertPoint(default_block); | ||||
|   EmitGlobalCall(*conditional->branch_computation(num_branches - 1), | ||||
|                  IrName(conditional, "_default")); | ||||
|   b_.CreateBr(after_block); | ||||
| 
 | ||||
|   // Prepare the switch (branch_index) { ... } instruction.
 | ||||
|   b_.SetInsertPoint(case_block); | ||||
|   llvm::SwitchInst* case_inst = | ||||
|       b_.CreateSwitch(branch_index_value, default_block, num_branches - 1); | ||||
|   // Lower each branch's computation.
 | ||||
|   for (int b = 0; b < num_branches - 1; ++b) {  // last branch is default
 | ||||
|     // Lower the case b: { ... ; break; } computation.
 | ||||
|     auto branch_block = | ||||
|         llvm_ir::CreateBasicBlock(nullptr, absl::StrCat("case-branch", b), &b_); | ||||
|     b_.SetInsertPoint(branch_block); | ||||
|     EmitGlobalCall(*conditional->branch_computation(b), | ||||
|                    IrName(conditional, absl::StrCat("_branch", b))); | ||||
|     b_.CreateBr(after_block); | ||||
|     case_inst->addCase(b_.getInt32(b), branch_block); | ||||
|   } | ||||
| 
 | ||||
|   SetToFirstInsertPoint(after_block, &b_); | ||||
|   return Status::OK(); | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -26,7 +26,7 @@ namespace xla { | ||||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
| // Helper to replace the called computation at a while-, call-, or
 | ||||
| // Helper to replace the called computation at a while-, call-, case-, or
 | ||||
| // conditional-instruction. This function replaces exactly one instance of
 | ||||
| // 'computation' with 'new_computation' even if 'instruction' calls
 | ||||
| // 'computation' more than once.
 | ||||
| @ -49,11 +49,14 @@ void ReplaceCalledComputation(HloInstruction* instruction, | ||||
|       break; | ||||
|     } | ||||
|     case HloOpcode::kConditional: { | ||||
|       if (computation == instruction->true_computation()) { | ||||
|         instruction->set_true_computation(new_computation); | ||||
|       } else { | ||||
|         CHECK_EQ(computation, instruction->false_computation()); | ||||
|         instruction->set_false_computation(new_computation); | ||||
|       for (int b = 0; b < instruction->branch_count(); ++b) { | ||||
|         if (b == instruction->branch_count() - 1) { | ||||
|           CHECK_EQ(computation, instruction->branch_computation(b)); | ||||
|         } | ||||
|         if (computation == instruction->branch_computation(b)) { | ||||
|           instruction->set_branch_computation(b, new_computation); | ||||
|           break; | ||||
|         } | ||||
|       } | ||||
|       break; | ||||
|     } | ||||
|  | ||||
| @ -24,25 +24,35 @@ namespace xla { | ||||
| namespace gpu { | ||||
| 
 | ||||
| ConditionalThunk::ConditionalThunk( | ||||
|     const BufferAllocation::Slice& predicate_buffer_index, | ||||
|     const BufferAllocation::Slice& true_operand_buffer_index, | ||||
|     const BufferAllocation::Slice& false_operand_buffer_index, | ||||
|     ThunkSequence true_thunk_sequence, ThunkSequence false_thunk_sequence, | ||||
|     const BufferAllocation::Slice& branch_index_buffer_index, | ||||
|     absl::Span<const BufferAllocation::Slice> branch_operand_buffer_indexes, | ||||
|     std::vector<ThunkSequence> branch_thunk_sequences, | ||||
|     const HloInstruction* hlo) | ||||
|     : Thunk(Kind::kConditional, hlo), | ||||
|       predicate_buffer_index_(predicate_buffer_index), | ||||
|       true_operand_buffer_index_(true_operand_buffer_index), | ||||
|       false_operand_buffer_index_(false_operand_buffer_index), | ||||
|       // Pass nullptr as the HloInstruction* to the true_thunk_ and false_thunk_
 | ||||
|       // constructors because these SequentialThunks are logically "part of"
 | ||||
|       // this ConditionalThunk, and shouldn't be profiled separately from it.
 | ||||
|       true_thunk_(std::move(true_thunk_sequence), nullptr), | ||||
|       false_thunk_(std::move(false_thunk_sequence), nullptr) {} | ||||
|       branch_index_is_bool_(hlo->operand(0)->shape().element_type() == PRED), | ||||
|       branch_index_buffer_index_(branch_index_buffer_index), | ||||
|       branch_operand_buffer_indexes_(branch_operand_buffer_indexes.begin(), | ||||
|                                      branch_operand_buffer_indexes.end()) { | ||||
|   // Pass nullptr as the HloInstruction* to the branch_thunks_
 | ||||
|   // constructors because these SequentialThunks are logically "part of"
 | ||||
|   // this ConditionalThunk, and shouldn't be profiled separately from it.
 | ||||
|   branch_thunks_.reserve(branch_thunk_sequences.size()); | ||||
|   for (auto& branch_thunk_sequence : branch_thunk_sequences) { | ||||
|     branch_thunks_.emplace_back( | ||||
|         new SequentialThunk(std::move(branch_thunk_sequence), nullptr)); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| Status ConditionalThunk::Initialize(const GpuExecutable& executable, | ||||
|                                     se::StreamExecutor* executor) { | ||||
|   TF_RETURN_IF_ERROR(true_thunk_.Initialize(executable, executor)); | ||||
|   TF_RETURN_IF_ERROR(false_thunk_.Initialize(executable, executor)); | ||||
|   if (branch_index_is_bool_) { | ||||
|     TF_RET_CHECK(branch_thunks_.size() == 2); | ||||
|   } else { | ||||
|     TF_RET_CHECK(!branch_thunks_.empty()); | ||||
|   } | ||||
|   for (auto& branch_thunk : branch_thunks_) { | ||||
|     TF_RETURN_IF_ERROR(branch_thunk->Initialize(executable, executor)); | ||||
|   } | ||||
|   return Status::OK(); | ||||
| } | ||||
| 
 | ||||
| @ -51,30 +61,37 @@ Status ConditionalThunk::ExecuteOnStream( | ||||
|     HloExecutionProfiler* profiler) { | ||||
|   auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); | ||||
|   // Copy the predicate value from device.
 | ||||
|   bool predicate; | ||||
|   se::DeviceMemoryBase predicate_address = | ||||
|       buffer_allocations.GetDeviceAddress(predicate_buffer_index_); | ||||
|   stream->ThenMemcpy(&predicate, predicate_address, sizeof(bool)); | ||||
|   int32 branch_index = -1; | ||||
|   bool pred = false; | ||||
|   se::DeviceMemoryBase branch_index_address = | ||||
|       buffer_allocations.GetDeviceAddress(branch_index_buffer_index_); | ||||
|   if (branch_index_is_bool_) { | ||||
|     stream->ThenMemcpy(&pred, branch_index_address, sizeof(bool)); | ||||
|   } else { | ||||
|     stream->ThenMemcpy(&branch_index, branch_index_address, sizeof(int32)); | ||||
|   } | ||||
| 
 | ||||
|   Status block_status = stream->BlockHostUntilDone(); | ||||
|   if (!block_status.ok()) { | ||||
|     return InternalError("Failed to retrieve predicate value on stream %p: %s.", | ||||
|                          stream, block_status.error_message()); | ||||
|     return InternalError( | ||||
|         "Failed to retrieve branch_index value on stream %p: %s.", stream, | ||||
|         block_status.error_message()); | ||||
|   } | ||||
|   if (branch_index_is_bool_) { | ||||
|     branch_index = pred ? 0 : 1; | ||||
|   } else { | ||||
|     // Handle default scenario for branch_index not in [0, num_branches).
 | ||||
|     if (branch_index < 0 || branch_index >= hlo_instruction()->branch_count()) { | ||||
|       branch_index = hlo_instruction()->branch_count() - 1; | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // Execute the true or the false computation depending on the value of the
 | ||||
|   // predicate.
 | ||||
|   if (predicate) { | ||||
|     profiler->StartHloComputation(); | ||||
|     TF_RETURN_IF_ERROR( | ||||
|         true_thunk_.ExecuteOnStream(buffer_allocations, stream, profiler)); | ||||
|     profiler->FinishHloComputation(hlo_instruction()->true_computation()); | ||||
|   } else { | ||||
|     profiler->StartHloComputation(); | ||||
|     TF_RETURN_IF_ERROR( | ||||
|         false_thunk_.ExecuteOnStream(buffer_allocations, stream, profiler)); | ||||
|     profiler->FinishHloComputation(hlo_instruction()->false_computation()); | ||||
|   } | ||||
|   // Execute the branch computation corresponding to the value of branch_index.
 | ||||
|   profiler->StartHloComputation(); | ||||
|   TF_RETURN_IF_ERROR(branch_thunks_[branch_index]->ExecuteOnStream( | ||||
|       buffer_allocations, stream, profiler)); | ||||
|   profiler->FinishHloComputation( | ||||
|       hlo_instruction()->branch_computation(branch_index)); | ||||
| 
 | ||||
|   return Status::OK(); | ||||
| } | ||||
|  | ||||
| @ -16,6 +16,10 @@ limitations under the License. | ||||
| #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_ | ||||
| #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_ | ||||
| 
 | ||||
| #include <memory> | ||||
| #include <vector> | ||||
| 
 | ||||
| #include "absl/types/span.h" | ||||
| #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" | ||||
| #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" | ||||
| #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" | ||||
| @ -38,12 +42,11 @@ namespace gpu { | ||||
| // false computation share the same allocation.
 | ||||
| class ConditionalThunk : public Thunk { | ||||
|  public: | ||||
|   ConditionalThunk(const BufferAllocation::Slice& predicate_buffer_index, | ||||
|                    const BufferAllocation::Slice& true_operand_buffer_index, | ||||
|                    const BufferAllocation::Slice& false_operand_buffer_index, | ||||
|                    ThunkSequence true_thunk_sequence, | ||||
|                    ThunkSequence false_thunk_sequence, | ||||
|                    const HloInstruction* hlo); | ||||
|   ConditionalThunk( | ||||
|       const BufferAllocation::Slice& branch_index_buffer_index, | ||||
|       absl::Span<const BufferAllocation::Slice> branch_operand_buffer_indexes, | ||||
|       std::vector<ThunkSequence> branch_thunk_sequences, | ||||
|       const HloInstruction* hlo); | ||||
| 
 | ||||
|   ConditionalThunk(const ConditionalThunk&) = delete; | ||||
|   ConditionalThunk& operator=(const ConditionalThunk&) = delete; | ||||
| @ -55,11 +58,10 @@ class ConditionalThunk : public Thunk { | ||||
|                          HloExecutionProfiler* profiler) override; | ||||
| 
 | ||||
|  private: | ||||
|   BufferAllocation::Slice predicate_buffer_index_; | ||||
|   BufferAllocation::Slice true_operand_buffer_index_; | ||||
|   BufferAllocation::Slice false_operand_buffer_index_; | ||||
|   SequentialThunk true_thunk_; | ||||
|   SequentialThunk false_thunk_; | ||||
|   const bool branch_index_is_bool_; | ||||
|   BufferAllocation::Slice branch_index_buffer_index_; | ||||
|   std::vector<BufferAllocation::Slice> branch_operand_buffer_indexes_; | ||||
|   std::vector<std::unique_ptr<SequentialThunk>> branch_thunks_; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace gpu
 | ||||
|  | ||||
| @ -2024,41 +2024,32 @@ Status CheckWhileBuffersShareAllocation( | ||||
| // Checks that the buffers used in a conditional instruction are shared with the
 | ||||
| // operands and result as follows:
 | ||||
| //   * The result buffer of the conditional should share the allocation with the
 | ||||
| //     result buffers of the true and false computations.
 | ||||
| //   * The buffer of operand 1 should share the allocation with the buffer of
 | ||||
| //     the parameter 0 instruction of the true computation.
 | ||||
| //   * The buffer of operand 2 should share the allocation with the buffer of
 | ||||
| //     the parameter 0 instruction of the false computation.
 | ||||
| //     result buffers of each branch computation.
 | ||||
| //   * The buffer of operand b+1 should share the allocation with the buffer of
 | ||||
| //     the parameter 0 instruction of the b'th computation.
 | ||||
| Status CheckConditionalBuffersShareAllocation( | ||||
|     const HloInstruction* conditional, | ||||
|     const BufferAssignment& buffer_assignment) { | ||||
|   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( | ||||
|       conditional->shape(), | ||||
|       [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status { | ||||
|         TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation( | ||||
|             conditional, conditional->true_computation()->root_instruction(), | ||||
|             index, buffer_assignment)); | ||||
|         TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation( | ||||
|             conditional, conditional->false_computation()->root_instruction(), | ||||
|             index, buffer_assignment)); | ||||
|         for (auto branch_computation : conditional->branch_computations()) { | ||||
|           TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation( | ||||
|               conditional, branch_computation->root_instruction(), index, | ||||
|               buffer_assignment)); | ||||
|         } | ||||
|         return Status::OK(); | ||||
|       })); | ||||
|   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( | ||||
|       conditional->operand(1)->shape(), | ||||
|       [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status { | ||||
|         return CheckHloBuffersShareAllocation( | ||||
|             conditional->operand(1), | ||||
|             conditional->true_computation()->parameter_instruction(0), index, | ||||
|             buffer_assignment); | ||||
|       })); | ||||
|   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( | ||||
|       conditional->operand(2)->shape(), | ||||
|       [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status { | ||||
|         return CheckHloBuffersShareAllocation( | ||||
|             conditional->operand(2), | ||||
|             conditional->false_computation()->parameter_instruction(0), index, | ||||
|             buffer_assignment); | ||||
|       })); | ||||
|   for (int j = 0; j < conditional->branch_count(); ++j) { | ||||
|     TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( | ||||
|         conditional->operand(j + 1)->shape(), | ||||
|         [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status { | ||||
|           return CheckHloBuffersShareAllocation( | ||||
|               conditional->operand(j + 1), | ||||
|               conditional->branch_computation(j)->parameter_instruction(0), | ||||
|               index, buffer_assignment); | ||||
|         })); | ||||
|   } | ||||
|   return Status::OK(); | ||||
| } | ||||
| 
 | ||||
| @ -2111,22 +2102,20 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk( | ||||
|   TF_CHECK_OK(CheckConditionalBuffersShareAllocation( | ||||
|       hlo, ir_emitter_context_->buffer_assignment())); | ||||
| 
 | ||||
|   HloComputation* true_computation = hlo->true_computation(); | ||||
|   IrEmitterUnnested ir_emitter_true(hlo_module_config_, true_computation, | ||||
|                                     ir_emitter_context_); | ||||
|   TF_CHECK_OK(true_computation->Accept(&ir_emitter_true)); | ||||
| 
 | ||||
|   HloComputation* false_computation = hlo->false_computation(); | ||||
|   IrEmitterUnnested ir_emitter_false(hlo_module_config_, false_computation, | ||||
|                                      ir_emitter_context_); | ||||
|   TF_CHECK_OK(false_computation->Accept(&ir_emitter_false)); | ||||
|   std::vector<BufferAllocation::Slice> branch_operands; | ||||
|   std::vector<ThunkSequence> branch_thunks; | ||||
|   for (int j = 0; j < hlo->branch_count(); ++j) { | ||||
|     branch_operands.emplace_back(GetAllocationSlice(*hlo->operand(j + 1))); | ||||
|     HloComputation* branch_computation = hlo->branch_computation(j); | ||||
|     IrEmitterUnnested ir_emitter(hlo_module_config_, branch_computation, | ||||
|                                  ir_emitter_context_); | ||||
|     TF_CHECK_OK(branch_computation->Accept(&ir_emitter)); | ||||
|     branch_thunks.push_back(std::move(*ir_emitter.ConsumeThunkSequence())); | ||||
|   } | ||||
| 
 | ||||
|   return absl::make_unique<ConditionalThunk>( | ||||
|       GetAllocationSlice(*hlo->operand(0)), | ||||
|       GetAllocationSlice(*hlo->operand(1)), | ||||
|       GetAllocationSlice(*hlo->operand(2)), | ||||
|       std::move(*ir_emitter_true.ConsumeThunkSequence()), | ||||
|       std::move(*ir_emitter_false.ConsumeThunkSequence()), hlo); | ||||
|       GetAllocationSlice(*hlo->operand(0)), branch_operands, | ||||
|       std::move(branch_thunks), hlo); | ||||
| } | ||||
| 
 | ||||
| Status IrEmitterUnnested::EmitTargetElementLoopInThunk( | ||||
|  | ||||
| @ -356,9 +356,9 @@ class IrEmitterUnnested : public IrEmitter { | ||||
|   std::unique_ptr<Thunk> BuildForThunk(const HloInstruction* hlo, | ||||
|                                        const int64 loop_limit); | ||||
| 
 | ||||
|   // Returns a ConditionalThunk that executes the thunk sequence for
 | ||||
|   // 'true_computation' or 'false_computation' depending on the value of the
 | ||||
|   // predicate in the given conditional instruction.
 | ||||
|   // Returns a ConditionalThunk which executes the thunk sequence for the
 | ||||
|   // 'branch_computation' corresponding to the predicate/branch_index of the
 | ||||
|   // given conditional instruction.
 | ||||
|   std::unique_ptr<Thunk> BuildConditionalThunk(const HloInstruction* hlo); | ||||
| 
 | ||||
|   Status Postprocess(HloInstruction* hlo) override; | ||||
|  | ||||
| @ -293,7 +293,7 @@ class BufferValueMap { | ||||
|             VLOG(3) | ||||
|                 << "  value @ " << position << " is root of " | ||||
|                 << callsite.instruction()->name() | ||||
|                 << "; true/false branch roots must share buffer among them : " | ||||
|                 << "; branch computation roots must share buffer among them : " | ||||
|                 << cond_value.ToShortString(); | ||||
|             aliased_buffers->push_back(GetBufferForValue(cond_value)); | ||||
|           } | ||||
|  | ||||
| @ -684,19 +684,22 @@ Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) { | ||||
| } | ||||
| 
 | ||||
| Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) { | ||||
|   // Compute the cost of the true and false computations and take the maximum
 | ||||
|   // from those for each property.
 | ||||
|   // Compute the cost of the branch computations and take the maximum from those
 | ||||
|   // for each property.
 | ||||
|   TF_ASSIGN_OR_RETURN( | ||||
|       const Properties true_computation_properties, | ||||
|       ProcessUnnestedSubcomputation(conditional->true_computation())); | ||||
|   TF_ASSIGN_OR_RETURN( | ||||
|       const Properties false_computation_properties, | ||||
|       ProcessUnnestedSubcomputation(conditional->false_computation())); | ||||
|   current_properties_ = true_computation_properties; | ||||
|   for (const auto& property : false_computation_properties) { | ||||
|     if (!tensorflow::gtl::InsertIfNotPresent(¤t_properties_, property)) { | ||||
|       current_properties_[property.first] = | ||||
|           std::max(current_properties_[property.first], property.second); | ||||
|       const Properties branch0_computation_properties, | ||||
|       ProcessUnnestedSubcomputation(conditional->branch_computation(0))); | ||||
|   current_properties_ = branch0_computation_properties; | ||||
|   for (int j = 1; j < conditional->branch_count(); ++j) { | ||||
|     TF_ASSIGN_OR_RETURN( | ||||
|         const Properties branch_computation_properties, | ||||
|         ProcessUnnestedSubcomputation(conditional->branch_computation(j))); | ||||
|     for (const auto& property : branch_computation_properties) { | ||||
|       if (!tensorflow::gtl::InsertIfNotPresent(¤t_properties_, | ||||
|                                                property)) { | ||||
|         auto& current_property = current_properties_[property.first]; | ||||
|         current_property = std::max(current_property, property.second); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   current_should_compute_bottleneck_time_ = false; | ||||
|  | ||||
| @ -414,11 +414,11 @@ bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) { | ||||
| bool HloDataflowAnalysis::UpdateConditionalValueSet( | ||||
|     HloInstruction* conditional) { | ||||
|   CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); | ||||
|   const InstructionValueSet* const inputs[] = { | ||||
|       &GetInstructionValueSet( | ||||
|           conditional->true_computation()->root_instruction()), | ||||
|       &GetInstructionValueSet( | ||||
|           conditional->false_computation()->root_instruction())}; | ||||
|   std::vector<const InstructionValueSet*> inputs(conditional->branch_count()); | ||||
|   for (int j = 0; j < conditional->branch_count(); ++j) { | ||||
|     inputs[j] = &GetInstructionValueSet( | ||||
|         conditional->branch_computation(j)->root_instruction()); | ||||
|   } | ||||
|   if (ssa_form_) { | ||||
|     return Phi(conditional, inputs); | ||||
|   } else { | ||||
| @ -546,20 +546,23 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) { | ||||
|     } else if (callsite.instruction()->opcode() == HloOpcode::kConditional) { | ||||
|       CHECK_EQ(parameter->parameter_number(), 0); | ||||
|       auto conditional = callsite.instruction(); | ||||
|       // Conditional has 3 operands. Operand 0 is the predicate, operand 1 is
 | ||||
|       // the argument to the true computation and operand 2 is the argument to
 | ||||
|       // the false computation.
 | ||||
|       // Conditional has branch_count+1 operands. Operand 0 is the branch_index,
 | ||||
|       // operands 1 and onward are the arguments to the branch computations.
 | ||||
|       //
 | ||||
|       // If the parameter belongs to conditional's true computation, then
 | ||||
|       // If the parameter belongs to conditional's branch 0 computation, then
 | ||||
|       // operand 1 is forwarded to this parameter instruction. If the parameter
 | ||||
|       // belongs to conditional's false computation, then operand 2 is forwarded
 | ||||
|       // to this parameter instruction.
 | ||||
|       if (parameter->parent() == conditional->true_computation()) { | ||||
|         inputs.push_back(&GetInstructionValueSet(conditional->operand(1))); | ||||
|       } else { | ||||
|         CHECK_EQ(parameter->parent(), conditional->false_computation()); | ||||
|         inputs.push_back(&GetInstructionValueSet(conditional->operand(2))); | ||||
|       // belongs to conditional's branch 5 computation, then operand 6 is
 | ||||
|       // forwarded to this parameter instruction.
 | ||||
|       bool found_parent = false; | ||||
|       for (int j = 0; j < conditional->branch_count(); ++j) { | ||||
|         if (parameter->parent() == conditional->branch_computation(j)) { | ||||
|           inputs.push_back( | ||||
|               &GetInstructionValueSet(conditional->operand(j + 1))); | ||||
|           found_parent = true; | ||||
|           break; | ||||
|         } | ||||
|       } | ||||
|       CHECK(found_parent); | ||||
|       need_phi = true; | ||||
|     } else { | ||||
|       LOG(FATAL) << "CallContext::kSequential computations should only be " | ||||
| @ -710,19 +713,17 @@ void HloDataflowAnalysis::Propagate() { | ||||
|       // parameter(s) of the computation need to be updated.
 | ||||
|       if (user->opcode() == HloOpcode::kConditional) { | ||||
|         // If operand 0 is the use of instruction, then no parameters need to be
 | ||||
|         // updated, since that is the predicate of the conditional.
 | ||||
|         // If operand 1 is the use of instruction, then the true_computation's
 | ||||
|         // parameter need to be updated.
 | ||||
|         // If operand 2 is the use of instruction, then the false_computation's
 | ||||
|         // parameter need to be updated.
 | ||||
|         // updated, since that is the branch_index of the conditional.
 | ||||
|         // If operand n+1 is the use of instruction, then the branch_computation
 | ||||
|         // n's parameter need to be updated.
 | ||||
|         //
 | ||||
|         // Note that the same instruction can be used in both operand 1 and
 | ||||
|         // operand 2.
 | ||||
|         if (user->operand(1) == instruction) { | ||||
|           add_to_worklist(user->true_computation()->parameter_instruction(0)); | ||||
|         } | ||||
|         if (user->operand(2) == instruction) { | ||||
|           add_to_worklist(user->false_computation()->parameter_instruction(0)); | ||||
|         // Note that the same instruction can be used in multiple branches'
 | ||||
|         // operands.
 | ||||
|         for (int j = 0; j < user->branch_count(); ++j) { | ||||
|           if (user->operand(j + 1) == instruction) { | ||||
|             add_to_worklist( | ||||
|                 user->branch_computation(j)->parameter_instruction(0)); | ||||
|           } | ||||
|         } | ||||
|       } else { | ||||
|         for (HloComputation* called_computation : user->called_computations()) { | ||||
| @ -744,8 +745,8 @@ void HloDataflowAnalysis::Propagate() { | ||||
|       const CallGraphNode& call_graph_node = | ||||
|           call_graph_->GetNode(instruction->parent()); | ||||
|       for (const CallSite& callsite : call_graph_node.caller_callsites()) { | ||||
|         if ((callsite.instruction()->opcode() == HloOpcode::kCall) || | ||||
|             (callsite.instruction()->opcode() == HloOpcode::kConditional)) { | ||||
|         if (callsite.instruction()->opcode() == HloOpcode::kCall || | ||||
|             callsite.instruction()->opcode() == HloOpcode::kConditional) { | ||||
|           add_to_worklist(callsite.instruction()); | ||||
|         } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) { | ||||
|           // Add the while itself, and the body and condition parameters.
 | ||||
|  | ||||
| @ -1270,28 +1270,27 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { | ||||
| } | ||||
| 
 | ||||
| Status HloEvaluator::HandleConditional(HloInstruction* conditional) { | ||||
|   const auto& pred = GetEvaluatedLiteralFor(conditional->operand(0)); | ||||
|   const auto& true_computation_arg = | ||||
|       GetEvaluatedLiteralFor(conditional->operand(1)); | ||||
|   const auto& false_computation_arg = | ||||
|       GetEvaluatedLiteralFor(conditional->operand(2)); | ||||
| 
 | ||||
|   auto* true_computation = conditional->true_computation(); | ||||
|   auto* false_computation = conditional->false_computation(); | ||||
|   const auto& branch_index_literal = | ||||
|       GetEvaluatedLiteralFor(conditional->operand(0)); | ||||
|   int branch_index; | ||||
|   if (conditional->operand(0)->shape().element_type() == PRED) { | ||||
|     branch_index = branch_index_literal.Get<bool>({}) ? 0 : 1; | ||||
|   } else { | ||||
|     branch_index = branch_index_literal.Get<int32>({}); | ||||
|     if (branch_index < 0 || branch_index >= conditional->branch_count()) { | ||||
|       branch_index = conditional->branch_count() - 1; | ||||
|     } | ||||
|   } | ||||
|   const auto& branch_computation_arg = | ||||
|       GetEvaluatedLiteralFor(conditional->operand(1 + branch_index)); | ||||
| 
 | ||||
|   HloEvaluator embedded_evaluator; | ||||
|   embedded_evaluator.set_dynamic_dimension_inference( | ||||
|       dynamic_dimension_inference_); | ||||
|   Literal result; | ||||
|   if (pred.Get<bool>({})) { | ||||
|     result = | ||||
|         embedded_evaluator.Evaluate(*true_computation, {&true_computation_arg}) | ||||
|             .ConsumeValueOrDie(); | ||||
|   } else { | ||||
|     result = embedded_evaluator | ||||
|                  .Evaluate(*false_computation, {&false_computation_arg}) | ||||
|                  .ConsumeValueOrDie(); | ||||
|   } | ||||
|   Literal result = embedded_evaluator | ||||
|                        .Evaluate(*conditional->branch_computation(branch_index), | ||||
|                                  {&branch_computation_arg}) | ||||
|                        .ConsumeValueOrDie(); | ||||
| 
 | ||||
|   evaluated_[conditional] = std::move(result); | ||||
|   return Status::OK(); | ||||
|  | ||||
| @ -82,6 +82,15 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( | ||||
|   const auto computations = [&computation_map, &proto](int index) { | ||||
|     return computation_map.at(proto.called_computation_ids(index)); | ||||
|   }; | ||||
|   const auto all_computations = [&computation_map, &proto]() { | ||||
|     std::vector<HloComputation*> result(proto.called_computation_ids_size()); | ||||
|     std::transform(proto.called_computation_ids().begin(), | ||||
|                    proto.called_computation_ids().end(), result.begin(), | ||||
|                    [&computation_map](int64 computation_id) { | ||||
|                      return computation_map.at(computation_id); | ||||
|                    }); | ||||
|     return result; | ||||
|   }; | ||||
| 
 | ||||
|   TF_RET_CHECK( | ||||
|       absl::c_all_of(proto.operand_ids(), | ||||
| @ -163,6 +172,26 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( | ||||
|       instruction = | ||||
|           CreateConcatenate(shape, all_operands(), proto.dimensions(0)); | ||||
|       break; | ||||
|     case HloOpcode::kConditional: { | ||||
|       TF_RET_CHECK(proto.called_computation_ids_size() > 0) | ||||
|           << "conditional should have at least 1 called computation"; | ||||
|       if (operands(0)->shape().element_type() == PRED) { | ||||
|         TF_RET_CHECK(proto.called_computation_ids_size() == 2) | ||||
|             << "conditional should have exactly 2 called computations but got " | ||||
|             << proto.called_computation_ids_size(); | ||||
|       } | ||||
|       TF_RET_CHECK(proto.operand_ids_size() == | ||||
|                    proto.called_computation_ids_size() + 1) | ||||
|           << "conditional should have one branch_index operand plus one " | ||||
|              "operand per called computation but got " | ||||
|           << proto.operand_ids_size() << " operands for " | ||||
|           << proto.called_computation_ids_size() << " branch computations"; | ||||
|       auto cond_operands = all_operands(); | ||||
|       instruction = | ||||
|           CreateConditional(shape, cond_operands[0], all_computations(), | ||||
|                             absl::MakeSpan(cond_operands).subspan(1)); | ||||
|       break; | ||||
|     } | ||||
|     case HloOpcode::kReduce: | ||||
|       TF_RET_CHECK(proto.operand_ids_size() % 2 == 0) | ||||
|           << "Reduce instruction should have an even number of operands but " | ||||
| @ -898,6 +927,21 @@ HloInstruction::CreateAddDependency(HloInstruction* data_operand, | ||||
|   return instruction; | ||||
| } | ||||
| 
 | ||||
| /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional( | ||||
|     const Shape& shape, HloInstruction* branch_index, | ||||
|     absl::Span<HloComputation* const> branch_computations, | ||||
|     absl::Span<HloInstruction* const> branch_computation_args) { | ||||
|   auto instruction = | ||||
|       absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape)); | ||||
|   instruction->AppendOperand(branch_index); | ||||
|   CHECK_EQ(branch_computations.size(), branch_computation_args.size()); | ||||
|   for (int i = 0; i < branch_computations.size(); ++i) { | ||||
|     instruction->called_computations_.push_back(branch_computations[i]); | ||||
|     instruction->AppendOperand(branch_computation_args[i]); | ||||
|   } | ||||
|   return instruction; | ||||
| } | ||||
| 
 | ||||
| /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice( | ||||
|     const Shape& shape, HloInstruction* operand, | ||||
|     absl::Span<const int64> start_indices, | ||||
| @ -1397,10 +1441,10 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( | ||||
|           CreateWhile(shape, while_condition(), while_body(), new_operands[0]); | ||||
|       break; | ||||
|     case HloOpcode::kConditional: | ||||
|       CHECK_EQ(new_operands.size(), 3); | ||||
|       clone = CreateConditional(shape, new_operands[0], new_operands[1], | ||||
|                                 true_computation(), new_operands[2], | ||||
|                                 false_computation()); | ||||
|       CHECK_EQ(new_operands.size(), branch_count() + 1); | ||||
|       clone = CreateConditional(shape, new_operands[0], | ||||
|                                 absl::MakeSpan(branch_computations()), | ||||
|                                 new_operands.subspan(1)); | ||||
|       break; | ||||
|     case HloOpcode::kAfterAll: | ||||
|       if (new_operands.empty()) { | ||||
| @ -1711,16 +1755,16 @@ bool HloInstruction::IdenticalSlowPath( | ||||
|     case HloOpcode::kCall: | ||||
|       return eq_computations(to_apply(), other.to_apply()); | ||||
|     case HloOpcode::kConditional: | ||||
|       return eq_computations(true_computation(), other.true_computation()) && | ||||
|              eq_computations(false_computation(), other.false_computation()); | ||||
| 
 | ||||
|     case HloOpcode::kWhile: { | ||||
|       if (eq_computations(while_body(), other.while_body()) && | ||||
|           eq_computations(while_condition(), other.while_condition())) { | ||||
|         return true; | ||||
|       for (int j = 0; j < branch_count(); ++j) { | ||||
|         if (!eq_computations(branch_computation(j), | ||||
|                              other.branch_computation(j))) { | ||||
|           return false; | ||||
|         } | ||||
|       } | ||||
|       return false; | ||||
|     } | ||||
|       return true; | ||||
|     case HloOpcode::kWhile: | ||||
|       return (eq_computations(while_body(), other.while_body()) && | ||||
|               eq_computations(while_condition(), other.while_condition())); | ||||
| 
 | ||||
|     // Ops migrated to subclasses should never come to this line.
 | ||||
|     // TODO(b/80131774): Remove this switch when migration is complete.
 | ||||
| @ -1983,28 +2027,41 @@ HloInstruction* HloInstruction::while_init() const { | ||||
| 
 | ||||
// Returns the branch taken when the predicate is true.
// Precondition: this is a predicated (PRED-selector) kConditional; indexed
// conditionals must use branch_computation() instead.
HloComputation* HloInstruction::true_computation() const {
  CHECK_EQ(HloOpcode::kConditional, opcode_);
  // Reject indexed conditionals: operand 0 must be a predicate.
  CHECK_EQ(PRED, operand(0)->shape().element_type());
  return called_computations_[kTrueComputationIndex];
}
| 
 | ||||
// Returns the branch taken when the predicate is false.
// Precondition: this is a predicated (PRED-selector) kConditional; indexed
// conditionals must use branch_computation() instead.
HloComputation* HloInstruction::false_computation() const {
  CHECK_EQ(HloOpcode::kConditional, opcode_);
  // Reject indexed conditionals: operand 0 must be a predicate.
  CHECK_EQ(PRED, operand(0)->shape().element_type());
  return called_computations_[kFalseComputationIndex];
}
| 
 | ||||
| void HloInstruction::set_true_computation(HloComputation* true_computation) { | ||||
|   // Don't allow changing the computation for fused instructions so we don't
 | ||||
|   // have to recompute called_instructions for the entire fusion instruction.
 | ||||
|   CHECK(!IsFused()); | ||||
|   CHECK_EQ(HloOpcode::kConditional, opcode_); | ||||
|   called_computations_[kTrueComputationIndex] = true_computation; | ||||
| const std::vector<HloComputation*>& HloInstruction::branch_computations() | ||||
|     const { | ||||
|   CHECK(HloOpcode::kConditional == opcode_); | ||||
|   return called_computations_; | ||||
| } | ||||
| 
 | ||||
| void HloInstruction::set_false_computation(HloComputation* false_computation) { | ||||
| int HloInstruction::branch_count() const { | ||||
|   CHECK(HloOpcode::kConditional == opcode_); | ||||
|   return called_computations_.size(); | ||||
| } | ||||
| 
 | ||||
| HloComputation* HloInstruction::branch_computation(int b) const { | ||||
|   CHECK(HloOpcode::kConditional == opcode_); | ||||
|   CHECK_GE(b, 0); | ||||
|   CHECK_LT(b, called_computations_.size()); | ||||
|   return called_computations_[b]; | ||||
| } | ||||
| 
 | ||||
| void HloInstruction::set_branch_computation(int b, | ||||
|                                             HloComputation* computation) { | ||||
|   // Don't allow changing the computation for fused instructions so we don't
 | ||||
|   // have to recompute called_instructions for the entire fusion instruction.
 | ||||
|   CHECK(!IsFused()); | ||||
|   CHECK_EQ(HloOpcode::kConditional, opcode_); | ||||
|   called_computations_[kFalseComputationIndex] = false_computation; | ||||
|   called_computations_[b] = computation; | ||||
| } | ||||
| 
 | ||||
| string HloInstruction::SignatureString() const { | ||||
| @ -2207,10 +2264,21 @@ std::vector<string> HloInstruction::ExtraAttributesToString( | ||||
|       extra.push_back( | ||||
|           StrCat("scatter=", PrintName(scatter()->name(), options))); | ||||
|     } else if (opcode() == HloOpcode::kConditional) { | ||||
|       extra.push_back(StrCat("true_computation=", | ||||
|                              PrintName(true_computation()->name(), options))); | ||||
|       extra.push_back(StrCat("false_computation=", | ||||
|                              PrintName(false_computation()->name(), options))); | ||||
|       if (operand(0)->shape().element_type() == PRED) { | ||||
|         extra.push_back(StrCat("true_computation=", | ||||
|                                PrintName(true_computation()->name(), options))); | ||||
|         extra.push_back( | ||||
|             StrCat("false_computation=", | ||||
|                    PrintName(false_computation()->name(), options))); | ||||
|       } else { | ||||
|         extra.push_back(StrCat( | ||||
|             "branch_computations={", | ||||
|             StrJoin(branch_computations(), ", ", | ||||
|                     [&](string* out, const HloComputation* computation) { | ||||
|                       StrAppend(out, PrintName(computation->name(), options)); | ||||
|                     }), | ||||
|             "}")); | ||||
|       } | ||||
|     } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || | ||||
|                opcode() == HloOpcode::kReduceWindow || | ||||
|                opcode() == HloOpcode::kReduce || | ||||
| @ -2242,10 +2310,20 @@ std::vector<string> HloInstruction::ExtraAttributesToString( | ||||
|         extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options))); | ||||
|         break; | ||||
|       case HloOpcode::kConditional: | ||||
|         extra.push_back(StrCat("true_computation=\n", | ||||
|                                true_computation()->ToString(new_options))); | ||||
|         extra.push_back(StrCat("false_computation=\n", | ||||
|                                false_computation()->ToString(new_options))); | ||||
|         if (operand(0)->shape().element_type() == PRED) { | ||||
|           extra.push_back(StrCat("true_computation=\n", | ||||
|                                  true_computation()->ToString(new_options))); | ||||
|           extra.push_back(StrCat("false_computation=\n", | ||||
|                                  false_computation()->ToString(new_options))); | ||||
|         } else { | ||||
|           extra.push_back(StrCat( | ||||
|               "branch_computations={\n", | ||||
|               StrJoin(branch_computations(), ",\n", | ||||
|                       [&](string* out, const HloComputation* computation) { | ||||
|                         StrAppend(out, computation->ToString(new_options)); | ||||
|                       }), | ||||
|               "\n}")); | ||||
|         } | ||||
|         break; | ||||
|       case HloOpcode::kCall: | ||||
|       case HloOpcode::kMap: | ||||
|  | ||||
| @ -711,6 +711,11 @@ class HloInstruction { | ||||
|       HloInstruction* true_computation_arg, HloComputation* true_computation, | ||||
|       HloInstruction* false_computation_arg, HloComputation* false_computation); | ||||
| 
 | ||||
|   static std::unique_ptr<HloInstruction> CreateConditional( | ||||
|       const Shape& shape, HloInstruction* branch_index, | ||||
|       absl::Span<HloComputation* const> branch_computations, | ||||
|       absl::Span<HloInstruction* const> branch_computation_args); | ||||
| 
 | ||||
|   static std::unique_ptr<HloInstruction> CreateGather( | ||||
|       const Shape& shape, HloInstruction* operand, | ||||
|       HloInstruction* start_indices, | ||||
| @ -1057,14 +1062,23 @@ class HloInstruction { | ||||
| 
 | ||||
|   HloInstruction* while_init() const; | ||||
| 
 | ||||
|   // Gets/sets the true and false HloComputation for Conditional. The setters
 | ||||
|   // should only be called by HloModule or HloComputation methods.
 | ||||
|   // Gets/sets the true and false HloComputation for Conditional.
 | ||||
|   //
 | ||||
|   // Precondition: The instruction is a Conditional instruction.
 | ||||
|   // Precondition: The instruction is a predicated Conditional instruction.
 | ||||
|   HloComputation* true_computation() const; | ||||
|   HloComputation* false_computation() const; | ||||
|   void set_true_computation(HloComputation* true_computation); | ||||
|   void set_false_computation(HloComputation* false_computation); | ||||
| 
 | ||||
|   // Gets the branch HloComputations for Conditional.
 | ||||
|   //
 | ||||
|   // Precondition: The instruction is a Conditional instruction.
 | ||||
|   const std::vector<HloComputation*>& branch_computations() const; | ||||
|   int branch_count() const; | ||||
|   HloComputation* branch_computation(int b) const; | ||||
|   // Sets a branch HloComputation for Conditional.
 | ||||
|   // The setter should only be called by HloModule or HloComputation methods.
 | ||||
|   //
 | ||||
|   // Precondition: The instruction is a Conditional instruction.
 | ||||
|   void set_branch_computation(int b, HloComputation* computation); | ||||
| 
 | ||||
|   // Returns a string for the signature of this instruction if considered as a
 | ||||
|   // function, e.g. the signature of an F32 add is (F32, F32) -> F32.
 | ||||
|  | ||||
| @ -158,17 +158,12 @@ void HloModule::ReplaceComputations( | ||||
|           break; | ||||
|         } | ||||
|         case HloOpcode::kConditional: { | ||||
|           HloComputation* new_true_computation = | ||||
|               tensorflow::gtl::FindWithDefault( | ||||
|                   replacements, instruction->true_computation(), nullptr); | ||||
|           if (new_true_computation != nullptr) { | ||||
|             instruction->set_true_computation(new_true_computation); | ||||
|           } | ||||
|           HloComputation* new_false_computation = | ||||
|               tensorflow::gtl::FindWithDefault( | ||||
|                   replacements, instruction->false_computation(), nullptr); | ||||
|           if (new_false_computation != nullptr) { | ||||
|             instruction->set_false_computation(new_false_computation); | ||||
|           for (int b = 0; b < instruction->branch_count(); ++b) { | ||||
|             HloComputation* new_computation = tensorflow::gtl::FindWithDefault( | ||||
|                 replacements, instruction->branch_computation(b), nullptr); | ||||
|             if (new_computation != nullptr) { | ||||
|               instruction->set_branch_computation(b, new_computation); | ||||
|             } | ||||
|           } | ||||
|           break; | ||||
|         } | ||||
|  | ||||
| @ -45,11 +45,8 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const { | ||||
|     case ComputationKind::kWhileBody: | ||||
|       repr += ":WHILE_BODY"; | ||||
|       break; | ||||
|     case ComputationKind::kConditionalTrue: | ||||
|       repr += ":CONDITIONAL_TRUE"; | ||||
|       break; | ||||
|     case ComputationKind::kConditionalFalse: | ||||
|       repr += ":CONDITIONAL_FALSE"; | ||||
|     case ComputationKind::kConditionalBranch: | ||||
|       repr += absl::StrCat(":CONDITIONAL_BRANCH_", index_); | ||||
|       break; | ||||
|     case ComputationKind::kCallFunction: | ||||
|       repr += ":CALL"; | ||||
| @ -307,10 +304,10 @@ Status HloModuleGroupMetadata::RecordInstructions() { | ||||
|       tracked_instructions_[hlo->while_body()] = | ||||
|           TrackedInstruction(hlo, ComputationKind::kWhileBody); | ||||
|     } else if (hlo->opcode() == HloOpcode::kConditional) { | ||||
|       tracked_instructions_[hlo->true_computation()] = | ||||
|           TrackedInstruction(hlo, ComputationKind::kConditionalTrue); | ||||
|       tracked_instructions_[hlo->false_computation()] = | ||||
|           TrackedInstruction(hlo, ComputationKind::kConditionalFalse); | ||||
|       for (int b = 0; b < hlo->branch_count(); ++b) { | ||||
|         tracked_instructions_[hlo->branch_computation(b)] = | ||||
|             TrackedInstruction(hlo, ComputationKind::kConditionalBranch, b); | ||||
|       } | ||||
|     } else if (hlo->opcode() == HloOpcode::kCall) { | ||||
|       tracked_instructions_[hlo->to_apply()] = | ||||
|           TrackedInstruction(hlo, ComputationKind::kCallFunction); | ||||
|  | ||||
| @ -67,8 +67,7 @@ class HloModuleGroupMetadata { | ||||
|     kInvalid, | ||||
|     kWhileCondition, | ||||
|     kWhileBody, | ||||
|     kConditionalTrue, | ||||
|     kConditionalFalse, | ||||
|     kConditionalBranch, | ||||
|     kCallFunction, | ||||
|   }; | ||||
| 
 | ||||
| @ -80,12 +79,13 @@ class HloModuleGroupMetadata { | ||||
|   class TrackedInstruction { | ||||
|    public: | ||||
|     TrackedInstruction() = default; | ||||
|     TrackedInstruction(HloInstruction* instruction, ComputationKind kind) | ||||
|         : instruction_(instruction), kind_(kind) {} | ||||
|     TrackedInstruction(HloInstruction* instruction, ComputationKind kind, | ||||
|                        int index = -1) | ||||
|         : instruction_(instruction), kind_(kind), index_(index) {} | ||||
| 
 | ||||
|     bool operator==(const TrackedInstruction& rhs) const { | ||||
|       return instruction_->opcode() == rhs.instruction_->opcode() && | ||||
|              kind_ == rhs.kind_; | ||||
|              kind_ == rhs.kind_ && index_ == rhs.index_; | ||||
|     } | ||||
|     bool operator!=(const TrackedInstruction& rhs) const { | ||||
|       return !operator==(rhs); | ||||
| @ -98,6 +98,7 @@ class HloModuleGroupMetadata { | ||||
|    private: | ||||
|     HloInstruction* instruction_ = nullptr; | ||||
|     ComputationKind kind_ = ComputationKind::kInvalid; | ||||
|     int index_ = -1; | ||||
|   }; | ||||
| 
 | ||||
|   // Represents a channel and the instructions that form the channel.
 | ||||
|  | ||||
| @ -67,7 +67,7 @@ namespace xla { | ||||
|   V(kClz, "count-leading-zeros", 1)                                    \ | ||||
|   V(kComplex, "complex", 2)                                            \ | ||||
|   V(kConcatenate, "concatenate", kHloOpcodeIsVariadic)                 \ | ||||
|   V(kConditional, "conditional", 3)                                    \ | ||||
|   V(kConditional, "conditional", kHloOpcodeIsVariadic)                 \ | ||||
|   V(kConstant, "constant", 0)                                          \ | ||||
|   V(kConvert, "convert", 1)                                            \ | ||||
|   V(kConvolution, "convolution", 2)                                    \ | ||||
|  | ||||
| @ -59,6 +59,7 @@ TEST(HloOpcodeTest, OpcodeProperties) { | ||||
|       case HloOpcode::kAllToAll: | ||||
|       case HloOpcode::kCall: | ||||
|       case HloOpcode::kConcatenate: | ||||
|       case HloOpcode::kConditional: | ||||
|       case HloOpcode::kCustomCall: | ||||
|       case HloOpcode::kDynamicSlice: | ||||
|       case HloOpcode::kDynamicUpdateSlice: | ||||
|  | ||||
| @ -66,24 +66,31 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // If the common ancestor is a conditional instruction, even though the true
 | ||||
|   // and false computations are not really ordered per-se, we define the true
 | ||||
|   // computation to be ordered before the false one.
 | ||||
|   // This ensures that buffers can still be shared among the two computations
 | ||||
|   // If the common ancestor is a conditional instruction, even though the branch
 | ||||
|   // computations are not really ordered per-se, we define the 0th branch
 | ||||
|   // computation to be ordered before the 1st one, before the 2nd and so forth.
 | ||||
|   // This ensures that buffers can still be shared among branch computations
 | ||||
|   // as they will forcibly have disjoint liveness.
 | ||||
|   if (a_ancestor == b_ancestor && | ||||
|       a_ancestor->opcode() == HloOpcode::kConditional) { | ||||
|     const HloComputation* true_computation = a_ancestor->true_computation(); | ||||
|     const HloComputation* false_computation = a_ancestor->false_computation(); | ||||
|     if (call_graph_->InstructionIsNestedIn(a, true_computation) && | ||||
|         call_graph_->InstructionIsNestedIn(b, false_computation)) { | ||||
|       (a_ancestor->opcode() == HloOpcode::kConditional)) { | ||||
|     int a_branch = -1; | ||||
|     int b_branch = -1; | ||||
|     for (int j = 0; j < a_ancestor->branch_count(); ++j) { | ||||
|       if (call_graph_->InstructionIsNestedIn( | ||||
|               a, a_ancestor->branch_computation(j))) { | ||||
|         a_branch = j; | ||||
|       } | ||||
|       if (call_graph_->InstructionIsNestedIn( | ||||
|               b, a_ancestor->branch_computation(j))) { | ||||
|         b_branch = j; | ||||
|       } | ||||
|     } | ||||
|     if (a_branch != -1 && a_branch < b_branch) { | ||||
|       return true; | ||||
|     } | ||||
|     // If 'b' is the conditional ancestor, and 'a' is within the true or false
 | ||||
|     // computations, 'a' executes before 'b'.
 | ||||
|     if (b == a_ancestor && | ||||
|         (call_graph_->InstructionIsNestedIn(a, true_computation) || | ||||
|          call_graph_->InstructionIsNestedIn(a, false_computation))) { | ||||
|     // If 'b' is the conditional ancestor, and 'a' is within a branch
 | ||||
|     // computation, 'a' executes before 'b'.
 | ||||
|     if (b == a_ancestor && a_branch != -1) { | ||||
|       return true; | ||||
|     } | ||||
|   } | ||||
| @ -144,17 +151,17 @@ bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const { | ||||
|            b.defining_instruction()->while_condition()))) { | ||||
|     return true; | ||||
|   } | ||||
|   // If 'b' is a conditional phi and 'a' is in the true or false computation,
 | ||||
|   // then 'a' executes before 'b'.
 | ||||
|   // If 'b' is a conditional phi and 'a' is in some branch computation, then 'a'
 | ||||
|   // executes before 'b'.
 | ||||
|   if (b.is_phi() && | ||||
|       b.defining_instruction()->opcode() == HloOpcode::kConditional && | ||||
|       (call_graph_->InstructionIsNestedIn( | ||||
|            a.defining_instruction(), | ||||
|            b.defining_instruction()->true_computation()) || | ||||
|        call_graph_->InstructionIsNestedIn( | ||||
|            a.defining_instruction(), | ||||
|            b.defining_instruction()->false_computation()))) { | ||||
|     return true; | ||||
|       b.defining_instruction()->opcode() == HloOpcode::kConditional) { | ||||
|     for (int j = 0; j < b.defining_instruction()->branch_count(); ++j) { | ||||
|       if (call_graph_->InstructionIsNestedIn( | ||||
|               a.defining_instruction(), | ||||
|               b.defining_instruction()->branch_computation(j))) { | ||||
|         return true; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   return ExecutesBefore(a.defining_instruction(), b.defining_instruction()); | ||||
| } | ||||
| @ -225,17 +232,14 @@ bool HloOrdering::UseIsBeforeValueDefinition( | ||||
| 
 | ||||
|   if (use.instruction->opcode() == HloOpcode::kConditional) { | ||||
|     const HloInstruction* conditional = use.instruction; | ||||
|     if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), | ||||
|                                            conditional->true_computation())) { | ||||
|       VLOG(4) << "  use is conditional " << use.instruction->name() | ||||
|               << " and def is in TRUE computation"; | ||||
|       return true; | ||||
|     } | ||||
|     if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), | ||||
|                                            conditional->false_computation())) { | ||||
|       VLOG(4) << "  use is conditional " << use.instruction->name() | ||||
|               << " and def is in FALSE computation"; | ||||
|       return true; | ||||
|     for (int j = 0; j < conditional->branch_count(); ++j) { | ||||
|       if (call_graph_->InstructionIsNestedIn( | ||||
|               value.defining_instruction(), | ||||
|               conditional->branch_computation(j))) { | ||||
|         VLOG(4) << "  use is conditional " << use.instruction->name() | ||||
|                 << " and def is in " << j << "th branch computation"; | ||||
|         return true; | ||||
|       } | ||||
|     } | ||||
|     if (value.defining_instruction() == use.instruction) { | ||||
|       VLOG(4) << "  use is conditional " << use << " and def is " | ||||
|  | ||||
| @ -22,6 +22,7 @@ limitations under the License. | ||||
| #include "absl/strings/str_format.h" | ||||
| #include "absl/strings/str_join.h" | ||||
| #include "absl/strings/str_split.h" | ||||
| #include "absl/types/span.h" | ||||
| #include "absl/types/variant.h" | ||||
| #include "tensorflow/compiler/xla/literal.h" | ||||
| #include "tensorflow/compiler/xla/literal_util.h" | ||||
| @ -180,6 +181,7 @@ class HloParser { | ||||
|     kBracedInt64List, | ||||
|     kBracedInt64ListList, | ||||
|     kHloComputation, | ||||
|     kBracedHloComputationList, | ||||
|     kFftType, | ||||
|     kWindow, | ||||
|     kConvolutionDimensionNumbers, | ||||
| @ -276,6 +278,8 @@ class HloParser { | ||||
| 
 | ||||
|   bool ParseSliceRanges(SliceRanges* result); | ||||
|   bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result); | ||||
|   bool ParseHloComputation(HloComputation** result); | ||||
|   bool ParseHloComputationList(std::vector<HloComputation*>* result); | ||||
|   bool ParseShapeList(std::vector<Shape>* result); | ||||
|   bool ParseInt64List(const TokKind start, const TokKind end, | ||||
|                       const TokKind delim, std::vector<int64>* result); | ||||
| @ -1436,18 +1440,36 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, | ||||
|     case HloOpcode::kConditional: { | ||||
|       optional<HloComputation*> true_computation; | ||||
|       optional<HloComputation*> false_computation; | ||||
|       attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation, | ||||
|                                    &true_computation}; | ||||
|       attrs["false_computation"] = {/*required=*/true, AttrTy::kHloComputation, | ||||
|                                     &false_computation}; | ||||
|       if (!ParseOperands(&operands, /*expected_size=*/3) || | ||||
|           !ParseAttributes(attrs)) { | ||||
|       optional<std::vector<HloComputation*>> branch_computations; | ||||
|       if (!ParseOperands(&operands)) { | ||||
|         return false; | ||||
|       } | ||||
|       const bool branch_index_is_bool = | ||||
|           operands[0]->shape().element_type() == PRED; | ||||
|       if (branch_index_is_bool) { | ||||
|         attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation, | ||||
|                                      &true_computation}; | ||||
|         attrs["false_computation"] = { | ||||
|             /*required=*/true, AttrTy::kHloComputation, &false_computation}; | ||||
|       } else { | ||||
|         attrs["branch_computations"] = {/*required=*/true, | ||||
|                                         AttrTy::kBracedHloComputationList, | ||||
|                                         &branch_computations}; | ||||
|       } | ||||
|       if (!ParseAttributes(attrs)) { | ||||
|         return false; | ||||
|       } | ||||
|       if (branch_index_is_bool) { | ||||
|         branch_computations.emplace({*true_computation, *false_computation}); | ||||
|       } | ||||
|       if (branch_computations->empty() || | ||||
|           operands.size() != branch_computations->size() + 1) { | ||||
|         return false; | ||||
|       } | ||||
|       instruction = builder->AddInstruction(HloInstruction::CreateConditional( | ||||
|           shape, /*pred=*/operands[0], | ||||
|           /*true_computation_arg=*/operands[1], *true_computation, | ||||
|           /*false_computation_arg=*/operands[2], *false_computation)); | ||||
|           shape, /*branch_index=*/operands[0], | ||||
|           absl::MakeSpan(*branch_computations), | ||||
|           absl::MakeSpan(operands).subspan(1))); | ||||
|       break; | ||||
|     } | ||||
|     case HloOpcode::kCustomCall: { | ||||
| @ -2683,20 +2705,21 @@ bool HloParser::ParseAttributeHelper( | ||||
|       } | ||||
|       case AttrTy::kHloComputation: { | ||||
|         HloComputation* result = nullptr; | ||||
|         if (lexer_.GetKind() == TokKind::kLbrace) { | ||||
|           // This means it is a nested computation.
 | ||||
|           if (!ParseInstructionList(&result, /*computation_name=*/"_")) { | ||||
|             return false; | ||||
|           } | ||||
|         } else { | ||||
|           // This means it is a computation name.
 | ||||
|           if (!ParseComputationName(&result)) { | ||||
|             return false; | ||||
|           } | ||||
|         if (!ParseHloComputation(&result)) { | ||||
|           return false; | ||||
|         } | ||||
|         static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result); | ||||
|         return true; | ||||
|       } | ||||
|       case AttrTy::kBracedHloComputationList: { | ||||
|         std::vector<HloComputation*> result; | ||||
|         if (!ParseHloComputationList(&result)) { | ||||
|           return false; | ||||
|         } | ||||
|         static_cast<optional<std::vector<HloComputation*>>*>(attr_out_ptr) | ||||
|             ->emplace(result); | ||||
|         return true; | ||||
|       } | ||||
|       case AttrTy::kFftType: { | ||||
|         FftType result; | ||||
|         if (!ParseFftType(&result)) { | ||||
| @ -3230,6 +3253,29 @@ bool HloParser::ParsePrecisionList( | ||||
|                    parse_and_add_item); | ||||
| } | ||||
| 
 | ||||
| bool HloParser::ParseHloComputation(HloComputation** result) { | ||||
|   if (lexer_.GetKind() == TokKind::kLbrace) { | ||||
|     // This means it is a nested computation.
 | ||||
|     return ParseInstructionList(result, /*computation_name=*/"_"); | ||||
|   } | ||||
|   // This means it is a computation name.
 | ||||
|   return ParseComputationName(result); | ||||
| } | ||||
| 
 | ||||
| bool HloParser::ParseHloComputationList(std::vector<HloComputation*>* result) { | ||||
|   auto parse_and_add_item = [&]() { | ||||
|     HloComputation* computation; | ||||
|     if (!ParseHloComputation(&computation)) { | ||||
|       return false; | ||||
|     } | ||||
|     LOG(INFO) << "parsed computation " << computation->name(); | ||||
|     result->push_back(computation); | ||||
|     return true; | ||||
|   }; | ||||
|   return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, | ||||
|                    parse_and_add_item); | ||||
| } | ||||
| 
 | ||||
| // shapelist ::= '{' shapes '}'
 | ||||
| // precision_elements
 | ||||
| //   ::= /*empty*/
 | ||||
|  | ||||
| @ -952,6 +952,36 @@ ENTRY %ParseC128Literal () -> c128[2] { | ||||
|   ROOT %c = c128[2]{0} constant({(1, 2), (-inf, nan)}) | ||||
| } | ||||
| 
 | ||||
| )" | ||||
| }, | ||||
| // Indexed Conditional
 | ||||
| { | ||||
| "IndexedConditional", | ||||
| R"(HloModule indexed_conditional | ||||
| 
 | ||||
| %Negate (x: f32[]) -> f32[] { | ||||
|   %x = f32[] parameter(0) | ||||
|   ROOT %negate = f32[] negate(f32[] %x) | ||||
| } | ||||
| 
 | ||||
| %Identity (y: f32[]) -> f32[] { | ||||
|   %y = f32[] parameter(0) | ||||
|   ROOT %copy = f32[] copy(f32[] %y) | ||||
| } | ||||
| 
 | ||||
| %Floor (z: f32[]) -> f32[] { | ||||
|   %z = f32[] parameter(0) | ||||
|   ROOT %floor = f32[] floor(f32[] %z) | ||||
| } | ||||
| 
 | ||||
| ENTRY %Parameters1.v4 () -> f32[] { | ||||
|   %constant = s32[] constant(1) | ||||
|   %constant.1 = f32[] constant(56) | ||||
|   %constant.2 = f32[] constant(12) | ||||
|   %constant.3 = f32[] constant(13) | ||||
|   ROOT %conditional = f32[] conditional(s32[] %constant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%Negate, %Identity, %Floor} | ||||
| } | ||||
| 
 | ||||
| )" | ||||
| }, | ||||
|   }); | ||||
| @ -1191,10 +1221,40 @@ ENTRY Sort { | ||||
| 
 | ||||
| )" | ||||
| }, | ||||
| // Conditional
 | ||||
| // Indexed Conditional
 | ||||
| { | ||||
| "Conditional", | ||||
| R"(HloModule conditional | ||||
| "IndexedConditional", | ||||
| R"(HloModule indexed_conditional | ||||
| 
 | ||||
| Negate { | ||||
|   x = f32[] parameter(0) | ||||
|   ROOT negate = f32[] negate(x) | ||||
| } | ||||
| 
 | ||||
| Identity { | ||||
|   y = f32[] parameter(0) | ||||
|   ROOT copy = f32[] copy(y) | ||||
| } | ||||
| 
 | ||||
| Floor { | ||||
|   z = f32[] parameter(0) | ||||
|   ROOT floor = f32[] floor(z) | ||||
| } | ||||
| 
 | ||||
| ENTRY Parameters1.v4 { | ||||
|   constant = s32[] constant(1) | ||||
|   constant.1 = f32[] constant(56) | ||||
|   constant.2 = f32[] constant(12) | ||||
|   constant.3 = f32[] constant(13) | ||||
|   ROOT conditional = f32[] conditional(constant, constant.1, constant.2, constant.3), branch_computations={Negate, Identity, Floor} | ||||
| } | ||||
| 
 | ||||
| )" | ||||
| }, | ||||
| // Predicated Conditional
 | ||||
| { | ||||
| "PredicatedConditional", | ||||
| R"(HloModule pred_conditional | ||||
| 
 | ||||
| Negate { | ||||
|   x = f32[] parameter(0) | ||||
| @ -2317,6 +2377,31 @@ TEST(HloParserSingleOpTest, CanonicalOpWithNested) { | ||||
|       text); | ||||
| } | ||||
| 
 | ||||
| TEST(HloParserSingleOpTest, CanonicalOpIndexedConditionalInlinedBranches) { | ||||
|   const string text = | ||||
|       R"(f32[5,10]{1,0} conditional(s32[], f32[5,10]{1,0}, f32[5,10]{1,0}, f32[5,10]{1,0}), branch_computations={ | ||||
| { | ||||
|   tmp_0 = f32[5,10]{1,0} parameter(0) | ||||
|   ROOT tmp_1 = f32[5,10]{1,0} ceil(f32[5,10]{1,0} tmp_0) | ||||
| }, | ||||
| { | ||||
|   tmp_0 = f32[5,10]{1,0} parameter(0) | ||||
|   ROOT tmp_1 = f32[5,10]{1,0} floor(f32[5,10]{1,0} tmp_0) | ||||
| }, | ||||
| { | ||||
|   tmp_0 = f32[5,10]{1,0} parameter(0) | ||||
|   ROOT tmp_1 = f32[5,10]{1,0} copy(f32[5,10]{1,0} tmp_0) | ||||
| } | ||||
| })"; | ||||
| 
 | ||||
|   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); | ||||
|   const HloComputation* computation = module->entry_computation(); | ||||
|   ASSERT_NE(computation, nullptr); | ||||
|   EXPECT_EQ( | ||||
|       computation->root_instruction()->ToString(HloPrintOptions::Canonical()), | ||||
|       text); | ||||
| } | ||||
| 
 | ||||
| TEST(HloParserSingleOpTest, SingleOpWithNested) { | ||||
|   const string text = | ||||
|       R"(%fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %p0, f32[2]{0} %p1), kind=kLoop, calls= | ||||
|  | ||||
| @ -667,20 +667,22 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { | ||||
| } | ||||
| 
 | ||||
| Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { | ||||
|   TF_RETURN_IF_ERROR( | ||||
|       CheckParameterCount(conditional, conditional->true_computation(), 1)); | ||||
|   TF_RETURN_IF_ERROR( | ||||
|       CheckParameterCount(conditional, conditional->false_computation(), 1)); | ||||
|   TF_RETURN_IF_ERROR(CheckOperandAndParameter( | ||||
|       conditional, 1, conditional->true_computation(), 0)); | ||||
|   TF_RETURN_IF_ERROR(CheckOperandAndParameter( | ||||
|       conditional, 2, conditional->false_computation(), 0)); | ||||
|   TF_RETURN_IF_ERROR( | ||||
|       CheckShape(conditional, | ||||
|                  conditional->true_computation()->root_instruction()->shape())); | ||||
|   TF_RETURN_IF_ERROR(CheckShape( | ||||
|       conditional, | ||||
|       conditional->false_computation()->root_instruction()->shape())); | ||||
|   const int num_branches = conditional->branch_count(); | ||||
|   if (conditional->operand(0)->shape().element_type() == PRED) { | ||||
|     TF_RET_CHECK(num_branches == 2); | ||||
|   } else { | ||||
|     TF_RET_CHECK(num_branches >= 1); | ||||
|   } | ||||
|   TF_RETURN_IF_ERROR(CheckOperandCount(conditional, num_branches + 1)); | ||||
|   for (int j = 0; j < num_branches; ++j) { | ||||
|     TF_RETURN_IF_ERROR(CheckParameterCount( | ||||
|         conditional, conditional->branch_computation(j), 1)); | ||||
|     TF_RETURN_IF_ERROR(CheckOperandAndParameter( | ||||
|         conditional, j + 1, conditional->branch_computation(j), 0)); | ||||
|     TF_RETURN_IF_ERROR(CheckShape( | ||||
|         conditional, | ||||
|         conditional->branch_computation(j)->root_instruction()->shape())); | ||||
|   } | ||||
|   return Status::OK(); | ||||
| } | ||||
| 
 | ||||
| @ -1370,17 +1372,13 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { | ||||
|   } | ||||
| 
 | ||||
|   Status HandleConditional(HloInstruction* conditional) override { | ||||
|     if (conditional->true_computation()->num_parameters() != 1) { | ||||
|       return FailedPrecondition( | ||||
|           "True computation %s of %s must have 1 parameter insted of %d", | ||||
|           conditional->true_computation()->name(), conditional->ToString(), | ||||
|           conditional->true_computation()->num_parameters()); | ||||
|     } | ||||
|     if (conditional->false_computation()->num_parameters() != 1) { | ||||
|       return FailedPrecondition( | ||||
|           "False computation %s of %s must have 1 parameter insted of %d", | ||||
|           conditional->false_computation()->name(), conditional->ToString(), | ||||
|           conditional->false_computation()->num_parameters()); | ||||
|     for (int b = 0; b < conditional->branch_count(); ++b) { | ||||
|       if (conditional->branch_computation(b)->num_parameters() != 1) { | ||||
|         return FailedPrecondition( | ||||
|             "Branch computation %s of %s must have 1 parameter insted of %d", | ||||
|             conditional->branch_computation(b)->name(), conditional->ToString(), | ||||
|             conditional->branch_computation(b)->num_parameters()); | ||||
|       } | ||||
|     } | ||||
|     return Status::OK(); | ||||
|   } | ||||
|  | ||||
| @ -588,48 +588,40 @@ Status LayoutAssignment::AddMandatoryConstraints( | ||||
|       TF_RETURN_IF_ERROR(constraints->SetOperandLayout( | ||||
|           body_layout.result_shape(), instruction, 0)); | ||||
|     } else if (instruction->opcode() == HloOpcode::kConditional) { | ||||
|       // The layout of the true and false computations must match, and must
 | ||||
|       // The layout of the branch computations must match, and must
 | ||||
|       // be the layout of the kConditional instruction.
 | ||||
|       TF_RET_CHECK(instruction->operand_count() == 3); | ||||
|       ComputationLayout& branch0_computation_layout = | ||||
|           FindOrDie(computation_layouts_, instruction->branch_computation(0)); | ||||
|       for (int j = 0; j < instruction->branch_count(); ++j) { | ||||
|         TF_RET_CHECK(instruction->branch_computation(j)->num_parameters() == 1); | ||||
|         ComputationLayout& branch_computation_layout = | ||||
|             FindOrDie(computation_layouts_, instruction->branch_computation(j)); | ||||
| 
 | ||||
|       HloComputation* true_computation = instruction->true_computation(); | ||||
|       HloComputation* false_computation = instruction->false_computation(); | ||||
|       const HloInstruction* true_operand = instruction->operand(1); | ||||
|       const HloInstruction* false_operand = instruction->operand(2); | ||||
| 
 | ||||
|       TF_RET_CHECK(true_computation->num_parameters() == 1); | ||||
|       TF_RET_CHECK(false_computation->num_parameters() == 1); | ||||
|       ComputationLayout& true_computation_layout = | ||||
|           FindOrDie(computation_layouts_, true_computation); | ||||
|       ComputationLayout& false_computation_layout = | ||||
|           FindOrDie(computation_layouts_, false_computation); | ||||
| 
 | ||||
|       DCHECK(ShapeUtil::Compatible(true_operand->shape(), | ||||
|                                    true_computation_layout.parameter_shape(0))); | ||||
|       DCHECK(ShapeUtil::Compatible( | ||||
|           false_operand->shape(), false_computation_layout.parameter_shape(0))); | ||||
|       if (true_computation_layout.result_layout() != | ||||
|           false_computation_layout.result_layout()) { | ||||
|         // We assign layouts in DFS fashion, so the true and false computations
 | ||||
|         // might have negotiated a different layout. But for the conditional
 | ||||
|         // instruction POV the layout must match, so we run again on the false
 | ||||
|         // computation, this time with proper computation layout.
 | ||||
|         VLOG(2) << "Reset %conditional false computation result layout: " | ||||
|                    "false_computation=" | ||||
|                 << false_computation->name() | ||||
|                 << " conditional=" << instruction->name() << " shape=" | ||||
|                 << true_computation_layout.result_layout().ToString(); | ||||
|         *false_computation_layout.mutable_result_layout() = | ||||
|             true_computation_layout.result_layout(); | ||||
|         DCHECK(ShapeUtil::Compatible( | ||||
|             instruction->operand(j + 1)->shape(), | ||||
|             branch_computation_layout.parameter_shape(0))); | ||||
|         if (branch0_computation_layout.result_layout() != | ||||
|             branch_computation_layout.result_layout()) { | ||||
|           // We assign layouts in DFS fashion, so the br_0 and br_j
 | ||||
|           // computations might have negotiated a different layout. But for the
 | ||||
|           // case instruction POV the layout must match, so we run again
 | ||||
|           // on the br_j computation, this time with proper computation layout.
 | ||||
|           VLOG(2) << "Reset %conditional branch " << j | ||||
|                   << " computation result layout: branch_computation=" | ||||
|                   << instruction->branch_computation(j)->name() | ||||
|                   << " case=" << instruction->name() << " shape=" | ||||
|                   << branch0_computation_layout.result_layout().ToString(); | ||||
|           *branch_computation_layout.mutable_result_layout() = | ||||
|               branch0_computation_layout.result_layout(); | ||||
|         } | ||||
|         if (j == 0) { | ||||
|           TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( | ||||
|               branch0_computation_layout.result_shape(), instruction)); | ||||
|         } | ||||
|         TF_RETURN_IF_ERROR(constraints->SetOperandLayout( | ||||
|             branch_computation_layout.parameter_shape(0), instruction, j + 1, | ||||
|             /*mandatory=*/true)); | ||||
|       } | ||||
|       TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( | ||||
|           true_computation_layout.result_shape(), instruction)); | ||||
|       TF_RETURN_IF_ERROR(constraints->SetOperandLayout( | ||||
|           true_computation_layout.parameter_shape(0), instruction, 1, | ||||
|           /*mandatory=*/true)); | ||||
|       TF_RETURN_IF_ERROR(constraints->SetOperandLayout( | ||||
|           false_computation_layout.parameter_shape(0), instruction, 2, | ||||
|           /*mandatory=*/true)); | ||||
|     } | ||||
|   } | ||||
|   // Finally set the result layout to match ComputationLayout, if there is one.
 | ||||
| @ -699,28 +691,21 @@ Status CheckWhileLayout(HloInstruction* while_inst, | ||||
| 
 | ||||
| Status CheckConditionalLayout( | ||||
|     HloInstruction* instruction, | ||||
|     const ComputationLayout& true_computation_layout, | ||||
|     const ComputationLayout& false_computation_layout) { | ||||
|   HloComputation* true_computation = instruction->true_computation(); | ||||
|   HloComputation* false_computation = instruction->false_computation(); | ||||
|   const HloInstruction* true_operand = instruction->operand(1); | ||||
|   const HloInstruction* false_operand = instruction->operand(2); | ||||
| 
 | ||||
|   TF_RET_CHECK(true_computation_layout.result_layout() == | ||||
|                false_computation_layout.result_layout()); | ||||
|   TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape( | ||||
|       instruction->shape())); | ||||
|   TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape( | ||||
|       true_computation->root_instruction()->shape())); | ||||
|   TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape( | ||||
|       instruction->shape())); | ||||
|   TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape( | ||||
|       false_computation->root_instruction()->shape())); | ||||
|   TF_RET_CHECK(true_computation_layout.parameter_layout(0).MatchesLayoutInShape( | ||||
|       true_operand->shape())); | ||||
|   TF_RET_CHECK( | ||||
|       false_computation_layout.parameter_layout(0).MatchesLayoutInShape( | ||||
|           false_operand->shape())); | ||||
|     absl::Span<const ComputationLayout> branch_computation_layouts) { | ||||
|   for (int j = 0; j < instruction->branch_count(); ++j) { | ||||
|     const HloInstruction* branch_operand = instruction->operand(j + 1); | ||||
|     TF_RET_CHECK(branch_computation_layouts[0].result_layout() == | ||||
|                  branch_computation_layouts[j].result_layout()); | ||||
|     TF_RET_CHECK( | ||||
|         branch_computation_layouts[j].result_layout().MatchesLayoutInShape( | ||||
|             instruction->shape())); | ||||
|     TF_RET_CHECK( | ||||
|         branch_computation_layouts[j].result_layout().MatchesLayoutInShape( | ||||
|             instruction->branch_computation(j)->root_instruction()->shape())); | ||||
|     TF_RET_CHECK( | ||||
|         branch_computation_layouts[j].parameter_layout(0).MatchesLayoutInShape( | ||||
|             branch_operand->shape())); | ||||
|   } | ||||
|   return Status::OK(); | ||||
| } | ||||
| 
 | ||||
| @ -937,13 +922,16 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { | ||||
|               FindOrDie(computation_layouts_, instruction->while_condition()), | ||||
|               FindOrDie(computation_layouts_, instruction->while_body()))); | ||||
|           break; | ||||
|         case HloOpcode::kConditional: | ||||
|         case HloOpcode::kConditional: { | ||||
|           std::vector<ComputationLayout> branch_computation_layouts; | ||||
|           for (auto branch_computation : instruction->branch_computations()) { | ||||
|             branch_computation_layouts.emplace_back( | ||||
|                 FindOrDie(computation_layouts_, branch_computation)); | ||||
|           } | ||||
|           TF_RETURN_IF_ERROR(CheckConditionalLayout( | ||||
|               instruction, | ||||
|               FindOrDie(computation_layouts_, instruction->true_computation()), | ||||
|               FindOrDie(computation_layouts_, | ||||
|                         instruction->false_computation()))); | ||||
|               instruction, absl::MakeSpan(branch_computation_layouts))); | ||||
|           break; | ||||
|         } | ||||
|         default: | ||||
|           break; | ||||
|       } | ||||
|  | ||||
| @ -1291,16 +1291,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, | ||||
|   TF_RETURN_IF_ERROR( | ||||
|       ExpectArray(scale_shape, "scale input of batch norm inference")); | ||||
| 
 | ||||
|   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == | ||||
|                Status::OK()); | ||||
|   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) == | ||||
|                Status::OK()); | ||||
|   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) == | ||||
|                Status::OK()); | ||||
|   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape) == | ||||
|                Status::OK()); | ||||
|   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape) == | ||||
|                Status::OK()); | ||||
|   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape)); | ||||
|   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape)); | ||||
|   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape)); | ||||
|   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape)); | ||||
|   TF_RETURN_IF_ERROR( | ||||
|       ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape)); | ||||
| 
 | ||||
|   if (feature_index >= operand_shape.rank()) { | ||||
|     return InvalidArgument( | ||||
| @ -2562,59 +2558,55 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, | ||||
| } | ||||
| 
 | ||||
| /* static */ StatusOr<Shape> ShapeInference::InferConditionalShape( | ||||
|     const Shape& predicate, const Shape& true_operand, | ||||
|     const Shape& false_operand, const ProgramShape& true_computation, | ||||
|     const ProgramShape& false_computation) { | ||||
|   if (!ShapeUtil::Equal(predicate, ShapeUtil::MakeShape(PRED, {}))) { | ||||
|     return InvalidArgument("Predicate must be a boolean; got %s.", | ||||
|                            ShapeUtil::HumanString(predicate)); | ||||
|     const Shape& branch_index, | ||||
|     absl::Span<const ProgramShape> branch_computations, | ||||
|     absl::Span<const Shape> branch_operands) { | ||||
|   if (!ShapeUtil::Equal(branch_index, ShapeUtil::MakeShape(PRED, {})) && | ||||
|       !ShapeUtil::Equal(branch_index, ShapeUtil::MakeShape(S32, {}))) { | ||||
|     return InvalidArgument("branch_index must be bool or int32; got %s.", | ||||
|                            ShapeUtil::HumanString(branch_index)); | ||||
|   } | ||||
|   if (branch_index.element_type() == PRED) { | ||||
|     TF_RET_CHECK(2 == branch_computations.size()); | ||||
|   } else { | ||||
|     TF_RET_CHECK(!branch_computations.empty()); | ||||
|   } | ||||
|   TF_RET_CHECK(branch_computations.size() == branch_operands.size()); | ||||
| 
 | ||||
|   if (true_computation.parameters_size() != 1) { | ||||
|     return InvalidArgument("true_computation must take 1 argument; got %d.", | ||||
|                            true_computation.parameters_size()); | ||||
|   } | ||||
|   if (!ShapeUtil::Compatible(true_computation.parameters(0), true_operand)) { | ||||
|     auto true_shape_string = [&]() { | ||||
|       return StrFormat("true_operand: %s; true_computation: %s", | ||||
|                        ShapeUtil::HumanString(true_operand), | ||||
|                        ShapeUtil::HumanString(true_computation)); | ||||
|     }; | ||||
|     return InvalidArgument( | ||||
|         "true_operand must match the shape of the only parameter of " | ||||
|         "true_computation: got %s.", | ||||
|         true_shape_string()); | ||||
|   } | ||||
|   for (int j = 0; j < branch_computations.size(); ++j) { | ||||
|     if (branch_computations[j].parameters_size() != 1) { | ||||
|       return InvalidArgument( | ||||
|           "branch computation %d must take 1 argument; got %d.", j, | ||||
|           branch_computations[j].parameters_size()); | ||||
|     } | ||||
|     if (!ShapeUtil::Compatible(branch_computations[j].parameters(0), | ||||
|                                branch_operands[j])) { | ||||
|       auto shape_string = [&]() { | ||||
|         return StrFormat("operand: %s; computation: %s", | ||||
|                          ShapeUtil::HumanString(branch_operands[j]), | ||||
|                          ShapeUtil::HumanString(branch_computations[j])); | ||||
|       }; | ||||
|       return InvalidArgument( | ||||
|           "branch operand %d must match the shape of the only parameter of " | ||||
|           "branch computation %d: got %s.", | ||||
|           j, j, shape_string()); | ||||
|     } | ||||
| 
 | ||||
|   if (false_computation.parameters_size() != 1) { | ||||
|     return InvalidArgument("false_computation must take 1 argument; got %d.", | ||||
|                            false_computation.parameters_size()); | ||||
|     if (!ShapeUtil::Compatible(branch_computations[0].result(), | ||||
|                                branch_computations[j].result())) { | ||||
|       auto shape_string = [&]() { | ||||
|         return StrFormat( | ||||
|             "branch 0 computation result: %s; branch %d computation result: %s", | ||||
|             ShapeUtil::HumanString(branch_computations[0].result()), j, | ||||
|             ShapeUtil::HumanString(branch_computations[j].result())); | ||||
|       }; | ||||
|       return InvalidArgument( | ||||
|           "the result of branch 0 computation and branch %d computation must " | ||||
|           "have the same shape: got %s.", | ||||
|           j, shape_string()); | ||||
|     } | ||||
|   } | ||||
|   if (!ShapeUtil::Compatible(false_computation.parameters(0), false_operand)) { | ||||
|     auto false_shape_string = [&]() { | ||||
|       return StrFormat("false_operand: %s; false_computation: %s", | ||||
|                        ShapeUtil::HumanString(false_operand), | ||||
|                        ShapeUtil::HumanString(false_computation)); | ||||
|     }; | ||||
|     return InvalidArgument( | ||||
|         "false_operand must match the shape of the only parameter of " | ||||
|         "false_computation: got %s.", | ||||
|         false_shape_string()); | ||||
|   } | ||||
|   if (!ShapeUtil::Compatible(true_computation.result(), | ||||
|                              false_computation.result())) { | ||||
|     auto shape_string = [&]() { | ||||
|       return StrFormat( | ||||
|           "true_computation result: %s; false_computation result: %s.", | ||||
|           ShapeUtil::HumanString(true_computation.result()), | ||||
|           ShapeUtil::HumanString(false_computation.result())); | ||||
|     }; | ||||
|     return InvalidArgument( | ||||
|         "the result of true_computation and false_computation must have the " | ||||
|         "same shape: got %s.", | ||||
|         shape_string()); | ||||
|   } | ||||
|   return true_computation.result(); | ||||
|   return branch_computations[0].result(); | ||||
| } | ||||
| 
 | ||||
| /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape( | ||||
|  | ||||
| @ -208,11 +208,11 @@ class ShapeInference { | ||||
|                                          const ProgramShape& body, | ||||
|                                          const Shape& init); | ||||
| 
 | ||||
  // Infers the shape produced by a predicated or indexed conditional operation.
|   static StatusOr<Shape> InferConditionalShape( | ||||
|       const Shape& predicate, const Shape& true_operand, | ||||
|       const Shape& false_operand, const ProgramShape& true_computation, | ||||
|       const ProgramShape& false_computation); | ||||
|       const Shape& branch_index, | ||||
|       absl::Span<const ProgramShape> branch_computations, | ||||
|       absl::Span<const Shape> branch_operands); | ||||
| 
 | ||||
|   // Infers the shape produced by a broadcast operation.
 | ||||
|   static StatusOr<Shape> InferBroadcastShape( | ||||
|  | ||||
| @ -1582,79 +1582,166 @@ TEST_F(ShapeInferenceTest, Rank1Transpose) { | ||||
|       ShapeUtil::Compatible(inferred_shape, ShapeUtil::MakeShape(F32, {5}))); | ||||
| } | ||||
| 
 | ||||
| TEST_F(ShapeInferenceTest, Conditional) { | ||||
| TEST_F(ShapeInferenceTest, ConditionalPred) { | ||||
|   auto inferred_status0 = ShapeInference::InferConditionalShape( | ||||
|       pred_, vector_32_, vector_64_, | ||||
|       ShapeUtil::MakeProgramShape({vector_32_}, f32_), | ||||
|       ShapeUtil::MakeProgramShape({vector_64_}, f32_)); | ||||
|       pred_, | ||||
|       {ShapeUtil::MakeProgramShape({vector_32_}, f32_), | ||||
|        ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, | ||||
|       {vector_32_, vector_64_}); | ||||
|   EXPECT_IS_OK(inferred_status0.status()); | ||||
|   EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie())); | ||||
| 
 | ||||
|   auto inferred_status1 = ShapeInference::InferConditionalShape( | ||||
|       pred_, matrix_32_48_, vector_32_, | ||||
|       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_), | ||||
|       ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)); | ||||
|       pred_, | ||||
|       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_), | ||||
|        ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)}, | ||||
|       {matrix_32_48_, vector_32_}); | ||||
|   EXPECT_IS_OK(inferred_status1.status()); | ||||
|   EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie())); | ||||
| 
 | ||||
|   auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_}); | ||||
|   auto inferred_status2 = ShapeInference::InferConditionalShape( | ||||
|       pred_, matrix_32_48_, tuple_f32_v32, | ||||
|       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), | ||||
|       ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)); | ||||
|       pred_, | ||||
|       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), | ||||
|        ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)}, | ||||
|       {matrix_32_48_, tuple_f32_v32}); | ||||
|   EXPECT_IS_OK(inferred_status2.status()); | ||||
|   EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie())); | ||||
| 
 | ||||
|   auto inferred_status_error0 = ShapeInference::InferConditionalShape( | ||||
|       s32_, vector_32_, vector_64_, | ||||
|       ShapeUtil::MakeProgramShape({vector_32_}, f32_), | ||||
|       ShapeUtil::MakeProgramShape({vector_64_}, f32_)); | ||||
|       f32_, | ||||
|       {ShapeUtil::MakeProgramShape({vector_32_}, f32_), | ||||
|        ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, | ||||
|       {vector_32_, vector_64_}); | ||||
|   EXPECT_FALSE(inferred_status_error0.ok()); | ||||
|   EXPECT_THAT(inferred_status_error0.status().error_message(), | ||||
|               HasSubstr("Predicate must be a boolean")); | ||||
|               HasSubstr("must be bool or int32")); | ||||
| 
 | ||||
|   auto inferred_status_error1 = ShapeInference::InferConditionalShape( | ||||
|       pred_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_, | ||||
|       ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_), | ||||
|       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)); | ||||
|       pred_, | ||||
|       {ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_), | ||||
|        ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)}, | ||||
|       {ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_}); | ||||
|   EXPECT_FALSE(inferred_status_error1.ok()); | ||||
|   EXPECT_THAT(inferred_status_error1.status().error_message(), | ||||
|               HasSubstr("true_computation must take 1 argument")); | ||||
|               HasSubstr("branch computation 0 must take 1 argument")); | ||||
| 
 | ||||
|   auto inferred_status_error2 = ShapeInference::InferConditionalShape( | ||||
|       pred_, vector_32_, vector_64_, | ||||
|       ShapeUtil::MakeProgramShape({vector_64_}, f32_), | ||||
|       ShapeUtil::MakeProgramShape({vector_64_}, f32_)); | ||||
|       pred_, | ||||
|       {ShapeUtil::MakeProgramShape({vector_64_}, f32_), | ||||
|        ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, | ||||
|       {vector_32_, vector_64_}); | ||||
|   EXPECT_FALSE(inferred_status_error2.ok()); | ||||
|   EXPECT_THAT(inferred_status_error2.status().error_message(), | ||||
|               HasSubstr("true_operand must match the shape of the only " | ||||
|                         "parameter of true_computation")); | ||||
|               HasSubstr("branch operand 0 must match the shape of the only " | ||||
|                         "parameter of branch computation 0")); | ||||
| 
 | ||||
|   auto inferred_status_error3 = ShapeInference::InferConditionalShape( | ||||
|       pred_, matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), | ||||
|       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), | ||||
|       ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)); | ||||
|       pred_, | ||||
|       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), | ||||
|        ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)}, | ||||
|       {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_})}); | ||||
|   EXPECT_FALSE(inferred_status_error3.ok()); | ||||
|   EXPECT_THAT(inferred_status_error3.status().error_message(), | ||||
|               HasSubstr("false_computation must take 1 argument")); | ||||
|               HasSubstr("branch computation 1 must take 1 argument")); | ||||
| 
 | ||||
|   auto inferred_status_error4 = ShapeInference::InferConditionalShape( | ||||
|       pred_, vector_32_, vector_64_, | ||||
|       ShapeUtil::MakeProgramShape({vector_32_}, f32_), | ||||
|       ShapeUtil::MakeProgramShape({vector_32_}, f32_)); | ||||
|       pred_, | ||||
|       {ShapeUtil::MakeProgramShape({vector_32_}, f32_), | ||||
|        ShapeUtil::MakeProgramShape({vector_32_}, f32_)}, | ||||
|       {vector_32_, vector_64_}); | ||||
|   EXPECT_FALSE(inferred_status_error4.ok()); | ||||
|   EXPECT_THAT(inferred_status_error4.status().error_message(), | ||||
|               HasSubstr("false_operand must match the shape of the only " | ||||
|                         "parameter of false_computation")); | ||||
|               HasSubstr("branch operand 1 must match the shape of the only " | ||||
|                         "parameter of branch computation 1")); | ||||
| 
 | ||||
|   auto inferred_status_error5 = ShapeInference::InferConditionalShape( | ||||
|       pred_, vector_32_, vector_64_, | ||||
|       ShapeUtil::MakeProgramShape({vector_32_}, f32_), | ||||
|       ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)); | ||||
|       pred_, | ||||
|       {ShapeUtil::MakeProgramShape({vector_32_}, f32_), | ||||
|        ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)}, | ||||
|       {vector_32_, vector_64_}); | ||||
|   EXPECT_FALSE(inferred_status_error5.ok()); | ||||
|   EXPECT_THAT(inferred_status_error5.status().error_message(), | ||||
|               HasSubstr("the result of true_computation and false_computation " | ||||
|                         "must have the same shape")); | ||||
|               HasSubstr("the result of branch 0 computation and branch 1 " | ||||
|                         "computation must have the same shape")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(ShapeInferenceTest, ConditionalIndexed) { | ||||
|   auto r0s32 = ShapeUtil::MakeShape(S32, {}); | ||||
|   auto inferred_status0 = ShapeInference::InferConditionalShape( | ||||
|       r0s32, | ||||
|       {ShapeUtil::MakeProgramShape({vector_32_}, f32_), | ||||
|        ShapeUtil::MakeProgramShape({vector_64_}, f32_), | ||||
|        ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, | ||||
|       {vector_32_, vector_64_, vector_64_}); | ||||
|   EXPECT_IS_OK(inferred_status0.status()); | ||||
|   EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie())); | ||||
| 
 | ||||
|   auto inferred_status1 = ShapeInference::InferConditionalShape( | ||||
|       r0s32, | ||||
|       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_), | ||||
|        ShapeUtil::MakeProgramShape({vector_32_}, vector_64_), | ||||
|        ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_)}, | ||||
|       {matrix_32_48_, vector_32_, matrix_32_48_}); | ||||
|   EXPECT_IS_OK(inferred_status1.status()); | ||||
|   EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie())); | ||||
| 
 | ||||
|   auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_}); | ||||
|   auto inferred_status2 = ShapeInference::InferConditionalShape( | ||||
|       r0s32, {ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)}, | ||||
|       {tuple_f32_v32}); | ||||
|   EXPECT_IS_OK(inferred_status2.status()); | ||||
|   EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie())); | ||||
| 
 | ||||
|   auto inferred_status_error0 = ShapeInference::InferConditionalShape( | ||||
|       pred_, | ||||
|       {ShapeUtil::MakeProgramShape({vector_32_}, f32_), | ||||
|        ShapeUtil::MakeProgramShape({vector_32_}, f32_), | ||||
|        ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, | ||||
|       {vector_32_, vector_32_, vector_64_}); | ||||
|   EXPECT_FALSE(inferred_status_error0.ok()); | ||||
|   EXPECT_THAT(inferred_status_error0.status().error_message(), | ||||
|               HasSubstr("2 == branch_computations.size()")); | ||||
| 
 | ||||
|   auto inferred_status_error1 = ShapeInference::InferConditionalShape( | ||||
|       r0s32, | ||||
|       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), | ||||
|        ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_), | ||||
|        ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)}, | ||||
|       {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), | ||||
|        matrix_32_48_}); | ||||
|   EXPECT_FALSE(inferred_status_error1.ok()); | ||||
|   EXPECT_THAT(inferred_status_error1.status().error_message(), | ||||
|               HasSubstr("branch computation 1 must take 1 argument")); | ||||
| 
 | ||||
|   auto inferred_status_error2 = ShapeInference::InferConditionalShape( | ||||
|       r0s32, | ||||
|       {ShapeUtil::MakeProgramShape({r0s32}, f32_), | ||||
|        ShapeUtil::MakeProgramShape({vector_32_}, f32_), | ||||
|        ShapeUtil::MakeProgramShape({vector_32_}, f32_)}, | ||||
|       {r0s32, vector_32_, vector_64_}); | ||||
|   EXPECT_FALSE(inferred_status_error2.ok()); | ||||
|   EXPECT_THAT(inferred_status_error2.status().error_message(), | ||||
|               HasSubstr("branch operand 2 must match the shape of the only " | ||||
|                         "parameter of branch computation 2")); | ||||
| 
 | ||||
|   auto inferred_status_error3 = ShapeInference::InferConditionalShape( | ||||
|       r0s32, | ||||
|       {ShapeUtil::MakeProgramShape({vector_32_}, f32_), | ||||
|        ShapeUtil::MakeProgramShape({vector_32_}, f32_), | ||||
|        ShapeUtil::MakeProgramShape({vector_32_}, f32_), | ||||
|        ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)}, | ||||
|       {vector_32_, vector_32_, vector_32_, vector_64_}); | ||||
|   EXPECT_FALSE(inferred_status_error3.ok()); | ||||
|   EXPECT_THAT(inferred_status_error3.status().error_message(), | ||||
|               HasSubstr("the result of branch 0 computation and branch 3 " | ||||
|                         "computation must have the same shape")); | ||||
| 
 | ||||
|   auto inferred_status_error4 = | ||||
|       ShapeInference::InferConditionalShape(r0s32, {}, {}); | ||||
|   EXPECT_FALSE(inferred_status_error4.ok()); | ||||
|   EXPECT_THAT(inferred_status_error4.status().error_message(), | ||||
|               HasSubstr("!branch_computations.empty()")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(ShapeInferenceTest, BadSlice) { | ||||
|  | ||||
| @ -552,6 +552,7 @@ xla_test( | ||||
| xla_test( | ||||
|     name = "conditional_test", | ||||
|     srcs = ["conditional_test.cc"], | ||||
|     shard_count = 2, | ||||
|     deps = [ | ||||
|         "//tensorflow/compiler/xla:xla_data_proto", | ||||
|         "//tensorflow/compiler/xla/client:global_data", | ||||
|  | ||||
| @ -13,6 +13,7 @@ See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include <random> | ||||
| #include "tensorflow/compiler/xla/client/xla_builder.h" | ||||
| #include "tensorflow/compiler/xla/client/xla_computation.h" | ||||
| #include "tensorflow/compiler/xla/tests/client_library_test_base.h" | ||||
| @ -169,6 +170,11 @@ class ConditionalOpTest : public ClientLibraryTestBase { | ||||
|   ErrorSpec error_spec_{0.001}; | ||||
| }; | ||||
| 
 | ||||
| // Test fixture to run indexed conditional (switch/case) tests with varying
 | ||||
| // number of branches.
 | ||||
// GetParam() is the branch count; see INSTANTIATE_TEST_SUITE_P for the values.
class CaseOpTest : public ConditionalOpTest,
                   public ::testing::WithParamInterface<int> {};
| 
 | ||||
| // Test true and false computations that do not take any parameters.
 | ||||
| XLA_TEST_F(ConditionalOpTest, Parameters0) { | ||||
|   XlaBuilder builder(TestName()); | ||||
| @ -182,6 +188,36 @@ XLA_TEST_F(ConditionalOpTest, Parameters0) { | ||||
|   ComputeAndCompareR0<float>(&builder, 56.0f, {pred_arg.get()}, error_spec_); | ||||
| } | ||||
| 
 | ||||
| // Test branch computations that do not take any parameters.
 | ||||
| XLA_TEST_P(CaseOpTest, Parameters0) { | ||||
|   int num_branches = GetParam(); | ||||
|   for (int bi = -1; bi <= num_branches; ++bi) { | ||||
|     SCOPED_TRACE(bi); | ||||
|     XlaBuilder builder(TestName()); | ||||
|     XlaOp branch_index; | ||||
|     auto branch_index_arg = CreateR0Parameter<int32>(bi, 0, "branch_index_arg", | ||||
|                                                      &builder, &branch_index); | ||||
|     auto operand = Tuple(&builder, {}); | ||||
| 
 | ||||
|     std::vector<XlaOp> operands(num_branches, operand); | ||||
|     std::vector<XlaComputation> branches; | ||||
|     branches.reserve(num_branches); | ||||
|     std::vector<const XlaComputation*> branches_p(num_branches); | ||||
|     for (int i = 0; i < num_branches; ++i) { | ||||
|       branches.emplace_back( | ||||
|           CreateR0ConstantComputation(static_cast<float>(i) * 10)); | ||||
|       branches_p[i] = &branches[i]; | ||||
|     } | ||||
|     Conditional(branch_index, branches_p, operands); | ||||
| 
 | ||||
|     float expected = 10 * static_cast<float>((bi < 0 || bi >= num_branches) | ||||
|                                                  ? num_branches - 1 | ||||
|                                                  : bi); | ||||
|     ComputeAndCompareR0<float>(&builder, expected, {branch_index_arg.get()}, | ||||
|                                error_spec_); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| // Test true and false computations that take in 1 parameter.
 | ||||
| XLA_TEST_F(ConditionalOpTest, Parameters1) { | ||||
|   XlaBuilder builder(TestName()); | ||||
| @ -195,6 +231,45 @@ XLA_TEST_F(ConditionalOpTest, Parameters1) { | ||||
|   ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_); | ||||
| } | ||||
| 
 | ||||
| // Test branch computations that take in 1 parameter.
 | ||||
| XLA_TEST_P(CaseOpTest, Parameters1) { | ||||
|   int num_branches = GetParam(); | ||||
|   for (int bi = -1; bi <= num_branches; ++bi) { | ||||
|     SCOPED_TRACE(bi); | ||||
|     XlaBuilder builder(TestName()); | ||||
|     XlaOp branch_index; | ||||
|     auto branch_index_arg = CreateR0Parameter<int32>(bi, 0, "branch_index_arg", | ||||
|                                                      &builder, &branch_index); | ||||
| 
 | ||||
|     auto make_branch = [&builder, this](int i) { | ||||
|       auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i)); | ||||
|       Add(ConstantR0<float>(sb.get(), static_cast<float>(i)), | ||||
|           Parameter(sb.get(), 0, r0f32_, "p0")); | ||||
|       return sb->BuildAndNoteError(); | ||||
|     }; | ||||
|     std::vector<XlaComputation> branches; | ||||
|     branches.reserve(num_branches); | ||||
|     std::vector<const XlaComputation*> branches_p(num_branches); | ||||
|     std::vector<XlaOp> operands; | ||||
|     operands.reserve(num_branches); | ||||
|     std::vector<float> expecteds(num_branches); | ||||
|     for (int i = 0; i < num_branches; ++i) { | ||||
|       branches.emplace_back(make_branch(i)); | ||||
|       branches_p[i] = &branches[i]; | ||||
|       auto fi = static_cast<float>(i); | ||||
|       operands.emplace_back(ConstantR0<float>(&builder, 10 * fi + 7)); | ||||
|       expecteds[i] = 10 * fi + 7 + fi; | ||||
|     } | ||||
| 
 | ||||
|     Conditional(branch_index, branches_p, operands); | ||||
|     float expected = (bi < 0 || bi >= num_branches) | ||||
|                          ? expecteds[num_branches - 1] | ||||
|                          : expecteds[bi]; | ||||
|     ComputeAndCompareR0<float>(&builder, expected, {branch_index_arg.get()}, | ||||
|                                error_spec_); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| // Test conditional with two different computations in the true and false cases
 | ||||
| // that take in different arguments.
 | ||||
| XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { | ||||
| @ -331,6 +406,46 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { | ||||
|                              error_spec_); | ||||
| } | ||||
| 
 | ||||
| // Test branch computations that take in 2 array parameters.
 | ||||
XLA_TEST_P(CaseOpTest, Parameters2Array) {
  int num_branches = GetParam();
  // Covers every valid branch index plus one out-of-range value on each side.
  for (int bi = -1; bi <= num_branches; ++bi) {
    SCOPED_TRACE(bi);
    XlaBuilder builder(TestName());
    XlaOp branch_index;
    auto branch_index_arg =
        CreateR0Parameter<int32>(bi, 0, "pred", &builder, &branch_index);
    // Every branch receives the same tuple of two R1<F32> operands.
    auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
    auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
    auto operands = Tuple(&builder, {operand1, operand2});
    // Branch i computes i * operand1 + operand2, element-wise.
    auto make_branch = [&builder, this](int i) {
      auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i));
      auto p = Parameter(sb.get(), 0, tuple_2_r1s2f32_, "p0");
      Add(Mul(ConstantR0<float>(sb.get(), static_cast<float>(i)),
              GetTupleElement(p, 0)),
          GetTupleElement(p, 1));
      return sb->BuildAndNoteError();
    };
    std::vector<XlaComputation> branches;
    branches.reserve(num_branches);
    std::vector<const XlaComputation*> branches_p(num_branches);
    for (int i = 0; i < num_branches; ++i) {
      branches.emplace_back(make_branch(i));
      branches_p[i] = &branches[i];
    }
    Conditional(branch_index, branches_p,
                std::vector<XlaOp>(num_branches, operands));
    // The expectation assumes out-of-range indices take the last branch.
    auto modified_bi = static_cast<float>(
        (bi < 0 || bi >= num_branches) ? num_branches - 1 : bi);
    ComputeAndCompareR1<float>(
        &builder, {24.0f * modified_bi + 10, 56.0f * modified_bi + 11},
        {branch_index_arg.get()}, error_spec_);
  }
}
| 
 | ||||
// Run every CaseOpTest with 1 through 5 branches.
INSTANTIATE_TEST_SUITE_P(CaseOpTest_Instantiation, CaseOpTest,
                         ::testing::Values(1, 2, 3, 4, 5));
| 
 | ||||
| // Test true and false computations that take in 2 array parameters and
 | ||||
| // predicate is false.
 | ||||
| XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { | ||||
| @ -582,8 +697,8 @@ XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { | ||||
|   auto result = builder.Build(); | ||||
|   EXPECT_FALSE(result.ok()); | ||||
|   EXPECT_THAT(result.status().error_message(), | ||||
|               ::testing::HasSubstr("true_operand must match the shape of the " | ||||
|                                    "only parameter of true_computation")); | ||||
|               ::testing::HasSubstr("operand 0 must match the shape of the " | ||||
|                                    "only parameter of branch computation 0")); | ||||
| } | ||||
| 
 | ||||
| XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user