[XLA:CPU/GPU] Remove extra copy insertion to work around buffer assignment's local view

BufferAssignment is module-scoped now, so the extra copies shouldn't be needed.

This means that parameter instructions can be at the root of a computation now,
adjust code that relies on that not happening.

PiperOrigin-RevId: 252002504
This commit is contained in:
Benjamin Kramer 2019-06-07 00:32:19 -07:00 committed by TensorFlower Gardener
parent 1ad04efcee
commit bfc8733ffb
10 changed files with 17 additions and 342 deletions

View File

@ -1222,57 +1222,4 @@ bool IsWhileBody(const HloComputation* computation,
} }
} // namespace } // namespace
/* static */ StatusOr<bool> CopyInsertion::AddCopiesForBufferAssignment(
HloModule* module) {
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
HloDataflowAnalysis::Run(*module));
bool changed = false;
// If a buffer live out of a computation is a constant, a parameter, or not
// defined in the computation, then copy it to account for the limited
// computation-scoped analysis in buffer assignment. An exception to this rule
// is the while body which is handled properly without copies.
for (HloComputation* computation : module->computations()) {
if (computation == module->entry_computation() ||
IsWhileBody(computation, *call_graph)) {
continue;
}
HloInstruction* root = computation->root_instruction();
ShapeTree<bool> indices_to_copy(root->shape(), /*init_value=*/false);
bool copy_root = false;
for (const auto& pair : dataflow->GetInstructionValueSet(root)) {
const ShapeIndex& index = pair.first;
const HloValueSet& value_set = pair.second;
for (const HloValue* value : value_set.values()) {
HloInstruction* def = value->defining_instruction();
if (def->parent() != computation ||
def->opcode() == HloOpcode::kConstant ||
def->opcode() == HloOpcode::kParameter) {
*indices_to_copy.mutable_element(index) = true;
copy_root = true;
}
}
}
if (copy_root) {
TF_ASSIGN_OR_RETURN(
HloInstruction * root_copy,
computation->DeepCopyInstruction(root, &indices_to_copy));
computation->set_root_instruction(root_copy);
changed = true;
}
}
TupleSimplifier tuple_simplifier;
HloDCE dce;
TF_ASSIGN_OR_RETURN(bool tuple_simplifier_changed,
tuple_simplifier.Run(module));
TF_ASSIGN_OR_RETURN(bool dce_changed, dce.Run(module));
return changed || tuple_simplifier_changed || dce_changed;
}
} // namespace xla } // namespace xla

View File

