PiperOrigin-RevId: 310975358
Change-Id: I2d0978a23c371702c5d83266c8214223bf267630
This commit is contained in:
A. Unique TensorFlower 2020-05-11 13:01:44 -07:00 committed by TensorFlower Gardener
parent 5caba44997
commit 7a691ccd98
9 changed files with 102 additions and 104 deletions

View File

@ -695,9 +695,9 @@ cc_library(
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:LoopOpsTransforms",
"@llvm-project//mlir:MlirTranslateMain", "@llvm-project//mlir:MlirTranslateMain",
"@llvm-project//mlir:QuantOps", "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
"@llvm-project//mlir:Translation", "@llvm-project//mlir:Translation",

View File

@ -556,7 +556,7 @@ cc_library(
deps = [ deps = [
":tensorflow", ":tensorflow",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:LoopOpsTransforms", "@llvm-project//mlir:SCFTransforms",
], ],
alwayslink = 1, alwayslink = 1,
) )

View File

@ -240,8 +240,8 @@ cc_library(
"@llvm-project//llvm:support", "@llvm-project//llvm:support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LoopOps",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms", "@llvm-project//mlir:Transforms",
], ],
@ -278,8 +278,8 @@ cc_library(
"@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LoopOps",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms", "@llvm-project//mlir:Transforms",
], ],

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project #include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project
#include "mlir/Dialect/LoopOps/LoopOps.h" // from @llvm-project #include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
@ -112,7 +112,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
auto step = rewriter.create<mlir::ConstantOp>( auto step = rewriter.create<mlir::ConstantOp>(
loc, rewriter.getIndexType(), loc, rewriter.getIndexType(),
rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
auto loop = rewriter.create<mlir::loop::ForOp>(loc, zero, upper, step); auto loop = rewriter.create<mlir::scf::ForOp>(loc, zero, upper, step);
rewriter.setInsertionPointToStart(loop.getBody()); rewriter.setInsertionPointToStart(loop.getBody());
// Compute memrefs for the value to reduce. This makes it easier to just // Compute memrefs for the value to reduce. This makes it easier to just
@ -173,8 +173,7 @@ struct LhloLegalizeToGpu : public PassWrapper<LhloLegalizeToGpu, FunctionPass> {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
ConversionTarget target(getContext()); ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect, target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
gpu::GPUDialect, loop::LoopOpsDialect, gpu::GPUDialect, scf::SCFDialect, XlaLhloDialect>();
XlaLhloDialect>();
target.addIllegalOp<ReduceOp>(); target.addIllegalOp<ReduceOp>();
auto func = getFunction(); auto func = getFunction();
patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext()); patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());

View File

@ -18,7 +18,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project #include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project
#include "mlir/Dialect/LoopOps/LoopOps.h" // from @llvm-project #include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project
@ -64,12 +64,12 @@ Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
// into a reduction operator of loop.reduce by doing buffer allocation for // into a reduction operator of loop.reduce by doing buffer allocation for
// scalar arguments and the result of `loop.reduce` to make it compatible with // scalar arguments and the result of `loop.reduce` to make it compatible with
// LHLO ops. // LHLO ops.
void ConvertToReductionOperator(Location loc, loop::ReduceOp reduce_op, void ConvertToReductionOperator(Location loc, scf::ReduceOp reduce_op,
Block* lhlo_block, OpBuilder* b) { Block* lhlo_block, OpBuilder* b) {
Block& loop_reduce_op_body = reduce_op.reductionOperator().front(); Block& loop_reduce_op_body = reduce_op.reductionOperator().front();
OpBuilder::InsertionGuard guard(*b); OpBuilder::InsertionGuard guard(*b);
b->setInsertionPointToStart(&loop_reduce_op_body); b->setInsertionPointToStart(&loop_reduce_op_body);
b->create<loop::ReduceReturnOp>( b->create<scf::ReduceReturnOp>(
loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(), loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(),
lhlo_block, b)); lhlo_block, b));
} }
@ -136,9 +136,9 @@ MappedIvs MapWindowIvsToInput(OpTy op, ValueRange ivs, ValueRange window_ivs,
return mapped_ivs; return mapped_ivs;
} }
// Returns loop::Parallel over a shaped value with static or dynamic shape. // Returns scf::Parallel over a shaped value with static or dynamic shape.
loop::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
OpBuilder* b) { OpBuilder* b) {
Value zero = b->create<ConstantIndexOp>(loc, 0); Value zero = b->create<ConstantIndexOp>(loc, 0);
Value one = b->create<ConstantIndexOp>(loc, 1); Value one = b->create<ConstantIndexOp>(loc, 1);
@ -151,10 +151,10 @@ loop::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
lower.push_back(zero); lower.push_back(zero);
step.push_back(one); step.push_back(one);
} }
return b->create<loop::ParallelOp>(loc, lower, upper, step); return b->create<scf::ParallelOp>(loc, lower, upper, step);
} }
// Converts `xla_lhlo.ReduceOp` into two loop::ParallelOp and a loop::ReduceOp. // Converts `xla_lhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp.
// The outper `ParallelOp` refers to the parallel loops if there are // The outper `ParallelOp` refers to the parallel loops if there are
// any. The inner `ParalleOp` refers to the reduction loops and `ReduceOp` // any. The inner `ParalleOp` refers to the reduction loops and `ReduceOp`
// contains the reduction operator. // contains the reduction operator.
@ -197,7 +197,7 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
// TODO(b/137624192) Implement variadic reduce. // TODO(b/137624192) Implement variadic reduce.
if (xla_reduce_op.out().size() != 1) return failure(); if (xla_reduce_op.out().size() != 1) return failure();
loop::ReduceOp reduce_op = scf::ReduceOp reduce_op =
CreateReduceOpInNestedParallelLoops(xla_reduce_op, &rewriter); CreateReduceOpInNestedParallelLoops(xla_reduce_op, &rewriter);
ConvertToReductionOperator(xla_reduce_op.getLoc(), reduce_op, ConvertToReductionOperator(xla_reduce_op.getLoc(), reduce_op,
&xla_reduce_op.body().front(), &rewriter); &xla_reduce_op.body().front(), &rewriter);
@ -225,7 +225,7 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
// } : f32 // } : f32
// loop.yield // loop.yield
// } // }
loop::ReduceOp CreateReduceOpInNestedParallelLoops( scf::ReduceOp CreateReduceOpInNestedParallelLoops(
xla_lhlo::ReduceOp xla_reduce_op, xla_lhlo::ReduceOp xla_reduce_op,
ConversionPatternRewriter* rewriter) const { ConversionPatternRewriter* rewriter) const {
auto loc = xla_reduce_op.getLoc(); auto loc = xla_reduce_op.getLoc();
@ -254,13 +254,13 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
SmallVector<Value, 1> init_value = { SmallVector<Value, 1> init_value = {
rewriter->create<LoadOp>(loc, *xla_reduce_op.init_values().begin())}; rewriter->create<LoadOp>(loc, *xla_reduce_op.init_values().begin())};
// Outer ParallelOp is not needed if it is a reduction across all dims. // Outer ParallelOp is not needed if it is a reduction across all dims.
loop::ParallelOp outer; scf::ParallelOp outer;
if (!parallel_lower.empty()) { if (!parallel_lower.empty()) {
outer = rewriter->create<loop::ParallelOp>(loc, parallel_lower, outer = rewriter->create<scf::ParallelOp>(loc, parallel_lower,
parallel_upper, parallel_step); parallel_upper, parallel_step);
rewriter->setInsertionPointToStart(outer.getBody()); rewriter->setInsertionPointToStart(outer.getBody());
} }
loop::ParallelOp inner = rewriter->create<loop::ParallelOp>( scf::ParallelOp inner = rewriter->create<scf::ParallelOp>(
loc, reduce_lower, reduce_upper, reduce_step, init_value); loc, reduce_lower, reduce_upper, reduce_step, init_value);
Value reduction_result = *inner.getResults().begin(); Value reduction_result = *inner.getResults().begin();
@ -294,7 +294,7 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
rewriter->setInsertionPointToStart(inner.getBody()); rewriter->setInsertionPointToStart(inner.getBody());
Value elem = rewriter->create<mlir::LoadOp>( Value elem = rewriter->create<mlir::LoadOp>(
loc, *xla_reduce_op.operands().begin(), indices); loc, *xla_reduce_op.operands().begin(), indices);
return rewriter->create<loop::ReduceOp>(loc, elem); return rewriter->create<scf::ReduceOp>(loc, elem);
} }
}; };
@ -314,8 +314,8 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
// accumulator = reduction_operator(output[O], value) // accumulator = reduction_operator(output[O], value)
// output[O] = accumulator // output[O] = accumulator
// //
// Converts `xla_lhlo.ReduceWindowOp` into two loop::ParallelOp and a // Converts `xla_lhlo.ReduceWindowOp` into two scf::ParallelOp and a
// loop::ReduceOp. // scf::ReduceOp.
// The outper `ParallelOp` refers to the parallel loops that traverese output // The outper `ParallelOp` refers to the parallel loops that traverese output
// buffer. The inner `ParalleOp` refers to the reduction loops that traverse // buffer. The inner `ParalleOp` refers to the reduction loops that traverse
// reduction windows and `ReduceOp` contains the reduction operator. // reduction windows and `ReduceOp` contains the reduction operator.
@ -366,12 +366,12 @@ class ReduceWindowOpConverter
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/, xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
loop::ParallelOp output_loop, window_loop; scf::ParallelOp output_loop, window_loop;
std::tie(output_loop, window_loop) = std::tie(output_loop, window_loop) =
CreateParallelLoopsToTraverseOutputAndWindow(xla_reduce_window_op, CreateParallelLoopsToTraverseOutputAndWindow(xla_reduce_window_op,
&rewriter); &rewriter);
loop::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops( scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops(
xla_reduce_window_op, output_loop, window_loop, &rewriter); xla_reduce_window_op, output_loop, window_loop, &rewriter);
ConvertToReductionOperator(xla_reduce_window_op.getLoc(), reduce_op, ConvertToReductionOperator(xla_reduce_window_op.getLoc(), reduce_op,
@ -381,7 +381,7 @@ class ReduceWindowOpConverter
} }
private: private:
std::pair<loop::ParallelOp, loop::ParallelOp> std::pair<scf::ParallelOp, scf::ParallelOp>
CreateParallelLoopsToTraverseOutputAndWindow( CreateParallelLoopsToTraverseOutputAndWindow(
xla_lhlo::ReduceWindowOp xla_reduce_window_op, xla_lhlo::ReduceWindowOp xla_reduce_window_op,
ConversionPatternRewriter* rewriter) const { ConversionPatternRewriter* rewriter) const {
@ -405,7 +405,7 @@ class ReduceWindowOpConverter
window_upper.push_back( window_upper.push_back(
rewriter->create<ConstantIndexOp>(loc, window_dim.getSExtValue())); rewriter->create<ConstantIndexOp>(loc, window_dim.getSExtValue()));
} }
auto window_loop = rewriter->create<loop::ParallelOp>( auto window_loop = rewriter->create<scf::ParallelOp>(
loc, window_lower, window_upper, window_step, init_value); loc, window_lower, window_upper, window_step, init_value);
Value reduction_result = *window_loop.getResults().begin(); Value reduction_result = *window_loop.getResults().begin();
@ -414,9 +414,9 @@ class ReduceWindowOpConverter
return std::make_pair(output_loop, window_loop); return std::make_pair(output_loop, window_loop);
} }
loop::ReduceOp CreateReduceOpInNestedParallelLoops( scf::ReduceOp CreateReduceOpInNestedParallelLoops(
xla_lhlo::ReduceWindowOp xla_reduce_window_op, xla_lhlo::ReduceWindowOp xla_reduce_window_op,
loop::ParallelOp output_loop, loop::ParallelOp window_loop, scf::ParallelOp output_loop, scf::ParallelOp window_loop,
ConversionPatternRewriter* rewriter) const { ConversionPatternRewriter* rewriter) const {
rewriter->setInsertionPointToStart(window_loop.getBody()); rewriter->setInsertionPointToStart(window_loop.getBody());
auto loc = xla_reduce_window_op.getLoc(); auto loc = xla_reduce_window_op.getLoc();
@ -436,20 +436,20 @@ class ReduceWindowOpConverter
xla_reduce_window_op, output_loop.getInductionVars(), xla_reduce_window_op, output_loop.getInductionVars(),
window_loop.getInductionVars(), rewriter); window_loop.getInductionVars(), rewriter);
auto elem_or_init = rewriter->create<loop::IfOp>( auto elem_or_init = rewriter->create<scf::IfOp>(
loc, xla_operand_type.getElementType(), mapped_ivs.in_bounds, loc, xla_operand_type.getElementType(), mapped_ivs.in_bounds,
/*withElseRegion=*/true); /*withElseRegion=*/true);
OpBuilder then_builder = elem_or_init.getThenBodyBuilder(); OpBuilder then_builder = elem_or_init.getThenBodyBuilder();
Value elem = then_builder.create<mlir::LoadOp>( Value elem = then_builder.create<mlir::LoadOp>(
loc, xla_reduce_window_op.operand(), mapped_ivs.ivs); loc, xla_reduce_window_op.operand(), mapped_ivs.ivs);
then_builder.create<loop::YieldOp>(loc, elem); then_builder.create<scf::YieldOp>(loc, elem);
OpBuilder else_builder = elem_or_init.getElseBodyBuilder(); OpBuilder else_builder = elem_or_init.getElseBodyBuilder();
else_builder.create<loop::YieldOp>(loc, *window_loop.initVals().begin()); else_builder.create<scf::YieldOp>(loc, *window_loop.initVals().begin());
return rewriter->create<loop::ReduceOp>(loc, return rewriter->create<scf::ReduceOp>(loc,
*elem_or_init.results().begin()); *elem_or_init.results().begin());
} }
}; };
@ -490,7 +490,7 @@ class SelectAndScatterOpConverter
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto loc = s_and_s_op.getLoc(); auto loc = s_and_s_op.getLoc();
InitializeOutput(s_and_s_op, &rewriter); InitializeOutput(s_and_s_op, &rewriter);
loop::ParallelOp loop_over_src = scf::ParallelOp loop_over_src =
MakeLoopOverShape(loc, s_and_s_op.source(), &rewriter); MakeLoopOverShape(loc, s_and_s_op.source(), &rewriter);
rewriter.setInsertionPointToStart(loop_over_src.getBody()); rewriter.setInsertionPointToStart(loop_over_src.getBody());
@ -520,7 +520,7 @@ class SelectAndScatterOpConverter
auto loc = s_and_s_op.getLoc(); auto loc = s_and_s_op.getLoc();
Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value()); Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value());
loop::ParallelOp loop_over_output = scf::ParallelOp loop_over_output =
MakeLoopOverShape(loc, s_and_s_op.out(), b); MakeLoopOverShape(loc, s_and_s_op.out(), b);
OpBuilder::InsertionGuard guard(*b); OpBuilder::InsertionGuard guard(*b);
b->setInsertionPointToStart(loop_over_output.getBody()); b->setInsertionPointToStart(loop_over_output.getBody());
@ -531,10 +531,10 @@ class SelectAndScatterOpConverter
struct WindowLoops { struct WindowLoops {
SmallVector<Value, 2> selected_ivs; SmallVector<Value, 2> selected_ivs;
SmallVector<Value, 2> window_ivs; SmallVector<Value, 2> window_ivs;
loop::ForOp inner_loop; scf::ForOp inner_loop;
}; };
WindowLoops InsertWindowLoops(xla_lhlo::SelectAndScatterOp s_and_s_op, WindowLoops InsertWindowLoops(xla_lhlo::SelectAndScatterOp s_and_s_op,
loop::ParallelOp loop_over_src, scf::ParallelOp loop_over_src,
OpBuilder* b) const { OpBuilder* b) const {
auto loc = s_and_s_op.getLoc(); auto loc = s_and_s_op.getLoc();
Value zero = b->create<ConstantIndexOp>(loc, 0); Value zero = b->create<ConstantIndexOp>(loc, 0);
@ -558,12 +558,12 @@ class SelectAndScatterOpConverter
s_and_s_op.window_dimensions()->getIntValues()) { s_and_s_op.window_dimensions()->getIntValues()) {
Value upper = b->create<ConstantIndexOp>(loc, window_dim.getSExtValue()); Value upper = b->create<ConstantIndexOp>(loc, window_dim.getSExtValue());
result.inner_loop = result.inner_loop =
b->create<loop::ForOp>(loc, zero, upper, one, iter_args); b->create<scf::ForOp>(loc, zero, upper, one, iter_args);
if (b->getInsertionBlock() == loop_over_src.getBody()) { if (b->getInsertionBlock() == loop_over_src.getBody()) {
ip = b->saveInsertionPoint(); ip = b->saveInsertionPoint();
result.selected_ivs = result.inner_loop.getResults().take_front(rank); result.selected_ivs = result.inner_loop.getResults().take_front(rank);
} else { } else {
b->create<loop::YieldOp>(loc, result.inner_loop.getResults()); b->create<scf::YieldOp>(loc, result.inner_loop.getResults());
} }
b->setInsertionPointToStart(result.inner_loop.getBody()); b->setInsertionPointToStart(result.inner_loop.getBody());
iter_args = ValueRange{result.inner_loop.getRegionIterArgs()}; iter_args = ValueRange{result.inner_loop.getRegionIterArgs()};
@ -599,7 +599,7 @@ class SelectAndScatterOpConverter
}; };
SmallVector<Value, 2> SelectIvs(xla_lhlo::SelectAndScatterOp s_and_s_op, SmallVector<Value, 2> SelectIvs(xla_lhlo::SelectAndScatterOp s_and_s_op,
loop::ParallelOp loop_over_src, scf::ParallelOp loop_over_src,
OpBuilder* b) const { OpBuilder* b) const {
auto loc = s_and_s_op.getLoc(); auto loc = s_and_s_op.getLoc();
@ -614,7 +614,7 @@ class SelectAndScatterOpConverter
IterArgs ivs_val_flag(window_loops.inner_loop.getRegionIterArgs()); IterArgs ivs_val_flag(window_loops.inner_loop.getRegionIterArgs());
auto if_in_bounds = inner_loop_b.create<loop::IfOp>( auto if_in_bounds = inner_loop_b.create<scf::IfOp>(
loc, window_loops.inner_loop.getResultTypes(), mapped_ivs.in_bounds, loc, window_loops.inner_loop.getResultTypes(), mapped_ivs.in_bounds,
/*withElseRegion=*/true); /*withElseRegion=*/true);
@ -623,16 +623,16 @@ class SelectAndScatterOpConverter
OpBuilder in_bounds_then_b = if_in_bounds.getThenBodyBuilder(); OpBuilder in_bounds_then_b = if_in_bounds.getThenBodyBuilder();
auto select_or_init_results = SelectOrInitialize( auto select_or_init_results = SelectOrInitialize(
s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b); s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b);
in_bounds_then_b.create<loop::YieldOp>(loc, select_or_init_results); in_bounds_then_b.create<scf::YieldOp>(loc, select_or_init_results);
} }
// Case when we are in the pad. // Case when we are in the pad.
{ {
OpBuilder in_bounds_else_b = if_in_bounds.getElseBodyBuilder(); OpBuilder in_bounds_else_b = if_in_bounds.getElseBodyBuilder();
in_bounds_else_b.create<loop::YieldOp>(loc, ivs_val_flag.to_vector()); in_bounds_else_b.create<scf::YieldOp>(loc, ivs_val_flag.to_vector());
} }
inner_loop_b.create<loop::YieldOp>(loc, if_in_bounds.getResults()); inner_loop_b.create<scf::YieldOp>(loc, if_in_bounds.getResults());
return window_loops.selected_ivs; return window_loops.selected_ivs;
} }
@ -647,8 +647,8 @@ class SelectAndScatterOpConverter
Value operand_elem = Value operand_elem =
b->create<LoadOp>(loc, s_and_s_op.operand(), operand_ivs); b->create<LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
auto if_init = auto if_init =
b->create<loop::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(), b->create<scf::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(),
/*withElseRegion=*/true); /*withElseRegion=*/true);
// Init == true, i.e. iter args are already initialized with a selected // Init == true, i.e. iter args are already initialized with a selected
// element in boundaries of the operand. Select function has to be computed // element in boundaries of the operand. Select function has to be computed
// here. // here.
@ -660,32 +660,31 @@ class SelectAndScatterOpConverter
ApplySingleResultLhloCode(loc, {operand_elem, ivs_val_flag->value()}, ApplySingleResultLhloCode(loc, {operand_elem, ivs_val_flag->value()},
&lhlo_select, &if_init_then_b); &lhlo_select, &if_init_then_b);
auto if_pred = auto if_pred = if_init_then_b.create<scf::IfOp>(loc, iter_arg_types, pred,
if_init_then_b.create<loop::IfOp>(loc, iter_arg_types, pred, /*withElseRegion=*/true);
/*withElseRegion=*/true);
// Pred == true, therefore pack newly selected ivs, val and init flag back // Pred == true, therefore pack newly selected ivs, val and init flag back
// to iter_args and return. // to iter_args and return.
{ {
OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder(); OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder();
if_pred_then_b.create<loop::YieldOp>( if_pred_then_b.create<scf::YieldOp>(
loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector()); loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
} }
// Pred == false, therefore return old iter_args. // Pred == false, therefore return old iter_args.
{ {
OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder(); OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder();
if_pred_else_b.create<loop::YieldOp>(loc, ivs_val_flag->to_vector()); if_pred_else_b.create<scf::YieldOp>(loc, ivs_val_flag->to_vector());
} }
if_init_then_b.create<loop::YieldOp>(loc, if_pred.getResults()); if_init_then_b.create<scf::YieldOp>(loc, if_pred.getResults());
} }
// Init == false, i.e. only pad was visited before and this is the first // Init == false, i.e. only pad was visited before and this is the first
// element in the boundaries of the operand. // element in the boundaries of the operand.
{ {
OpBuilder if_init_else_b = if_init.getElseBodyBuilder(); OpBuilder if_init_else_b = if_init.getElseBodyBuilder();
if_init_else_b.create<loop::YieldOp>( if_init_else_b.create<scf::YieldOp>(
loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector()); loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
} }
return if_init.getResults(); return if_init.getResults();
@ -708,7 +707,7 @@ struct LhloLegalizeToParallelLoops
ConversionTarget target(getContext()); ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect, target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
loop::LoopOpsDialect, XlaLhloDialect>(); scf::SCFDialect, XlaLhloDialect>();
target.addIllegalOp<xla_lhlo::ReduceOp, xla_lhlo::ReduceWindowOp, target.addIllegalOp<xla_lhlo::ReduceOp, xla_lhlo::ReduceWindowOp,
xla_lhlo::SelectAndScatterOp>(); xla_lhlo::SelectAndScatterOp>();

View File

@ -185,11 +185,11 @@ cc_library(
"@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgToLLVM", "@llvm-project//mlir:LinalgToLLVM",
"@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:LoopOps",
"@llvm-project//mlir:LoopOpsTransforms",
"@llvm-project//mlir:LoopsToGPUPass", "@llvm-project//mlir:LoopsToGPUPass",
"@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:NVVMDialect",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms", "@llvm-project//mlir:Transforms",

View File

@ -31,9 +31,9 @@ limitations under the License.
#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project #include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project
#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project
#include "mlir/Dialect/LoopOps/LoopOps.h" // from @llvm-project #include "mlir/Dialect/SCF/Passes.h" // from @llvm-project
#include "mlir/Dialect/LoopOps/Passes.h" // from @llvm-project #include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/Dialect/LoopOps/Transforms.h" // from @llvm-project #include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
@ -132,7 +132,7 @@ struct StoreForwardingPass
// No store operation found. Continue search outside of the parallel // No store operation found. Continue search outside of the parallel
// loop if block is in a parallel loop. // loop if block is in a parallel loop.
if (auto parallelOp = if (auto parallelOp =
llvm::dyn_cast<mlir::loop::ParallelOp>(block->getParentOp())) { llvm::dyn_cast<mlir::scf::ParallelOp>(block->getParentOp())) {
return findStore(parallelOp.getOperation(), matches); return findStore(parallelOp.getOperation(), matches);
} }
return {}; return {};
@ -388,8 +388,8 @@ struct MapParallelLoops
struct FuseInnerParallelLoops struct FuseInnerParallelLoops
: public mlir::PassWrapper<FuseInnerParallelLoops, mlir::FunctionPass> { : public mlir::PassWrapper<FuseInnerParallelLoops, mlir::FunctionPass> {
void runOnFunction() override { void runOnFunction() override {
getFunction().walk([](mlir::loop::ParallelOp op) { getFunction().walk([](mlir::scf::ParallelOp op) {
mlir::loop::naivelyFuseParallelOps(op.region()); mlir::scf::naivelyFuseParallelOps(op.region());
}); });
} }
}; };
@ -401,7 +401,7 @@ struct ParallelLoopCollapsingToFirstDim
void runOnOperation() override { void runOnOperation() override {
mlir::Operation* module = getOperation(); mlir::Operation* module = getOperation();
module->walk([&](mlir::loop::ParallelOp op) { module->walk([&](mlir::scf::ParallelOp op) {
unsigned num_loops = op.getNumLoops(); unsigned num_loops = op.getNumLoops();
std::vector<unsigned> combinedLoops; std::vector<unsigned> combinedLoops;
combinedLoops.reserve(num_loops); combinedLoops.reserve(num_loops);

View File

@ -297,9 +297,9 @@ cc_library(
) )
filegroup( filegroup(
name = "LoopOpsTdFiles", name = "SCFTdFiles",
srcs = [ srcs = [
"include/mlir/Dialect/LoopOps/LoopOps.td", "include/mlir/Dialect/SCF/SCFOps.td",
"include/mlir/Interfaces/ControlFlowInterfaces.td", "include/mlir/Interfaces/ControlFlowInterfaces.td",
"include/mlir/Interfaces/LoopLikeInterface.td", "include/mlir/Interfaces/LoopLikeInterface.td",
"include/mlir/Interfaces/SideEffects.td", "include/mlir/Interfaces/SideEffects.td",
@ -308,26 +308,26 @@ filegroup(
) )
gentbl( gentbl(
name = "LoopOpsIncGen", name = "SCFIncGen",
strip_include_prefix = "include", strip_include_prefix = "include",
tbl_outs = [ tbl_outs = [
( (
"-gen-op-decls", "-gen-op-decls",
"include/mlir/Dialect/LoopOps/LoopOps.h.inc", "include/mlir/Dialect/SCF/SCFOps.h.inc",
), ),
( (
"-gen-op-defs", "-gen-op-defs",
"include/mlir/Dialect/LoopOps/LoopOps.cpp.inc", "include/mlir/Dialect/SCF/SCFOps.cpp.inc",
), ),
( (
"-gen-dialect-decls", "-gen-dialect-decls",
"include/mlir/Dialect/LoopOps/LoopOpsDialect.h.inc", "include/mlir/Dialect/SCF/SCFOpsDialect.h.inc",
), ),
], ],
tblgen = ":mlir-tblgen", tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/LoopOps/LoopOps.td", td_file = "include/mlir/Dialect/SCF/SCFOps.td",
td_srcs = [ td_srcs = [
":LoopOpsTdFiles", ":SCFTdFiles",
], ],
) )
@ -337,30 +337,30 @@ gentbl(
tbl_outs = [ tbl_outs = [
( (
"-gen-pass-decls", "-gen-pass-decls",
"include/mlir/Dialect/LoopOps/Passes.h.inc", "include/mlir/Dialect/SCF/Passes.h.inc",
), ),
], ],
tblgen = ":mlir-tblgen", tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/LoopOps/Passes.td", td_file = "include/mlir/Dialect/SCF/Passes.td",
td_srcs = [ td_srcs = [
":PassBaseTdFiles", ":PassBaseTdFiles",
], ],
) )
cc_library( cc_library(
name = "LoopOpsTransforms", name = "SCFTransforms",
srcs = glob([ srcs = glob([
"lib/Dialect/LoopOps/Transforms/*.cpp", "lib/Dialect/SCF/Transforms/*.cpp",
"lib/Dialect/LoopOps/Transforms/*.h", "lib/Dialect/SCF/Transforms/*.h",
]), ]),
hdrs = ["include/mlir/Dialect/LoopOps/Passes.h"], hdrs = ["include/mlir/Dialect/SCF/Passes.h"],
includes = ["include"], includes = ["include"],
deps = [ deps = [
":Affine", ":Affine",
":IR", ":IR",
":LoopOps",
":LoopPassIncGen", ":LoopPassIncGen",
":Pass", ":Pass",
":SCFDialect",
":StandardOps", ":StandardOps",
":Transforms", ":Transforms",
"@llvm-project//llvm:support", "@llvm-project//llvm:support",
@ -521,8 +521,8 @@ cc_library(
":AffinePassIncGen", ":AffinePassIncGen",
":Analysis", ":Analysis",
":IR", ":IR",
":LoopOps",
":Pass", ":Pass",
":SCFDialect",
":StandardOps", ":StandardOps",
":Support", ":Support",
":Transforms", ":Transforms",
@ -559,8 +559,8 @@ cc_library(
":Affine", ":Affine",
":ConversionPassIncGen", ":ConversionPassIncGen",
":IR", ":IR",
":LoopOps",
":Pass", ":Pass",
":SCFDialect",
":StandardOps", ":StandardOps",
":Support", ":Support",
":Transforms", ":Transforms",
@ -588,17 +588,17 @@ cc_library(
) )
cc_library( cc_library(
name = "LoopOps", name = "SCFDialect",
srcs = glob( srcs = glob(
[ [
"lib/Dialect/LoopOps/*.cpp", "lib/Dialect/SCF/*.cpp",
"lib/Dialect/LoopOps/*.h", "lib/Dialect/SCF/*.h",
"lib/Dialect/LoopOps/EDSC/*.cpp", "lib/Dialect/SCF/EDSC/*.cpp",
], ],
), ),
hdrs = glob([ hdrs = glob([
"include/mlir/Dialect/LoopOps/*.h", "include/mlir/Dialect/SCF/*.h",
"include/mlir/Dialect/LoopOps/EDSC/*.h", "include/mlir/Dialect/SCF/EDSC/*.h",
]), ]),
includes = ["include"], includes = ["include"],
deps = [ deps = [
@ -606,7 +606,7 @@ cc_library(
":EDSC", ":EDSC",
":IR", ":IR",
":LoopLikeInterface", ":LoopLikeInterface",
":LoopOpsIncGen", ":SCFIncGen",
":SideEffects", ":SideEffects",
":StandardOps", ":StandardOps",
":Support", ":Support",
@ -1113,9 +1113,9 @@ cc_library(
":GPUDialect", ":GPUDialect",
":GPUPassIncGen", ":GPUPassIncGen",
":IR", ":IR",
":LoopOps",
":ParallelLoopMapperAttrGen", ":ParallelLoopMapperAttrGen",
":Pass", ":Pass",
":SCFDialect",
":StandardOps", ":StandardOps",
":Support", ":Support",
":Transforms", ":Transforms",
@ -1324,8 +1324,8 @@ cc_library(
":GPUDialect", ":GPUDialect",
":GPUToSPIRVIncGen", ":GPUToSPIRVIncGen",
":IR", ":IR",
":LoopOps",
":Pass", ":Pass",
":SCFDialect",
":SPIRVDialect", ":SPIRVDialect",
":SPIRVLowering", ":SPIRVLowering",
":StandardToSPIRVConversions", ":StandardToSPIRVConversions",
@ -1883,7 +1883,7 @@ cc_library(
":ControlFlowInterfaces", ":ControlFlowInterfaces",
":IR", ":IR",
":LoopLikeInterface", ":LoopLikeInterface",
":LoopOps", ":SCFDialect",
":SideEffects", ":SideEffects",
":StandardOps", ":StandardOps",
":Support", ":Support",
@ -2000,8 +2000,8 @@ cc_library(
":ControlFlowInterfaces", ":ControlFlowInterfaces",
":IR", ":IR",
":LoopLikeInterface", ":LoopLikeInterface",
":LoopOps",
":Pass", ":Pass",
":SCFDialect",
":SideEffects", ":SideEffects",
":StandardOps", ":StandardOps",
":Support", ":Support",
@ -2037,8 +2037,8 @@ cc_library(
":GPUDialect", ":GPUDialect",
":GPUTransforms", ":GPUTransforms",
":IR", ":IR",
":LoopOps",
":Pass", ":Pass",
":SCFDialect",
":StandardOps", ":StandardOps",
":Support", ":Support",
":TransformUtils", ":TransformUtils",
@ -2061,9 +2061,9 @@ cc_library(
":Affine", ":Affine",
":ConversionPassIncGen", ":ConversionPassIncGen",
":GPUDialect", ":GPUDialect",
":LoopOps",
":LoopsToGPU", ":LoopsToGPU",
":Pass", ":Pass",
":SCFDialect",
":StandardOps", ":StandardOps",
":Support", ":Support",
":Transforms", ":Transforms",
@ -2085,8 +2085,8 @@ cc_library(
":ConversionPassIncGen", ":ConversionPassIncGen",
":IR", ":IR",
":LLVMDialect", ":LLVMDialect",
":LoopOps",
":Pass", ":Pass",
":SCFDialect",
":StandardOps", ":StandardOps",
":Support", ":Support",
":TransformUtils", ":TransformUtils",
@ -2292,7 +2292,7 @@ cc_library(
":Affine", ":Affine",
":CallOpInterfaces", ":CallOpInterfaces",
":IR", ":IR",
":LoopOps", ":SCFDialect",
":StandardOps", ":StandardOps",
":Support", ":Support",
"@llvm-project//llvm:support", "@llvm-project//llvm:support",
@ -2479,10 +2479,10 @@ cc_library(
":LLVMTransforms", ":LLVMTransforms",
":LinalgToLLVM", ":LinalgToLLVM",
":LinalgToSPIRV", ":LinalgToSPIRV",
":LoopOpsTransforms",
":NVVMDialect", ":NVVMDialect",
":Parser", ":Parser",
":Pass", ":Pass",
":SCFTransforms",
":StandardOpsTransforms", ":StandardOpsTransforms",
":StandardToSPIRVConversions", ":StandardToSPIRVConversions",
":StandardToStandard", ":StandardToStandard",
@ -2566,8 +2566,6 @@ cc_library(
":LinalgToLLVM", ":LinalgToLLVM",
":LinalgToSPIRV", ":LinalgToSPIRV",
":LinalgTransforms", ":LinalgTransforms",
":LoopOps",
":LoopOpsTransforms",
":LoopPassIncGen", ":LoopPassIncGen",
":LoopsToGPUPass", ":LoopsToGPUPass",
":NVVMDialect", ":NVVMDialect",
@ -2575,6 +2573,8 @@ cc_library(
":QuantOps", ":QuantOps",
":QuantPassIncGen", ":QuantPassIncGen",
":ROCDLDialect", ":ROCDLDialect",
":SCFDialect",
":SCFTransforms",
":SDBM", ":SDBM",
":SPIRVDialect", ":SPIRVDialect",
":SPIRVLowering", ":SPIRVLowering",
@ -3245,8 +3245,8 @@ cc_library(
":LinalgOps", ":LinalgOps",
":LinalgPassIncGen", ":LinalgPassIncGen",
":LinalgStructuredOpsIncGen", ":LinalgStructuredOpsIncGen",
":LoopOps",
":Pass", ":Pass",
":SCFDialect",
":StandardOps", ":StandardOps",
":Support", ":Support",
":TransformUtils", ":TransformUtils",
@ -3367,8 +3367,8 @@ cc_library(
":IR", ":IR",
":LLVMDialect", ":LLVMDialect",
":LLVMTransforms", ":LLVMTransforms",
":LoopOps",
":Pass", ":Pass",
":SCFDialect",
":StandardOps", ":StandardOps",
":Support", ":Support",
":Transforms", ":Transforms",

View File

@ -163,8 +163,8 @@ cc_library(
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:LoopOps",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:TransformUtils",