Add EDSC support for loop.for operations

This CL adds support for loop.for operations in EDSC and adds a test.
This will be used in a followup commit to implement lowering of vector_transfer ops so that it works more generally and is not subject to affine constraints.

PiperOrigin-RevId: 275461067
Change-Id: I543d48895fb065d3a861835c949885a48a307c56
This commit is contained in:
Nicolas Vasilache 2019-10-18 06:44:41 -07:00 committed by TensorFlower Gardener
parent cb677d1cc1
commit ca21a79e60
5 changed files with 93 additions and 23 deletions

View File

@ -173,6 +173,7 @@ cc_library(
":AffineOps", ":AffineOps",
":Analysis", ":Analysis",
":IR", ":IR",
":LoopOps",
":StandardOps", ":StandardOps",
":Support", ":Support",
":TransformUtils", ":TransformUtils",

View File

@ -306,7 +306,7 @@ struct PythonLoopContext {
PythonValueHandle enter() { PythonValueHandle enter() {
ValueHandle iv(lb.value.getType()); ValueHandle iv(lb.value.getType());
builder = new LoopBuilder(&iv, lb.value, ub.value, step); builder = new AffineLoopNestBuilder(&iv, lb.value, ub.value, step);
return iv; return iv;
} }
@ -318,7 +318,7 @@ struct PythonLoopContext {
PythonValueHandle lb, ub; PythonValueHandle lb, ub;
int64_t step; int64_t step;
LoopBuilder *builder = nullptr; AffineLoopNestBuilder *builder = nullptr;
}; };
struct PythonLoopNestContext { struct PythonLoopNestContext {

View File

@ -24,6 +24,7 @@
#define MLIR_EDSC_BUILDERS_H_ #define MLIR_EDSC_BUILDERS_H_
#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
@ -161,8 +162,14 @@ public:
/// Constructs a new AffineForOp and captures the associated induction /// Constructs a new AffineForOp and captures the associated induction
/// variable. A ValueHandle pointer is passed as the first argument and is the /// variable. A ValueHandle pointer is passed as the first argument and is the
/// *only* way to capture the loop induction variable. /// *only* way to capture the loop induction variable.
LoopBuilder(ValueHandle *iv, ArrayRef<ValueHandle> lbHandles, static LoopBuilder makeAffine(ValueHandle *iv,
ArrayRef<ValueHandle> lbHandles,
ArrayRef<ValueHandle> ubHandles, int64_t step); ArrayRef<ValueHandle> ubHandles, int64_t step);
/// Constructs a new loop::ForOp and captures the associated induction
/// variable. A ValueHandle pointer is passed as the first argument and is the
/// *only* way to capture the loop induction variable.
static LoopBuilder makeLoop(ValueHandle *iv, ValueHandle lbHandle,
ValueHandle ubHandle, ValueHandle stepHandle);
LoopBuilder(const LoopBuilder &) = delete; LoopBuilder(const LoopBuilder &) = delete;
LoopBuilder(LoopBuilder &&) = default; LoopBuilder(LoopBuilder &&) = default;
@ -172,7 +179,10 @@ public:
/// The only purpose of this operator is to serve as a sequence point so that /// The only purpose of this operator is to serve as a sequence point so that
/// the evaluation of `fun` (which build IR snippets in a scoped fashion) is /// the evaluation of `fun` (which build IR snippets in a scoped fashion) is
/// scoped within a LoopBuilder. /// scoped within a LoopBuilder.
ValueHandle operator()(llvm::function_ref<void(void)> fun = nullptr); void operator()(llvm::function_ref<void(void)> fun = nullptr);
private:
LoopBuilder() = default;
}; };
/// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid /// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid
@ -200,15 +210,34 @@ public:
/// ``` /// ```
class AffineLoopNestBuilder { class AffineLoopNestBuilder {
public: public:
// This entry point accomodates the fact that AffineForOp implicitly uses
// multiple `lbs` and `ubs` with one single `iv` and `step` to encode `max`
// and and `min` constraints respectively.
AffineLoopNestBuilder(ValueHandle *iv, ArrayRef<ValueHandle> lbs,
ArrayRef<ValueHandle> ubs, int64_t step);
AffineLoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs, AffineLoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps); ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps);
ValueHandle operator()(llvm::function_ref<void(void)> fun = nullptr); void operator()(llvm::function_ref<void(void)> fun = nullptr);
private: private:
SmallVector<LoopBuilder, 4> loops; SmallVector<LoopBuilder, 4> loops;
}; };
/// Helper class to sugar building loop.for loop nests from ranges.
/// This is similar to edsc::AffineLoopNestBuilder except it operates on
/// loop.for.
class LoopNestBuilder {
public:
LoopNestBuilder(llvm::ArrayRef<edsc::ValueHandle *> ivs,
ArrayRef<ValueHandle> lbs, ArrayRef<ValueHandle> ubs,
ArrayRef<ValueHandle> steps);
void operator()(std::function<void(void)> fun = nullptr);
private:
llvm::SmallVector<LoopBuilder, 4> loops;
};
// This class exists solely to handle the C++ vexing parse case when // This class exists solely to handle the C++ vexing parse case when
// trying to enter a Block that has already been constructed. // trying to enter a Block that has already been constructed.
class Append {}; class Append {};

