[XLA] Support hoisting copy instructions in conditional code motion.
PiperOrigin-RevId: 329021839 Change-Id: I699c547f462466508b90289468f46e400008b4c6
This commit is contained in:
parent
155412683d
commit
ee4736e5bf
@ -2343,6 +2343,7 @@ cc_library(
|
||||
":hlo_dce",
|
||||
":hlo_pass",
|
||||
":hlo_pass_pipeline",
|
||||
":hlo_verifier",
|
||||
":tuple_simplifier",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
|
||||
@ -34,6 +34,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
||||
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -114,6 +115,8 @@ int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
|
||||
case HloOpcode::kConstant:
|
||||
case HloOpcode::kGetTupleElement:
|
||||
return 0;
|
||||
case HloOpcode::kConditional:
|
||||
return 10;
|
||||
default:
|
||||
// Assume fusion will not happen anyway if user count > 1.
|
||||
if (op->user_count() > 1) {
|
||||
@ -582,6 +585,7 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionIn(
|
||||
// to replace the conditional directly in the new computation.
|
||||
b_opd_use.mutable_operands().push_back(conditional);
|
||||
}
|
||||
|
||||
HloInstruction* new_root =
|
||||
computation->AddInstruction(HloInstruction::CreateTuple(operands));
|
||||
VLOG(2) << "setting new root: " << new_root->ToString() << "\n";
|
||||
@ -592,6 +596,15 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionIn(
|
||||
}
|
||||
VLOG(2) << "new branch computation: " << computation->ToString() << "\n";
|
||||
}
|
||||
// Update get tuple element index of the conditional.
|
||||
if (use_index != -1) {
|
||||
for (auto* user : conditional->users()) {
|
||||
if (user->opcode() == HloOpcode::kGetTupleElement &&
|
||||
user->tuple_index() > use_index) {
|
||||
user->set_tuple_index(user->tuple_index() - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
hoisted_instructions[conditional] = b_old_root;
|
||||
int64 cp_start = 0;
|
||||
if (use_index >= 0) {
|
||||
@ -677,7 +690,7 @@ class GroupConnectedBoundaries {
|
||||
: conditional_(conditional),
|
||||
conditional_parent_(conditional->parent()),
|
||||
is_layout_sensitive_(is_layout_sensitive) {}
|
||||
// Returns true if `instruction` is worth hoisting out.
|
||||
// Returns true if `instruction` is worth hoisting.
|
||||
bool WorthHoisting(HloInstruction* instruction) {
|
||||
// This is needed for the "moving-in" transformation, to prevent the root
|
||||
// of the parent computation (which contains the conditional) to be moved
|
||||
@ -708,6 +721,7 @@ class GroupConnectedBoundaries {
|
||||
case HloOpcode::kAllReduce:
|
||||
case HloOpcode::kAdd:
|
||||
case HloOpcode::kPower:
|
||||
case HloOpcode::kCopy:
|
||||
case HloOpcode::kConstant:
|
||||
case HloOpcode::kSubtract:
|
||||
case HloOpcode::kMultiply:
|
||||
@ -1070,6 +1084,7 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
||||
subpipeline.AddPass<HloDCE>();
|
||||
subpipeline.AddPass<TupleSimplifier>();
|
||||
subpipeline.AddPass<HloDCE>();
|
||||
subpipeline.AddPass<HloVerifier>(false, true);
|
||||
TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module));
|
||||
changed |= cleanup_changed;
|
||||
}
|
||||
|
||||
@ -728,6 +728,66 @@ ENTRY main {
|
||||
EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional())));
|
||||
}
|
||||
|
||||
// Runs ConditionalCodeMotion on a two-branch conditional in which branch2
// produces two of its three tuple outputs through copy instructions and the
// entry computation applies a further copy to element 0 of the conditional's
// result. Expects the pass to return true, the two branch computations to end
// up with 9 and 8 instructions respectively, and the entry root to be a tuple
// of get-tuple-elements of the conditional with indices 2, 0 and 1.
TEST_F(ConditionalCodeMotionTest, MoveCopyInBranch) {
|
||||
absl::string_view hlo_string =
|
||||
R"(
|
||||
HloModule RemoveIdenticalInstruction
|
||||
|
||||
branch1 {
|
||||
arg_tuple.1 = (s32[], f32[10,3]{0,1}) parameter(0)
|
||||
constant.1 = s32[] constant(4)
|
||||
get-tuple-element.1 = s32[] get-tuple-element(arg_tuple.1), index=0
|
||||
add.1 = s32[] add(get-tuple-element.1, constant.1)
|
||||
get-tuple-element.2 = f32[10,3]{0,1} get-tuple-element(arg_tuple.1), index=1
|
||||
slice.1 = f32[4,3]{0,1} slice(get-tuple-element.2),
|
||||
slice={[0:4:1], [0:3:1]}
|
||||
constant.2 = f32[] constant(0.0)
|
||||
ROOT tuple.1 = (f32[4,3]{0,1}, s32[],f32[]) tuple(slice.1, add.1, constant.2)
|
||||
}
|
||||
|
||||
branch2 {
|
||||
arg_tuple.2 = (s32[], f32[4,3]{1,0}) parameter(0)
|
||||
get-tuple-element.3 = s32[] get-tuple-element(arg_tuple.2), index=0
|
||||
copy.1 = s32[] copy(get-tuple-element.3)
|
||||
get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element(arg_tuple.2), index=1
|
||||
copy.2 = f32[4,3]{0,1} copy(get-tuple-element.4)
|
||||
constant.2 = f32[] constant(0.0)
|
||||
ROOT tuple.2 = (f32[4,3]{0,1}, s32[], f32[]) tuple(copy.2, copy.1, constant.2)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
pred.1 = pred[] parameter(0)
|
||||
tuple.3 = (s32[], f32[10,3]{0,1}) parameter(1)
|
||||
tuple.4 = (s32[], f32[4,3]{1,0}) parameter(2)
|
||||
conditional = (f32[4,3]{0,1}, s32[], f32[])
|
||||
conditional(pred.1, tuple.3, tuple.4), true_computation=branch1,
|
||||
false_computation=branch2
|
||||
get-zero-index = f32[4,3]{0,1} get-tuple-element(conditional), index=0
|
||||
get-first-index = s32[] get-tuple-element(conditional), index=1
|
||||
get-second-index = f32[] get-tuple-element(conditional), index=2
|
||||
copy.3 = f32[4,3]{1,0} copy(get-zero-index)
|
||||
ROOT tuple.5 = (f32[4,3]{0,1}, s32[], f32[]) tuple(copy.3, get-first-index,
|
||||
get-second-index)
|
||||
}
|
||||
)";
|
||||
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||
// NOTE(review): the meaning of the two boolean constructor arguments is not
// visible in this hunk -- confirm against the ConditionalCodeMotion
// declaration.
ConditionalCodeMotion pass(true, true);
|
||||
// The pass must succeed and return true for this module.
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
|
||||
VLOG(1) << module->ToString();
|
||||
|
||||
const HloInstruction* conditional =
|
||||
FindInstruction(module.get(), "conditional");
|
||||
// Check the instruction count of each branch computation after the pass.
const HloComputation* on_true = conditional->branch_computation(0);
|
||||
ASSERT_EQ(on_true->instruction_count(), 9);
|
||||
const HloComputation* on_false = conditional->branch_computation(1);
|
||||
ASSERT_EQ(on_false->instruction_count(), 8);
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
// The entry root must be a tuple whose operands are get-tuple-elements of the
// conditional with indices 2, 0 and 1, in that order.
EXPECT_THAT(root,
|
||||
AllOf(op::Tuple(op::GetTupleElement(op::Conditional(), 2),
|
||||
op::GetTupleElement(op::Conditional(), 0),
|
||||
op::GetTupleElement(op::Conditional(), 1))));
|
||||
}
|
||||
|
||||
} // namespace conditional_opt
|
||||
|
||||
} // namespace xla
|
||||
|
||||
@ -648,7 +648,9 @@ void ShapeTree<T>::CopySubtreeFrom(const ShapeTree<T>& other,
|
||||
const ShapeIndex& target_base_index) {
|
||||
CHECK(ShapeUtil::Compatible(
|
||||
ShapeUtil::GetSubshape(shape(), target_base_index),
|
||||
ShapeUtil::GetSubshape(other.shape(), source_base_index)));
|
||||
ShapeUtil::GetSubshape(other.shape(), source_base_index)))
|
||||
<< ShapeUtil::GetSubshape(shape(), target_base_index) << " vs "
|
||||
<< ShapeUtil::GetSubshape(other.shape(), source_base_index);
|
||||
ForEachMutableElement([this, &other, &source_base_index, &target_base_index](
|
||||
const ShapeIndex& index, T* data) {
|
||||
// Copy the data element only if index is in the
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user