From 519326041d9cbe8f4e3e7a9f02a264062a1000e2 Mon Sep 17 00:00:00 2001
From: Yunxing Dai
Date: Mon, 13 Apr 2020 11:50:46 -0700
Subject: [PATCH] [XLA] Support dynamic conditional input/outputs in dynamic
 padder.

Dynamic dimensions of inputs and outputs are passed as additional tuple
elements.

PiperOrigin-RevId: 306277465
Change-Id: Ica1d69a813fa504fb228a63aad2226d4a51078db
---
 tensorflow/compiler/xla/service/BUILD         |   3 +
 .../service/dynamic_dimension_inference.cc    | 224 +++++++++++++++++-
 .../dynamic_dimension_inference_test.cc       | 123 ++++++++++
 .../compiler/xla/service/dynamic_padder.cc    |   2 +-
 .../xla/service/dynamic_padder_test.cc        |  50 ++++
 5 files changed, 398 insertions(+), 4 deletions(-)

diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 255444fb53c..3f06c6a29ce 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2403,9 +2403,12 @@ cc_library(
     deps = [
         ":hlo",
         ":hlo_casting_utils",
+        ":tuple_util",
        ":while_util",
         "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:shape_tree",
         "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
index 5f5d02f92b6..a103b555df6 100644
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
@@ -15,6 +15,9 @@ limitations under the License.

 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"

+#include <functional>
+
+#include "absl/container/flat_hash_map.h"
 #include "absl/strings/match.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -23,12 +26,45 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/tuple_util.h"
 #include "tensorflow/compiler/xla/service/while_util.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
+#include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/window_util.h"

 namespace xla {

+namespace {
+// Replace `narrow_comp` with a new computation that takes `wide_shape` as
+// input.
+StatusOr<HloComputation*> WidenComputation(HloComputation* narrow_comp,
+                                           const Shape& wide_shape) {
+  TF_RET_CHECK(wide_shape.IsTuple());
+  const Shape& narrow_shape = narrow_comp->parameter_instruction(0)->shape();
+  if (Shape::Equal()(wide_shape, narrow_shape)) {
+    // No need to widen the computation.
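+    // (wide_shape already matches the existing parameter shape, so the
+    // original computation can be reused as-is.)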
+    return narrow_comp;
+  }
+  HloComputation* wide_comp = [&]() {
+    HloComputation::Builder builder(
+        absl::StrCat("wide.", narrow_comp->name()));
+    builder.AddInstruction(
+        HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
+    return narrow_comp->parent()->AddEmbeddedComputation(builder.Build());
+  }();
+
+  HloInstruction* wide_parameter = wide_comp->parameter_instruction(0);
+  HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
+      wide_parameter, narrow_shape.tuple_shapes_size());
+  HloInstruction* call_narrow_comp = wide_comp->AddInstruction(
+      HloInstruction::CreateCall(narrow_comp->root_instruction()->shape(),
+                                 {truncated_parameter}, narrow_comp));
+  wide_comp->set_root_instruction(call_narrow_comp,
+                                  /*accept_different_shape=*/true);
+  TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_comp).status());
+  return wide_comp;
+}
+}  // namespace
+
 class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
  public:
   explicit DynamicDimensionInferenceVisitor(
@@ -95,6 +131,8 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {

   Status HandleClamp(HloInstruction* hlo) override;

+  Status HandleConditional(HloInstruction* hlo) override;
+
   Status HandleWhile(HloInstruction* hlo) override;

   Status HandleSlice(HloInstruction* hlo) override;
@@ -116,15 +154,21 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
                          int64 operand_index, HloInstruction* dynamic_size,
                          DimensionConstraint constraint)>;

+  using DynamicDimensionFn = std::function<Status(
+      ShapeIndex index, int64 dimension, HloInstruction* dynamic_size,
+      DimensionConstraint constraint)>;
+
   Status ForEachOperandDynamicDimension(HloInstruction* inst,
                                         const OperandDynamicDimensionFn&);
   Status ForEachDynamicDimensionInOperand(HloInstruction* inst,
                                           int64 operand_index,
                                           const OperandDynamicDimensionFn&);
+  Status ForEachDynamicDimension(HloInstruction* inst,
+                                 const DynamicDimensionFn& fn);

-  // Pass through a dynamic dimension from the input to the output with the same
-  // value and index in the shape. This is a helper function to handle trivial
-  // instructions like elementwise operations.
+  // Pass through a dynamic dimension from the input to the output with the
+  // same value and index in the shape. This is a helper function to handle
+  // trivial instructions like elementwise operations.
   Status PassThroughDynamicDimension(HloInstruction*);

   // The dynamic parameter bindings of this computation.
@@ -1139,6 +1183,163 @@ Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) {
   });
 }

+Status DynamicDimensionInferenceVisitor::HandleConditional(
+    HloInstruction* hlo) {
+  // Conditionals are handled by producing additional inputs and outputs of
+  // the conditional instruction.
+  std::vector<HloComputation*> new_branch_computations;
+  std::vector<HloInstruction*> new_operands;
+  // If the output of the conditional contains a dynamic dimension, we send
+  // the dynamic dimension size out by adding an additional root element. A
+  // mapping from the root instruction's dynamic dimension index (represented
+  // by a shape index as output index and an int64 dimension number) to output
+  // index (represented by an int64) is tracked for the conditional
+  // instruction (all branches should have the same mapping).
+  ShapeTree<absl::flat_hash_map<int64, int64>> dynamic_output_mapping(
+      hlo->shape());
+
+  bool need_rewrite = false;
+
+  for (int64 branch_index = 0; branch_index < hlo->branch_count();
+       ++branch_index) {
+    std::vector<HloInstruction*> operands_to_add;
+
+    absl::flat_hash_map<HloInstruction*, int64>
+        dynamic_size_to_operand_id_index_map;
+    // Only look at branch_index + 1, the correct operand index for a
+    // given branch.
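+    // (Operand 0 of a conditional is the branch index or predicate, so the
+    // argument tuple for branch `branch_index` is operand `branch_index + 1`.)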
+    const int64 operand_index = branch_index + 1;
+
+    int64 operand_count =
+        hlo->operand(operand_index)->shape().tuple_shapes_size();
+    // Prepare to pass dynamic dimensions into the new computation and add
+    // dynamic dimension sizes as parameters to the new tuple.
+    TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
+        hlo, operand_index,
+        [&](HloInstruction*, ShapeIndex, int64, int64,
+            HloInstruction* dynamic_size,
+            DimensionConstraint constraint) -> Status {
+          TF_RET_CHECK(hlo->operand(operand_index)->shape().IsTuple())
+              << "Only tuple typed inputs can have dynamic dimension. Please "
+                 "file a bug against XLA team.";
+          const HloInstruction* tuple_operand = hlo->operand(operand_index);
+          for (int64 i = 0; i < tuple_operand->operand_count(); ++i) {
+            // If the dynamic size is already an operand to the computation,
+            // skip adding it to the computation input again.
+            if (dynamic_size == tuple_operand->operand(i)) {
+              dynamic_size_to_operand_id_index_map[dynamic_size] = i;
+              return Status::OK();
+            }
+          }
+          auto iter = dynamic_size_to_operand_id_index_map.find(dynamic_size);
+          if (iter == dynamic_size_to_operand_id_index_map.end()) {
+            operands_to_add.push_back(dynamic_size);
+            dynamic_size_to_operand_id_index_map[dynamic_size] =
+                operand_count++;
+          }
+          return Status::OK();
+        }));
+
+    HloInstruction* original_input = hlo->mutable_operand(operand_index);
+    HloComputation* branch_computation = hlo->branch_computation(branch_index);
+
+    HloComputation* new_computation = branch_computation;
+    HloInstruction* new_operand = hlo->mutable_operand(operand_index);
+    if (!operands_to_add.empty()) {
+      TF_RET_CHECK(original_input->shape().IsTuple());
+      need_rewrite = true;
+      new_operand = TupleUtil::AppendSuffix(original_input, operands_to_add);
+      TF_ASSIGN_OR_RETURN(
+          new_computation,
+          WidenComputation(branch_computation, new_operand->shape()));
+    }
+    // Set the dynamic dimensions for the newly created branch computation's
+    // parameters so that the hlos inside the computation can see dynamic
+    // dimensions.
+    DynamicParameterBinding dynamic_parameter_binding;
+    TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
+        hlo, operand_index,
+        [&](HloInstruction*, ShapeIndex index, int64 dimension,
+            int64 operand_index, HloInstruction* dynamic_size,
+            DimensionConstraint constraint) {
+          DynamicParameterBinding::DynamicParameter dynamic_parameter{
+              0, {dynamic_size_to_operand_id_index_map[dynamic_size]}};
+          DynamicParameterBinding::DynamicDimension dynamic_dimension{
+              0, {index}, dimension};
+          TF_RETURN_IF_ERROR(dynamic_parameter_binding.Bind(dynamic_parameter,
+                                                            dynamic_dimension));
+
+          return Status::OK();
+        }));
+    VLOG(2) << "dynamic_parameter_binding for conditional branch"
+            << dynamic_parameter_binding;
+    TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
+        new_computation, dynamic_parameter_binding, parent_));
+    std::vector<HloInstruction*> hlos_to_add_in_root;
+    int64 original_tuple_count = hlo->shape().tuple_shapes_size();
+    // There may be some dynamic dimensions coming out of the computation; wire
+    // them into the root instruction as additional tuple elements.
+    TF_RETURN_IF_ERROR(ForEachDynamicDimension(
+        new_computation->root_instruction(),
+        [&](ShapeIndex index, int64 dim, HloInstruction* dynamic_size,
+            DimensionConstraint) -> Status {
+          TF_RET_CHECK(hlo->shape().IsTuple())
+              << "Only tuple typed conditionals can have dynamic dimension. "
+                 "Please file a bug against XLA team.";
+          dynamic_output_mapping.mutable_element(index)->emplace(
+              dim, original_tuple_count++);
+          hlos_to_add_in_root.push_back(dynamic_size);
+          return Status::OK();
+        }));
+
+    VLOG(2) << "hlos_to_add_in_root:" << hlos_to_add_in_root.size();
+    if (!hlos_to_add_in_root.empty()) {
+      need_rewrite = true;
+      HloInstruction* new_branch_root = TupleUtil::AppendSuffix(
+          new_computation->root_instruction(), hlos_to_add_in_root);
+      new_computation->set_root_instruction(new_branch_root,
+                                            /*accept_different_shape=*/true);
+    }
+
+    new_branch_computations.push_back(new_computation);
+    new_operands.push_back(new_operand);
+  }
+  if (!need_rewrite) {
+    return Status::OK();
+  }
+  // Create a new conditional with the new operands and computations.
+  HloInstruction* new_conditional =
+      hlo->parent()->AddInstruction(HloInstruction::CreateConditional(
+          new_branch_computations[0]->root_instruction()->shape(),
+          hlo->mutable_operand(0), new_branch_computations, new_operands));
+
+  HloInstruction* new_conditional_extracted = TupleUtil::ExtractPrefix(
+      new_conditional, hlo->shape().tuple_shapes_size());
+  // Now set the dynamic dimensions of the newly created conditional.
+  dynamic_output_mapping.ForEachElement(
+      [&](const ShapeIndex& index,
+          const absl::flat_hash_map<int64, int64>& dim_to_output) {
+        for (auto iter : dim_to_output) {
+          int64 dim = iter.first;
+          int64 output_index = iter.second;
+          HloInstruction* dynamic_size = hlo->parent()->AddInstruction(
+              HloInstruction::CreateGetTupleElement(
+                  ShapeUtil::MakeScalarShape(S32), new_conditional,
+                  output_index));
+          parent_->SetDynamicSize(new_conditional, index, dim, dynamic_size);
+          parent_->SetDynamicSize(new_conditional_extracted, index, dim,
+                                  dynamic_size);
+        }
+      });
+
+  TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_conditional_extracted));
+  // Remove the original instruction even if it has side-effects.
+  TF_RETURN_IF_ERROR(hlo->parent()->RemoveInstruction(hlo));
+  SetVisited(*new_conditional);
+  SetVisited(*new_conditional_extracted);
+  return Status::OK();
+}
+
 Status DynamicDimensionInferenceVisitor::HandleScatter(HloInstruction* hlo) {
   return ForEachOperandDynamicDimension(
       hlo,
@@ -1314,6 +1515,23 @@ Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) {
   });
 }

+Status DynamicDimensionInferenceVisitor::ForEachDynamicDimension(
+    HloInstruction* inst, const DynamicDimensionFn& fn) {
+  auto iter = parent_->per_hlo_dynamic_dimensions_.find(inst);
+  if (iter != parent_->per_hlo_dynamic_dimensions_.end()) {
+    for (auto& dynamic_dimension : iter->second) {
+      HloInstruction* dynamic_size = parent_->GetDynamicSize(
+          dynamic_dimension.inst, dynamic_dimension.index,
+          dynamic_dimension.dim);
+      CHECK_NE(parent_->constraint_mapping_.count(dynamic_dimension), 0);
+      TF_RETURN_IF_ERROR(fn(dynamic_dimension.index, dynamic_dimension.dim,
+                            dynamic_size,
+                            parent_->constraint_mapping_[dynamic_dimension]));
+    }
+  }
+  return Status::OK();
+}
+
 Status DynamicDimensionInferenceVisitor::ForEachDynamicDimensionInOperand(
     HloInstruction* inst, int64 operand_index,
     const OperandDynamicDimensionFn& fn) {
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc
index dbe57985fd4..dc295669fa9 100644
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc
@@ -815,6 +815,129 @@ TEST_F(DynamicDimensionInferenceTest, WhileTest) {
   test_dynamic_dimension();
 }

+TEST_F(DynamicDimensionInferenceTest, ConditionalInputTest) {
+  // Test the ability to trace into conditional branches.
+  auto builder = HloComputation::Builder(TestName());
+  auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
+  auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
+  // In this test, the two branches take input tuples of different shapes.
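+  // (tuple_shape_2 is the true branch's input, tuple_shape_3 is the false
+  // branch's input, and tuple_shape_1 is the conditional's result shape.)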
+  auto tuple_shape_1 = ShapeUtil::MakeTupleShape({input_shape});
+  auto tuple_shape_2 = ShapeUtil::MakeTupleShape({input_shape, input_shape});
+  auto tuple_shape_3 =
+      ShapeUtil::MakeTupleShape({input_shape, input_shape, input_shape});
+
+  // true branch:
+  //
+  //      Param
+  //     |     |
+  //   GTE1   GTE2
+  //     |     |
+  //    Tuple(ADD)
+  auto true_builder = HloComputation::Builder("true");
+  {
+    auto true_param = true_builder.AddInstruction(
+        HloInstruction::CreateParameter(0, tuple_shape_2, "param"));
+    auto gte_0 = true_builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(input_shape, true_param, 0));
+    auto gte_1 = true_builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(input_shape, true_param, 1));
+    auto add = true_builder.AddInstruction(HloInstruction::CreateBinary(
+        input_shape, HloOpcode::kAdd, gte_0, gte_1));
+    true_builder.AddInstruction(HloInstruction::CreateTuple({add}));
+  }
+  HloComputation* true_branch =
+      module_->AddEmbeddedComputation(true_builder.Build());
+  // false branch:
+  //
+  //       Param
+  //     |   |   |
+  //  GTE1  GTE2  GTE3
+  //         |   |
+  //       Tuple(ADD)
+  auto false_builder = HloComputation::Builder("false");
+  {
+    auto false_param = false_builder.AddInstruction(
+        HloInstruction::CreateParameter(0, tuple_shape_3, "param"));
+    auto gte_0 = false_builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(input_shape, false_param, 1));
+    auto gte_1 = false_builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(input_shape, false_param, 2));
+    auto add = false_builder.AddInstruction(HloInstruction::CreateBinary(
+        input_shape, HloOpcode::kAdd, gte_0, gte_1));
+    false_builder.AddInstruction(HloInstruction::CreateTuple({add}));
+  }
+  HloComputation* false_branch =
+      module_->AddEmbeddedComputation(false_builder.Build());
+
+  // Entry:
+  //
+  //  Param(bool)  Param2 (tuple_2)  Param3 (tuple_3)
+  //       |              |                 |
+  //       +---------Condition--------------+
+  auto* pred_param = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/0, ShapeUtil::MakeScalarShape(PRED), "pred"));
+
+  auto* tuple_2_param = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/1, tuple_shape_2, "tuple_2_param"));
+  auto* tuple_3_param = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/2, tuple_shape_3, "tuple_3_param"));
+  builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/3, scalar_shape_, "size_param"));
+  builder.AddInstruction(HloInstruction::CreateConditional(
+      tuple_shape_1, pred_param, tuple_2_param, true_branch, tuple_3_param,
+      false_branch));
+
+  module_->AddEntryComputation(builder.Build());
+
+  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
+      DynamicParameterBinding::DynamicParameter{3, {}},
+      DynamicParameterBinding::DynamicDimension{1, {0}, 0}));
+  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
+      DynamicParameterBinding::DynamicParameter{3, {}},
+      DynamicParameterBinding::DynamicDimension{1, {1}, 0}));
+  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
+      DynamicParameterBinding::DynamicParameter{3, {}},
+      DynamicParameterBinding::DynamicDimension{2, {1}, 0}));
+  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
+      DynamicParameterBinding::DynamicParameter{3, {}},
+      DynamicParameterBinding::DynamicDimension{2, {2}, 0}));
+
+  TF_ASSERT_OK(RunInference());
+
+  HloInstruction* conditional_hlo = nullptr;
+  // The conditional hlo has been replaced; find the new one.
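+  // (HandleConditional creates a brand-new conditional instruction and removes
+  // the original, so the rewritten conditional has to be found by opcode.)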
+  for (HloInstruction* inst : module_->entry_computation()->instructions()) {
+    if (inst->opcode() == HloOpcode::kConditional) {
+      conditional_hlo = inst;
+    }
+  }
+  ASSERT_NE(conditional_hlo, nullptr);
+  // The original conditional shape has one tuple element. With a dynamic size
+  // passed out from the computation, another element is added to the tuple.
+  EXPECT_EQ(conditional_hlo->shape().tuple_shapes_size(), 2);
+  HloInstruction* add_true_branch = nullptr;
+  for (HloInstruction* inst :
+       conditional_hlo->true_computation()->instructions()) {
+    if (inst->opcode() == HloOpcode::kAdd) {
+      add_true_branch = inst;
+    }
+  }
+  EXPECT_NE(add_true_branch, nullptr);
+  EXPECT_NE(inference_->GetDynamicSize(add_true_branch, {}, 0), nullptr);
+
+  HloInstruction* add_false_branch = nullptr;
+  for (HloInstruction* inst :
+       conditional_hlo->false_computation()->instructions()) {
+    if (inst->opcode() == HloOpcode::kAdd) {
+      add_false_branch = inst;
+    }
+  }
+  EXPECT_NE(add_false_branch, nullptr);
+  EXPECT_NE(inference_->GetDynamicSize(add_false_branch, {}, 0), nullptr);
+
+  EXPECT_NE(inference_->GetDynamicSize(conditional_hlo, {0}, 0), nullptr);
+}
+
 TEST_F(DynamicDimensionInferenceTest, ReduceWindowBatchTest) {
   // Test the ability to trace reduce window batch dimensions.
   auto builder = HloComputation::Builder(TestName());
diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc
index afff48783b7..e0fe9c08d0a 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder.cc
@@ -967,7 +967,7 @@ Status DynamicShapeRemovingVisitor::HandleParameter(HloInstruction* hlo) {

 StatusOr<bool> DynamicPadder::Run(HloModule* module) {
   bool changed = false;
   VLOG(2) << "Pre DynamicPadder HLO:";
-
+  XLA_VLOG_LINES(2, module->ToString());
   // Removes dynamic dimensions on parameters if there is already a binding for
   // it. We do this because we have two different APIs to express a dynamic
   // dimension:
diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc
index d3b68266b4f..c937bf2c723 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc
@@ -996,6 +996,56 @@ ENTRY main {
   EXPECT_EQ(result, expected);
 }

+XLA_TEST_F(ExecutionTest, DynamicConditionalDimension) {
+  const string hlo_text = R"(
+HloModule module
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+  lhs = s32[] parameter(0)
+  rhs = s32[] parameter(1)
+  ROOT add = s32[] add(lhs, rhs)
+}
+
+true_branch {
+  true_param = (s32[<=3,2]) parameter(0)
+  param = s32[<=3, 2] get-tuple-element(true_param), index=0
+  add = s32[<=3,2] add(param, param)
+  ROOT true_tuple = (s32[<=3,2], s32[<=3,2]) tuple(add, add)
+}
+
+false_branch {
+  false_param = (s32[<=3,2]) parameter(0)
+  param = s32[<=3, 2] get-tuple-element(false_param), index=0
+  add = s32[<=3,2] add(param, param)
+  ROOT false_tuple = (s32[<=3,2], s32[<=3,2]) tuple(add, add)
+}
+
+ENTRY entry {
+  param0 = s32[3,2] parameter(0)
+  size = s32[] constant(2)
+  branch = pred[] constant(false)
+  param_dynamic = s32[<=3, 2] set-dimension-size(param0, size), dimensions={0}
+  param_tuple = (s32[<=3, 2]) tuple(param_dynamic)
+  conditional = (s32[<=3, 2], s32[<=3, 2]) conditional(branch, param_tuple, param_tuple),
+    true_computation=true_branch, false_computation=false_branch
+  gte0 = s32[<=3,2] get-tuple-element(conditional), index=1
+  init = s32[] constant(0)
+  ROOT reduce = s32[2] reduce(gte0, init),
+    dimensions={0},
+    to_apply=update_s32
+}
+)";
+
+  Literal operand = LiteralUtil::CreateR2<int32>({{0, 1}, {2, 3}, {4, 5}});
+  auto module = GetHloModule(hlo_text);
+
+  Literal result = PadAndExecute(std::move(module), {&operand},
+                                 /*slice_dynamic_output=*/false);
+  Literal expected = LiteralUtil::CreateR1<int32>({4, 8});
+
+  EXPECT_EQ(result, expected);
+}
+
 XLA_TEST_F(ExecutionTest, DynamicTupleSort) {
   const string hlo_text = R"(
 HloModule TEST