@ -60,17 +60,6 @@ class CopyInsertion : public HloModulePass {
// (copies were inserted). // (copies were inserted).
StatusOr<bool> Run(HloModule* module) override; StatusOr<bool> Run(HloModule* module) override;
// The CPU and GPU backend need additional copies added due to deficiencies in
// buffer assignment. Specifically, copies are needed for constants live-out
// of computations, and for values which are live-in and live-out of the same
// computation. These copies are needed because buffer-assignment uses a
// computation-scoped analyis (TuplePointsToAnalysis) and has limited
// visibility across computation boundaries. This method adds these necessary
// copies. Returns whether the module was modified.
//
// TODO(b/62548313): Remove this when buffer assignment is module-scoped.
static StatusOr<bool> AddCopiesForBufferAssignment(HloModule* module);
// Try to remove as many copies from the module as possible without // Try to remove as many copies from the module as possible without
// introducing live range interference. Only copy instructions that are // introducing live range interference. Only copy instructions that are
// eligible for copy elision are considered for removal. // eligible for copy elision are considered for removal.

View File

@ -76,7 +76,6 @@ cc_library(
":compiler_functor", ":compiler_functor",
":buffer_info_util", ":buffer_info_util",
":conv_canonicalization", ":conv_canonicalization",
":cpu_copy_insertion",
":cpu_executable", ":cpu_executable",
":cpu_hlo_support_checker", ":cpu_hlo_support_checker",
":cpu_instruction_fusion", ":cpu_instruction_fusion",
@ -92,6 +91,7 @@ cc_library(
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
":target_machine_features", ":target_machine_features",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
"//tensorflow/compiler/xla/service:copy_insertion",
"//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:dump", "//tensorflow/compiler/xla/service:dump",
"//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:map_inliner",
@ -956,18 +956,6 @@ cc_library(
], ],
) )
cc_library(
name = "cpu_copy_insertion",
srcs = ["cpu_copy_insertion.cc"],
hdrs = ["cpu_copy_insertion.h"],
deps = [
"//tensorflow/compiler/xla/service:copy_insertion",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
],
)
cc_library( cc_library(
name = "vector_support_library", name = "vector_support_library",
srcs = ["vector_support_library.cc"], srcs = ["vector_support_library.cc"],
@ -986,26 +974,6 @@ cc_library(
], ],
) )
tf_cc_test(
name = "cpu_copy_insertion_test",
srcs = ["cpu_copy_insertion_test.cc"],
deps = [
":cpu_copy_insertion",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
cc_library( cc_library(
name = "cpu_hlo_support_checker", name = "cpu_hlo_support_checker",
srcs = ["cpu_hlo_support_checker.cc"], srcs = ["cpu_hlo_support_checker.cc"],

View File

@ -56,10 +56,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/conditional_to_select.h" #include "tensorflow/compiler/xla/service/conditional_to_select.h"
#include "tensorflow/compiler/xla/service/convolution_group_converter.h" #include "tensorflow/compiler/xla/service/convolution_group_converter.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
@ -402,7 +402,7 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn(
// interfering with the rewrites. // interfering with the rewrites.
pipeline.AddPass<HloDCE>(); pipeline.AddPass<HloDCE>();
pipeline.AddPass<FlattenCallGraph>(); pipeline.AddPass<FlattenCallGraph>();
pipeline.AddPass<CpuCopyInsertion>(); pipeline.AddPass<CopyInsertion>();
pipeline.AddPass<HloDCE>(); pipeline.AddPass<HloDCE>();
return pipeline.Run(module).status(); return pipeline.Run(module).status();
} }

View File

@ -1,43 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
#include <memory>
#include <set>
#include <vector>
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
StatusOr<bool> CpuCopyInsertion::Run(HloModule* module) {
CopyInsertion generic_copy_insertion;
TF_ASSIGN_OR_RETURN(bool generic_changed, generic_copy_insertion.Run(module));
// The CPU backend needs additional copies added due to deficiencies in
// buffer assignment.
TF_ASSIGN_OR_RETURN(bool buffer_assignment_changed,
CopyInsertion::AddCopiesForBufferAssignment(module));
return generic_changed || buffer_assignment_changed;
}
} // namespace xla

View File

@ -1,42 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
// Besides the modifications made by the generic xla::CopyInsertion, this
// CPU-specific copy insertion pass also adds copies to values live out of
// computations satisfying certain conditions (defined by constant or parameter,
// etc). This is necessary because of deficiencies of buffer
// assignment. Specifically, buffer assignment is computation-scoped and does
// not recognized aliasing between arguments and outputs of computations.
//
// TODO(b/62548313): Remove this when buffer assignment is smarter
// (module-scoped).
class CpuCopyInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "copy-insertion"; }
StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_

View File

@ -1,139 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace xla {
namespace {
namespace op = xla::testing::opcode_matchers;
int64 CountCopies(const HloComputation& computation) {
int64 count = 0;
for (const auto& instruction : computation.instructions()) {
if (instruction->opcode() == HloOpcode::kCopy) {
count++;
}
}
return count;
}
int64 CountCopies(const HloModule& module) {
int64 count = 0;
for (const auto& computation : module.computations()) {
count += CountCopies(*computation);
}
return count;
}
class CpuCopyInsertionTest : public HloTestBase {
protected:
void InsertCopies(HloModule* module) {
CpuCopyInsertion copy_insertion;
ASSERT_IS_OK(copy_insertion.Run(module).status());
}
const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
};
TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) {
// Test a while body and condition which are each simply a constant (root of
// computation is a constant). Each constant should be copied.
auto module = CreateNewVerifiedModule();
auto builder = HloComputation::Builder(TestName());
auto param_0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param_0"));
auto body_builder = HloComputation::Builder("body");
body_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
body_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0)));
HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
auto cond_builder = HloComputation::Builder("condition");
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
cond_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module->AddEmbeddedComputation(cond_builder.Build());
auto xla_while = builder.AddInstruction(
HloInstruction::CreateWhile(scalar_shape_, condition, body, param_0));
module->AddEntryComputation(builder.Build());
InsertCopies(module.get());
EXPECT_EQ(CountCopies(*module), 3);
EXPECT_THAT(xla_while->operand(0), op::Copy(op::Parameter()));
EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant()));
EXPECT_THAT(condition->root_instruction(), op::Copy(op::Constant()));
}
TEST_F(CpuCopyInsertionTest, TupleCall) {
// Test a kCall instruction which calls a computation which produces a three
// element tuple: one is a constant, one is a parameter, and one is produced
// in the computation. The constant and parameter should be copied.
auto module = CreateNewVerifiedModule();
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param_0"));
const Shape tuple_shape =
ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_, scalar_shape_});
auto sub_builder = HloComputation::Builder("subcomputation");
auto sub_param = sub_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
auto constant = sub_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0)));
auto add = sub_builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape_, HloOpcode::kAdd, sub_param, constant));
sub_builder.AddInstruction(
HloInstruction::CreateTuple({sub_param, constant, add}));
HloComputation* subcomputation =
module->AddEmbeddedComputation(sub_builder.Build());
builder.AddInstruction(
HloInstruction::CreateCall(tuple_shape, {param}, subcomputation));
module->AddEntryComputation(builder.Build());
InsertCopies(module.get());
EXPECT_EQ(CountCopies(*subcomputation), 2);
EXPECT_THAT(subcomputation->root_instruction(),
op::Tuple(op::Copy(op::Parameter()), op::Copy(op::Constant()),
op::Add()));
}
} // namespace
} // namespace xla

