Use constant buffer allocations for XLA:CPU

This is simpler than the corresponding change to XLA:GPU because on XLA:CPU all
instructions are codegened, so we can always embed a pointer to the constant
global variable directly in the generated LLVM IR.

PiperOrigin-RevId: 206363887
Sanjoy Das 2018-07-27 13:24:46 -07:00 committed by TensorFlower Gardener
parent 90fe37ab8d
commit 388d0d8601
16 changed files with 180 additions and 78 deletions
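As a rough sketch of the idea described above (hand-written for illustration, not code from this change; BufferPointerSketch, buffer_table, and index are made-up names, and the calls use LLVM's typed-pointer C++ API of this era), a constant allocation can hand back the address of its global directly, while every other allocation still loads its address from the buffer table the runtime passes in:

#include <cstdint>

#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"

// Sketch only; the real dispatch lives in IrEmitter::EmitTempBufferPointer.
llvm::Value* BufferPointerSketch(llvm::IRBuilder<>* b, bool is_constant,
                                 llvm::GlobalVariable* constant_global,
                                 llvm::Value* buffer_table /* i8** */,
                                 int64_t index) {
  if (is_constant) {
    // The global's address is known when the module is compiled, so it is
    // embedded directly in the generated IR; no run-time lookup is needed.
    return constant_global;
  }
  // Non-constant buffers keep the indirect path: load the i8* for this
  // allocation out of the buffer table supplied at execution time.
  llvm::Value* index_value = b->getInt64(index);
  llvm::Value* slot = b->CreateInBoundsGEP(buffer_table, {index_value});
  return b->CreateLoad(slot);
}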

View File

@ -252,6 +252,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/compiler/xla/service/llvm_ir:alias_analysis",
"//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
"//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",

View File

@ -562,7 +562,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
BufferAssigner::Run(
module.get(),
xla::MakeUnique<SequentialHloOrdering>(module.get(), module_sequence),
BufferSizeBytesFunction(), memory_alignment));
BufferSizeBytesFunction(), memory_alignment,
/*allow_input_output_aliasing=*/false,
/*allocate_buffers_for_constants=*/true));
// BufferAssignment::ToString() includes a header, so no need for us to
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
@ -584,6 +586,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
std::move(computation_to_profile_idx),
&target_machine_features);
TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
for (auto embedded_computation :
entry_computation->MakeEmbeddedComputationsList()) {
if (embedded_computation->IsFusionComputation()) {
@ -747,7 +751,9 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
BufferAssigner::Run(
module,
xla::MakeUnique<SequentialHloOrdering>(module, module_sequence),
BufferSizeBytesFunction(), memory_alignment));
BufferSizeBytesFunction(), memory_alignment,
/*allow_input_output_aliasing=*/false,
/*allocate_buffers_for_constants=*/true));
// BufferAssignment::ToString() includes a header, so no need for us to
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
@ -776,6 +782,9 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
std::move(instruction_to_profile_idx),
std::move(computation_to_profile_idx),
&target_machine_features);
TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
HloComputation* computation = module->entry_computation();
for (auto embedded_computation :
computation->MakeEmbeddedComputationsList()) {
@ -832,7 +841,8 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
BufferSizes buffer_sizes;
for (const BufferAllocation& allocation : assignment->Allocations()) {
// Callers don't need to allocate temporary buffers for parameters.
if (allocation.is_entry_computation_parameter()) {
if (allocation.is_entry_computation_parameter() ||
allocation.is_constant()) {
buffer_sizes.push_back(-1);
continue;
}
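The hunk above reports constant allocations with a size of -1, the same sentinel already used for entry parameters, so ahead-of-time callers know they do not have to provide storage for them. A hypothetical caller-side loop (illustrative only, not part of this change and not the actual AOT caller code) might consume the size table like this:

#include <cstdint>
#include <cstdlib>
#include <vector>

// Allocate backing storage only for entries with a real size; -1 marks
// parameter buffers (supplied separately by the caller) and, after this
// change, constant buffers (baked into the compiled object as globals).
std::vector<void*> AllocateTempBuffers(const std::vector<int64_t>& buffer_sizes) {
  std::vector<void*> buffers(buffer_sizes.size(), nullptr);
  for (std::size_t i = 0; i < buffer_sizes.size(); ++i) {
    if (buffer_sizes[i] < 0) {
      continue;  // parameter or constant: nothing for the caller to allocate
    }
    buffers[i] = std::malloc(static_cast<std::size_t>(buffer_sizes[i]));
  }
  return buffers;
}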

View File

@ -88,6 +88,11 @@ Status CpuExecutable::AllocateBuffers(
continue;
}
if (allocation.is_constant()) {
VLOG(3) << "allocation #" << i << " is a constant";
continue;
}
if (allocation.is_thread_local()) {
VLOG(3) << "buffer #" << i << " is thread-local";
continue;

View File

@ -51,6 +51,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
@ -175,23 +176,34 @@ llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
result_global, IrShapeType(literal.shape())->getPointerTo());
}
Status IrEmitter::EmitConstantGlobals() {
for (const BufferAllocation& allocation : assignment_.Allocations()) {
if (!allocation.is_constant()) {
continue;
}
const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation);
llvm::Constant* global_for_const;
auto it = emitted_literals_.find(&literal);
if (it != emitted_literals_.end()) {
global_for_const = it->second;
} else {
global_for_const = EmitGlobalForLiteral(literal);
InsertOrDie(&emitted_literals_, &literal, global_for_const);
}
InsertOrDie(&constant_buffer_to_global_, allocation.index(),
global_for_const);
}
return Status::OK();
}
Status IrEmitter::HandleConstant(HloInstruction* constant) {
VLOG(2) << "HandleConstant: " << constant->ToString();
const Literal& literal = constant->literal();
llvm::Constant* global_for_const;
auto it = emitted_literals_.find(&literal);
if (it != emitted_literals_.end()) {
global_for_const = it->second;
} else {
global_for_const = EmitGlobalForLiteral(literal);
emitted_literals_[&literal] = global_for_const;
}
emitted_value_[constant] = global_for_const;
VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*global_for_const);
VLOG(2) << " its type: "
<< llvm_ir::DumpToString(*global_for_const->getType());
return Status::OK();
// IrEmitter::EmitConstantGlobals has already taken care of emitting the body
// of the constant.
return EmitTargetAddressForOp(constant);
}
Status IrEmitter::HandleCopy(HloInstruction* copy) {
@ -2712,6 +2724,10 @@ llvm::Value* IrEmitter::EmitTempBufferPointer(
return b_.CreateBitCast(tempbuf_address, element_type->getPointerTo());
}
if (allocation.is_constant()) {
return FindOrDie(constant_buffer_to_global_, allocation.index());
}
llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
GetTempBuffersArgument(), slice.index(), &b_);
llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr);
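For reference, the LLVM mechanics behind EmitGlobalForLiteral and EmitConstantGlobals boil down to turning the literal's bytes into a private constant global and then presenting that global as a pointer to the shape's IR type. A minimal sketch of that step, with illustrative names and assuming the literal's raw bytes are already in hand (this is not XLA's exact code):

#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"

// The bytes become a private, constant i8-array global; a ConstantExpr
// bitcast then types it as a pointer to the shape's IR type so later code
// can index it like any other buffer.
llvm::Constant* EmitConstantGlobalSketch(llvm::Module* module,
                                         llvm::StringRef name,
                                         llvm::ArrayRef<uint8_t> bytes,
                                         llvm::Type* shape_ir_type) {
  llvm::Constant* initializer =
      llvm::ConstantDataArray::get(module->getContext(), bytes);
  auto* global = new llvm::GlobalVariable(
      *module, initializer->getType(), /*isConstant=*/true,
      llvm::GlobalValue::PrivateLinkage, initializer, name);
  return llvm::ConstantExpr::getBitCast(global, shape_ir_type->getPointerTo());
}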

View File

@ -105,6 +105,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
PrimitiveType return_type, HloComputation* computation,
const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name);
// Emit an LLVM global variable for every constant buffer allocation.
Status EmitConstantGlobals();
protected:
//
// The following methods implement the DfsHloVisitor interface.
@ -560,6 +563,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
LiteralPtrHashFunctor, LiteralPtrEqualityFunctor>
emitted_literals_;
tensorflow::gtl::FlatMap<BufferAllocation::Index, llvm::Constant*>
constant_buffer_to_global_;
TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter);
};

View File

@ -120,6 +120,7 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/llvm_ir:alias_analysis",
"//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
@ -165,6 +166,7 @@ cc_library(
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
"//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
@ -323,9 +325,9 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_execution_profile",
"//tensorflow/compiler/xla/service:logical_buffer",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:stream_pool",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/compiler/xla/service:tuple_points_to_analysis",
"//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:stream_executor_no_cuda",

View File

@ -173,45 +173,6 @@ void BufferAllocations::SetBuffer(BufferAllocation::Index buffer_index,
buffers_[buffer_index] = buffer;
}
static const HloInstruction& InstrForConstantBufferAllocation(
const BufferAllocation& allocation) {
CHECK(allocation.is_constant());
HloInstruction* const_instr = nullptr;
for (const auto& buffer_offset_pair : allocation.assigned_buffers()) {
const LogicalBuffer* buffer = buffer_offset_pair.first;
// BufferAssignment may have assigned non-constant instructions to this
// allocation too so we can't CHECK this condition. E.g. for
//
// while(init = constant, body = identity, cond = ...)
//
// the LogicalBuffer for the kWhile instruction will have the same
// BufferAllocation as the LogicalBuffer for the (init) constant.
if (buffer->instruction()->opcode() == HloOpcode::kConstant) {
CHECK_EQ(const_instr, nullptr)
<< const_instr->ToString() << " " << buffer->ToString();
const_instr = buffer->instruction();
}
}
CHECK_NE(const_instr, nullptr);
return *const_instr;
}
string ConstantBufferAllocationToGlobalName(
const BufferAllocation& allocation) {
string instr_name = InstrForConstantBufferAllocation(allocation).name();
for (char& c : instr_name) {
if (c == '.') {
c = '_';
}
}
return tensorflow::strings::StrCat("buffer_for_", instr_name);
}
const Literal& LiteralForConstantAllocation(
const BufferAllocation& allocation) {
return InstrForConstantBufferAllocation(allocation).literal();
}
bool ShouldEmitLiteralInLlvmIr(const Literal& literal) {
// LLVM can sometimes do interesting optimizations using scalar constants.
return ShapeUtil::IsScalar(literal.shape());

View File

@ -107,15 +107,6 @@ class BufferAllocations {
bool torn_down_ = false;
};
// In XLA:GPU we map constant buffer allocations to globals in the generated
// LLVM IR. This function gives us the name of the global variable a constant
// buffer is mapped to.
string ConstantBufferAllocationToGlobalName(const BufferAllocation& allocation);
// Return the Literal corresponding to `allocation`, which must be a constant
// allocation.
const Literal& LiteralForConstantAllocation(const BufferAllocation& allocation);
// LLVM and PTXAS don't deal well with large constants, so we only emit very
// small constants directly in LLVM IR. Larger constants are emitted with zero
// initializers in LLVM IR and are later overwritten when the PTX/CUBIN is

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
@ -206,13 +207,15 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) {
TF_ASSIGN_OR_RETURN(
se::DeviceMemoryBase global,
executor->GetUntypedSymbol(
ConstantBufferAllocationToGlobalName(allocation), module_handle));
llvm_ir::ConstantBufferAllocationToGlobalName(allocation),
module_handle));
VLOG(3) << "Resolved global "
<< ConstantBufferAllocationToGlobalName(allocation) << " to "
<< global.opaque();
<< llvm_ir::ConstantBufferAllocationToGlobalName(allocation)
<< " to " << global.opaque();
InsertOrDie(&globals, i, global);
const Literal& literal = LiteralForConstantAllocation(allocation);
const Literal& literal =
llvm_ir::LiteralForConstantAllocation(allocation);
CHECK(ShapeUtil::IsArray(literal.shape()));
if (!ShouldEmitLiteralInLlvmIr(literal)) {
VLOG(3) << "H2D memcpy for constant with shape "

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/core/lib/strings/str_util.h"
@ -114,7 +115,8 @@ void HloToIrBindings::EmitBasePointersForHlos(
} else if (slice.allocation()->is_constant()) {
llvm::Value* global_for_constant =
module_->getGlobalVariable(llvm_ir::AsStringRef(
ConstantBufferAllocationToGlobalName(*slice.allocation())));
llvm_ir::ConstantBufferAllocationToGlobalName(
*slice.allocation())));
BindHloToIrValue(*non_io_hlo, global_for_constant);
} else {
const int64 offset = slice.offset();

View File

@ -60,6 +60,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
@ -2411,8 +2412,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
llvm::Value* loc;
if (slice.allocation()->is_constant()) {
loc = ir_emitter_context_->llvm_module()->getGlobalVariable(
llvm_ir::AsStringRef(
ConstantBufferAllocationToGlobalName(*slice.allocation())));
llvm_ir::AsStringRef(llvm_ir::ConstantBufferAllocationToGlobalName(
*slice.allocation())));
CHECK_NE(loc, nullptr);
} else {
loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()),
@ -3428,7 +3429,7 @@ Status IrEmitterUnnested::EmitConstantGlobals() {
continue;
}
const Literal& literal = LiteralForConstantAllocation(allocation);
const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation);
const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal);
llvm::ArrayType* global_type =
llvm::ArrayType::get(b_.getInt8Ty(), allocation.size());
@ -3453,7 +3454,8 @@ Status IrEmitterUnnested::EmitConstantGlobals() {
global_type, /*isConstant=*/should_emit_initializer,
llvm::GlobalValue::ExternalLinkage,
/*Initializer=*/initializer,
llvm_ir::AsStringRef(ConstantBufferAllocationToGlobalName(allocation)));
llvm_ir::AsStringRef(
llvm_ir::ConstantBufferAllocationToGlobalName(allocation)));
global_for_const->setAlignment(kConstantBufferAlignBytes);
ir_emitter_context_->llvm_module()->getGlobalList().push_back(
global_for_const);

View File

@ -224,6 +224,15 @@ cc_library(
],
)
cc_library(
name = "buffer_assignment_util",
srcs = ["buffer_assignment_util.cc"],
hdrs = ["buffer_assignment_util.h"],
deps = [
"//tensorflow/compiler/xla/service:buffer_assignment",
],
)
cc_library(
name = "math_ops",
srcs = ["math_ops.cc"],

View File

@ -58,7 +58,7 @@ ENTRY while3 {
CompileAndVerifyIr(hlo_string, R"(
; CHECK-LABEL: @body(i8* align 4 dereferenceable(4) %retval
; CHECK: %[[add_result:.*]] = fadd fast float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]]
; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:.*]]
; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:[0-9]+]]
;
; CHECK-LABEL: @condition(i8* align 1 dereferenceable(1) %fusion, i8* noalias %run_options, i8** noalias %params
; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %params, i64 0

View File

@ -0,0 +1,59 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
namespace xla {
namespace llvm_ir {
static const HloInstruction& InstrForConstantBufferAllocation(
const BufferAllocation& allocation) {
CHECK(allocation.is_constant());
HloInstruction* const_instr = nullptr;
for (const auto& buffer_offset_pair : allocation.assigned_buffers()) {
const LogicalBuffer* buffer = buffer_offset_pair.first;
// BufferAssignment may have assigned non-constant instructions to this
// allocation too so we can't CHECK this condition. E.g. for
//
// while(init = constant, body = identity, cond = ...)
//
// the LogicalBuffer for the kWhile instruction will have the same
// BufferAllocation as the LogicalBuffer for the (init) constant.
if (buffer->instruction()->opcode() == HloOpcode::kConstant) {
CHECK_EQ(const_instr, nullptr)
<< const_instr->ToString() << " " << buffer->ToString();
const_instr = buffer->instruction();
}
}
CHECK_NE(const_instr, nullptr);
return *const_instr;
}
string ConstantBufferAllocationToGlobalName(
const BufferAllocation& allocation) {
string instr_name = InstrForConstantBufferAllocation(allocation).name();
for (char& c : instr_name) {
if (c == '.') {
c = '_';
}
}
return tensorflow::strings::StrCat("buffer_for_", instr_name);
}
const Literal& LiteralForConstantAllocation(
const BufferAllocation& allocation) {
return InstrForConstantBufferAllocation(allocation).literal();
}
} // namespace llvm_ir
} // namespace xla

View File

@ -0,0 +1,34 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_BUFFER_ASSIGNMENT_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_BUFFER_ASSIGNMENT_UTIL_H_
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
namespace xla {
namespace llvm_ir {
// In XLA:GPU we map constant buffer allocations to globals in the generated
// LLVM IR. This function gives us the name of the global variable a constant
// buffer is mapped to. Not used on XLA:CPU.
string ConstantBufferAllocationToGlobalName(const BufferAllocation& allocation);
// Returns the Literal corresponding to `allocation`, which must be a constant
// allocation.
const Literal& LiteralForConstantAllocation(const BufferAllocation& allocation);
} // namespace llvm_ir
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_BUFFER_ASSIGNMENT_UTIL_H_
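A hypothetical usage sketch for the two helpers declared above (not part of the commit; it mirrors how the GpuExecutable and IrEmitterUnnested changes earlier in the diff consume them, and it assumes the usual XLA headers are available):

#include <string>

#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

// Walk an existing BufferAssignment and pair each constant allocation's
// global name with the literal that backs it.
void LogConstantAllocations(const xla::BufferAssignment& assignment) {
  for (const xla::BufferAllocation& allocation : assignment.Allocations()) {
    if (!allocation.is_constant()) {
      continue;
    }
    const std::string name =
        xla::llvm_ir::ConstantBufferAllocationToGlobalName(allocation);
    const xla::Literal& literal =
        xla::llvm_ir::LiteralForConstantAllocation(allocation);
    VLOG(1) << name << " backs a constant of shape "
            << xla::ShapeUtil::HumanString(literal.shape());
  }
}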

View File

@ -92,9 +92,10 @@ int main(int argc, char** argv) {
// It's lame to hard-code the buffer assignments, but we need
// local_client_aot_test.cc to be able to easily invoke the function.
CHECK_EQ(result->result_buffer_index(), 1);
CHECK_EQ(result->buffer_sizes().size(), 2);
CHECK_EQ(result->buffer_sizes().size(), 3);
CHECK_EQ(result->buffer_sizes()[0], -1); // param buffer
CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // result buffer
CHECK_EQ(result->buffer_sizes()[2], -1); // const buffer
if (triple.isOSBinFormatELF()) {
// Check the ELF magic.
CHECK_EQ(result->object_file_data()[0], 0x7F);