[XLA] Convert simple conditionals into array select instructions to allow for

fusion and avoid copies in buffer assignment.

PiperOrigin-RevId: 251500037
This commit is contained in:
Blake Hechtman 2019-06-04 13:09:26 -07:00 committed by TensorFlower Gardener
parent f589c2507a
commit ebe48a8757
3 changed files with 112 additions and 20 deletions

View File

@ -1841,6 +1841,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
],
)

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
@ -55,15 +56,24 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
}
// We can always inline a 1-branch conditional due to default branch fallback.
int branch_index = 0;
if (conditional->branch_count() > 1) {
if (conditional->operand(0)->opcode() != HloOpcode::kConstant) {
VLOG(2) << "Not attempting to remove conditional as its branch_index is "
"not a compile-time constant: "
<< conditional->ToShortString();
return false;
}
auto computation = conditional->parent();
auto create_call = [&](int64 branch) {
auto call = computation->AddInstruction(HloInstruction::CreateCall(
conditional->shape(), {conditional->mutable_operand(1 + branch)},
conditional->branch_computation(branch)));
conditional->SetupDerivedInstruction(call);
return call;
};
if (conditional->branch_count() == 1) {
HloInstruction* call_op = create_call(0);
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());
return true;
}
if (conditional->operand(0)->opcode() == HloOpcode::kConstant) {
int branch_index = 0;
if (conditional->operand(0)->shape().element_type() == PRED) {
branch_index = conditional->operand(0)->literal().Get<bool>({}) ? 0 : 1;
} else {
@ -72,16 +82,83 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
branch_index = conditional->branch_count() - 1;
}
}
}
auto computation = conditional->parent();
HloInstruction* call_op;
call_op = computation->AddInstruction(HloInstruction::CreateCall(
conditional->shape(), {conditional->mutable_operand(branch_index + 1)},
conditional->branch_computation(branch_index)));
conditional->SetupDerivedInstruction(call_op);
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());
HloInstruction* call_op = create_call(branch_index);
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());
return true;
}
// Predicate deciding whether `hlo` blocks the conditional-to-select rewrite.
// A fixed set of opcodes is always accepted; any other instruction is
// accepted only when it reports itself as elementwise.
auto instruction_is_expensive = [](const HloInstruction* hlo) {
  const auto opcode = hlo->opcode();
  const bool always_allowed = opcode == HloOpcode::kBroadcast ||
                              opcode == HloOpcode::kConcatenate ||
                              opcode == HloOpcode::kDynamicSlice ||
                              opcode == HloOpcode::kDynamicUpdateSlice ||
                              opcode == HloOpcode::kGetTupleElement ||
                              opcode == HloOpcode::kReduce ||
                              opcode == HloOpcode::kReshape ||
                              opcode == HloOpcode::kPad ||
                              opcode == HloOpcode::kParameter ||
                              opcode == HloOpcode::kSlice ||
                              opcode == HloOpcode::kTuple;
  // Short-circuit keeps IsElementwise() from being queried for the
  // always-allowed opcodes, exactly as the early `return false` did.
  return !always_allowed && !hlo->IsElementwise();
};
if (conditional->branch_count() != 2 ||
conditional->operand(0)->shape().element_type() != PRED ||
absl::c_any_of(conditional->branch_computation(0)->instructions(),
instruction_is_expensive) ||
absl::c_any_of(conditional->branch_computation(1)->instructions(),
instruction_is_expensive)) {
VLOG(2)
<< "Not attempting to remove conditional as its branch_index is not a "
"compile-time constant or contains expensive instructions: "
<< conditional->ToShortString();
return false;
}
HloInstruction* true_call_op = create_call(0);
HloInstruction* false_call_op = create_call(1);
// Returns the predicate operand of `conditional` in a form matching `shape`:
// the operand itself for scalar shapes, otherwise a newly added broadcast of
// it to `shape` with the element type changed to PRED.
auto condition_broadcast = [&](const Shape& shape) {
  HloInstruction* predicate = conditional->mutable_operand(0);
  if (!ShapeUtil::IsScalar(shape)) {
    const Shape pred_shape = ShapeUtil::ChangeElementType(shape, PRED);
    predicate = computation->AddInstruction(
        HloInstruction::CreateBroadcast(pred_shape, predicate, {}));
  }
  return predicate;
};
auto gte = [&](HloInstruction* hlo, int64 i) {
return computation->AddInstruction(HloInstruction::CreateGetTupleElement(
hlo->shape().tuple_shapes(i), hlo, i));
};
// Builds, inside `computation`, an instruction that picks between `t` and `f`
// based on the conditional's predicate operand. Declared as std::function
// (rather than `auto`) so the lambda can call itself recursively.
//   - If `f` has an array shape, a single kSelect is emitted whose predicate
//     comes from condition_broadcast(f->shape()).
//   - Otherwise the shape is treated as a tuple: each element pair is
//     extracted with gte(), selected recursively, and the results are
//     re-packed into a new tuple.
// NOTE(review): every non-array shape takes the tuple path; token/opaque
// shapes are presumably not expected here -- confirm against callers.
std::function<HloInstruction*(HloInstruction*, HloInstruction*)> select =
[&](HloInstruction* t, HloInstruction* f) {
if (f->shape().IsArray()) {
return computation->AddInstruction(HloInstruction::CreateTernary(
f->shape(), HloOpcode::kSelect, condition_broadcast(f->shape()),
t, f));
}
std::vector<HloInstruction*> selects;
// The element count is read from `f`'s shape only.
// NOTE(review): presumably `t` has the same tuple arity -- confirm.
const int64 tuple_element_count =
ShapeUtil::TupleElementCount(f->shape());
selects.reserve(tuple_element_count);
for (int64 i = 0; i < tuple_element_count; ++i) {
// The two gte() calls are function arguments, so the order in which the
// get-tuple-element instructions are added is not specified by C++.
selects.push_back(select(gte(t, i), gte(f, i)));
}
return computation->AddInstruction(
HloInstruction::CreateTuple(selects));
};
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
conditional, select(true_call_op, false_call_op)));
TF_RETURN_IF_ERROR(CallInliner::Inline(false_call_op).status());
TF_RETURN_IF_ERROR(CallInliner::Inline(true_call_op).status());
return true;
}
StatusOr<bool> TryRemoveUnusedConditionalOperands(

View File

@ -41,10 +41,11 @@ namespace op = xla::testing::opcode_matchers;
class ConditionalSimplifierTest : public HloTestBase {
public:
// Makes a computation that contains a conditional with constant predicate.
HloComputation* MakeConditional(HloModule* module);
HloComputation* MakeConditional(HloModule* module, bool is_constant = true);
};
HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module,
bool is_constant) {
HloComputation::Builder builder(TestName());
// true_computation returns param+1.
@ -83,7 +84,10 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
}
auto false_instrn = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
is_constant
? HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))
: HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(PRED, {}),
"cond"));
auto false_param = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {}), "false_param"));
auto one = builder.AddInstruction(
@ -104,6 +108,16 @@ TEST_F(ConditionalSimplifierTest, ConditionalGetsInlined) {
op::Add(op::Parameter(), op::Constant()));
}
// With a non-constant predicate (is_constant=false makes the predicate a
// parameter), the simplifier is expected to report a change and leave a
// select on parameter 1 between the two branch bodies as the root.
TEST_F(ConditionalSimplifierTest, BranchGetsInlined) {
  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      MakeConditional(module.get(), /*is_constant=*/false);

  const bool changed = ConditionalSimplifier().Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  HloInstruction* root = entry->root_instruction();
  EXPECT_THAT(root,
              op::Select(op::Parameter(1),
                         op::Add(op::Constant(), op::Constant()),
                         op::Add(op::Parameter(0), op::Constant())));
}
TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) {
auto m = CreateNewVerifiedModule();
HloComputation* computation = MakeConditional(m.get());