View File

@ -54,12 +54,7 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
} }
} }
// The GPU backend needs additional copies added due to deficiencies in return changed;
// buffer assignment.
TF_ASSIGN_OR_RETURN(bool buffer_assignment_changed,
CopyInsertion::AddCopiesForBufferAssignment(module));
return changed || buffer_assignment_changed;
} }
} // namespace gpu } // namespace gpu

View File

@ -51,7 +51,8 @@ void HloToIrBindings::EmitBasePointersForHlos(
absl::flat_hash_set<const HloInstruction*> already_bound_for_this_function; absl::flat_hash_set<const HloInstruction*> already_bound_for_this_function;
auto arg_iter = function->arg_begin(); auto arg_iter = function->arg_begin();
for (const HloInstruction* io_hlo : io_hlos) { for (const HloInstruction* io_hlo : io_hlos) {
CHECK(!absl::c_count(non_io_hlos, io_hlo)) CHECK(io_hlo == io_hlo->parent()->root_instruction() ||
!absl::c_count(non_io_hlos, io_hlo))
<< "IO HLOs and non-IO HLOs should be disjoint"; << "IO HLOs and non-IO HLOs should be disjoint";
if (!already_bound_for_this_function.contains(io_hlo)) { if (!already_bound_for_this_function.contains(io_hlo)) {
if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) { if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) {

View File

@ -177,13 +177,6 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
llvm::Value* source_address) { llvm::Value* source_address) {
CHECK_EQ(2, computation.num_parameters()); CHECK_EQ(2, computation.num_parameters());
if (computation.instruction_count() != 3) {
// We special-case only computations with one computing instruction for now.
// Such computation has exactly three instructions given it has two
// parameters.
return false;
}
HloOpcode root_opcode = computation.root_instruction()->opcode(); HloOpcode root_opcode = computation.root_instruction()->opcode();
PrimitiveType element_type = PrimitiveType element_type =
computation.root_instruction()->shape().element_type(); computation.root_instruction()->shape().element_type();
@ -191,12 +184,11 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
element_type == S64 || element_type == U64; element_type == S64 || element_type == U64;
llvm::Value* source = Load(source_address, "source"); llvm::Value* source = Load(source_address, "source");
// kCopy of RHS -> atomic store. // Just passing along RHS -> atomic store.
if (root_opcode == HloOpcode::kCopy && if (computation.instruction_count() == 2 &&
root_opcode == HloOpcode::kParameter &&
(element_type == F32 || is_atomic_integral) && (element_type == F32 || is_atomic_integral) &&
computation.root_instruction()->operand(0)->opcode() == computation.root_instruction()->parameter_number() == 1) {
HloOpcode::kParameter &&
computation.root_instruction()->operand(0)->parameter_number() == 1) {
llvm::StoreInst* store = Store(source, output_address); llvm::StoreInst* store = Store(source, output_address);
store->setAtomic(llvm::AtomicOrdering::Unordered); store->setAtomic(llvm::AtomicOrdering::Unordered);
// Derive a minimum alignment from the type. The optimizer can increase it // Derive a minimum alignment from the type. The optimizer can increase it
@ -205,6 +197,13 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
return true; return true;
} }
if (computation.instruction_count() != 3) {
// We special-case only computations with one computing instruction for now.
// Such computation has exactly three instructions given it has two
// parameters.
return false;
}
if (root_opcode == HloOpcode::kAdd) { if (root_opcode == HloOpcode::kAdd) {
// NVPTX supports atomicAdd on F32 and integer types. // NVPTX supports atomicAdd on F32 and integer types.
if (element_type == F32) { if (element_type == F32) {