PiperOrigin-RevId: 312775865
Change-Id: Iee2170660e6b2cd0a81695e8843bebfb311c480b
This commit is contained in:
A. Unique TensorFlower 2020-05-21 18:03:31 -07:00 committed by TensorFlower Gardener
parent c7229fcabb
commit 8fc976574e
4 changed files with 70 additions and 13 deletions

View File

@ -444,6 +444,10 @@ def TerminatorOp :
let description = [{
Terminator operation for the LHLO dialect.
}];
let builders = [OpBuilder<
"OpBuilder &b, OperationState &result, ValueRange operands",
[{ build(b, result, llvm::None, operands, llvm::None); }]
>];
}
#endif // LHLO_OPS

View File

@ -43,8 +43,8 @@ constexpr StringRef kTempBufferAttr = "temp";
// Shorthand alias: BaseOpConversion<T> names
// BufferAssignmentOpConversionPattern<T> for the op type T.
template <typename T>
using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
using StdReturnOpConverter =
NonVoidToVoidReturnOpConverter<mlir::ReturnOp, xla_lhlo::TerminatorOp,
xla_lhlo::CopyOp>;
NoBufferOperandsReturnOpConverter<mlir::ReturnOp, xla_lhlo::TerminatorOp,
xla_lhlo::CopyOp>;
Value InsertDynamicAllocAndDealloc(Location loc, Value result,
Value shape_operand,

View File

@ -57,8 +57,9 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
for (auto func_arg : func.getArguments()) {
func_args.insert(func_arg);
}
MLIRContext* ctx = func.getContext();
OpBuilder b(func);
OperationFolder folder(func.getContext());
OperationFolder folder(ctx);
func.walk([&](linalg::GenericOp generic_op) {
SmallVector<int64_t, 2> tile_sizes(tile_sizes_.begin(),
tile_sizes_.end());
@ -68,12 +69,14 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
auto op = cast<LinalgOp>(generic_op.getOperation());
for (const Value result : op.getOutputBuffers()) {
if (!func_args.count(result)) continue;
if (tileGenericOp(op, tile_sizes, &b, &folder)) {
if (tileGenericOp(op, tile_sizes, &b)) {
generic_op.erase();
return;
}
}
});
auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
applyPatternsAndFoldGreedily(func, patterns);
// Fuse producers of tiled linalg ops.
llvm::SmallDenseSet<Operation*> erase_set;
@ -92,19 +95,22 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
*originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
}
}
auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
applyPatternsAndFoldGreedily(func, patterns);
}
for (auto* e : erase_set) e->erase();
}
private:
bool tileGenericOp(LinalgOp op, ArrayRef<int64_t> tile_sizes, OpBuilder* b,
OperationFolder* folder) {
auto tiled_generic_op =
use_parallel_loops_
? linalg::tileLinalgOpToParallelLoops(*b, op, tile_sizes,
/*permutation=*/{}, folder)
: linalg::tileLinalgOp(*b, op, tile_sizes,
/*permutation=*/{}, folder);
// Tiles `op` with the given `tile_sizes`, creating the tiled ops through
// the builder `b`.
//
// The loop flavour is chosen from the pass member use_parallel_loops_:
// LinalgTilingLoopType::ParallelLoops when it is set, and
// LinalgTilingLoopType::Loops otherwise.
//
// Returns true iff linalg::tileLinalgOp yielded a value (i.e. tiling
// produced a result); the tiled op itself is not returned to the caller.
bool tileGenericOp(LinalgOp op, ArrayRef<int64_t> tile_sizes, OpBuilder* b) {
// Select the kind of loop nest to generate.
auto loopType = use_parallel_loops_
? linalg::LinalgTilingLoopType::ParallelLoops
: linalg::LinalgTilingLoopType::Loops;
// Tile sizes and loop type are passed via a LinalgTilingOptions object.
auto tiled_generic_op = linalg::tileLinalgOp(*b, op,
linalg::LinalgTilingOptions()
.setTileSizes(tile_sizes)
.setLoopType(loopType));
return tiled_generic_op.hasValue();
}

View File

@ -175,6 +175,7 @@ filegroup(
filegroup(
name = "AffineOpsTdFiles",
srcs = [
"include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td",
"include/mlir/Dialect/Affine/IR/AffineOps.td",
"include/mlir/Dialect/Affine/IR/AffineOpsBase.td",
"include/mlir/Interfaces/LoopLikeInterface.td",
@ -207,6 +208,26 @@ gentbl(
],
)
# Runs mlir-tblgen on AffineMemoryOpInterfaces.td and produces two outputs:
# the op-interface declarations (AffineMemoryOpInterfaces.h.inc) and the
# op-interface definitions (AffineMemoryOpInterfaces.cpp.inc).
gentbl(
name = "AffineMemoryOpInterfacesIncGen",
strip_include_prefix = "include",
# Each entry pairs a tblgen flag with the file that flag generates.
tbl_outs = [
(
"-gen-op-interface-decls",
"include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h.inc",
),
(
"-gen-op-interface-defs",
"include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td",
# The .td inputs are supplied by the AffineOpsTdFiles target.
td_srcs = [
":AffineOpsTdFiles",
],
)
##---------------------------------------------------------------------------##
# AVX512 dialect.
##---------------------------------------------------------------------------##
@ -462,6 +483,7 @@ cc_library(
]),
includes = ["include"],
deps = [
":AffineMemoryOpInterfacesIncGen",
":AffineOpsIncGen",
":EDSC",
":IR",
@ -677,6 +699,7 @@ cc_library(
deps = [
":CallOpInterfaces",
":CommonFolders",
":ControlFlowInterfaces",
":Dialect",
":IR",
":InferTypeOpInterface",
@ -1153,6 +1176,28 @@ cc_library(
],
)
# Library for the GPUCommon conversion: compiles
# ConvertLaunchFuncToRuntimeCalls.cpp (together with the shared
# lib/Conversion/PassDetail.h) and exposes GPUCommonPass.h as its header.
cc_library(
name = "GPURuntimeTransforms",
srcs = [
"lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp",
"lib/Conversion/PassDetail.h",
],
hdrs = [
"include/mlir/Conversion/GPUCommon/GPUCommonPass.h",
],
includes = ["include"],
# Depends on the GPU and LLVM dialects, core IR/pass infrastructure, the
# generated conversion pass declarations, and LLVM core/support.
deps = [
":ConversionPassIncGen",
":GPUDialect",
":IR",
":LLVMDialect",
":Pass",
":Support",
"@llvm-project//llvm:core",
"@llvm-project//llvm:support",
],
)
gentbl(
name = "GPUToNVVMGen",
strip_include_prefix = "lib/Conversion/GPUToNVVM",
@ -1265,7 +1310,6 @@ cc_library(
name = "GPUToCUDATransforms",
srcs = [
"lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp",
"lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp",
"lib/Conversion/PassDetail.h",
],
hdrs = ["include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"],
@ -2446,6 +2490,7 @@ cc_library(
includes = ["include"],
deps = [
":Analysis",
":GPURuntimeTransforms",
":GPUToNVVMTransforms",
":GPUToROCDLTransforms",
":GPUToSPIRVTransforms",
@ -2525,6 +2570,7 @@ cc_library(
":ConversionPassIncGen",
":GPUDialect",
":GPUPassIncGen",
":GPURuntimeTransforms",
":GPUToCUDATransforms",
":GPUToNVVMTransforms",
":GPUToROCDLTransforms",
@ -2730,6 +2776,7 @@ cc_binary(
":AllPassesAndDialectsNoRegistration",
":ExecutionEngineUtils",
":GPUDialect",
":GPURuntimeTransforms",
":GPUToNVVMTransforms",
":GPUToROCDLTransforms",
":GPUTransforms",