[XLA] Convert simple conditionals into array select instructions to allow for
fusion and avoid copies in buffer assignment. PiperOrigin-RevId: 251500037
This commit is contained in:
parent
f589c2507a
commit
ebe48a8757
@ -1841,6 +1841,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/call_graph.h"
|
||||
@ -55,15 +56,24 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
|
||||
}
|
||||
|
||||
// We can always inline a 1-branch conditional due to default branch fallback.
|
||||
int branch_index = 0;
|
||||
if (conditional->branch_count() > 1) {
|
||||
if (conditional->operand(0)->opcode() != HloOpcode::kConstant) {
|
||||
VLOG(2) << "Not attempting to remove conditional as its branch_index is "
|
||||
"not a compile-time constant: "
|
||||
<< conditional->ToShortString();
|
||||
return false;
|
||||
}
|
||||
auto computation = conditional->parent();
|
||||
auto create_call = [&](int64 branch) {
|
||||
auto call = computation->AddInstruction(HloInstruction::CreateCall(
|
||||
conditional->shape(), {conditional->mutable_operand(1 + branch)},
|
||||
conditional->branch_computation(branch)));
|
||||
conditional->SetupDerivedInstruction(call);
|
||||
return call;
|
||||
};
|
||||
|
||||
if (conditional->branch_count() == 1) {
|
||||
HloInstruction* call_op = create_call(0);
|
||||
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
|
||||
TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());
|
||||
return true;
|
||||
}
|
||||
|
||||
if (conditional->operand(0)->opcode() == HloOpcode::kConstant) {
|
||||
int branch_index = 0;
|
||||
if (conditional->operand(0)->shape().element_type() == PRED) {
|
||||
branch_index = conditional->operand(0)->literal().Get<bool>({}) ? 0 : 1;
|
||||
} else {
|
||||
@ -72,16 +82,83 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
|
||||
branch_index = conditional->branch_count() - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
auto computation = conditional->parent();
|
||||
HloInstruction* call_op;
|
||||
call_op = computation->AddInstruction(HloInstruction::CreateCall(
|
||||
conditional->shape(), {conditional->mutable_operand(branch_index + 1)},
|
||||
conditional->branch_computation(branch_index)));
|
||||
conditional->SetupDerivedInstruction(call_op);
|
||||
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
|
||||
TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());
|
||||
HloInstruction* call_op = create_call(branch_index);
|
||||
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
|
||||
TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
auto instruction_is_expensive = [](const HloInstruction* hlo) {
|
||||
switch (hlo->opcode()) {
|
||||
case HloOpcode::kBroadcast:
|
||||
case HloOpcode::kConcatenate:
|
||||
case HloOpcode::kDynamicSlice:
|
||||
case HloOpcode::kDynamicUpdateSlice:
|
||||
case HloOpcode::kGetTupleElement:
|
||||
case HloOpcode::kReduce:
|
||||
case HloOpcode::kReshape:
|
||||
case HloOpcode::kPad:
|
||||
case HloOpcode::kParameter:
|
||||
case HloOpcode::kSlice:
|
||||
case HloOpcode::kTuple:
|
||||
return false;
|
||||
default:
|
||||
return !hlo->IsElementwise();
|
||||
}
|
||||
};
|
||||
|
||||
if (conditional->branch_count() != 2 ||
|
||||
conditional->operand(0)->shape().element_type() != PRED ||
|
||||
absl::c_any_of(conditional->branch_computation(0)->instructions(),
|
||||
instruction_is_expensive) ||
|
||||
absl::c_any_of(conditional->branch_computation(1)->instructions(),
|
||||
instruction_is_expensive)) {
|
||||
VLOG(2)
|
||||
<< "Not attempting to remove conditional as its branch_index is not a "
|
||||
"compile-time constant or contains expensive instructions: "
|
||||
<< conditional->ToShortString();
|
||||
return false;
|
||||
}
|
||||
|
||||
HloInstruction* true_call_op = create_call(0);
|
||||
HloInstruction* false_call_op = create_call(1);
|
||||
auto condition_broadcast = [&](const Shape& shape) {
|
||||
if (ShapeUtil::IsScalar(shape)) {
|
||||
return conditional->mutable_operand(0);
|
||||
}
|
||||
return computation->AddInstruction(HloInstruction::CreateBroadcast(
|
||||
ShapeUtil::ChangeElementType(shape, PRED),
|
||||
conditional->mutable_operand(0), {}));
|
||||
};
|
||||
|
||||
auto gte = [&](HloInstruction* hlo, int64 i) {
|
||||
return computation->AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||
hlo->shape().tuple_shapes(i), hlo, i));
|
||||
};
|
||||
std::function<HloInstruction*(HloInstruction*, HloInstruction*)> select =
|
||||
[&](HloInstruction* t, HloInstruction* f) {
|
||||
if (f->shape().IsArray()) {
|
||||
return computation->AddInstruction(HloInstruction::CreateTernary(
|
||||
f->shape(), HloOpcode::kSelect, condition_broadcast(f->shape()),
|
||||
t, f));
|
||||
}
|
||||
std::vector<HloInstruction*> selects;
|
||||
const int64 tuple_element_count =
|
||||
ShapeUtil::TupleElementCount(f->shape());
|
||||
selects.reserve(tuple_element_count);
|
||||
for (int64 i = 0; i < tuple_element_count; ++i) {
|
||||
selects.push_back(select(gte(t, i), gte(f, i)));
|
||||
}
|
||||
return computation->AddInstruction(
|
||||
HloInstruction::CreateTuple(selects));
|
||||
};
|
||||
|
||||
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
|
||||
conditional, select(true_call_op, false_call_op)));
|
||||
|
||||
TF_RETURN_IF_ERROR(CallInliner::Inline(false_call_op).status());
|
||||
TF_RETURN_IF_ERROR(CallInliner::Inline(true_call_op).status());
|
||||
return true;
|
||||
}
|
||||
StatusOr<bool> TryRemoveUnusedConditionalOperands(
|
||||
|
@ -41,10 +41,11 @@ namespace op = xla::testing::opcode_matchers;
|
||||
// Test fixture for ConditionalSimplifier tests; provides a helper that builds
// a computation containing a conditional.
class ConditionalSimplifierTest : public HloTestBase {
|
||||
public:
|
||||
// Makes a computation that contains a conditional. When `is_constant` is true
// (the default) the predicate is a constant; otherwise it is supplied by a
// scalar PRED parameter so the predicate is not known at compile time.
|
||||
HloComputation* MakeConditional(HloModule* module);
|
||||
HloComputation* MakeConditional(HloModule* module, bool is_constant = true);
|
||||
};
|
||||
|
||||
HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
|
||||
HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module,
|
||||
bool is_constant) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
|
||||
// true_computation returns param+1.
|
||||
@ -83,7 +84,10 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
|
||||
}
|
||||
|
||||
auto false_instrn = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||
is_constant
|
||||
? HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))
|
||||
: HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(PRED, {}),
|
||||
"cond"));
|
||||
auto false_param = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(S32, {}), "false_param"));
|
||||
auto one = builder.AddInstruction(
|
||||
@ -104,6 +108,16 @@ TEST_F(ConditionalSimplifierTest, ConditionalGetsInlined) {
|
||||
op::Add(op::Parameter(), op::Constant()));
|
||||
}
|
||||
|
||||
// Checks that a conditional whose predicate is not a compile-time constant is
// rewritten by ConditionalSimplifier into a select over the two branch results.
TEST_F(ConditionalSimplifierTest, BranchGetsInlined) {
|
||||
auto m = CreateNewVerifiedModule();
|
||||
// Build the conditional with a parameter (non-constant) predicate.
HloComputation* computation = MakeConditional(m.get(), /*is_constant=*/false);
|
||||
// The pass must report that it changed the module.
ASSERT_TRUE(ConditionalSimplifier().Run(m.get()).ValueOrDie());
|
||||
// The root is expected to be a select keyed on parameter 1, choosing between
// the two add expressions.
EXPECT_THAT(
|
||||
computation->root_instruction(),
|
||||
op::Select(op::Parameter(1), op::Add(op::Constant(), op::Constant()),
|
||||
op::Add(op::Parameter(0), op::Constant())));
|
||||
}
|
||||
|
||||
TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) {
|
||||
auto m = CreateNewVerifiedModule();
|
||||
HloComputation* computation = MakeConditional(m.get());
|
||||
|
Loading…
Reference in New Issue
Block a user