Lower concatenate operations to memcpy.

This usually ends up being faster than the elemental IR implementation.

PiperOrigin-RevId: 163489782
A. Unique TensorFlower 2017-07-28 10:52:57 -07:00 committed by TensorFlower Gardener
parent a553aff131
commit efc63f6248
6 changed files with 252 additions and 2 deletions
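
For intuition, a minimal standalone sketch of the idea (illustrative only; this helper is hypothetical and not part of the change): when operand and output layouts agree and there is no padding, concatenating two row-major matrices along dimension 0 reduces to one memcpy per operand instead of a per-element loop.

#include <cstdint>
#include <cstring>
#include <vector>

// Concatenates two row-major matrices with 'cols' columns along dimension 0.
std::vector<float> ConcatDim0(const std::vector<float>& a, int64_t a_rows,
                              const std::vector<float>& b, int64_t b_rows,
                              int64_t cols) {
  std::vector<float> out((a_rows + b_rows) * cols);
  // 'a' occupies the first a_rows * cols elements of the output...
  std::memcpy(out.data(), a.data(), a.size() * sizeof(float));
  // ...and 'b' starts immediately after it.
  std::memcpy(out.data() + a.size(), b.data(), b.size() * sizeof(float));
  return out;
}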

@@ -2382,6 +2382,173 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
return Status::OK();
}

StatusOr<bool> IrEmitter::EmitFastConcatenate(
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands,
string* failure_reason) {
if (ShouldEmitParallelLoopFor(*concatenate)) {
*failure_reason =
"cannot generate memcpy-based concat for the parallel CPU backend";
return false;
}
const Shape& output_shape = concatenate->shape();
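// memcpy copies raw bytes, so every operand must share the output's physical
// element order and contain no internal padding.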
for (auto* op : operands) {
if (!LayoutUtil::Equal(op->shape().layout(), output_shape.layout())) {
*failure_reason = "operand has mismatching layouts";
return false;
}
if (LayoutUtil::IsPadded(op->shape())) {
*failure_reason = "operand has padded layout";
return false;
}
}
CHECK(!LayoutUtil::IsPadded(concatenate->shape()));
// We split the dimensions into three categories: the dimension over which we
// are concatenating (concat_dim), the dimensions that are minor to it
// (inner_dims) and the dimensions that are major to it (outer_dims).
int64 concat_dim = concatenate->dimensions(0);
const Layout& output_layout = output_shape.layout();
auto concat_dim_layout_itr =
std::find(output_layout.minor_to_major().begin(),
output_layout.minor_to_major().end(), concat_dim);
std::vector<int64> inner_dims(output_layout.minor_to_major().begin(),
concat_dim_layout_itr);
std::vector<int64> outer_dims(std::next(concat_dim_layout_itr),
output_layout.minor_to_major().end());
llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy();
llvm::Type* i8_type = ir_builder_.getInt8Ty();
TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
EmitTargetAddressForOp(concatenate));
llvm_ir::IrArray target_array(target_address, output_shape);
llvm_ir::ForLoopNest loops(&ir_builder_);
llvm_ir::IrArray::Index outer_dims_index =
loops.AddLoopsForShapeOnDimensions(output_shape, outer_dims, "concat");
std::replace(outer_dims_index.begin(), outer_dims_index.end(),
static_cast<llvm::Value*>(nullptr),
static_cast<llvm::Value*>(ir_builder_.getInt64(0)));
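// Entries for dimensions we did not loop over (the concat and inner
// dimensions) come back as nullptr; pinning them to zero makes the addresses
// computed below point at the start of each contiguous region.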
if (!outer_dims.empty()) {
SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
}
PrimitiveType primitive_type = output_shape.element_type();
unsigned primitive_type_size =
ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
AddAliasingInformationToIrArray(*concatenate, &target_array);
// Contiguous subregions of each operand contribute to a contiguous subregion
// of the target buffer, starting at target_region_begin.
llvm::Value* target_region_begin = ir_builder_.CreateBitCast(
target_array.EmitArrayElementAddress(outer_dims_index, &ir_builder_,
"target_region"),
i8_ptr_type);
int64 byte_offset_into_target_region = 0;
int64 inner_dims_product =
std::accumulate(inner_dims.begin(), inner_dims.end(), 1l,
[&](int64 product, int64 inner_dim) {
return product * output_shape.dimensions(inner_dim);
});
// For each operand, emit a memcpy from the operand to the target; its size is
// the product of the inner dimensions and the operand's extent along the
// concat dimension.
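// E.g. concatenating f32[2,3] and f32[5,3] along dim 0 with minor_to_major =
// {1, 0}: inner_dims_product is 3, so the two memcpys move 2*3 = 6 and
// 5*3 = 15 elements respectively.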
for (HloInstruction* operand : operands) {
const Shape& input_shape = operand->shape();
llvm_ir::IrArray source_array(GetEmittedValueFor(operand), input_shape);
AddAliasingInformationToIrArray(*operand, &source_array);
llvm::Value* copy_source_address = ir_builder_.CreateBitCast(
source_array.EmitArrayElementAddress(outer_dims_index, &ir_builder_,
"src_addr"),
i8_ptr_type);
llvm::Value* copy_target_address = ir_builder_.CreateGEP(
i8_type, target_region_begin,
ir_builder_.getInt64(byte_offset_into_target_region));
EmitTransferElements(
copy_target_address, copy_source_address,
inner_dims_product * input_shape.dimensions(concat_dim), primitive_type,
target_array, source_array);
byte_offset_into_target_region += inner_dims_product *
input_shape.dimensions(concat_dim) *
primitive_type_size;
}
if (!outer_dims.empty()) {
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
}
emitted_value_[concatenate] = target_address;
return true;
}

void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
int64 element_count,
PrimitiveType primitive_type,
const llvm_ir::IrArray& target_array,
const llvm_ir::IrArray& source_array) {
unsigned primitive_type_size =
ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
unsigned element_alignment = GCD(
primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type));
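// The copies start at arbitrary multiples of the element size from an
// element-aligned base, so the GCD is the strongest alignment guarantee that
// holds for every transfer.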
llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual(
llvm_ir::PrimitiveTypeToIrType(primitive_type, &ir_builder_));
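// A single element is moved with an aligned load/store pair; larger regions
// become a memcpy.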
if (element_count == 1) {
auto* load_instruction = ir_builder_.CreateAlignedLoad(
ir_builder_.CreateBitCast(source, primitive_ptr_type),
element_alignment);
source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction);
auto* store_instruction = ir_builder_.CreateAlignedStore(
load_instruction, ir_builder_.CreateBitCast(target, primitive_ptr_type),
element_alignment);
target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
} else {
auto* memcpy_instruction = ir_builder_.CreateMemCpy(
target, source, element_count * primitive_type_size, element_alignment);
// The memcpy does the load and the store internally. The aliasing related
// metadata has to reflect that.
std::map<int, llvm::MDNode*> merged_metadata =
llvm_ir::MergeMetadata(&module_->getContext(), source_array.metadata(),
target_array.metadata());
for (const auto& kind_md_pair : merged_metadata) {
memcpy_instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
}
}
}

Status IrEmitter::HandleConcatenate(
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
string failure_reason;
TF_ASSIGN_OR_RETURN(
bool successful,
EmitFastConcatenate(concatenate, operands, &failure_reason));
if (successful) {
VLOG(1) << "Emitted fast concatenate for " << concatenate->ToString();
return Status::OK();
}
VLOG(1) << "Could not emit fast concatenate for " << concatenate->ToString()
<< ": " << failure_reason;
return DefaultAction(concatenate);
}

Status IrEmitter::FinishVisit(HloInstruction* root) {
// When this method is called, we should have already emitted an IR value for
// the root (return) op. The IR value holds the address of the buffer holding

@@ -191,6 +191,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
tensorflow::gtl::ArraySlice<HloInstruction*> operands,
tensorflow::StringPiece custom_call_target) override;
Status HandleWhile(HloInstruction* xla_while) override;
Status HandleConcatenate(
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
Status FinishVisit(HloInstruction* root) override;
Status Preprocess(HloInstruction* hlo) override;
@@ -407,6 +410,21 @@ class IrEmitter : public DfsHloVisitorWithDefault {
HloInstruction* arg, tensorflow::gtl::ArraySlice<int64> dimensions,
unsigned element_alignment);
// Tries to emit a fast concatenate operation using memcpy. Returns true if
// successful, and false on failure. On failure, sets "failure_reason" to a
// string describing why it could not emit a fast concatenate.
StatusOr<bool> EmitFastConcatenate(
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands,
string* failure_reason);
// Emits LLVM IR to transfer "element_count" elements of type "primitive_type"
// from the address "source" to the address "target".
void EmitTransferElements(llvm::Value* target, llvm::Value* source,
int64 element_count, PrimitiveType primitive_type,
const llvm_ir::IrArray& target_array,
const llvm_ir::IrArray& source_array);
// Name of the computation entry function. This function serves as the
// top-level "main" of the computation and will be invoked by the JIT.
string entry_function_name_;

@@ -56,7 +56,9 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
alias_scope_md =
GetAliasScopeMetadataForBuffer(buffer_slice, GetAliasDomain());
}
array->AddAliasScopeMetadata(alias_scope_md);
if (alias_scope_md != nullptr) {
array->AddAliasScopeMetadata(alias_scope_md);
}
}
if (module_.config().debug_options().xla_llvm_enable_noalias_metadata()) {
@@ -65,7 +67,9 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
noalias_md = GetNoaliasMetadataForBuffer(buffer_slice, GetAliasDomain(),
assignment_, hlo);
}
array->AddNoaliasMetadata(noalias_md);
if (noalias_md != nullptr) {
array->AddNoaliasMetadata(noalias_md);
}
}
if (module_.config()

@@ -218,17 +218,22 @@ class IrArray {
llvm::IRBuilder<>* ir_builder) const;
void AddAliasScopeMetadata(llvm::MDNode* alias_scope) {
CHECK_NE(alias_scope, nullptr);
AddMetadata(llvm::LLVMContext::MD_alias_scope, alias_scope);
}
void AddNoaliasMetadata(llvm::MDNode* noalias) {
CHECK_NE(noalias, nullptr);
AddMetadata(llvm::LLVMContext::MD_noalias, noalias);
}
void AddInvariantLoad(llvm::MDNode* invariant_load) {
CHECK_NE(invariant_load, nullptr);
AddMetadata(llvm::LLVMContext::MD_invariant_load, invariant_load);
}
const std::map<int, llvm::MDNode*>& metadata() const { return metadata_; }
// Bumps the "which_dimension" value within the provided index by the provided
// addend.
static Index BumpIndex(const Index& index, int64 which_dimension,

@@ -464,5 +464,53 @@ void SetTargetOptions(bool fast_math_enabled,
target_options->NoSignedZerosFPMath = fast_math_enabled;
}

std::map<int, llvm::MDNode*> MergeMetadata(
llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
const std::map<int, llvm::MDNode*>& b) {
// We should extend this as needed to deal with other kinds of metadata like
// !dereferenceable and !range.
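// alias.scope lists are unioned: an instruction carrying the merged metadata
// performs both accesses, so it belongs to every scope either input belongs
// to. noalias lists are intersected: independence may only be claimed from
// scopes that *both* inputs are independent of.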
std::map<int, llvm::MDNode*> result;
for (auto kind_md_pair : a) {
if (kind_md_pair.first == llvm::LLVMContext::MD_alias_scope) {
llvm::SmallVector<llvm::Metadata*, 8> union_of_scopes;
llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
for (const auto& scope_a : kind_md_pair.second->operands()) {
scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_a.get()));
}
auto it = b.find(kind_md_pair.first);
if (it != b.end()) {
for (const auto& scope_b : it->second->operands()) {
if (!scope_set.count(llvm::cast<llvm::MDNode>(scope_b.get()))) {
union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b.get()));
}
}
}
result[llvm::LLVMContext::MD_alias_scope] =
llvm::MDNode::get(*context, union_of_scopes);
} else if (kind_md_pair.first == llvm::LLVMContext::MD_noalias) {
llvm::SmallVector<llvm::Metadata*, 8> intersection_of_scopes;
llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
for (const auto& scope_a : kind_md_pair.second->operands()) {
scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
}
auto it = b.find(kind_md_pair.first);
if (it != b.end()) {
for (const auto& scope_b : it->second->operands()) {
if (scope_set.count(llvm::cast<llvm::MDNode>(scope_b))) {
intersection_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b));
}
}
}
if (!intersection_of_scopes.empty()) {
result[llvm::LLVMContext::MD_noalias] =
llvm::MDNode::get(*context, intersection_of_scopes);
}
}
}
return result;
}
} // namespace llvm_ir
} // namespace xla

@@ -238,6 +238,14 @@ llvm::FastMathFlags GetFastMathFlags(bool fast_math_enabled);
void SetTargetOptions(bool fast_math_enabled,
llvm::TargetOptions* target_options);
// Computes a conservative union of the metadata in "a" and "b". For
// aliasing-related metadata, this means the result can be applied to
// instructions whose aliasing relationship can be described either by "a" *or*
// by "b".
std::map<int, llvm::MDNode*> MergeMetadata(
llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
const std::map<int, llvm::MDNode*>& b);
} // namespace llvm_ir
} // namespace xla