[XLA:CPU/GPU] Remove extra copy insertion to work around buffer assignment's local view
BufferAssignment is module-scoped now, so the extra copies shouldn't be needed. This means that parameter instructions can now be at the root of a computation, so adjust code that relies on that not happening. PiperOrigin-RevId: 252002504
This commit is contained in:
parent
1ad04efcee
commit
bfc8733ffb
@ -1222,57 +1222,4 @@ bool IsWhileBody(const HloComputation* computation,
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
/* static */ StatusOr<bool> CopyInsertion::AddCopiesForBufferAssignment(
|
|
||||||
HloModule* module) {
|
|
||||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
|
|
||||||
HloDataflowAnalysis::Run(*module));
|
|
||||||
|
|
||||||
bool changed = false;
|
|
||||||
|
|
||||||
// If a buffer live out of a computation is a constant, a parameter, or not
|
|
||||||
// defined in the computation, then copy it to account for the limited
|
|
||||||
// computation-scoped analysis in buffer assignment. An exception to this rule
|
|
||||||
// is the while body which is handled properly without copies.
|
|
||||||
for (HloComputation* computation : module->computations()) {
|
|
||||||
if (computation == module->entry_computation() ||
|
|
||||||
IsWhileBody(computation, *call_graph)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
HloInstruction* root = computation->root_instruction();
|
|
||||||
ShapeTree<bool> indices_to_copy(root->shape(), /*init_value=*/false);
|
|
||||||
bool copy_root = false;
|
|
||||||
for (const auto& pair : dataflow->GetInstructionValueSet(root)) {
|
|
||||||
const ShapeIndex& index = pair.first;
|
|
||||||
const HloValueSet& value_set = pair.second;
|
|
||||||
for (const HloValue* value : value_set.values()) {
|
|
||||||
HloInstruction* def = value->defining_instruction();
|
|
||||||
if (def->parent() != computation ||
|
|
||||||
def->opcode() == HloOpcode::kConstant ||
|
|
||||||
def->opcode() == HloOpcode::kParameter) {
|
|
||||||
*indices_to_copy.mutable_element(index) = true;
|
|
||||||
copy_root = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (copy_root) {
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
|
||||||
HloInstruction * root_copy,
|
|
||||||
computation->DeepCopyInstruction(root, &indices_to_copy));
|
|
||||||
computation->set_root_instruction(root_copy);
|
|
||||||
changed = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TupleSimplifier tuple_simplifier;
|
|
||||||
HloDCE dce;
|
|
||||||
TF_ASSIGN_OR_RETURN(bool tuple_simplifier_changed,
|
|
||||||
tuple_simplifier.Run(module));
|
|
||||||
TF_ASSIGN_OR_RETURN(bool dce_changed, dce.Run(module));
|
|
||||||
|
|
||||||
return changed || tuple_simplifier_changed || dce_changed;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -60,17 +60,6 @@ class CopyInsertion : public HloModulePass {
|
|||||||
// (copies were inserted).
|
// (copies were inserted).
|
||||||
StatusOr<bool> Run(HloModule* module) override;
|
StatusOr<bool> Run(HloModule* module) override;
|
||||||
|
|
||||||
// The CPU and GPU backend need additional copies added due to deficiencies in
|
|
||||||
// buffer assignment. Specifically, copies are needed for constants live-out
|
|
||||||
// of computations, and for values which are live-in and live-out of the same
|
|
||||||
// computation. These copies are needed because buffer-assignment uses a
|
|
||||||
// computation-scoped analysis (TuplePointsToAnalysis) and has limited
|
|
||||||
// visibility across computation boundaries. This method adds these necessary
|
|
||||||
// copies. Returns whether the module was modified.
|
|
||||||
//
|
|
||||||
// TODO(b/62548313): Remove this when buffer assignment is module-scoped.
|
|
||||||
static StatusOr<bool> AddCopiesForBufferAssignment(HloModule* module);
|
|
||||||
|
|
||||||
// Try to remove as many copies from the module as possible without
|
// Try to remove as many copies from the module as possible without
|
||||||
// introducing live range interference. Only copy instructions that are
|
// introducing live range interference. Only copy instructions that are
|
||||||
// eligible for copy elision are considered for removal.
|
// eligible for copy elision are considered for removal.
|
||||||
|
@ -76,7 +76,6 @@ cc_library(
|
|||||||
":compiler_functor",
|
":compiler_functor",
|
||||||
":buffer_info_util",
|
":buffer_info_util",
|
||||||
":conv_canonicalization",
|
":conv_canonicalization",
|
||||||
":cpu_copy_insertion",
|
|
||||||
":cpu_executable",
|
":cpu_executable",
|
||||||
":cpu_hlo_support_checker",
|
":cpu_hlo_support_checker",
|
||||||
":cpu_instruction_fusion",
|
":cpu_instruction_fusion",
|
||||||
@ -92,6 +91,7 @@ cc_library(
|
|||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
":target_machine_features",
|
":target_machine_features",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
|
"//tensorflow/compiler/xla/service:copy_insertion",
|
||||||
"//tensorflow/compiler/xla/service:hlo_casting_utils",
|
"//tensorflow/compiler/xla/service:hlo_casting_utils",
|
||||||
"//tensorflow/compiler/xla/service:dump",
|
"//tensorflow/compiler/xla/service:dump",
|
||||||
"//tensorflow/compiler/xla/service:map_inliner",
|
"//tensorflow/compiler/xla/service:map_inliner",
|
||||||
@ -956,18 +956,6 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "cpu_copy_insertion",
|
|
||||||
srcs = ["cpu_copy_insertion.cc"],
|
|
||||||
hdrs = ["cpu_copy_insertion.h"],
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/compiler/xla/service:copy_insertion",
|
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
|
||||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "vector_support_library",
|
name = "vector_support_library",
|
||||||
srcs = ["vector_support_library.cc"],
|
srcs = ["vector_support_library.cc"],
|
||||||
@ -986,26 +974,6 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
|
||||||
name = "cpu_copy_insertion_test",
|
|
||||||
srcs = ["cpu_copy_insertion_test.cc"],
|
|
||||||
deps = [
|
|
||||||
":cpu_copy_insertion",
|
|
||||||
"//tensorflow/compiler/xla:debug_options_flags",
|
|
||||||
"//tensorflow/compiler/xla:literal",
|
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
|
||||||
"//tensorflow/compiler/xla:test",
|
|
||||||
"//tensorflow/compiler/xla:test_helpers",
|
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
|
||||||
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
|
|
||||||
"//tensorflow/compiler/xla/service:hlo_matchers",
|
|
||||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
|
||||||
"//tensorflow/core:test",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "cpu_hlo_support_checker",
|
name = "cpu_hlo_support_checker",
|
||||||
srcs = ["cpu_hlo_support_checker.cc"],
|
srcs = ["cpu_hlo_support_checker.cc"],
|
||||||
|
@ -56,10 +56,10 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
|
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
|
||||||
#include "tensorflow/compiler/xla/service/conditional_to_select.h"
|
#include "tensorflow/compiler/xla/service/conditional_to_select.h"
|
||||||
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
|
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/copy_insertion.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
|
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
|
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
|
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
|
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
|
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
|
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
|
||||||
@ -402,7 +402,7 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn(
|
|||||||
// interfering with the rewrites.
|
// interfering with the rewrites.
|
||||||
pipeline.AddPass<HloDCE>();
|
pipeline.AddPass<HloDCE>();
|
||||||
pipeline.AddPass<FlattenCallGraph>();
|
pipeline.AddPass<FlattenCallGraph>();
|
||||||
pipeline.AddPass<CpuCopyInsertion>();
|
pipeline.AddPass<CopyInsertion>();
|
||||||
pipeline.AddPass<HloDCE>();
|
pipeline.AddPass<HloDCE>();
|
||||||
return pipeline.Run(module).status();
|
return pipeline.Run(module).status();
|
||||||
}
|
}
|
||||||
|
@ -1,43 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include <set>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/copy_insertion.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
|
||||||
#include "tensorflow/core/platform/logging.h"
|
|
||||||
|
|
||||||
namespace xla {
|
|
||||||
|
|
||||||
StatusOr<bool> CpuCopyInsertion::Run(HloModule* module) {
|
|
||||||
CopyInsertion generic_copy_insertion;
|
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(bool generic_changed, generic_copy_insertion.Run(module));
|
|
||||||
|
|
||||||
// The CPU backend needs additional copies added due to deficiencies in
|
|
||||||
// buffer assignment.
|
|
||||||
TF_ASSIGN_OR_RETURN(bool buffer_assignment_changed,
|
|
||||||
CopyInsertion::AddCopiesForBufferAssignment(module));
|
|
||||||
|
|
||||||
return generic_changed || buffer_assignment_changed;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace xla
|
|
@ -1,42 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_
|
|
||||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_
|
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
|
||||||
|
|
||||||
namespace xla {
|
|
||||||
|
|
||||||
// Besides the modifications made by the generic xla::CopyInsertion, this
|
|
||||||
// CPU-specific copy insertion pass also adds copies to values live out of
|
|
||||||
// computations satisfying certain conditions (defined by constant or parameter,
|
|
||||||
// etc). This is necessary because of deficiencies of buffer
|
|
||||||
// assignment. Specifically, buffer assignment is computation-scoped and does
|
|
||||||
// not recognize aliasing between arguments and outputs of computations.
|
|
||||||
//
|
|
||||||
// TODO(b/62548313): Remove this when buffer assignment is smarter
|
|
||||||
// (module-scoped).
|
|
||||||
class CpuCopyInsertion : public HloModulePass {
|
|
||||||
public:
|
|
||||||
absl::string_view name() const override { return "copy-insertion"; }
|
|
||||||
|
|
||||||
StatusOr<bool> Run(HloModule* module) override;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace xla
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_
|
|
@ -1,139 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
|
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
|
||||||
#include "tensorflow/compiler/xla/literal.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
|
||||||
#include "tensorflow/compiler/xla/test.h"
|
|
||||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
|
||||||
#include "tensorflow/core/platform/test_benchmark.h"
|
|
||||||
|
|
||||||
namespace xla {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
namespace op = xla::testing::opcode_matchers;
|
|
||||||
|
|
||||||
int64 CountCopies(const HloComputation& computation) {
|
|
||||||
int64 count = 0;
|
|
||||||
for (const auto& instruction : computation.instructions()) {
|
|
||||||
if (instruction->opcode() == HloOpcode::kCopy) {
|
|
||||||
count++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return count;
|
|
||||||
}
|
|
||||||
|
|
||||||
int64 CountCopies(const HloModule& module) {
|
|
||||||
int64 count = 0;
|
|
||||||
for (const auto& computation : module.computations()) {
|
|
||||||
count += CountCopies(*computation);
|
|
||||||
}
|
|
||||||
return count;
|
|
||||||
}
|
|
||||||
|
|
||||||
class CpuCopyInsertionTest : public HloTestBase {
|
|
||||||
protected:
|
|
||||||
void InsertCopies(HloModule* module) {
|
|
||||||
CpuCopyInsertion copy_insertion;
|
|
||||||
ASSERT_IS_OK(copy_insertion.Run(module).status());
|
|
||||||
}
|
|
||||||
|
|
||||||
const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) {
|
|
||||||
// Test a while body and condition which are each simply a constant (root of
|
|
||||||
// computation is a constant). Each constant should be copied.
|
|
||||||
auto module = CreateNewVerifiedModule();
|
|
||||||
auto builder = HloComputation::Builder(TestName());
|
|
||||||
auto param_0 = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateParameter(0, scalar_shape_, "param_0"));
|
|
||||||
|
|
||||||
auto body_builder = HloComputation::Builder("body");
|
|
||||||
body_builder.AddInstruction(
|
|
||||||
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
|
|
||||||
body_builder.AddInstruction(
|
|
||||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0)));
|
|
||||||
HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
|
|
||||||
|
|
||||||
auto cond_builder = HloComputation::Builder("condition");
|
|
||||||
cond_builder.AddInstruction(
|
|
||||||
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
|
|
||||||
cond_builder.AddInstruction(
|
|
||||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
|
||||||
HloComputation* condition =
|
|
||||||
module->AddEmbeddedComputation(cond_builder.Build());
|
|
||||||
|
|
||||||
auto xla_while = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateWhile(scalar_shape_, condition, body, param_0));
|
|
||||||
|
|
||||||
module->AddEntryComputation(builder.Build());
|
|
||||||
|
|
||||||
InsertCopies(module.get());
|
|
||||||
|
|
||||||
EXPECT_EQ(CountCopies(*module), 3);
|
|
||||||
|
|
||||||
EXPECT_THAT(xla_while->operand(0), op::Copy(op::Parameter()));
|
|
||||||
EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant()));
|
|
||||||
EXPECT_THAT(condition->root_instruction(), op::Copy(op::Constant()));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(CpuCopyInsertionTest, TupleCall) {
|
|
||||||
// Test a kCall instruction which calls a computation which produces a three
|
|
||||||
// element tuple: one is a constant, one is a parameter, and one is produced
|
|
||||||
// in the computation. The constant and parameter should be copied.
|
|
||||||
auto module = CreateNewVerifiedModule();
|
|
||||||
auto builder = HloComputation::Builder(TestName());
|
|
||||||
auto param = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateParameter(0, scalar_shape_, "param_0"));
|
|
||||||
const Shape tuple_shape =
|
|
||||||
ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_, scalar_shape_});
|
|
||||||
|
|
||||||
auto sub_builder = HloComputation::Builder("subcomputation");
|
|
||||||
auto sub_param = sub_builder.AddInstruction(
|
|
||||||
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
|
|
||||||
auto constant = sub_builder.AddInstruction(
|
|
||||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0)));
|
|
||||||
auto add = sub_builder.AddInstruction(HloInstruction::CreateBinary(
|
|
||||||
scalar_shape_, HloOpcode::kAdd, sub_param, constant));
|
|
||||||
sub_builder.AddInstruction(
|
|
||||||
HloInstruction::CreateTuple({sub_param, constant, add}));
|
|
||||||
HloComputation* subcomputation =
|
|
||||||
module->AddEmbeddedComputation(sub_builder.Build());
|
|
||||||
|
|
||||||
builder.AddInstruction(
|
|
||||||
HloInstruction::CreateCall(tuple_shape, {param}, subcomputation));
|
|
||||||
|
|
||||||
module->AddEntryComputation(builder.Build());
|
|
||||||
|
|
||||||
InsertCopies(module.get());
|
|
||||||
|
|
||||||
EXPECT_EQ(CountCopies(*subcomputation), 2);
|
|
||||||
EXPECT_THAT(subcomputation->root_instruction(),
|
|
||||||
op::Tuple(op::Copy(op::Parameter()), op::Copy(op::Constant()),
|
|
||||||
op::Add()));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace xla
|
|
@ -54,12 +54,7 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// The GPU backend needs additional copies added due to deficiencies in
|
return changed;
|
||||||
// buffer assignment.
|
|
||||||
TF_ASSIGN_OR_RETURN(bool buffer_assignment_changed,
|
|
||||||
CopyInsertion::AddCopiesForBufferAssignment(module));
|
|
||||||
|
|
||||||
return changed || buffer_assignment_changed;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
|
@ -51,7 +51,8 @@ void HloToIrBindings::EmitBasePointersForHlos(
|
|||||||
absl::flat_hash_set<const HloInstruction*> already_bound_for_this_function;
|
absl::flat_hash_set<const HloInstruction*> already_bound_for_this_function;
|
||||||
auto arg_iter = function->arg_begin();
|
auto arg_iter = function->arg_begin();
|
||||||
for (const HloInstruction* io_hlo : io_hlos) {
|
for (const HloInstruction* io_hlo : io_hlos) {
|
||||||
CHECK(!absl::c_count(non_io_hlos, io_hlo))
|
CHECK(io_hlo == io_hlo->parent()->root_instruction() ||
|
||||||
|
!absl::c_count(non_io_hlos, io_hlo))
|
||||||
<< "IO HLOs and non-IO HLOs should be disjoint";
|
<< "IO HLOs and non-IO HLOs should be disjoint";
|
||||||
if (!already_bound_for_this_function.contains(io_hlo)) {
|
if (!already_bound_for_this_function.contains(io_hlo)) {
|
||||||
if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) {
|
if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) {
|
||||||
|
@ -177,13 +177,6 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
|
|||||||
llvm::Value* source_address) {
|
llvm::Value* source_address) {
|
||||||
CHECK_EQ(2, computation.num_parameters());
|
CHECK_EQ(2, computation.num_parameters());
|
||||||
|
|
||||||
if (computation.instruction_count() != 3) {
|
|
||||||
// We special-case only computations with one computing instruction for now.
|
|
||||||
// Such a computation has exactly three instructions, given that it has two
|
|
||||||
// parameters.
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
HloOpcode root_opcode = computation.root_instruction()->opcode();
|
HloOpcode root_opcode = computation.root_instruction()->opcode();
|
||||||
PrimitiveType element_type =
|
PrimitiveType element_type =
|
||||||
computation.root_instruction()->shape().element_type();
|
computation.root_instruction()->shape().element_type();
|
||||||
@ -191,12 +184,11 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
|
|||||||
element_type == S64 || element_type == U64;
|
element_type == S64 || element_type == U64;
|
||||||
llvm::Value* source = Load(source_address, "source");
|
llvm::Value* source = Load(source_address, "source");
|
||||||
|
|
||||||
// kCopy of RHS -> atomic store.
|
// Just passing along RHS -> atomic store.
|
||||||
if (root_opcode == HloOpcode::kCopy &&
|
if (computation.instruction_count() == 2 &&
|
||||||
|
root_opcode == HloOpcode::kParameter &&
|
||||||
(element_type == F32 || is_atomic_integral) &&
|
(element_type == F32 || is_atomic_integral) &&
|
||||||
computation.root_instruction()->operand(0)->opcode() ==
|
computation.root_instruction()->parameter_number() == 1) {
|
||||||
HloOpcode::kParameter &&
|
|
||||||
computation.root_instruction()->operand(0)->parameter_number() == 1) {
|
|
||||||
llvm::StoreInst* store = Store(source, output_address);
|
llvm::StoreInst* store = Store(source, output_address);
|
||||||
store->setAtomic(llvm::AtomicOrdering::Unordered);
|
store->setAtomic(llvm::AtomicOrdering::Unordered);
|
||||||
// Derive a minimum alignment from the type. The optimizer can increase it
|
// Derive a minimum alignment from the type. The optimizer can increase it
|
||||||
@ -205,6 +197,13 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (computation.instruction_count() != 3) {
|
||||||
|
// We special-case only computations with one computing instruction for now.
|
||||||
|
// Such a computation has exactly three instructions, given that it has two
|
||||||
|
// parameters.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
if (root_opcode == HloOpcode::kAdd) {
|
if (root_opcode == HloOpcode::kAdd) {
|
||||||
// NVPTX supports atomicAdd on F32 and integer types.
|
// NVPTX supports atomicAdd on F32 and integer types.
|
||||||
if (element_type == F32) {
|
if (element_type == F32) {
|
||||||
|
Loading…
Reference in New Issue
Block a user