[XLA] Support hoisting copy instructions in conditional code motion.

PiperOrigin-RevId: 329021839
Change-Id: I699c547f462466508b90289468f46e400008b4c6
This commit is contained in:
Yunxing Dai 2020-08-28 15:37:35 -07:00 committed by TensorFlower Gardener
parent 155412683d
commit ee4736e5bf
4 changed files with 80 additions and 2 deletions

View File

@ -2343,6 +2343,7 @@ cc_library(
":hlo_dce",
":hlo_pass",
":hlo_pass_pipeline",
":hlo_verifier",
":tuple_simplifier",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@ -114,6 +115,8 @@ int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
case HloOpcode::kConstant:
case HloOpcode::kGetTupleElement:
return 0;
case HloOpcode::kConditional:
return 10;
default:
// Assume fusion will not happen anyway if user count > 1.
if (op->user_count() > 1) {
@ -582,6 +585,7 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionIn(
// to replace the conditional directly in the new computation.
b_opd_use.mutable_operands().push_back(conditional);
}
HloInstruction* new_root =
computation->AddInstruction(HloInstruction::CreateTuple(operands));
VLOG(2) << "setting new root: " << new_root->ToString() << "\n";
@ -592,6 +596,15 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionIn(
}
VLOG(2) << "new branch computation: " << computation->ToString() << "\n";
}
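// If use_index is set (not -1), decrement the tuple index of every
// get-tuple-element user of the conditional whose index is greater than
// use_index.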
if (use_index != -1) {
for (auto* user : conditional->users()) {
if (user->opcode() == HloOpcode::kGetTupleElement &&
user->tuple_index() > use_index) {
user->set_tuple_index(user->tuple_index() - 1);
}
}
}
hoisted_instructions[conditional] = b_old_root;
int64 cp_start = 0;
if (use_index >= 0) {
@ -677,7 +690,7 @@ class GroupConnectedBoundaries {
: conditional_(conditional),
conditional_parent_(conditional->parent()),
is_layout_sensitive_(is_layout_sensitive) {}
// Returns true if `instruction` is worth hoisting out.
// Returns true if `instruction` is worth hoisting.
bool WorthHoisting(HloInstruction* instruction) {
// This is needed for the "moving-in" transformation, to prevent the root
// of the parent computation (which contains the conditional) to be moved
@ -708,6 +721,7 @@ class GroupConnectedBoundaries {
case HloOpcode::kAllReduce:
case HloOpcode::kAdd:
case HloOpcode::kPower:
case HloOpcode::kCopy:
case HloOpcode::kConstant:
case HloOpcode::kSubtract:
case HloOpcode::kMultiply:
@ -1070,6 +1084,7 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
subpipeline.AddPass<HloDCE>();
subpipeline.AddPass<TupleSimplifier>();
subpipeline.AddPass<HloDCE>();
subpipeline.AddPass<HloVerifier>(false, true);
TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module));
changed |= cleanup_changed;
}

View File

@ -728,6 +728,66 @@ ENTRY main {
EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional())));
}
// Runs ConditionalCodeMotion on a module whose entry computation applies a
// copy (copy.3) to element 0 of a two-branch conditional. branch1 produces
// that element with a slice, branch2 with a layout-changing copy. The test
// checks that the pass reports a change, that each branch computation ends up
// with one more instruction than in the HLO text, and that every operand of
// the entry root is a get-tuple-element reading the conditional directly.
TEST_F(ConditionalCodeMotionTest, MoveCopyInBranch) {
absl::string_view hlo_string =
R"(
HloModule RemoveIdenticalInstruction
branch1 {
arg_tuple.1 = (s32[], f32[10,3]{0,1}) parameter(0)
constant.1 = s32[] constant(4)
get-tuple-element.1 = s32[] get-tuple-element(arg_tuple.1), index=0
add.1 = s32[] add(get-tuple-element.1, constant.1)
get-tuple-element.2 = f32[10,3]{0,1} get-tuple-element(arg_tuple.1), index=1
slice.1 = f32[4,3]{0,1} slice(get-tuple-element.2),
slice={[0:4:1], [0:3:1]}
constant.2 = f32[] constant(0.0)
ROOT tuple.1 = (f32[4,3]{0,1}, s32[],f32[]) tuple(slice.1, add.1, constant.2)
}
branch2 {
arg_tuple.2 = (s32[], f32[4,3]{1,0}) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(arg_tuple.2), index=0
copy.1 = s32[] copy(get-tuple-element.3)
get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element(arg_tuple.2), index=1
copy.2 = f32[4,3]{0,1} copy(get-tuple-element.4)
constant.2 = f32[] constant(0.0)
ROOT tuple.2 = (f32[4,3]{0,1}, s32[], f32[]) tuple(copy.2, copy.1, constant.2)
}
ENTRY main {
pred.1 = pred[] parameter(0)
tuple.3 = (s32[], f32[10,3]{0,1}) parameter(1)
tuple.4 = (s32[], f32[4,3]{1,0}) parameter(2)
conditional = (f32[4,3]{0,1}, s32[], f32[])
conditional(pred.1, tuple.3, tuple.4), true_computation=branch1,
false_computation=branch2
get-zero-index = f32[4,3]{0,1} get-tuple-element(conditional), index=0
get-first-index = s32[] get-tuple-element(conditional), index=1
get-second-index = f32[] get-tuple-element(conditional), index=2
copy.3 = f32[4,3]{1,0} copy(get-zero-index)
ROOT tuple.5 = (f32[4,3]{0,1}, s32[], f32[]) tuple(copy.3, get-first-index,
get-second-index)
}
)";
// Parse the HLO text above into a module.
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
// NOTE(review): the meaning of the two boolean constructor arguments is not
// visible here -- presumably (is_layout_sensitive,
// pursue_full_conditional_code_motion); confirm against the class declaration.
ConditionalCodeMotion pass(true, true);
// The pass must report that it changed the module.
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
VLOG(1) << module->ToString();
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");
// Branch 0 (named on_true here) is expected to hold 9 instructions; branch1
// in the HLO text above lists 8.
const HloComputation* on_true = conditional->branch_computation(0);
ASSERT_EQ(on_true->instruction_count(), 9);
// Branch 1 (named on_false here) is expected to hold 8 instructions; branch2
// in the HLO text above lists 7. The +1 in each branch is consistent with one
// instruction (presumably copy.3) having been moved into it -- TODO confirm.
const HloComputation* on_false = conditional->branch_computation(1);
ASSERT_EQ(on_false->instruction_count(), 8);
// The entry root must be a 3-operand tuple whose operands are
// get-tuple-elements of the conditional with tuple indices 2, 0 and 1, in
// that order. In particular the first operand is no longer a copy, as it was
// (copy.3) in the input.
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root,
AllOf(op::Tuple(op::GetTupleElement(op::Conditional(), 2),
op::GetTupleElement(op::Conditional(), 0),
op::GetTupleElement(op::Conditional(), 1))));
}
} // namespace conditional_opt
} // namespace xla

View File

@ -648,7 +648,9 @@ void ShapeTree<T>::CopySubtreeFrom(const ShapeTree<T>& other,
const ShapeIndex& target_base_index) {
CHECK(ShapeUtil::Compatible(
ShapeUtil::GetSubshape(shape(), target_base_index),
ShapeUtil::GetSubshape(other.shape(), source_base_index)));
ShapeUtil::GetSubshape(other.shape(), source_base_index)))
<< ShapeUtil::GetSubshape(shape(), target_base_index) << " vs "
<< ShapeUtil::GetSubshape(other.shape(), source_base_index);
ForEachMutableElement([this, &other, &source_base_index, &target_base_index](
const ShapeIndex& index, T* data) {
// Copy the data element only if index is in the