[XLA] Convert simple conditionals into array select instructions to allow for
fusion and avoid copies in buffer assignment. PiperOrigin-RevId: 251500037
This commit is contained in:
parent
f589c2507a
commit
ebe48a8757
@ -1841,6 +1841,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/algorithm:container",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/algorithm/container.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "tensorflow/compiler/xla/literal.h"
|
#include "tensorflow/compiler/xla/literal.h"
|
||||||
#include "tensorflow/compiler/xla/service/call_graph.h"
|
#include "tensorflow/compiler/xla/service/call_graph.h"
|
||||||
@ -55,15 +56,24 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// We can always inline a 1-branch conditional due to default branch fallback.
|
// We can always inline a 1-branch conditional due to default branch fallback.
|
||||||
int branch_index = 0;
|
auto computation = conditional->parent();
|
||||||
if (conditional->branch_count() > 1) {
|
auto create_call = [&](int64 branch) {
|
||||||
if (conditional->operand(0)->opcode() != HloOpcode::kConstant) {
|
auto call = computation->AddInstruction(HloInstruction::CreateCall(
|
||||||
VLOG(2) << "Not attempting to remove conditional as its branch_index is "
|
conditional->shape(), {conditional->mutable_operand(1 + branch)},
|
||||||
"not a compile-time constant: "
|
conditional->branch_computation(branch)));
|
||||||
<< conditional->ToShortString();
|
conditional->SetupDerivedInstruction(call);
|
||||||
return false;
|
return call;
|
||||||
}
|
};
|
||||||
|
|
||||||
|
if (conditional->branch_count() == 1) {
|
||||||
|
HloInstruction* call_op = create_call(0);
|
||||||
|
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
|
||||||
|
TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (conditional->operand(0)->opcode() == HloOpcode::kConstant) {
|
||||||
|
int branch_index = 0;
|
||||||
if (conditional->operand(0)->shape().element_type() == PRED) {
|
if (conditional->operand(0)->shape().element_type() == PRED) {
|
||||||
branch_index = conditional->operand(0)->literal().Get<bool>({}) ? 0 : 1;
|
branch_index = conditional->operand(0)->literal().Get<bool>({}) ? 0 : 1;
|
||||||
} else {
|
} else {
|
||||||
@ -72,16 +82,83 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
|
|||||||
branch_index = conditional->branch_count() - 1;
|
branch_index = conditional->branch_count() - 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
HloInstruction* call_op = create_call(branch_index);
|
||||||
auto computation = conditional->parent();
|
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
|
||||||
HloInstruction* call_op;
|
TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());
|
||||||
call_op = computation->AddInstruction(HloInstruction::CreateCall(
|
|
||||||
conditional->shape(), {conditional->mutable_operand(branch_index + 1)},
|
|
||||||
conditional->branch_computation(branch_index)));
|
|
||||||
conditional->SetupDerivedInstruction(call_op);
|
|
||||||
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
|
|
||||||
TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());
|
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto instruction_is_expensive = [](const HloInstruction* hlo) {
|
||||||
|
switch (hlo->opcode()) {
|
||||||
|
case HloOpcode::kBroadcast:
|
||||||
|
case HloOpcode::kConcatenate:
|
||||||
|
case HloOpcode::kDynamicSlice:
|
||||||
|
case HloOpcode::kDynamicUpdateSlice:
|
||||||
|
case HloOpcode::kGetTupleElement:
|
||||||
|
case HloOpcode::kReduce:
|
||||||
|
case HloOpcode::kReshape:
|
||||||
|
case HloOpcode::kPad:
|
||||||
|
case HloOpcode::kParameter:
|
||||||
|
case HloOpcode::kSlice:
|
||||||
|
case HloOpcode::kTuple:
|
||||||
|
return false;
|
||||||
|
default:
|
||||||
|
return !hlo->IsElementwise();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if (conditional->branch_count() != 2 ||
|
||||||
|
conditional->operand(0)->shape().element_type() != PRED ||
|
||||||
|
absl::c_any_of(conditional->branch_computation(0)->instructions(),
|
||||||
|
instruction_is_expensive) ||
|
||||||
|
absl::c_any_of(conditional->branch_computation(1)->instructions(),
|
||||||
|
instruction_is_expensive)) {
|
||||||
|
VLOG(2)
|
||||||
|
<< "Not attempting to remove conditional as its branch_index is not a "
|
||||||
|
"compile-time constant or contains expensive instructions: "
|
||||||
|
<< conditional->ToShortString();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
HloInstruction* true_call_op = create_call(0);
|
||||||
|
HloInstruction* false_call_op = create_call(1);
|
||||||
|
auto condition_broadcast = [&](const Shape& shape) {
|
||||||
|
if (ShapeUtil::IsScalar(shape)) {
|
||||||
|
return conditional->mutable_operand(0);
|
||||||
|
}
|
||||||
|
return computation->AddInstruction(HloInstruction::CreateBroadcast(
|
||||||
|
ShapeUtil::ChangeElementType(shape, PRED),
|
||||||
|
conditional->mutable_operand(0), {}));
|
||||||
|
};
|
||||||
|
|
||||||
|
auto gte = [&](HloInstruction* hlo, int64 i) {
|
||||||
|
return computation->AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||||
|
hlo->shape().tuple_shapes(i), hlo, i));
|
||||||
|
};
|
||||||
|
std::function<HloInstruction*(HloInstruction*, HloInstruction*)> select =
|
||||||
|
[&](HloInstruction* t, HloInstruction* f) {
|
||||||
|
if (f->shape().IsArray()) {
|
||||||
|
return computation->AddInstruction(HloInstruction::CreateTernary(
|
||||||
|
f->shape(), HloOpcode::kSelect, condition_broadcast(f->shape()),
|
||||||
|
t, f));
|
||||||
|
}
|
||||||
|
std::vector<HloInstruction*> selects;
|
||||||
|
const int64 tuple_element_count =
|
||||||
|
ShapeUtil::TupleElementCount(f->shape());
|
||||||
|
selects.reserve(tuple_element_count);
|
||||||
|
for (int64 i = 0; i < tuple_element_count; ++i) {
|
||||||
|
selects.push_back(select(gte(t, i), gte(f, i)));
|
||||||
|
}
|
||||||
|
return computation->AddInstruction(
|
||||||
|
HloInstruction::CreateTuple(selects));
|
||||||
|
};
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
|
||||||
|
conditional, select(true_call_op, false_call_op)));
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(CallInliner::Inline(false_call_op).status());
|
||||||
|
TF_RETURN_IF_ERROR(CallInliner::Inline(true_call_op).status());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
StatusOr<bool> TryRemoveUnusedConditionalOperands(
|
StatusOr<bool> TryRemoveUnusedConditionalOperands(
|
||||||
|
@ -41,10 +41,11 @@ namespace op = xla::testing::opcode_matchers;
|
|||||||
class ConditionalSimplifierTest : public HloTestBase {
|
class ConditionalSimplifierTest : public HloTestBase {
|
||||||
public:
|
public:
|
||||||
// Makes a computation that contains a conditional with constant predicate.
|
// Makes a computation that contains a conditional with constant predicate.
|
||||||
HloComputation* MakeConditional(HloModule* module);
|
HloComputation* MakeConditional(HloModule* module, bool is_constant = true);
|
||||||
};
|
};
|
||||||
|
|
||||||
HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
|
HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module,
|
||||||
|
bool is_constant) {
|
||||||
HloComputation::Builder builder(TestName());
|
HloComputation::Builder builder(TestName());
|
||||||
|
|
||||||
// true_computation returns param+1.
|
// true_computation returns param+1.
|
||||||
@ -83,7 +84,10 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto false_instrn = builder.AddInstruction(
|
auto false_instrn = builder.AddInstruction(
|
||||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
is_constant
|
||||||
|
? HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))
|
||||||
|
: HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(PRED, {}),
|
||||||
|
"cond"));
|
||||||
auto false_param = builder.AddInstruction(HloInstruction::CreateParameter(
|
auto false_param = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
0, ShapeUtil::MakeShape(S32, {}), "false_param"));
|
0, ShapeUtil::MakeShape(S32, {}), "false_param"));
|
||||||
auto one = builder.AddInstruction(
|
auto one = builder.AddInstruction(
|
||||||
@ -104,6 +108,16 @@ TEST_F(ConditionalSimplifierTest, ConditionalGetsInlined) {
|
|||||||
op::Add(op::Parameter(), op::Constant()));
|
op::Add(op::Parameter(), op::Constant()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ConditionalSimplifierTest, BranchGetsInlined) {
|
||||||
|
auto m = CreateNewVerifiedModule();
|
||||||
|
HloComputation* computation = MakeConditional(m.get(), /*is_constant=*/false);
|
||||||
|
ASSERT_TRUE(ConditionalSimplifier().Run(m.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(
|
||||||
|
computation->root_instruction(),
|
||||||
|
op::Select(op::Parameter(1), op::Add(op::Constant(), op::Constant()),
|
||||||
|
op::Add(op::Parameter(0), op::Constant())));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) {
|
TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) {
|
||||||
auto m = CreateNewVerifiedModule();
|
auto m = CreateNewVerifiedModule();
|
||||||
HloComputation* computation = MakeConditional(m.get());
|
HloComputation* computation = MakeConditional(m.get());
|
||||||
|
Loading…
Reference in New Issue
Block a user