diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 484e96732e1..23e8dac27fa 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -4117,6 +4117,28 @@ tf_cc_test( ], ) +cc_library( + name = "root_instruction_sinker", + srcs = ["root_instruction_sinker.cc"], + hdrs = ["root_instruction_sinker.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":tuple_util", + ], +) + +tf_cc_test( + name = "root_instruction_sinker_test", + srcs = ["root_instruction_sinker_test.cc"], + deps = [ + ":hlo_matchers", + ":root_instruction_sinker", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "while_util", srcs = ["while_util.cc"], diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index e5b1756eb46..f7c66503a81 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -672,6 +672,12 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { // interval (5-6) can be allocated separately and this buffer // doesn't waste alternate memory space within the while loop body. HloComputation* while_body = use.instruction->while_body(); + // We require while body ROOTs to be the last in the schedule. + CHECK_EQ( + instruction_schedule.at(while_body->root_instruction()) + 1, + instruction_schedule.at(use.instruction)) + << "While body ROOTs need to be the last in the schedule! " + "Please run RootInstructionSinker."; // Replace the use time with the parameter time so that we can // decide on alternate memory allocations within the while loop body // when we look at uses within the while loop body. diff --git a/tensorflow/compiler/xla/service/root_instruction_sinker.cc b/tensorflow/compiler/xla/service/root_instruction_sinker.cc new file mode 100644 index 00000000000..bee703b85e5 --- /dev/null +++ b/tensorflow/compiler/xla/service/root_instruction_sinker.cc @@ -0,0 +1,73 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/root_instruction_sinker.h" + +#include "tensorflow/compiler/xla/service/tuple_util.h" +namespace xla { + +namespace { + +// Sinks the root of the given computation for tuple root types. +void SinkTupleRoot(HloComputation* computation) { + HloInstruction* root = computation->root_instruction(); + CHECK(root->shape().IsTuple()); + HloInstruction* new_root = TupleUtil::Duplicate(root); + // Add the new instructions to the schedule. + HloInstructionSequence& sequence = + computation->parent()->schedule().GetOrCreateSequence(computation); + for (HloInstruction* operand : new_root->operands()) { + sequence.push_back(operand); + } + sequence.push_back(new_root); + computation->set_root_instruction(new_root); +} + +// Sinks the root of the given computation for not-tuple root types. +void SinkNontupleRoot(HloComputation* computation) { + HloInstruction* root = computation->root_instruction(); + CHECK(!root->shape().IsTuple()); + HloInstruction* new_root = computation->AddInstruction( + HloInstruction::CreateBitcast(root->shape(), root)); + HloInstructionSequence& sequence = + computation->parent()->schedule().GetOrCreateSequence(computation); + sequence.push_back(new_root); + computation->set_root_instruction(new_root); +} + +} // namespace + +StatusOr RootInstructionSinker::Run(HloModule* module) { + TF_RET_CHECK(module->has_schedule()); + + bool modified = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + HloInstructionSequence& sequence = + module->schedule().GetOrCreateSequence(computation); + if (computation->root_instruction() == + sequence.instructions().at(sequence.size() - 1)) { + continue; + } + if (computation->root_instruction()->shape().IsTuple()) { + SinkTupleRoot(computation); + } else { + SinkNontupleRoot(computation); + } + modified = true; + } + return modified; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/root_instruction_sinker.h b/tensorflow/compiler/xla/service/root_instruction_sinker.h new file mode 100644 index 00000000000..d4d08870699 --- /dev/null +++ b/tensorflow/compiler/xla/service/root_instruction_sinker.h @@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ROOT_INSTRUCTION_SINKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_ROOT_INSTRUCTION_SINKER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Given a scheduled HLO module, this pass sinks the ROOT of the instruction to +// the bottom of the non-fusion computations. To avoid dependency violations of +// moving the ROOT instruction, it creates a new ROOT instruction that looks +// like the following: +// - For tuple ROOT type: +// new_root = tuple(gte(old_root), gte(old_root), ...) +// - For non-tuple ROOT type: +// new_root = bitcast(old_root) +class RootInstructionSinker : public HloModulePass { + public: + ~RootInstructionSinker() override = default; + absl::string_view name() const override { return "root-instruction-sinker"; } + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ROOT_INSTRUCTION_SINKER_H_ diff --git a/tensorflow/compiler/xla/service/root_instruction_sinker_test.cc b/tensorflow/compiler/xla/service/root_instruction_sinker_test.cc new file mode 100644 index 00000000000..8a03a92b88a --- /dev/null +++ b/tensorflow/compiler/xla/service/root_instruction_sinker_test.cc @@ -0,0 +1,170 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/root_instruction_sinker.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +using RootInstructionSinkerTest = HloTestBase; + +TEST_F(RootInstructionSinkerTest, TupleNoChange) { + // ROOTS are already sunk, no change performed to the module. + absl::string_view hlo_string = R"( + HloModule While, is_scheduled=true + While.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + } + While.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(100) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + ENTRY While { + constant.3 = s32[] constant(42) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition= + While.condition, body=While.body + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto while_body = + module->entry_computation()->root_instruction()->while_body(); + int num_body_instructions = while_body->instruction_count(); + RootInstructionSinker sinker; + EXPECT_FALSE(sinker.Run(module.get()).ValueOrDie()); + EXPECT_EQ(module->entry_computation() + ->root_instruction() + ->while_body() + ->instruction_count(), + num_body_instructions); +} + +TEST_F(RootInstructionSinkerTest, Tuple) { + // Sink tuple return type. + absl::string_view hlo_string = R"( + HloModule While, is_scheduled=true + While.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + after-all = token[] after-all() + send = (s32[3]{0}, u32[], token[]) send(multiply, after-all), channel_id=1 + send-done = token[] send-done(send), channel_id=1 + } + While.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(100) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + ENTRY While { + constant.3 = s32[] constant(42) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition= + While.condition, body=While.body + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + RootInstructionSinker sinker; + EXPECT_TRUE(sinker.Run(module.get()).ValueOrDie()); + auto while_body = + module->entry_computation()->root_instruction()->while_body(); + const auto& sequence = module->schedule().sequence(while_body); + EXPECT_EQ(sequence.instructions().at(sequence.size() - 1), + while_body->root_instruction()); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::GetTupleElement(op::Tuple()), + op::GetTupleElement(op::Tuple()))); +} + +TEST_F(RootInstructionSinkerTest, NontupleNoChange) { + // ROOTS are already sunk, no change performed to the module. + absl::string_view hlo_string = R"( + HloModule Call, is_scheduled=true + Call { + param = s32[3]{0} parameter(0) + ROOT multiply = s32[3]{0} multiply(param, param) + } + ENTRY While { + constant.4 = s32[3]{0} constant({0, 1, 2}) + ROOT call = s32[3]{0} call(constant.4), to_apply=Call + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto called_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + int num_instructions = called_computation->instruction_count(); + RootInstructionSinker sinker; + EXPECT_FALSE(sinker.Run(module.get()).ValueOrDie()); + EXPECT_EQ(module->entry_computation() + ->root_instruction() + ->called_computations()[0] + ->instruction_count(), + num_instructions); +} + +TEST_F(RootInstructionSinkerTest, Nontuple) { + // Sink a non-tuple return type. + absl::string_view hlo_string = R"( + HloModule Call, is_scheduled=true + Call { + param = s32[3]{0} parameter(0) + ROOT multiply = s32[3]{0} multiply(param, param) + after-all = token[] after-all() + send = (s32[3]{0}, u32[], token[]) send(multiply, after-all), channel_id=1 + send-done = token[] send-done(send), channel_id=1 + } + ENTRY While { + constant.4 = s32[3]{0} constant({0, 1, 2}) + ROOT call = s32[3]{0} call(constant.4), to_apply=Call + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + RootInstructionSinker sinker; + EXPECT_TRUE(sinker.Run(module.get()).ValueOrDie()); + auto called_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const auto& sequence = module->schedule().sequence(called_computation); + EXPECT_EQ(sequence.instructions().at(sequence.size() - 1), + called_computation->root_instruction()); + EXPECT_THAT(called_computation->root_instruction(), + op::Bitcast(op::Multiply())); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_util.h b/tensorflow/compiler/xla/service/tuple_util.h index bc5aac09f27..ee7b8be0818 100644 --- a/tensorflow/compiler/xla/service/tuple_util.h +++ b/tensorflow/compiler/xla/service/tuple_util.h @@ -39,6 +39,13 @@ class TupleUtil { static HloInstruction* AppendSuffix( HloInstruction* input_tuple, absl::Span trailing_values); + + // Generates HLO instructions that duplicates the tuple by inserting + // get-tuple-elements and a new tuple instruction. Returns the root of the + // graph of instructions generated. + static HloInstruction* Duplicate(HloInstruction* input_tuple) { + return ExtractPrefix(input_tuple, input_tuple->shape().tuple_shapes_size()); + } }; } // namespace xla