diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index 1944ebddb35..be90d309615 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -173,6 +173,7 @@ cc_library( ":AffineOps", ":Analysis", ":IR", + ":LoopOps", ":StandardOps", ":Support", ":TransformUtils", diff --git a/third_party/mlir/bindings/python/pybind.cpp b/third_party/mlir/bindings/python/pybind.cpp index cc8fc6e8e76..258455d6867 100644 --- a/third_party/mlir/bindings/python/pybind.cpp +++ b/third_party/mlir/bindings/python/pybind.cpp @@ -306,7 +306,7 @@ struct PythonLoopContext { PythonValueHandle enter() { ValueHandle iv(lb.value.getType()); - builder = new LoopBuilder(&iv, lb.value, ub.value, step); + builder = new AffineLoopNestBuilder(&iv, lb.value, ub.value, step); return iv; } @@ -318,7 +318,7 @@ struct PythonLoopContext { PythonValueHandle lb, ub; int64_t step; - LoopBuilder *builder = nullptr; + AffineLoopNestBuilder *builder = nullptr; }; struct PythonLoopNestContext { diff --git a/third_party/mlir/include/mlir/EDSC/Builders.h b/third_party/mlir/include/mlir/EDSC/Builders.h index 9fbca8933ce..5a80571a161 100644 --- a/third_party/mlir/include/mlir/EDSC/Builders.h +++ b/third_party/mlir/include/mlir/EDSC/Builders.h @@ -24,6 +24,7 @@ #define MLIR_EDSC_BUILDERS_H_ #include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/IR/Builders.h" @@ -161,8 +162,14 @@ public: /// Constructs a new AffineForOp and captures the associated induction /// variable. A ValueHandle pointer is passed as the first argument and is the /// *only* way to capture the loop induction variable. - LoopBuilder(ValueHandle *iv, ArrayRef lbHandles, - ArrayRef ubHandles, int64_t step); + static LoopBuilder makeAffine(ValueHandle *iv, + ArrayRef lbHandles, + ArrayRef ubHandles, int64_t step); + /// Constructs a new loop::ForOp and captures the associated induction + /// variable. A ValueHandle pointer is passed as the first argument and is the + /// *only* way to capture the loop induction variable. + static LoopBuilder makeLoop(ValueHandle *iv, ValueHandle lbHandle, + ValueHandle ubHandle, ValueHandle stepHandle); LoopBuilder(const LoopBuilder &) = delete; LoopBuilder(LoopBuilder &&) = default; @@ -172,7 +179,10 @@ public: /// The only purpose of this operator is to serve as a sequence point so that /// the evaluation of `fun` (which build IR snippets in a scoped fashion) is /// scoped within a LoopBuilder. - ValueHandle operator()(llvm::function_ref fun = nullptr); + void operator()(llvm::function_ref fun = nullptr); + +private: + LoopBuilder() = default; }; /// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid @@ -200,15 +210,34 @@ public: /// ``` class AffineLoopNestBuilder { public: + // This entry point accomodates the fact that AffineForOp implicitly uses + // multiple `lbs` and `ubs` with one single `iv` and `step` to encode `max` + // and and `min` constraints respectively. + AffineLoopNestBuilder(ValueHandle *iv, ArrayRef lbs, + ArrayRef ubs, int64_t step); AffineLoopNestBuilder(ArrayRef ivs, ArrayRef lbs, ArrayRef ubs, ArrayRef steps); - ValueHandle operator()(llvm::function_ref fun = nullptr); + void operator()(llvm::function_ref fun = nullptr); private: SmallVector loops; }; +/// Helper class to sugar building loop.for loop nests from ranges. +/// This is similar to edsc::AffineLoopNestBuilder except it operates on +/// loop.for. +class LoopNestBuilder { +public: + LoopNestBuilder(llvm::ArrayRef ivs, + ArrayRef lbs, ArrayRef ubs, + ArrayRef steps); + void operator()(std::function fun = nullptr); + +private: + llvm::SmallVector loops; +}; + // This class exists solely to handle the C++ vexing parse case when // trying to enter a Block that has already been constructed. class Append {}; diff --git a/third_party/mlir/lib/EDSC/Builders.cpp b/third_party/mlir/lib/EDSC/Builders.cpp index ba14a147cf8..f1dceec8f11 100644 --- a/third_party/mlir/lib/EDSC/Builders.cpp +++ b/third_party/mlir/lib/EDSC/Builders.cpp @@ -162,12 +162,12 @@ static llvm::Optional emitStaticFor(ArrayRef lbs, ubConst.getValue(), step); } -mlir::edsc::LoopBuilder::LoopBuilder(ValueHandle *iv, - ArrayRef lbHandles, - ArrayRef ubHandles, - int64_t step) { - if (auto res = emitStaticFor(lbHandles, ubHandles, step)) { - *iv = res.getValue(); +mlir::edsc::LoopBuilder mlir::edsc::LoopBuilder::makeAffine( + ValueHandle *iv, ArrayRef lbHandles, + ArrayRef ubHandles, int64_t step) { + mlir::edsc::LoopBuilder result; + if (auto staticFor = emitStaticFor(lbHandles, ubHandles, step)) { + *iv = staticFor.getValue(); } else { SmallVector lbs(lbHandles.begin(), lbHandles.end()); SmallVector ubs(ubHandles.begin(), ubHandles.end()); @@ -177,11 +177,24 @@ mlir::edsc::LoopBuilder::LoopBuilder(ValueHandle *iv, step); } auto *body = getForInductionVarOwner(iv->getValue()).getBody(); - enter(body, /*prev=*/1); + result.enter(body, /*prev=*/1); + return result; } -ValueHandle -mlir::edsc::LoopBuilder::operator()(llvm::function_ref fun) { +mlir::edsc::LoopBuilder +mlir::edsc::LoopBuilder::makeLoop(ValueHandle *iv, ValueHandle lbHandle, + ValueHandle ubHandle, + ValueHandle stepHandle) { + mlir::edsc::LoopBuilder result; + auto forOp = + OperationHandle::createOp(lbHandle, ubHandle, stepHandle); + *iv = ValueHandle(forOp.getInductionVar()); + auto *body = loop::getForInductionVarOwner(iv->getValue()).getBody(); + result.enter(body, /*prev=*/1); + return result; +} + +void mlir::edsc::LoopBuilder::operator()(llvm::function_ref fun) { // Call to `exit` must be explicit and asymmetric (cannot happen in the // destructor) because of ordering wrt comma operator. /// The particular use case concerns nested blocks: @@ -203,7 +216,12 @@ mlir::edsc::LoopBuilder::operator()(llvm::function_ref fun) { if (fun) fun(); exit(); - return ValueHandle::null(); +} + +mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder( + ValueHandle *iv, ArrayRef lbs, ArrayRef ubs, + int64_t step) { + loops.emplace_back(LoopBuilder::makeAffine(iv, lbs, ubs, step)); } mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder( @@ -212,13 +230,12 @@ mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder( assert(ivs.size() == lbs.size() && "Mismatch in number of arguments"); assert(ivs.size() == ubs.size() && "Mismatch in number of arguments"); assert(ivs.size() == steps.size() && "Mismatch in number of arguments"); - for (auto it : llvm::zip(ivs, lbs, ubs, steps)) { - loops.emplace_back(std::get<0>(it), std::get<1>(it), std::get<2>(it), - std::get<3>(it)); - } + for (auto it : llvm::zip(ivs, lbs, ubs, steps)) + loops.emplace_back(LoopBuilder::makeAffine( + std::get<0>(it), std::get<1>(it), std::get<2>(it), std::get<3>(it))); } -ValueHandle mlir::edsc::AffineLoopNestBuilder::operator()( +void mlir::edsc::AffineLoopNestBuilder::operator()( llvm::function_ref fun) { if (fun) fun(); @@ -227,10 +244,32 @@ ValueHandle mlir::edsc::AffineLoopNestBuilder::operator()( // to be asymmetric (i.e. enter() occurs on LoopBuilder construction, exit() // occurs on calling operator()). The asymmetry is required for properly // nesting imperfectly nested regions (see LoopBuilder::operator()). - for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit) { + for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit) (*lit)(); +} + +mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef ivs, + ArrayRef lbs, + ArrayRef ubs, + ArrayRef steps) { + assert(ivs.size() == lbs.size() && "expected size of ivs and lbs to match"); + assert(ivs.size() == ubs.size() && "expected size of ivs and ubs to match"); + assert(ivs.size() == steps.size() && + "expected size of ivs and steps to match"); + loops.reserve(ivs.size()); + for (auto it : llvm::zip(ivs, lbs, ubs, steps)) { + loops.emplace_back(LoopBuilder::makeLoop(std::get<0>(it), std::get<1>(it), + std::get<2>(it), std::get<3>(it))); } - return ValueHandle::null(); + assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size"); +} + +void LoopNestBuilder::LoopNestBuilder::operator()( + std::function fun) { + if (fun) + fun(); + for (auto &lit : reverse(loops)) + lit({}); } mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle bh, Append) { diff --git a/third_party/mlir/lib/EDSC/CMakeLists.txt b/third_party/mlir/lib/EDSC/CMakeLists.txt index 5f3dd9f1ee1..967d6add293 100644 --- a/third_party/mlir/lib/EDSC/CMakeLists.txt +++ b/third_party/mlir/lib/EDSC/CMakeLists.txt @@ -17,6 +17,7 @@ add_dependencies(MLIREDSC MLIRReferenceImplementationTestGen) target_link_libraries(MLIREDSC PUBLIC MLIRAffineOps + MLIRLoopOps MLIRStandardOps MLIRTransformUtils MLIRVectorOps