diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 644bf5fd745..8671452e7f8 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -2382,6 +2382,173 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
   return Status::OK();
 }
 
+StatusOr<bool> IrEmitter::EmitFastConcatenate(
+    HloInstruction* concatenate,
+    tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+    string* failure_reason) {
+  if (ShouldEmitParallelLoopFor(*concatenate)) {
+    *failure_reason =
+        "cannot generate memcpy-based concat for the parallel CPU backend";
+    return false;
+  }
+
+  const Shape& output_shape = concatenate->shape();
+  for (auto* op : operands) {
+    if (!LayoutUtil::Equal(op->shape().layout(), output_shape.layout())) {
+      *failure_reason = "operand has mismatching layouts";
+      return false;
+    }
+    if (LayoutUtil::IsPadded(op->shape())) {
+      *failure_reason = "operand has padded layout";
+      return false;
+    }
+  }
+
+  CHECK(!LayoutUtil::IsPadded(concatenate->shape()));
+
+  // We split the dimensions into three categories: the dimension over which we
+  // are concatenating (concat_dim), the dimensions that are minor to it
+  // (inner_dims) and the dimensions that are major to it (outer_dims).
+
+  int64 concat_dim = concatenate->dimensions(0);
+  const Layout& output_layout = output_shape.layout();
+  auto concat_dim_layout_itr =
+      std::find(output_layout.minor_to_major().begin(),
+                output_layout.minor_to_major().end(), concat_dim);
+
+  std::vector<int64> inner_dims(output_layout.minor_to_major().begin(),
+                                concat_dim_layout_itr);
+  std::vector<int64> outer_dims(std::next(concat_dim_layout_itr),
+                                output_layout.minor_to_major().end());
+
+  llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy();
+  llvm::Type* i8_type = ir_builder_.getInt8Ty();
+
+  TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+                      EmitTargetAddressForOp(concatenate));
+
+  llvm_ir::IrArray target_array(target_address, output_shape);
+
+  llvm_ir::ForLoopNest loops(&ir_builder_);
+  llvm_ir::IrArray::Index outer_dims_index =
+      loops.AddLoopsForShapeOnDimensions(output_shape, outer_dims, "concat");
+  std::replace(outer_dims_index.begin(), outer_dims_index.end(),
+               static_cast<llvm::Value*>(nullptr),
+               static_cast<llvm::Value*>(ir_builder_.getInt64(0)));
+
+  if (!outer_dims.empty()) {
+    SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
+  }
+
+  PrimitiveType primitive_type = output_shape.element_type();
+  unsigned primitive_type_size =
+      ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
+
+  AddAliasingInformationToIrArray(*concatenate, &target_array);
+
+  // Contiguous subregions from each operand to the concatenate contribute to a
+  // contiguous subregion in the target buffer starting at target_region_begin.
+  llvm::Value* target_region_begin = ir_builder_.CreateBitCast(
+      target_array.EmitArrayElementAddress(outer_dims_index, &ir_builder_,
+                                           "target_region"),
+      i8_ptr_type);
+  int64 byte_offset_into_target_region = 0;
+
+  int64 inner_dims_product =
+      std::accumulate(inner_dims.begin(), inner_dims.end(), 1l,
+                      [&](int64 product, int64 inner_dim) {
+                        return product * output_shape.dimensions(inner_dim);
+                      });
+
+  // For each operand, emit a memcpy from the operand to the target of size
+  // equal to the product of inner dimensions.
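+  // Because every operand shares the output's layout and nothing is padded,
+  // each operand's contribution for a fixed outer index is a single contiguous
+  // run of dimensions(concat_dim) * inner_dims_product elements in both the
+  // source and the target buffer, so one transfer per operand suffices;
+  // byte_offset_into_target_region tracks where the next operand's run starts
+  // within the target region.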
+  for (HloInstruction* operand : operands) {
+    const Shape& input_shape = operand->shape();
+    llvm_ir::IrArray source_array(GetEmittedValueFor(operand), input_shape);
+    AddAliasingInformationToIrArray(*operand, &source_array);
+
+    llvm::Value* copy_source_address = ir_builder_.CreateBitCast(
+        source_array.EmitArrayElementAddress(outer_dims_index, &ir_builder_,
+                                             "src_addr"),
+        i8_ptr_type);
+
+    llvm::Value* copy_target_address = ir_builder_.CreateGEP(
+        i8_type, target_region_begin,
+        ir_builder_.getInt64(byte_offset_into_target_region));
+
+    EmitTransferElements(
+        copy_target_address, copy_source_address,
+        inner_dims_product * input_shape.dimensions(concat_dim), primitive_type,
+        target_array, source_array);
+
+    byte_offset_into_target_region += inner_dims_product *
+                                      input_shape.dimensions(concat_dim) *
+                                      primitive_type_size;
+  }
+
+  if (!outer_dims.empty()) {
+    SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
+  }
+
+  emitted_value_[concatenate] = target_address;
+
+  return true;
+}
+
+void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
+                                     int64 element_count,
+                                     PrimitiveType primitive_type,
+                                     const llvm_ir::IrArray& target_array,
+                                     const llvm_ir::IrArray& source_array) {
+  unsigned primitive_type_size =
+      ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
+  unsigned element_alignment = GCD(
+      primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type));
+  llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual(
+      llvm_ir::PrimitiveTypeToIrType(primitive_type, &ir_builder_));
+
+  if (element_count == 1) {
+    auto* load_instruction = ir_builder_.CreateAlignedLoad(
+        ir_builder_.CreateBitCast(source, primitive_ptr_type),
+        element_alignment);
+    source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction);
+    auto* store_instruction = ir_builder_.CreateAlignedStore(
+        load_instruction, ir_builder_.CreateBitCast(target, primitive_ptr_type),
+        element_alignment);
+    target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
+  } else {
+    auto* memcpy_instruction = ir_builder_.CreateMemCpy(
+        target, source, element_count * primitive_type_size, element_alignment);
+
+    // The memcpy does the load and the store internally. The aliasing related
+    // metadata has to reflect that.
+    std::map<int, llvm::MDNode*> merged_metadata =
+        llvm_ir::MergeMetadata(&module_->getContext(), source_array.metadata(),
+                               target_array.metadata());
+    for (const auto& kind_md_pair : merged_metadata) {
+      memcpy_instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
+    }
+  }
+}
+
+Status IrEmitter::HandleConcatenate(
+    HloInstruction* concatenate,
+    tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+  string failure_reason;
+  TF_ASSIGN_OR_RETURN(
+      bool successful,
+      EmitFastConcatenate(concatenate, operands, &failure_reason));
+  if (successful) {
+    VLOG(1) << "Emitted fast concatenate for " << concatenate->ToString();
+    return Status::OK();
+  }
+
+  VLOG(1) << "Could not emit fast concatenate for " << concatenate->ToString()
+          << ": " << failure_reason;
+
+  return DefaultAction(concatenate);
+}
+
 Status IrEmitter::FinishVisit(HloInstruction* root) {
   // When this method is called, we should have already emitted an IR value for
   // the root (return) op. The IR value holds the address of the buffer holding
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 45332536808..2fea6846d88 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -191,6 +191,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
                           tensorflow::gtl::ArraySlice<HloInstruction*> operands,
                           tensorflow::StringPiece custom_call_target) override;
   Status HandleWhile(HloInstruction* xla_while) override;
+  Status HandleConcatenate(
+      HloInstruction* concatenate,
+      tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
   Status FinishVisit(HloInstruction* root) override;
 
   Status Preprocess(HloInstruction* hlo) override;
@@ -407,6 +410,21 @@ class IrEmitter : public DfsHloVisitorWithDefault {
       HloInstruction* arg, tensorflow::gtl::ArraySlice<int64> dimensions,
       unsigned element_alignment);
 
+  // Tries to emit a fast concatenate operation using memcpy. Returns true if
+  // successful, and false on failure. On failure, sets "failure_reason" to a
+  // string describing why it could not emit a fast concatenate.
+  StatusOr<bool> EmitFastConcatenate(
+      HloInstruction* concatenate,
+      tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+      string* failure_reason);
+
+  // Emits LLVM IR to transfer "element_count" elements of type "primitive_type"
+  // from the address "source" to the address "target".
+  void EmitTransferElements(llvm::Value* target, llvm::Value* source,
+                            int64 element_count, PrimitiveType primitive_type,
+                            const llvm_ir::IrArray& target_array,
+                            const llvm_ir::IrArray& source_array);
+
   // Name of the computation entry function. This function serves as the
   // top-level "main" of the computation and will be invoked by the JIT.
   string entry_function_name_;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
index 6bfe8bfc756..5e28e37600c 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
@@ -56,7 +56,9 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
       alias_scope_md =
           GetAliasScopeMetadataForBuffer(buffer_slice, GetAliasDomain());
     }
-    array->AddAliasScopeMetadata(alias_scope_md);
+    if (alias_scope_md != nullptr) {
+      array->AddAliasScopeMetadata(alias_scope_md);
+    }
   }
 
   if (module_.config().debug_options().xla_llvm_enable_noalias_metadata()) {
@@ -65,7 +67,9 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
       noalias_md = GetNoaliasMetadataForBuffer(buffer_slice, GetAliasDomain(),
                                                assignment_, hlo);
     }
-    array->AddNoaliasMetadata(noalias_md);
+    if (noalias_md != nullptr) {
+      array->AddNoaliasMetadata(noalias_md);
+    }
   }
 
   if (module_.config()
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
index a38cf0e5d9d..a6a3ea1adc4 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
@@ -218,17 +218,22 @@ class IrArray {
                                 llvm::IRBuilder<>* ir_builder) const;
 
   void AddAliasScopeMetadata(llvm::MDNode* alias_scope) {
+    CHECK_NE(alias_scope, nullptr);
     AddMetadata(llvm::LLVMContext::MD_alias_scope, alias_scope);
   }
 
   void AddNoaliasMetadata(llvm::MDNode* noalias) {
+    CHECK_NE(noalias, nullptr);
     AddMetadata(llvm::LLVMContext::MD_noalias, noalias);
   }
 
   void AddInvariantLoad(llvm::MDNode* invariant_load) {
+    CHECK_NE(invariant_load, nullptr);
     AddMetadata(llvm::LLVMContext::MD_invariant_load, invariant_load);
   }
 
+  const std::map<int, llvm::MDNode*>& metadata() const { return metadata_; }
+
   // Bumps the "which_dimension" value within the provided index by the provided
   // addend.
   static Index BumpIndex(const Index& index, int64 which_dimension,
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index 6d985fba0cb..0ae75c5b3c6 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -464,5 +464,53 @@ void SetTargetOptions(bool fast_math_enabled,
   target_options->NoSignedZerosFPMath = fast_math_enabled;
 }
 
+std::map<int, llvm::MDNode*> MergeMetadata(
+    llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
+    const std::map<int, llvm::MDNode*>& b) {
+  // We should extend this as needed to deal with other kinds of metadata like
+  // !dereferenceable and !range.
+
+  std::map<int, llvm::MDNode*> result;
+  for (auto kind_md_pair : a) {
+    if (kind_md_pair.first == llvm::LLVMContext::MD_alias_scope) {
+      llvm::SmallVector<llvm::Metadata*, 8> union_of_scopes;
+      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
+      for (const auto& scope_a : kind_md_pair.second->operands()) {
+        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
+        union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_a.get()));
+      }
+      auto it = b.find(kind_md_pair.first);
+      if (it != b.end()) {
+        for (const auto& scope_b : it->second->operands()) {
+          if (!scope_set.count(llvm::cast<llvm::MDNode>(scope_b.get()))) {
+            union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b.get()));
+          }
+        }
+      }
+      result[llvm::LLVMContext::MD_alias_scope] =
+          llvm::MDNode::get(*context, union_of_scopes);
+    } else if (kind_md_pair.first == llvm::LLVMContext::MD_noalias) {
+      llvm::SmallVector<llvm::Metadata*, 8> intersection_of_scopes;
+      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
+      for (const auto& scope_a : kind_md_pair.second->operands()) {
+        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
+      }
+      auto it = b.find(kind_md_pair.first);
+      if (it != b.end()) {
+        for (const auto& scope_b : it->second->operands()) {
+          if (scope_set.count(llvm::cast<llvm::MDNode>(scope_b))) {
+            intersection_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b));
+          }
+        }
+      }
+      if (!intersection_of_scopes.empty()) {
+        result[llvm::LLVMContext::MD_noalias] =
+            llvm::MDNode::get(*context, intersection_of_scopes);
+      }
+    }
+  }
+  return result;
+}
+
 }  // namespace llvm_ir
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
index 96d2c2dba8b..6d94603338c 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
@@ -238,6 +238,14 @@ llvm::FastMathFlags GetFastMathFlags(bool fast_math_enabled);
 void SetTargetOptions(bool fast_math_enabled,
                       llvm::TargetOptions* target_options);
 
+// Computes a conservative union of the metadata in "a" and "b". For
+// aliasing-related metadata, this means the result can be applied to
+// instructions whose aliasing relationship can be described either by "a" *or*
+// by "b".
+std::map<int, llvm::MDNode*> MergeMetadata(
+    llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
+    const std::map<int, llvm::MDNode*>& b);
+
 }  // namespace llvm_ir
 }  // namespace xla
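
A note on the merge rule implemented by MergeMetadata above: the memcpy emitted by EmitTransferElements performs both the load from the source buffer and the store to the target buffer, so its !alias.scope list has to be the union of the two accesses' scope lists, while its !noalias list may only keep the scopes that both accesses were already known not to alias, i.e. the intersection. The sketch below restates that rule without the LLVM metadata machinery; the ScopeSets type and MergeScopeSets helper are illustrative names invented for this example, not part of the patch.

#include <algorithm>
#include <iterator>
#include <set>
#include <string>

// Hypothetical stand-in for an access's aliasing metadata: the scopes the
// access belongs to (!alias.scope) and the scopes it does not alias (!noalias).
struct ScopeSets {
  std::set<std::string> alias_scope;
  std::set<std::string> noalias;
};

// Conservative merge for a single instruction that replaces two accesses: it
// belongs to every scope either access belonged to (union), but may only be
// declared noalias against scopes *both* accesses were noalias against
// (intersection).
ScopeSets MergeScopeSets(const ScopeSets& a, const ScopeSets& b) {
  ScopeSets result;
  std::set_union(a.alias_scope.begin(), a.alias_scope.end(),
                 b.alias_scope.begin(), b.alias_scope.end(),
                 std::inserter(result.alias_scope, result.alias_scope.begin()));
  std::set_intersection(
      a.noalias.begin(), a.noalias.end(), b.noalias.begin(), b.noalias.end(),
      std::inserter(result.noalias, result.noalias.begin()));
  return result;
}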