[XLA] Support dynamic conditional input/outputs in dynamic padder.

Dynamic dimensions of inputs and outputs are passed as additional tuple elements.
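As an illustration, a hand-written sketch of the rewrite (based on the new
DynamicConditionalDimension execution test below; the names new_operand,
new_conditional, size and size_out are illustrative, not the literal HLO
produced by the pass):

  // before: the branch operand and the output each carry a dynamic dimension
  operand     = (s32[<=3,2]) tuple(param_dynamic)
  conditional = (s32[<=3,2]) conditional(pred, operand, operand),
                  true_computation=..., false_computation=...

  // after (sketch): the dynamic size is appended as an extra tuple element
  // on the operand side and on each branch root, then read back out
  new_operand     = (s32[<=3,2], s32[]) tuple(param_dynamic, size)
  new_conditional = (s32[<=3,2], s32[]) conditional(pred, new_operand, new_operand),
                      true_computation=..., false_computation=...
  size_out        = s32[] get-tuple-element(new_conditional), index=1

The appended s32[] element carries the dynamic size into the widened branch
computations and back out, where it is re-registered as the dynamic size of
dimension 0 of output element {0}.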

PiperOrigin-RevId: 306277465
Change-Id: Ica1d69a813fa504fb228a63aad2226d4a51078db
Yunxing Dai 2020-04-13 11:50:46 -07:00 committed by TensorFlower Gardener
parent b2c5581d4c
commit 519326041d
5 changed files with 398 additions and 4 deletions


@@ -2403,9 +2403,12 @@ cc_library(
deps = [
":hlo",
":hlo_casting_utils",
":tuple_util",
":while_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",


@@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -23,12 +26,45 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
namespace xla {
namespace {
// Returns a copy of `narrow_comp` widened to accept `wide_shape` as its
// parameter; the body operates on the prefix of the wide parameter tuple.
StatusOr<HloComputation*> WidenComputation(HloComputation* narrow_comp,
const Shape& wide_shape) {
TF_RET_CHECK(wide_shape.IsTuple());
const Shape& narrow_shape = narrow_comp->parameter_instruction(0)->shape();
if (Shape::Equal()(wide_shape, narrow_shape)) {
// No need to widen the computation.
return narrow_comp;
}
HloComputation* wide_comp = [&]() {
HloComputation::Builder builder(absl::StrCat("wide.", narrow_comp->name()));
builder.AddInstruction(
HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
return narrow_comp->parent()->AddEmbeddedComputation(builder.Build());
}();
HloInstruction* wide_parameter = wide_comp->parameter_instruction(0);
HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
wide_parameter, narrow_shape.tuple_shapes_size());
HloInstruction* call_narrow_comp = wide_comp->AddInstruction(
HloInstruction::CreateCall(narrow_comp->root_instruction()->shape(),
{truncated_parameter}, narrow_comp));
wide_comp->set_root_instruction(call_narrow_comp,
/*accept_different_shape=*/true);
TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_comp).status());
return wide_comp;
}
} // namespace
class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
public:
explicit DynamicDimensionInferenceVisitor(
@@ -95,6 +131,8 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
Status HandleClamp(HloInstruction* hlo) override;
Status HandleConditional(HloInstruction* hlo) override;
Status HandleWhile(HloInstruction* hlo) override;
Status HandleSlice(HloInstruction* hlo) override;
@@ -116,15 +154,21 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint)>;
using DynamicDimensionFn = std::function<Status(
ShapeIndex index, int64 dimension, HloInstruction* dynamic_size,
DimensionConstraint constraint)>;
Status ForEachOperandDynamicDimension(HloInstruction* inst,
const OperandDynamicDimensionFn&);
Status ForEachDynamicDimensionInOperand(HloInstruction* inst,
int64 operand_index,
const OperandDynamicDimensionFn&);
Status ForEachDynamicDimension(HloInstruction* inst,
const DynamicDimensionFn& fn);
// Pass through a dynamic dimension from the input to the output with the
// same value and index in the shape. This is a helper function to handle
// trivial instructions like elementwise operations.
Status PassThroughDynamicDimension(HloInstruction*);
// The dynamic parameter bindings of this computation.
@@ -1139,6 +1183,163 @@ Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) {
});
}
Status DynamicDimensionInferenceVisitor::HandleConditional(
HloInstruction* hlo) {
// Conditionals are handled by producing additional inputs and outputs of
// the conditional instruction.
std::vector<HloComputation*> new_branch_computations;
std::vector<HloInstruction*> new_operands;
// If the output of the conditional contains a dynamic dimension, we send the
// dynamic dimension size out by adding an additional root element. A mapping
// from the root instruction's dynamic dimension index (represented by a shape
// index as output index and an int64 dimension number) to the output index
// (represented by an int64) is tracked for the conditional instruction (all
// branches should have the same mapping).
ShapeTree<absl::flat_hash_map<int64, int64>> dynamic_output_mapping(
hlo->shape());
bool need_rewrite = false;
for (int64 branch_index = 0; branch_index < hlo->branch_count();
++branch_index) {
std::vector<HloInstruction*> operands_to_add;
absl::flat_hash_map<HloInstruction*, int64>
dynamic_size_to_operand_id_index_map;
// Only look at branch_index + 1, the correct operand index for a
// given branch.
const int64 operand_index = branch_index + 1;
int64 operand_count =
hlo->operand(operand_index)->shape().tuple_shapes_size();
// Prepare to pass dynamic dimension into the new computation and add
// dynamic dimension sizes as parameters to the new tuple.
TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
hlo, operand_index,
[&](HloInstruction*, ShapeIndex, int64, int64,
HloInstruction* dynamic_size,
DimensionConstraint constraint) -> Status {
TF_RET_CHECK(hlo->operand(operand_index)->shape().IsTuple())
<< "Only tuple typed inputs can have dynamic dimension. Please "
"file a bug against XLA team.";
const HloInstruction* tuple_operand = hlo->operand(operand_index);
for (int64 i = 0; i < tuple_operand->operand_count(); ++i) {
// If the dynamic size is already an operand to the computation,
// skip adding it to the computation input again.
if (dynamic_size == tuple_operand->operand(i)) {
dynamic_size_to_operand_id_index_map[dynamic_size] = i;
return Status::OK();
}
}
auto iter = dynamic_size_to_operand_id_index_map.find(dynamic_size);
if (iter == dynamic_size_to_operand_id_index_map.end()) {
operands_to_add.push_back(dynamic_size);
dynamic_size_to_operand_id_index_map[dynamic_size] =
operand_count++;
}
return Status::OK();
}));
HloInstruction* original_input = hlo->mutable_operand(operand_index);
HloComputation* branch_computation = hlo->branch_computation(branch_index);
HloComputation* new_computation = branch_computation;
HloInstruction* new_operand = hlo->mutable_operand(operand_index);
if (!operands_to_add.empty()) {
TF_RET_CHECK(original_input->shape().IsTuple());
need_rewrite = true;
new_operand = TupleUtil::AppendSuffix(original_input, operands_to_add);
TF_ASSIGN_OR_RETURN(
new_computation,
WidenComputation(branch_computation, new_operand->shape()));
}
// Set the dynamic dimensions for the newly created branch computation's
// parameters so that the hlos inside the computation can see dynamic
// dimensions.
DynamicParameterBinding dynamic_parameter_binding;
TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
hlo, operand_index,
[&](HloInstruction*, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint) {
DynamicParameterBinding::DynamicParameter dynamic_parameter{
0, {dynamic_size_to_operand_id_index_map[dynamic_size]}};
DynamicParameterBinding::DynamicDimension dynamic_dimension{
0, {index}, dimension};
TF_RETURN_IF_ERROR(dynamic_parameter_binding.Bind(dynamic_parameter,
dynamic_dimension));
return Status::OK();
}));
VLOG(2) << "dynamic_parameter_binding for conditional branch"
<< dynamic_parameter_binding;
TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
new_computation, dynamic_parameter_binding, parent_));
std::vector<HloInstruction*> hlos_to_add_in_root;
int64 original_tuple_count = hlo->shape().tuple_shapes_size();
// There may be some dynamic dimensions coming out of the computation; wire
// them into the root instruction as additional tuple elements.
TF_RETURN_IF_ERROR(ForEachDynamicDimension(
new_computation->root_instruction(),
[&](ShapeIndex index, int64 dim, HloInstruction* dynamic_size,
DimensionConstraint) -> Status {
TF_RET_CHECK(hlo->shape().IsTuple())
<< "Only tuple typed conditionals can have dynamic dimension. "
"Please file a bug against XLA team.";
dynamic_output_mapping.mutable_element(index)->emplace(
dim, original_tuple_count++);
hlos_to_add_in_root.push_back(dynamic_size);
return Status::OK();
}));
VLOG(2) << "hlos_to_add_in_root:" << hlos_to_add_in_root.size();
if (!hlos_to_add_in_root.empty()) {
need_rewrite = true;
HloInstruction* new_branch_root = TupleUtil::AppendSuffix(
new_computation->root_instruction(), hlos_to_add_in_root);
new_computation->set_root_instruction(new_branch_root,
/*accept_different_shape=*/true);
}
new_branch_computations.push_back(new_computation);
new_operands.push_back(new_operand);
}
if (!need_rewrite) {
return Status::OK();
}
// Create a new conditional with the new operands and computations.
HloInstruction* new_conditional =
hlo->parent()->AddInstruction(HloInstruction::CreateConditional(
new_branch_computations[0]->root_instruction()->shape(),
hlo->mutable_operand(0), new_branch_computations, new_operands));
HloInstruction* new_conditional_extracted = TupleUtil::ExtractPrefix(
new_conditional, hlo->shape().tuple_shapes_size());
// Now set the dynamic dimensions of the newly created conditional.
dynamic_output_mapping.ForEachElement(
[&](const ShapeIndex& index,
const absl::flat_hash_map<int64, int64>& dim_to_output) {
for (auto iter : dim_to_output) {
int64 dim = iter.first;
int64 output_index = iter.second;
HloInstruction* dynamic_size = hlo->parent()->AddInstruction(
HloInstruction::CreateGetTupleElement(
ShapeUtil::MakeScalarShape(S32), new_conditional,
output_index));
parent_->SetDynamicSize(new_conditional, index, dim, dynamic_size);
parent_->SetDynamicSize(new_conditional_extracted, index, dim,
dynamic_size);
}
});
TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_conditional_extracted));
// Remove the original instruction even if it has side-effects.
TF_RETURN_IF_ERROR(hlo->parent()->RemoveInstruction(hlo));
SetVisited(*new_conditional);
SetVisited(*new_conditional_extracted);
return Status::OK();
}
Status DynamicDimensionInferenceVisitor::HandleScatter(HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo,
@@ -1314,6 +1515,23 @@ Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) {
});
}
Status DynamicDimensionInferenceVisitor::ForEachDynamicDimension(
HloInstruction* inst, const DynamicDimensionFn& fn) {
auto iter = parent_->per_hlo_dynamic_dimensions_.find(inst);
if (iter != parent_->per_hlo_dynamic_dimensions_.end()) {
for (auto& dynamic_dimension : iter->second) {
HloInstruction* dynamic_size = parent_->GetDynamicSize(
dynamic_dimension.inst, dynamic_dimension.index,
dynamic_dimension.dim);
CHECK_NE(parent_->constraint_mapping_.count(dynamic_dimension), 0);
TF_RETURN_IF_ERROR(fn(dynamic_dimension.index, dynamic_dimension.dim,
dynamic_size,
parent_->constraint_mapping_[dynamic_dimension]));
}
}
return Status::OK();
}
Status DynamicDimensionInferenceVisitor::ForEachDynamicDimensionInOperand(
HloInstruction* inst, int64 operand_index,
const OperandDynamicDimensionFn& fn) {


@@ -815,6 +815,129 @@ TEST_F(DynamicDimensionInferenceTest, WhileTest) {
test_dynamic_dimension();
}
TEST_F(DynamicDimensionInferenceTest, ConditionalInputTest) {
// Test the ability to trace into conditional branches.
auto builder = HloComputation::Builder(TestName());
auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
// In this test, the inputs to the different branches have different shapes.
auto tuple_shape_1 = ShapeUtil::MakeTupleShape({input_shape});
auto tuple_shape_2 = ShapeUtil::MakeTupleShape({input_shape, input_shape});
auto tuple_shape_3 =
ShapeUtil::MakeTupleShape({input_shape, input_shape, input_shape});
// true branch:
//
// Param
// | |
// GTE1 GTE2
// | |
// Tuple(ADD)
auto true_builder = HloComputation::Builder("true");
{
auto true_param = true_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape_2, "param"));
auto gte_0 = true_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(input_shape, true_param, 0));
auto gte_1 = true_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(input_shape, true_param, 1));
auto add = true_builder.AddInstruction(HloInstruction::CreateBinary(
input_shape, HloOpcode::kAdd, gte_0, gte_1));
true_builder.AddInstruction(HloInstruction::CreateTuple({add}));
}
HloComputation* true_branch =
module_->AddEmbeddedComputation(true_builder.Build());
// false branch:
//
// Param
// | | |
// GTE1 GTE2 GTE3
// | |
// Tuple(ADD)
auto false_builder = HloComputation::Builder("false");
{
auto false_param = false_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape_3, "param"));
auto gte_0 = false_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(input_shape, false_param, 1));
auto gte_1 = false_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(input_shape, false_param, 2));
auto add = false_builder.AddInstruction(HloInstruction::CreateBinary(
input_shape, HloOpcode::kAdd, gte_0, gte_1));
false_builder.AddInstruction(HloInstruction::CreateTuple({add}));
}
HloComputation* false_branch =
module_->AddEmbeddedComputation(false_builder.Build());
// Entry:
//
// Param(bool) Param2 (tuple_2) Param3(tuple_3)
// | | |
// +---------Condition------------+
auto* pred_param = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, ShapeUtil::MakeScalarShape(PRED), "pred"));
auto* tuple_2_param = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, tuple_shape_2, "tuple_2_param"));
auto* tuple_3_param = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/2, tuple_shape_3, "tuple_3_param"));
builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/3, scalar_shape_, "size_param"));
builder.AddInstruction(HloInstruction::CreateConditional(
tuple_shape_1, pred_param, tuple_2_param, true_branch, tuple_3_param,
false_branch));
module_->AddEntryComputation(builder.Build());
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
DynamicParameterBinding::DynamicParameter{3, {}},
DynamicParameterBinding::DynamicDimension{1, {0}, 0}));
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
DynamicParameterBinding::DynamicParameter{3, {}},
DynamicParameterBinding::DynamicDimension{1, {1}, 0}));
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
DynamicParameterBinding::DynamicParameter{3, {}},
DynamicParameterBinding::DynamicDimension{2, {1}, 0}));
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
DynamicParameterBinding::DynamicParameter{3, {}},
DynamicParameterBinding::DynamicDimension{2, {2}, 0}));
TF_ASSERT_OK(RunInference());
HloInstruction* conditional_hlo = nullptr;
// The conditional hlo has been replaced; find the new one.
for (HloInstruction* inst : module_->entry_computation()->instructions()) {
if (inst->opcode() == HloOpcode::kConditional) {
conditional_hlo = inst;
}
}
ASSERT_NE(conditional_hlo, nullptr);
// The original conditional output tuple has 1 element. With the dynamic size
// passed out from the computation, another element is added to the tuple.
EXPECT_EQ(conditional_hlo->shape().tuple_shapes_size(), 2);
HloInstruction* add_true_branch = nullptr;
for (HloInstruction* inst :
conditional_hlo->true_computation()->instructions()) {
if (inst->opcode() == HloOpcode::kAdd) {
add_true_branch = inst;
}
}
EXPECT_NE(add_true_branch, nullptr);
EXPECT_NE(inference_->GetDynamicSize(add_true_branch, {}, 0), nullptr);
HloInstruction* add_false_branch = nullptr;
for (HloInstruction* inst :
conditional_hlo->false_computation()->instructions()) {
if (inst->opcode() == HloOpcode::kAdd) {
add_false_branch = inst;
}
}
EXPECT_NE(add_false_branch, nullptr);
EXPECT_NE(inference_->GetDynamicSize(add_false_branch, {}, 0), nullptr);
EXPECT_NE(inference_->GetDynamicSize(conditional_hlo, {0}, 0), nullptr);
}
TEST_F(DynamicDimensionInferenceTest, ReduceWindowBatchTest) {
// Test the ability to trace reduce window batch dimensions.
auto builder = HloComputation::Builder(TestName());


@@ -967,7 +967,7 @@ Status DynamicShapeRemovingVisitor::HandleParameter(HloInstruction* hlo) {
StatusOr<bool> DynamicPadder::Run(HloModule* module) {
bool changed = false;
VLOG(2) << "Pre DynamicPadder HLO:";
XLA_VLOG_LINES(2, module->ToString());
// Removes dynamic dimensions on parameters if there is already a binding for
// it. We do this because we have two different APIs to express a dynamic
// dimension:


@@ -996,6 +996,56 @@ ENTRY main {
EXPECT_EQ(result, expected);
}
XLA_TEST_F(ExecutionTest, DynamicConditionalDimension) {
const string hlo_text = R"(
HloModule module
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(lhs, rhs)
}
true_branch {
true_param = (s32[<=3,2]) parameter(0)
param = s32[<=3, 2] get-tuple-element(true_param), index=0
add = s32[<=3,2] add(param, param)
ROOT true_tuple = (s32[<=3,2], s32[<=3,2]) tuple(add, add)
}
false_branch {
false_param = (s32[<=3,2]) parameter(0)
param = s32[<=3, 2] get-tuple-element(false_param), index=0
add = s32[<=3,2] add(param, param)
ROOT false_tuple = (s32[<=3,2], s32[<=3,2]) tuple(add, add)
}
ENTRY entry {
param0 = s32[3,2] parameter(0)
size = s32[] constant(2)
branch = pred[] constant(false)
param_dynamic = s32[<=3, 2] set-dimension-size(param0, size), dimensions={0}
param_tuple = (s32[<=3 ,2]) tuple(param_dynamic)
conditional = (s32[<=3, 2], s32[<=3, 2]) conditional(branch, param_tuple, param_tuple),
true_computation=true_branch, false_computation=false_branch
gte0 = s32[<=3,2] get-tuple-element(conditional), index=1
init = s32[] constant(0)
ROOT reduce = s32[2] reduce(gte0, init),
dimensions={0},
to_apply=update_s32
}
)";
Literal operand = LiteralUtil::CreateR2<int32>({{0, 1}, {2, 3}, {4, 5}});
auto module = GetHloModule(hlo_text);
Literal result = PadAndExecute(std::move(module), {&operand},
/*slice_dynamic_output=*/false);
Literal expected = LiteralUtil::CreateR1<int32>({4, 8});
EXPECT_EQ(result, expected);
}
XLA_TEST_F(ExecutionTest, DynamicTupleSort) {
const string hlo_text = R"(
HloModule TEST