View File

@ -162,12 +162,12 @@ static llvm::Optional<ValueHandle> emitStaticFor(ArrayRef<ValueHandle> lbs,
ubConst.getValue(), step); ubConst.getValue(), step);
} }
mlir::edsc::LoopBuilder::LoopBuilder(ValueHandle *iv, mlir::edsc::LoopBuilder mlir::edsc::LoopBuilder::makeAffine(
ArrayRef<ValueHandle> lbHandles, ValueHandle *iv, ArrayRef<ValueHandle> lbHandles,
ArrayRef<ValueHandle> ubHandles, ArrayRef<ValueHandle> ubHandles, int64_t step) {
int64_t step) { mlir::edsc::LoopBuilder result;
if (auto res = emitStaticFor(lbHandles, ubHandles, step)) { if (auto staticFor = emitStaticFor(lbHandles, ubHandles, step)) {
*iv = res.getValue(); *iv = staticFor.getValue();
} else { } else {
SmallVector<Value *, 4> lbs(lbHandles.begin(), lbHandles.end()); SmallVector<Value *, 4> lbs(lbHandles.begin(), lbHandles.end());
SmallVector<Value *, 4> ubs(ubHandles.begin(), ubHandles.end()); SmallVector<Value *, 4> ubs(ubHandles.begin(), ubHandles.end());
@ -177,11 +177,24 @@ mlir::edsc::LoopBuilder::LoopBuilder(ValueHandle *iv,
step); step);
} }
auto *body = getForInductionVarOwner(iv->getValue()).getBody(); auto *body = getForInductionVarOwner(iv->getValue()).getBody();
enter(body, /*prev=*/1); result.enter(body, /*prev=*/1);
return result;
} }
ValueHandle mlir::edsc::LoopBuilder
mlir::edsc::LoopBuilder::operator()(llvm::function_ref<void(void)> fun) { mlir::edsc::LoopBuilder::makeLoop(ValueHandle *iv, ValueHandle lbHandle,
ValueHandle ubHandle,
ValueHandle stepHandle) {
mlir::edsc::LoopBuilder result;
auto forOp =
OperationHandle::createOp<loop::ForOp>(lbHandle, ubHandle, stepHandle);
*iv = ValueHandle(forOp.getInductionVar());
auto *body = loop::getForInductionVarOwner(iv->getValue()).getBody();
result.enter(body, /*prev=*/1);
return result;
}
void mlir::edsc::LoopBuilder::operator()(llvm::function_ref<void(void)> fun) {
// Call to `exit` must be explicit and asymmetric (cannot happen in the // Call to `exit` must be explicit and asymmetric (cannot happen in the
// destructor) because of ordering wrt comma operator. // destructor) because of ordering wrt comma operator.
/// The particular use case concerns nested blocks: /// The particular use case concerns nested blocks:
@ -203,7 +216,12 @@ mlir::edsc::LoopBuilder::operator()(llvm::function_ref<void(void)> fun) {
if (fun) if (fun)
fun(); fun();
exit(); exit();
return ValueHandle::null(); }
mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder(
ValueHandle *iv, ArrayRef<ValueHandle> lbs, ArrayRef<ValueHandle> ubs,
int64_t step) {
loops.emplace_back(LoopBuilder::makeAffine(iv, lbs, ubs, step));
} }
mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder( mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder(
@ -212,13 +230,12 @@ mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder(
assert(ivs.size() == lbs.size() && "Mismatch in number of arguments"); assert(ivs.size() == lbs.size() && "Mismatch in number of arguments");
assert(ivs.size() == ubs.size() && "Mismatch in number of arguments"); assert(ivs.size() == ubs.size() && "Mismatch in number of arguments");
assert(ivs.size() == steps.size() && "Mismatch in number of arguments"); assert(ivs.size() == steps.size() && "Mismatch in number of arguments");
for (auto it : llvm::zip(ivs, lbs, ubs, steps)) { for (auto it : llvm::zip(ivs, lbs, ubs, steps))
loops.emplace_back(std::get<0>(it), std::get<1>(it), std::get<2>(it), loops.emplace_back(LoopBuilder::makeAffine(
std::get<3>(it)); std::get<0>(it), std::get<1>(it), std::get<2>(it), std::get<3>(it)));
}
} }
ValueHandle mlir::edsc::AffineLoopNestBuilder::operator()( void mlir::edsc::AffineLoopNestBuilder::operator()(
llvm::function_ref<void(void)> fun) { llvm::function_ref<void(void)> fun) {
if (fun) if (fun)
fun(); fun();
@ -227,10 +244,32 @@ ValueHandle mlir::edsc::AffineLoopNestBuilder::operator()(
// to be asymmetric (i.e. enter() occurs on LoopBuilder construction, exit() // to be asymmetric (i.e. enter() occurs on LoopBuilder construction, exit()
// occurs on calling operator()). The asymmetry is required for properly // occurs on calling operator()). The asymmetry is required for properly
// nesting imperfectly nested regions (see LoopBuilder::operator()). // nesting imperfectly nested regions (see LoopBuilder::operator()).
for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit) { for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit)
(*lit)(); (*lit)();
} }
return ValueHandle::null();
mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef<ValueHandle *> ivs,
ArrayRef<ValueHandle> lbs,
ArrayRef<ValueHandle> ubs,
ArrayRef<ValueHandle> steps) {
assert(ivs.size() == lbs.size() && "expected size of ivs and lbs to match");
assert(ivs.size() == ubs.size() && "expected size of ivs and ubs to match");
assert(ivs.size() == steps.size() &&
"expected size of ivs and steps to match");
loops.reserve(ivs.size());
for (auto it : llvm::zip(ivs, lbs, ubs, steps)) {
loops.emplace_back(LoopBuilder::makeLoop(std::get<0>(it), std::get<1>(it),
std::get<2>(it), std::get<3>(it)));
}
assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size");
}
void LoopNestBuilder::LoopNestBuilder::operator()(
std::function<void(void)> fun) {
if (fun)
fun();
for (auto &lit : reverse(loops))
lit({});
} }
mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle bh, Append) { mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle bh, Append) {

View File

@ -17,6 +17,7 @@ add_dependencies(MLIREDSC MLIRReferenceImplementationTestGen)
target_link_libraries(MLIREDSC target_link_libraries(MLIREDSC
PUBLIC PUBLIC
MLIRAffineOps MLIRAffineOps
MLIRLoopOps
MLIRStandardOps MLIRStandardOps
MLIRTransformUtils MLIRTransformUtils
MLIRVectorOps MLIRVectorOps