[XLA] Support dynamic conditional input/outputs in dynamic padder.
Dynamic dimensiosn of inputs and outputs are passed as additional tuple elements. PiperOrigin-RevId: 306277465 Change-Id: Ica1d69a813fa504fb228a63aad2226d4a51078db
This commit is contained in:
parent
b2c5581d4c
commit
519326041d
@ -2403,9 +2403,12 @@ cc_library(
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_casting_utils",
|
||||
":tuple_util",
|
||||
":while_util",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_tree",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
|
@ -15,6 +15,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
@ -23,12 +26,45 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/tuple_util.h"
|
||||
#include "tensorflow/compiler/xla/service/while_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_tree.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/window_util.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
// Replace `narrow_comp` with a new computation with `wide_shape` as input.
|
||||
StatusOr<HloComputation*> WidenComputation(HloComputation* narrow_comp,
|
||||
const Shape& wide_shape) {
|
||||
TF_RET_CHECK(wide_shape.IsTuple());
|
||||
const Shape& narrow_shape = narrow_comp->parameter_instruction(0)->shape();
|
||||
if (Shape::Equal()(wide_shape, narrow_shape)) {
|
||||
// No need to widen the computation.
|
||||
return narrow_comp;
|
||||
}
|
||||
HloComputation* wide_comp = [&]() {
|
||||
HloComputation::Builder builder(absl::StrCat("wide.", narrow_comp->name()));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
|
||||
return narrow_comp->parent()->AddEmbeddedComputation(builder.Build());
|
||||
}();
|
||||
|
||||
HloInstruction* wide_parameter = wide_comp->parameter_instruction(0);
|
||||
HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
|
||||
wide_parameter, narrow_shape.tuple_shapes_size());
|
||||
HloInstruction* call_narrow_comp = wide_comp->AddInstruction(
|
||||
HloInstruction::CreateCall(narrow_comp->root_instruction()->shape(),
|
||||
{truncated_parameter}, narrow_comp));
|
||||
wide_comp->set_root_instruction(call_narrow_comp,
|
||||
/*accept_different_shape=*/true);
|
||||
TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_comp).status());
|
||||
return wide_comp;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
|
||||
public:
|
||||
explicit DynamicDimensionInferenceVisitor(
|
||||
@ -95,6 +131,8 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
Status HandleClamp(HloInstruction* hlo) override;
|
||||
|
||||
Status HandleConditional(HloInstruction* hlo) override;
|
||||
|
||||
Status HandleWhile(HloInstruction* hlo) override;
|
||||
|
||||
Status HandleSlice(HloInstruction* hlo) override;
|
||||
@ -116,15 +154,21 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
|
||||
int64 operand_index, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint)>;
|
||||
|
||||
using DynamicDimensionFn = std::function<Status(
|
||||
ShapeIndex index, int64 dimension, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint)>;
|
||||
|
||||
Status ForEachOperandDynamicDimension(HloInstruction* inst,
|
||||
const OperandDynamicDimensionFn&);
|
||||
Status ForEachDynamicDimensionInOperand(HloInstruction* inst,
|
||||
int64 operand_index,
|
||||
const OperandDynamicDimensionFn&);
|
||||
Status ForEachDynamicDimension(HloInstruction* inst,
|
||||
const DynamicDimensionFn& fn);
|
||||
|
||||
// Pass through a dynamic dimension from the input to the output with the same
|
||||
// value and index in the shape. This is a helper function to handle trivial
|
||||
// instructions like elementwise operations.
|
||||
// Pass through a dynamic dimension from the input to the output with the
|
||||
// same value and index in the shape. This is a helper function to handle
|
||||
// trivial instructions like elementwise operations.
|
||||
Status PassThroughDynamicDimension(HloInstruction*);
|
||||
|
||||
// The dynamic parameter bindings of this computation.
|
||||
@ -1139,6 +1183,163 @@ Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) {
|
||||
});
|
||||
}
|
||||
|
||||
Status DynamicDimensionInferenceVisitor::HandleConditional(
|
||||
HloInstruction* hlo) {
|
||||
// Conditionals are handled by producing additional inputs and outputs of
|
||||
// the conditional instruction.
|
||||
std::vector<HloComputation*> new_branch_computations;
|
||||
std::vector<HloInstruction*> new_operands;
|
||||
// If the output of the conditional contains dynamic dimension. We send
|
||||
// dynamic dimension size out by adding additional root element. A mapping
|
||||
// from the root instruction's dynamic dimension index (represented by a shape
|
||||
// index as output index and a int64 dimension number) to output index
|
||||
// (represented by an int64) is tracked for the conditional intsruction (all
|
||||
// branches should have the same mapping).
|
||||
ShapeTree<absl::flat_hash_map<int64, int64>> dynamic_output_mapping(
|
||||
hlo->shape());
|
||||
|
||||
bool need_rewrite = false;
|
||||
|
||||
for (int64 branch_index = 0; branch_index < hlo->branch_count();
|
||||
++branch_index) {
|
||||
std::vector<HloInstruction*> operands_to_add;
|
||||
|
||||
absl::flat_hash_map<HloInstruction*, int64>
|
||||
dynamic_size_to_operand_id_index_map;
|
||||
// Only look at branch_index + 1, the correct operand index for a
|
||||
// given branch.
|
||||
const int64 operand_index = branch_index + 1;
|
||||
|
||||
int64 operand_count =
|
||||
hlo->operand(operand_index)->shape().tuple_shapes_size();
|
||||
// Prepare to pass dynamic dimension into the new computation and add
|
||||
// dynamic dimension sizes as parameters to the new tuple.
|
||||
TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
|
||||
hlo, operand_index,
|
||||
[&](HloInstruction*, ShapeIndex, int64, int64,
|
||||
HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) -> Status {
|
||||
TF_RET_CHECK(hlo->operand(operand_index)->shape().IsTuple())
|
||||
<< "Only tuple typed inputs can have dynamic dimension. Please "
|
||||
"file a bug against XLA team.";
|
||||
const HloInstruction* tuple_operand = hlo->operand(operand_index);
|
||||
for (int64 i = 0; i < tuple_operand->operand_count(); ++i) {
|
||||
// If the dynamic size is already an operand to the computation,
|
||||
// skip adding it to the computation input again.
|
||||
if (dynamic_size == tuple_operand->operand(i)) {
|
||||
dynamic_size_to_operand_id_index_map[dynamic_size] = i;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
auto iter = dynamic_size_to_operand_id_index_map.find(dynamic_size);
|
||||
if (iter == dynamic_size_to_operand_id_index_map.end()) {
|
||||
operands_to_add.push_back(dynamic_size);
|
||||
dynamic_size_to_operand_id_index_map[dynamic_size] =
|
||||
operand_count++;
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
|
||||
HloInstruction* original_input = hlo->mutable_operand(operand_index);
|
||||
HloComputation* branch_computation = hlo->branch_computation(branch_index);
|
||||
|
||||
HloComputation* new_computation = branch_computation;
|
||||
HloInstruction* new_operand = hlo->mutable_operand(operand_index);
|
||||
if (!operands_to_add.empty()) {
|
||||
TF_RET_CHECK(original_input->shape().IsTuple());
|
||||
need_rewrite = true;
|
||||
new_operand = TupleUtil::AppendSuffix(original_input, operands_to_add);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
new_computation,
|
||||
WidenComputation(branch_computation, new_operand->shape()));
|
||||
}
|
||||
// Set the dynamic dimensions for the newly created branch computation's
|
||||
// parameters so that the hlos inside the computation can see dynamic
|
||||
// dimensions.
|
||||
DynamicParameterBinding dynamic_parameter_binding;
|
||||
TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
|
||||
hlo, operand_index,
|
||||
[&](HloInstruction*, ShapeIndex index, int64 dimension,
|
||||
int64 operand_index, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) {
|
||||
DynamicParameterBinding::DynamicParameter dynamic_parameter{
|
||||
0, {dynamic_size_to_operand_id_index_map[dynamic_size]}};
|
||||
DynamicParameterBinding::DynamicDimension dynamic_dimension{
|
||||
0, {index}, dimension};
|
||||
TF_RETURN_IF_ERROR(dynamic_parameter_binding.Bind(dynamic_parameter,
|
||||
dynamic_dimension));
|
||||
|
||||
return Status::OK();
|
||||
}));
|
||||
VLOG(2) << "dynamic_parameter_binding for conditional branch"
|
||||
<< dynamic_parameter_binding;
|
||||
TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
|
||||
new_computation, dynamic_parameter_binding, parent_));
|
||||
std::vector<HloInstruction*> hlos_to_add_in_root;
|
||||
int64 original_tuple_count = hlo->shape().tuple_shapes_size();
|
||||
// There may be some dynamic dimensions coming out of the computation, wire
|
||||
// that into the root instruction as additional tuple elements.
|
||||
TF_RETURN_IF_ERROR(ForEachDynamicDimension(
|
||||
new_computation->root_instruction(),
|
||||
[&](ShapeIndex index, int64 dim, HloInstruction* dynamic_size,
|
||||
DimensionConstraint) -> Status {
|
||||
TF_RET_CHECK(hlo->shape().IsTuple())
|
||||
<< "Only tuple typed conditionals can have dynamic dimension. "
|
||||
"Please file a bug against XLA team.";
|
||||
dynamic_output_mapping.mutable_element(index)->emplace(
|
||||
dim, original_tuple_count++);
|
||||
hlos_to_add_in_root.push_back(dynamic_size);
|
||||
return Status::OK();
|
||||
}));
|
||||
|
||||
VLOG(2) << "hlos_to_add_in_root:" << hlos_to_add_in_root.size();
|
||||
if (!hlos_to_add_in_root.empty()) {
|
||||
need_rewrite = true;
|
||||
HloInstruction* new_branch_root = TupleUtil::AppendSuffix(
|
||||
new_computation->root_instruction(), hlos_to_add_in_root);
|
||||
new_computation->set_root_instruction(new_branch_root,
|
||||
/*accept_different_shape=*/true);
|
||||
}
|
||||
|
||||
new_branch_computations.push_back(new_computation);
|
||||
new_operands.push_back(new_operand);
|
||||
}
|
||||
if (!need_rewrite) {
|
||||
return Status::OK();
|
||||
}
|
||||
// Create a new conditional with the new operations and computations.
|
||||
HloInstruction* new_conditional =
|
||||
hlo->parent()->AddInstruction(HloInstruction::CreateConditional(
|
||||
new_branch_computations[0]->root_instruction()->shape(),
|
||||
hlo->mutable_operand(0), new_branch_computations, new_operands));
|
||||
|
||||
HloInstruction* new_conditional_extracted = TupleUtil::ExtractPrefix(
|
||||
new_conditional, hlo->shape().tuple_shapes_size());
|
||||
// Now set the dynamic dimensions of the newly created conditional.
|
||||
dynamic_output_mapping.ForEachElement(
|
||||
[&](const ShapeIndex& index,
|
||||
const absl::flat_hash_map<int64, int64>& dim_to_output) {
|
||||
for (auto iter : dim_to_output) {
|
||||
int64 dim = iter.first;
|
||||
int64 output_index = iter.second;
|
||||
HloInstruction* dynamic_size = hlo->parent()->AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(
|
||||
ShapeUtil::MakeScalarShape(S32), new_conditional,
|
||||
output_index));
|
||||
parent_->SetDynamicSize(new_conditional, index, dim, dynamic_size);
|
||||
parent_->SetDynamicSize(new_conditional_extracted, index, dim,
|
||||
dynamic_size);
|
||||
}
|
||||
});
|
||||
|
||||
TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_conditional_extracted));
|
||||
// Remove the original instruction even if has side-effects.
|
||||
TF_RETURN_IF_ERROR(hlo->parent()->RemoveInstruction(hlo));
|
||||
SetVisited(*new_conditional);
|
||||
SetVisited(*new_conditional_extracted);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DynamicDimensionInferenceVisitor::HandleScatter(HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo,
|
||||
@ -1314,6 +1515,23 @@ Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) {
|
||||
});
|
||||
}
|
||||
|
||||
Status DynamicDimensionInferenceVisitor::ForEachDynamicDimension(
|
||||
HloInstruction* inst, const DynamicDimensionFn& fn) {
|
||||
auto iter = parent_->per_hlo_dynamic_dimensions_.find(inst);
|
||||
if (iter != parent_->per_hlo_dynamic_dimensions_.end()) {
|
||||
for (auto& dynamic_dimension : iter->second) {
|
||||
HloInstruction* dynamic_size = parent_->GetDynamicSize(
|
||||
dynamic_dimension.inst, dynamic_dimension.index,
|
||||
dynamic_dimension.dim);
|
||||
CHECK_NE(parent_->constraint_mapping_.count(dynamic_dimension), 0);
|
||||
TF_RETURN_IF_ERROR(fn(dynamic_dimension.index, dynamic_dimension.dim,
|
||||
dynamic_size,
|
||||
parent_->constraint_mapping_[dynamic_dimension]));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DynamicDimensionInferenceVisitor::ForEachDynamicDimensionInOperand(
|
||||
HloInstruction* inst, int64 operand_index,
|
||||
const OperandDynamicDimensionFn& fn) {
|
||||
|
@ -815,6 +815,129 @@ TEST_F(DynamicDimensionInferenceTest, WhileTest) {
|
||||
test_dynamic_dimension();
|
||||
}
|
||||
|
||||
TEST_F(DynamicDimensionInferenceTest, ConditionalInputTest) {
|
||||
// Test the ability to trace into contional loops.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
|
||||
auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
|
||||
// In this test we set inputs to different branches to different shapes.
|
||||
auto tuple_shape_1 = ShapeUtil::MakeTupleShape({input_shape});
|
||||
auto tuple_shape_2 = ShapeUtil::MakeTupleShape({input_shape, input_shape});
|
||||
auto tuple_shape_3 =
|
||||
ShapeUtil::MakeTupleShape({input_shape, input_shape, input_shape});
|
||||
|
||||
// true branch:
|
||||
//
|
||||
// Param
|
||||
// | |
|
||||
// GTE1 GTE2
|
||||
// | |
|
||||
// Tuple(ADD)
|
||||
auto true_builder = HloComputation::Builder("true");
|
||||
{
|
||||
auto true_param = true_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, tuple_shape_2, "param"));
|
||||
auto gte_0 = true_builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(input_shape, true_param, 0));
|
||||
auto gte_1 = true_builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(input_shape, true_param, 1));
|
||||
auto add = true_builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
input_shape, HloOpcode::kAdd, gte_0, gte_1));
|
||||
true_builder.AddInstruction(HloInstruction::CreateTuple({add}));
|
||||
}
|
||||
HloComputation* true_branch =
|
||||
module_->AddEmbeddedComputation(true_builder.Build());
|
||||
// false branch:
|
||||
//
|
||||
// Param
|
||||
// | | |
|
||||
// GTE1 GTE2 GTE3
|
||||
// | |
|
||||
// Tuple(ADD)
|
||||
auto false_builder = HloComputation::Builder("false");
|
||||
{
|
||||
auto false_param = false_builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, tuple_shape_3, "param"));
|
||||
auto gte_0 = false_builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(input_shape, false_param, 1));
|
||||
auto gte_1 = false_builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(input_shape, false_param, 2));
|
||||
auto add = false_builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
input_shape, HloOpcode::kAdd, gte_0, gte_1));
|
||||
false_builder.AddInstruction(HloInstruction::CreateTuple({add}));
|
||||
}
|
||||
HloComputation* false_branch =
|
||||
module_->AddEmbeddedComputation(false_builder.Build());
|
||||
|
||||
// Entry:
|
||||
//
|
||||
// Param(bool) Param2 (tuple_2) Param3(tuple_3)
|
||||
// | | |
|
||||
// +---------Condition------------+
|
||||
auto* pred_param = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
/*parameter_number=*/0, ShapeUtil::MakeScalarShape(PRED), "pred"));
|
||||
|
||||
auto* tuple_2_param = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
/*parameter_number=*/1, tuple_shape_2, "tuple_2_param"));
|
||||
auto* tuple_3_param = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
/*parameter_number=*/2, tuple_shape_3, "tuple_3_param"));
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
/*parameter_number=*/3, scalar_shape_, "size_param"));
|
||||
builder.AddInstruction(HloInstruction::CreateConditional(
|
||||
tuple_shape_1, pred_param, tuple_2_param, true_branch, tuple_3_param,
|
||||
false_branch));
|
||||
|
||||
module_->AddEntryComputation(builder.Build());
|
||||
|
||||
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
|
||||
DynamicParameterBinding::DynamicParameter{3, {}},
|
||||
DynamicParameterBinding::DynamicDimension{1, {0}, 0}));
|
||||
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
|
||||
DynamicParameterBinding::DynamicParameter{3, {}},
|
||||
DynamicParameterBinding::DynamicDimension{1, {1}, 0}));
|
||||
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
|
||||
DynamicParameterBinding::DynamicParameter{3, {}},
|
||||
DynamicParameterBinding::DynamicDimension{2, {1}, 0}));
|
||||
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
|
||||
DynamicParameterBinding::DynamicParameter{3, {}},
|
||||
DynamicParameterBinding::DynamicDimension{2, {2}, 0}));
|
||||
|
||||
TF_ASSERT_OK(RunInference());
|
||||
|
||||
HloInstruction* conditional_hlo = nullptr;
|
||||
// The while hlo has been replaced, find the new one.
|
||||
for (HloInstruction* inst : module_->entry_computation()->instructions()) {
|
||||
if (inst->opcode() == HloOpcode::kConditional) {
|
||||
conditional_hlo = inst;
|
||||
}
|
||||
}
|
||||
ASSERT_NE(conditional_hlo, nullptr);
|
||||
// The original conditional shape has 1 parameters. With dynamic size passed
|
||||
// out from the computation, another element is added to the tuple.
|
||||
EXPECT_EQ(conditional_hlo->shape().tuple_shapes_size(), 2);
|
||||
HloInstruction* add_true_branch = nullptr;
|
||||
for (HloInstruction* inst :
|
||||
conditional_hlo->true_computation()->instructions()) {
|
||||
if (inst->opcode() == HloOpcode::kAdd) {
|
||||
add_true_branch = inst;
|
||||
}
|
||||
}
|
||||
EXPECT_NE(add_true_branch, nullptr);
|
||||
EXPECT_NE(inference_->GetDynamicSize(add_true_branch, {}, 0), nullptr);
|
||||
|
||||
HloInstruction* add_false_branch = nullptr;
|
||||
for (HloInstruction* inst :
|
||||
conditional_hlo->false_computation()->instructions()) {
|
||||
if (inst->opcode() == HloOpcode::kAdd) {
|
||||
add_false_branch = inst;
|
||||
}
|
||||
}
|
||||
EXPECT_NE(add_false_branch, nullptr);
|
||||
EXPECT_NE(inference_->GetDynamicSize(add_false_branch, {}, 0), nullptr);
|
||||
|
||||
EXPECT_NE(inference_->GetDynamicSize(conditional_hlo, {0}, 0), nullptr);
|
||||
}
|
||||
|
||||
TEST_F(DynamicDimensionInferenceTest, ReduceWindowBatchTest) {
|
||||
// Test the ability to trace reduce window batch dimensions.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
|
@ -967,7 +967,7 @@ Status DynamicShapeRemovingVisitor::HandleParameter(HloInstruction* hlo) {
|
||||
StatusOr<bool> DynamicPadder::Run(HloModule* module) {
|
||||
bool changed = false;
|
||||
VLOG(2) << "Pre DynamicPadder HLO:";
|
||||
|
||||
XLA_VLOG_LINES(2, module->ToString());
|
||||
// Removes dynamic dimensions on parameters if there is already a binding for
|
||||
// it. We do this because we have two different APIs to express a dynamic
|
||||
// dimension:
|
||||
|
@ -996,6 +996,56 @@ ENTRY main {
|
||||
EXPECT_EQ(result, expected);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ExecutionTest, DynamicConditionalDimension) {
|
||||
const string hlo_text = R"(
|
||||
HloModule module
|
||||
|
||||
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
|
||||
lhs = s32[] parameter(0)
|
||||
rhs = s32[] parameter(1)
|
||||
ROOT add = s32[] add(lhs, rhs)
|
||||
}
|
||||
|
||||
true_branch {
|
||||
true_param = (s32[<=3,2]) parameter(0)
|
||||
param = s32[<=3, 2] get-tuple-element(true_param), index=0
|
||||
add = s32[<=3,2] add(param, param)
|
||||
ROOT true_tuple = (s32[<=3,2], s32[<=3,2]) tuple(add, add)
|
||||
}
|
||||
|
||||
false_branch {
|
||||
false_param = (s32[<=3,2]) parameter(0)
|
||||
param = s32[<=3, 2] get-tuple-element(false_param), index=0
|
||||
add = s32[<=3,2] add(param, param)
|
||||
ROOT false_tuple = (s32[<=3,2], s32[<=3,2]) tuple(add, add)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
param0 = s32[3,2] parameter(0)
|
||||
size = s32[] constant(2)
|
||||
branch = pred[] constant(false)
|
||||
param_dynamic = s32[<=3, 2] set-dimension-size(param0, size), dimensions={0}
|
||||
param_tuple = (s32[<=3 ,2]) tuple(param_dynamic)
|
||||
conditional = (s32[<=3, 2], s32[<=3, 2]) conditional(branch, param_tuple, param_tuple),
|
||||
true_computation=true_branch, false_computation=false_branch
|
||||
gte0 = s32[<=3,2] get-tuple-element(conditional), index=1
|
||||
init = s32[] constant(0)
|
||||
ROOT reduce = s32[2] reduce(gte0, init),
|
||||
dimensions={0},
|
||||
to_apply=update_s32
|
||||
}
|
||||
)";
|
||||
|
||||
Literal operand = LiteralUtil::CreateR2<int32>({{0, 1}, {2, 3}, {4, 5}});
|
||||
auto module = GetHloModule(hlo_text);
|
||||
|
||||
Literal result = PadAndExecute(std::move(module), {&operand},
|
||||
/*slice_dynamic_output=*/false);
|
||||
Literal expected = LiteralUtil::CreateR1<int32>({4, 8});
|
||||
|
||||
EXPECT_EQ(result, expected);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ExecutionTest, DynamicTupleSort) {
|
||||
const string hlo_text = R"(
|
||||
HloModule TEST
|
||||
|
Loading…
Reference in New Issue
Block a user