Benjamin Kramer 01b38cd7c6 [XLA:CPU] Plumb through a minimal emitter for matmuls using the mlir linalg dialect
This is just the most basic lowering and will generate linalg.matmul for small
matmuls and then convert to loops. The result is fairly slow, but we can
iterate on that.

To make XLA use it set XLA_FLAGS=--xla_backend_extra_options=xla_use_linalg_for_dot

PiperOrigin-RevId: 312471829
Change-Id: I213d1f6114671bc595ac1647d3689736ee8f56f4
2020-05-20 06:43:17 -07:00

133 lines
5.2 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Transforms/IPO/Internalize.h"
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project
#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Target/LLVMIR.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
namespace xla {
namespace cpu {
namespace {
// Runs a lowering pipeline over `module` (linalg -> loops, then the linalg,
// vector and standard dialects -> LLVM dialect) and translates the result into
// a new llvm::Module. A failing pass pipeline aborts via CHECK; the translated
// module (which may be null if translation fails) is returned as-is.
std::unique_ptr<llvm::Module> MakeLLVMModule(mlir::OwningModuleRef module) {
  mlir::PassManager pm(module->getContext());
  // Loops first, so the subsequent conversions only have to deal with
  // loop/standard/vector/linalg-library constructs on the way to LLVM.
  pm.addPass(mlir::createConvertLinalgToLoopsPass());
  pm.addPass(mlir::createConvertLinalgToLLVMPass());
  pm.addPass(mlir::createConvertVectorToLLVMPass());
  pm.addPass(mlir::createLowerToLLVMPass());

  const mlir::LogicalResult lowered = pm.run(*module);
  CHECK(mlir::succeeded(lowered));

  return mlir::translateModuleToLLVMIR(*module);
}
// Appends to `args` the expanded memref descriptor fields that the
// MLIR-generated function takes for one buffer: allocated pointer, aligned
// pointer, offset (always 0), one size per dimension, then one stride per
// dimension. `op_val` is the buffer's base pointer and `opShape` its shape.
void BuildViewForBuffer(llvm::SmallVectorImpl<llvm::Value *> *args,
                        llvm::IRBuilder<> *b, const Shape &opShape,
                        llvm::Value *op_val) {
  // Strip nested array types from the pointee so the pointer addresses a
  // single element (e.g. [2 x [3 x float]]* becomes float*).
  llvm::Type *elem_ptr_ty = op_val->getType();
  for (;;) {
    llvm::Type *pointee =
        llvm::cast<llvm::PointerType>(elem_ptr_ty)->getElementType();
    auto *array_ty = llvm::dyn_cast<llvm::ArrayType>(pointee);
    if (array_ty == nullptr) {
      break;
    }
    elem_ptr_ty = array_ty->getElementType()->getPointerTo();
  }
  llvm::Value *data = b->CreateBitCast(op_val, elem_ptr_ty);

  // Allocated and aligned pointers both refer to the same buffer; the view
  // starts at offset zero.
  args->push_back(data);
  args->push_back(data);
  args->push_back(b->getInt64(0));

  // Sizes, in logical dimension order.
  for (int64 extent : opShape.dimensions()) {
    args->push_back(b->getInt64(extent));
  }

  // Strides follow the layout: walking dimensions from most minor to most
  // major, each stride is the product of the extents of all more-minor dims.
  llvm::SmallVector<int64_t, 4> strides(opShape.rank(), 1);
  int64_t running = 1;
  for (int64 dim : LayoutUtil::MinorToMajor(opShape)) {
    strides[dim] = running;
    running *= opShape.dimensions(dim);
  }
  for (int64_t s : strides) {
    args->push_back(b->getInt64(s));
  }
}
} // namespace
// Emits an MLIR function named `func_name`, lowers it to LLVM IR, links it into
// the LLVM module that `b` is currently inserting into, and emits a call to it
// at `b`'s insertion point.
//
// The function takes memref arguments and returns no values: first a memref
// for `result_shape`, then one per entry of `operand_shapes`, in order.
// `emitter` is invoked with a builder positioned at the start of the
// function's freshly created entry block and is expected to fill in the body.
// At the call site, `result_ptr` and `operand_ptrs` are passed as expanded
// memref descriptors built from the corresponding shapes.
//
// Returns an error if `operand_shapes` and `operand_ptrs` differ in length, if
// a shape cannot be converted to a memref type, if the lowered module cannot
// be translated to LLVM IR or linked, or if the function cannot be found in
// the LLVM module after linking.
Status EmitMlirFuncAndCall(
    mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape,
    llvm::ArrayRef<Shape> operand_shapes, llvm::Value *result_ptr,
    llvm::ArrayRef<llvm::Value *> operand_ptrs, llvm::StringRef func_name,
    llvm::function_ref<void(mlir::OpBuilder *, mlir::FuncOp)> emitter) {
  // Each operand shape describes exactly one operand buffer; a mismatch would
  // otherwise index past the end of `operand_ptrs` when building the call.
  if (operand_shapes.size() != operand_ptrs.size()) {
    return Status(tensorflow::error::INVALID_ARGUMENT,
                  "EmitMlirFuncAndCall: got " +
                      std::to_string(operand_shapes.size()) +
                      " operand shapes but " +
                      std::to_string(operand_ptrs.size()) +
                      " operand pointers");
  }

  llvm::Module *llvm_module = b->GetInsertBlock()->getParent()->getParent();
  mlir::Builder mlir_builder(context);

  // Memref types for the function signature: the result buffer first, then
  // the operands in order.
  std::vector<mlir::Type> operand_types;
  operand_types.reserve(operand_shapes.size() + 1);
  TF_ASSIGN_OR_RETURN(mlir::Type ret_memref, ConvertTensorShapeToMemRefType(
                                                 result_shape, mlir_builder));
  operand_types.push_back(ret_memref);
  for (const Shape &operand_shape : operand_shapes) {
    TF_ASSIGN_OR_RETURN(
        mlir::Type op_memref,
        ConvertTensorShapeToMemRefType(operand_shape, mlir_builder));
    operand_types.push_back(op_memref);
  }

  // Create the function and call the emission callback to fill in its body.
  mlir::Location loc = mlir::UnknownLoc::get(context);
  auto function = mlir::FuncOp::create(
      loc, func_name, mlir::FunctionType::get(operand_types, {}, context));
  function.addEntryBlock();
  mlir::OwningModuleRef mlir_module = mlir::ModuleOp::create(loc);
  mlir_module->push_back(function);
  mlir::OpBuilder op_builder(&function.getBody());
  emitter(&op_builder, function);

  // Lower to LLVM IR. translateModuleToLLVMIR signals failure by returning
  // null, which must not reach setDataLayout below.
  auto mlir_llvm_module = MakeLLVMModule(std::move(mlir_module));
  if (mlir_llvm_module == nullptr) {
    return Status(tensorflow::error::INTERNAL,
                  "Failed to translate MLIR function '" + func_name.str() +
                      "' to LLVM IR");
  }
  mlir_llvm_module->setDataLayout(llvm_module->getDataLayout());

  // Link into the main LLVM module. The callback gives internal linkage to the
  // named globals imported from the MLIR module (the names in `GVS`); globals
  // that already lived in `llvm_module`, and unnamed values, are preserved.
  // linkModules returns true on error (e.g. a conflicting symbol definition).
  const bool link_failed = llvm::Linker::linkModules(
      *llvm_module, std::move(mlir_llvm_module), llvm::Linker::None,
      [](llvm::Module &M, const llvm::StringSet<> &GVS) {
        llvm::internalizeModule(M, [&GVS](const llvm::GlobalValue &GV) {
          return !GV.hasName() || (GVS.count(GV.getName()) == 0);
        });
      });
  if (link_failed) {
    return Status(tensorflow::error::INTERNAL,
                  "Failed to link MLIR function '" + func_name.str() +
                      "' into the LLVM module");
  }

  // And leave behind a call to the function generated by MLIR.
  llvm::Function *func = llvm_module->getFunction(func_name);
  if (func == nullptr) {
    return Status(tensorflow::error::INTERNAL,
                  "MLIR function '" + func_name.str() +
                      "' not found after linking");
  }
  llvm::SmallVector<llvm::Value *, 4> op_vals;
  BuildViewForBuffer(&op_vals, b, result_shape, result_ptr);
  for (size_t i = 0; i < operand_shapes.size(); ++i) {
    BuildViewForBuffer(&op_vals, b, operand_shapes[i], operand_ptrs[i]);
  }
  b->CreateCall(func, op_vals);
  return Status::OK();
}
} // namespace cpu
} // namespace xla