Add EDSC support for loop.for operations
This CL adds support for loop.for operations in EDSC and adds a test. This will be used in a followup commit to implement lowering of vector_transfer ops so that it works more generally and is not subject to affine constraints. PiperOrigin-RevId: 275461067 Change-Id: I543d48895fb065d3a861835c949885a48a307c56
This commit is contained in:
parent
cb677d1cc1
commit
ca21a79e60
1
third_party/mlir/BUILD
vendored
1
third_party/mlir/BUILD
vendored
@ -173,6 +173,7 @@ cc_library(
|
|||||||
":AffineOps",
|
":AffineOps",
|
||||||
":Analysis",
|
":Analysis",
|
||||||
":IR",
|
":IR",
|
||||||
|
":LoopOps",
|
||||||
":StandardOps",
|
":StandardOps",
|
||||||
":Support",
|
":Support",
|
||||||
":TransformUtils",
|
":TransformUtils",
|
||||||
|
4
third_party/mlir/bindings/python/pybind.cpp
vendored
4
third_party/mlir/bindings/python/pybind.cpp
vendored
@ -306,7 +306,7 @@ struct PythonLoopContext {
|
|||||||
|
|
||||||
PythonValueHandle enter() {
|
PythonValueHandle enter() {
|
||||||
ValueHandle iv(lb.value.getType());
|
ValueHandle iv(lb.value.getType());
|
||||||
builder = new LoopBuilder(&iv, lb.value, ub.value, step);
|
builder = new AffineLoopNestBuilder(&iv, lb.value, ub.value, step);
|
||||||
return iv;
|
return iv;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -318,7 +318,7 @@ struct PythonLoopContext {
|
|||||||
|
|
||||||
PythonValueHandle lb, ub;
|
PythonValueHandle lb, ub;
|
||||||
int64_t step;
|
int64_t step;
|
||||||
LoopBuilder *builder = nullptr;
|
AffineLoopNestBuilder *builder = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct PythonLoopNestContext {
|
struct PythonLoopNestContext {
|
||||||
|
35
third_party/mlir/include/mlir/EDSC/Builders.h
vendored
35
third_party/mlir/include/mlir/EDSC/Builders.h
vendored
@ -24,6 +24,7 @@
|
|||||||
#define MLIR_EDSC_BUILDERS_H_
|
#define MLIR_EDSC_BUILDERS_H_
|
||||||
|
|
||||||
#include "mlir/Dialect/AffineOps/AffineOps.h"
|
#include "mlir/Dialect/AffineOps/AffineOps.h"
|
||||||
|
#include "mlir/Dialect/LoopOps/LoopOps.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
#include "mlir/Dialect/VectorOps/VectorOps.h"
|
#include "mlir/Dialect/VectorOps/VectorOps.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
@ -161,8 +162,14 @@ public:
|
|||||||
/// Constructs a new AffineForOp and captures the associated induction
|
/// Constructs a new AffineForOp and captures the associated induction
|
||||||
/// variable. A ValueHandle pointer is passed as the first argument and is the
|
/// variable. A ValueHandle pointer is passed as the first argument and is the
|
||||||
/// *only* way to capture the loop induction variable.
|
/// *only* way to capture the loop induction variable.
|
||||||
LoopBuilder(ValueHandle *iv, ArrayRef<ValueHandle> lbHandles,
|
static LoopBuilder makeAffine(ValueHandle *iv,
|
||||||
|
ArrayRef<ValueHandle> lbHandles,
|
||||||
ArrayRef<ValueHandle> ubHandles, int64_t step);
|
ArrayRef<ValueHandle> ubHandles, int64_t step);
|
||||||
|
/// Constructs a new loop::ForOp and captures the associated induction
|
||||||
|
/// variable. A ValueHandle pointer is passed as the first argument and is the
|
||||||
|
/// *only* way to capture the loop induction variable.
|
||||||
|
static LoopBuilder makeLoop(ValueHandle *iv, ValueHandle lbHandle,
|
||||||
|
ValueHandle ubHandle, ValueHandle stepHandle);
|
||||||
LoopBuilder(const LoopBuilder &) = delete;
|
LoopBuilder(const LoopBuilder &) = delete;
|
||||||
LoopBuilder(LoopBuilder &&) = default;
|
LoopBuilder(LoopBuilder &&) = default;
|
||||||
|
|
||||||
@ -172,7 +179,10 @@ public:
|
|||||||
/// The only purpose of this operator is to serve as a sequence point so that
|
/// The only purpose of this operator is to serve as a sequence point so that
|
||||||
/// the evaluation of `fun` (which build IR snippets in a scoped fashion) is
|
/// the evaluation of `fun` (which build IR snippets in a scoped fashion) is
|
||||||
/// scoped within a LoopBuilder.
|
/// scoped within a LoopBuilder.
|
||||||
ValueHandle operator()(llvm::function_ref<void(void)> fun = nullptr);
|
void operator()(llvm::function_ref<void(void)> fun = nullptr);
|
||||||
|
|
||||||
|
private:
|
||||||
|
LoopBuilder() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid
|
/// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid
|
||||||
@ -200,15 +210,34 @@ public:
|
|||||||
/// ```
|
/// ```
|
||||||
class AffineLoopNestBuilder {
|
class AffineLoopNestBuilder {
|
||||||
public:
|
public:
|
||||||
|
// This entry point accomodates the fact that AffineForOp implicitly uses
|
||||||
|
// multiple `lbs` and `ubs` with one single `iv` and `step` to encode `max`
|
||||||
|
// and and `min` constraints respectively.
|
||||||
|
AffineLoopNestBuilder(ValueHandle *iv, ArrayRef<ValueHandle> lbs,
|
||||||
|
ArrayRef<ValueHandle> ubs, int64_t step);
|
||||||
AffineLoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
|
AffineLoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
|
||||||
ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps);
|
ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps);
|
||||||
|
|
||||||
ValueHandle operator()(llvm::function_ref<void(void)> fun = nullptr);
|
void operator()(llvm::function_ref<void(void)> fun = nullptr);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
SmallVector<LoopBuilder, 4> loops;
|
SmallVector<LoopBuilder, 4> loops;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Helper class to sugar building loop.for loop nests from ranges.
|
||||||
|
/// This is similar to edsc::AffineLoopNestBuilder except it operates on
|
||||||
|
/// loop.for.
|
||||||
|
class LoopNestBuilder {
|
||||||
|
public:
|
||||||
|
LoopNestBuilder(llvm::ArrayRef<edsc::ValueHandle *> ivs,
|
||||||
|
ArrayRef<ValueHandle> lbs, ArrayRef<ValueHandle> ubs,
|
||||||
|
ArrayRef<ValueHandle> steps);
|
||||||
|
void operator()(std::function<void(void)> fun = nullptr);
|
||||||
|
|
||||||
|
private:
|
||||||
|
llvm::SmallVector<LoopBuilder, 4> loops;
|
||||||
|
};
|
||||||
|
|
||||||
// This class exists solely to handle the C++ vexing parse case when
|
// This class exists solely to handle the C++ vexing parse case when
|
||||||
// trying to enter a Block that has already been constructed.
|
// trying to enter a Block that has already been constructed.
|
||||||
class Append {};
|
class Append {};
|
||||||
|
73
third_party/mlir/lib/EDSC/Builders.cpp
vendored
73
third_party/mlir/lib/EDSC/Builders.cpp
vendored
@ -162,12 +162,12 @@ static llvm::Optional<ValueHandle> emitStaticFor(ArrayRef<ValueHandle> lbs,
|
|||||||
ubConst.getValue(), step);
|
ubConst.getValue(), step);
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::edsc::LoopBuilder::LoopBuilder(ValueHandle *iv,
|
mlir::edsc::LoopBuilder mlir::edsc::LoopBuilder::makeAffine(
|
||||||
ArrayRef<ValueHandle> lbHandles,
|
ValueHandle *iv, ArrayRef<ValueHandle> lbHandles,
|
||||||
ArrayRef<ValueHandle> ubHandles,
|
ArrayRef<ValueHandle> ubHandles, int64_t step) {
|
||||||
int64_t step) {
|
mlir::edsc::LoopBuilder result;
|
||||||
if (auto res = emitStaticFor(lbHandles, ubHandles, step)) {
|
if (auto staticFor = emitStaticFor(lbHandles, ubHandles, step)) {
|
||||||
*iv = res.getValue();
|
*iv = staticFor.getValue();
|
||||||
} else {
|
} else {
|
||||||
SmallVector<Value *, 4> lbs(lbHandles.begin(), lbHandles.end());
|
SmallVector<Value *, 4> lbs(lbHandles.begin(), lbHandles.end());
|
||||||
SmallVector<Value *, 4> ubs(ubHandles.begin(), ubHandles.end());
|
SmallVector<Value *, 4> ubs(ubHandles.begin(), ubHandles.end());
|
||||||
@ -177,11 +177,24 @@ mlir::edsc::LoopBuilder::LoopBuilder(ValueHandle *iv,
|
|||||||
step);
|
step);
|
||||||
}
|
}
|
||||||
auto *body = getForInductionVarOwner(iv->getValue()).getBody();
|
auto *body = getForInductionVarOwner(iv->getValue()).getBody();
|
||||||
enter(body, /*prev=*/1);
|
result.enter(body, /*prev=*/1);
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
ValueHandle
|
mlir::edsc::LoopBuilder
|
||||||
mlir::edsc::LoopBuilder::operator()(llvm::function_ref<void(void)> fun) {
|
mlir::edsc::LoopBuilder::makeLoop(ValueHandle *iv, ValueHandle lbHandle,
|
||||||
|
ValueHandle ubHandle,
|
||||||
|
ValueHandle stepHandle) {
|
||||||
|
mlir::edsc::LoopBuilder result;
|
||||||
|
auto forOp =
|
||||||
|
OperationHandle::createOp<loop::ForOp>(lbHandle, ubHandle, stepHandle);
|
||||||
|
*iv = ValueHandle(forOp.getInductionVar());
|
||||||
|
auto *body = loop::getForInductionVarOwner(iv->getValue()).getBody();
|
||||||
|
result.enter(body, /*prev=*/1);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void mlir::edsc::LoopBuilder::operator()(llvm::function_ref<void(void)> fun) {
|
||||||
// Call to `exit` must be explicit and asymmetric (cannot happen in the
|
// Call to `exit` must be explicit and asymmetric (cannot happen in the
|
||||||
// destructor) because of ordering wrt comma operator.
|
// destructor) because of ordering wrt comma operator.
|
||||||
/// The particular use case concerns nested blocks:
|
/// The particular use case concerns nested blocks:
|
||||||
@ -203,7 +216,12 @@ mlir::edsc::LoopBuilder::operator()(llvm::function_ref<void(void)> fun) {
|
|||||||
if (fun)
|
if (fun)
|
||||||
fun();
|
fun();
|
||||||
exit();
|
exit();
|
||||||
return ValueHandle::null();
|
}
|
||||||
|
|
||||||
|
mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder(
|
||||||
|
ValueHandle *iv, ArrayRef<ValueHandle> lbs, ArrayRef<ValueHandle> ubs,
|
||||||
|
int64_t step) {
|
||||||
|
loops.emplace_back(LoopBuilder::makeAffine(iv, lbs, ubs, step));
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder(
|
mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder(
|
||||||
@ -212,13 +230,12 @@ mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder(
|
|||||||
assert(ivs.size() == lbs.size() && "Mismatch in number of arguments");
|
assert(ivs.size() == lbs.size() && "Mismatch in number of arguments");
|
||||||
assert(ivs.size() == ubs.size() && "Mismatch in number of arguments");
|
assert(ivs.size() == ubs.size() && "Mismatch in number of arguments");
|
||||||
assert(ivs.size() == steps.size() && "Mismatch in number of arguments");
|
assert(ivs.size() == steps.size() && "Mismatch in number of arguments");
|
||||||
for (auto it : llvm::zip(ivs, lbs, ubs, steps)) {
|
for (auto it : llvm::zip(ivs, lbs, ubs, steps))
|
||||||
loops.emplace_back(std::get<0>(it), std::get<1>(it), std::get<2>(it),
|
loops.emplace_back(LoopBuilder::makeAffine(
|
||||||
std::get<3>(it));
|
std::get<0>(it), std::get<1>(it), std::get<2>(it), std::get<3>(it)));
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ValueHandle mlir::edsc::AffineLoopNestBuilder::operator()(
|
void mlir::edsc::AffineLoopNestBuilder::operator()(
|
||||||
llvm::function_ref<void(void)> fun) {
|
llvm::function_ref<void(void)> fun) {
|
||||||
if (fun)
|
if (fun)
|
||||||
fun();
|
fun();
|
||||||
@ -227,10 +244,32 @@ ValueHandle mlir::edsc::AffineLoopNestBuilder::operator()(
|
|||||||
// to be asymmetric (i.e. enter() occurs on LoopBuilder construction, exit()
|
// to be asymmetric (i.e. enter() occurs on LoopBuilder construction, exit()
|
||||||
// occurs on calling operator()). The asymmetry is required for properly
|
// occurs on calling operator()). The asymmetry is required for properly
|
||||||
// nesting imperfectly nested regions (see LoopBuilder::operator()).
|
// nesting imperfectly nested regions (see LoopBuilder::operator()).
|
||||||
for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit) {
|
for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit)
|
||||||
(*lit)();
|
(*lit)();
|
||||||
}
|
}
|
||||||
return ValueHandle::null();
|
|
||||||
|
mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef<ValueHandle *> ivs,
|
||||||
|
ArrayRef<ValueHandle> lbs,
|
||||||
|
ArrayRef<ValueHandle> ubs,
|
||||||
|
ArrayRef<ValueHandle> steps) {
|
||||||
|
assert(ivs.size() == lbs.size() && "expected size of ivs and lbs to match");
|
||||||
|
assert(ivs.size() == ubs.size() && "expected size of ivs and ubs to match");
|
||||||
|
assert(ivs.size() == steps.size() &&
|
||||||
|
"expected size of ivs and steps to match");
|
||||||
|
loops.reserve(ivs.size());
|
||||||
|
for (auto it : llvm::zip(ivs, lbs, ubs, steps)) {
|
||||||
|
loops.emplace_back(LoopBuilder::makeLoop(std::get<0>(it), std::get<1>(it),
|
||||||
|
std::get<2>(it), std::get<3>(it)));
|
||||||
|
}
|
||||||
|
assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size");
|
||||||
|
}
|
||||||
|
|
||||||
|
void LoopNestBuilder::LoopNestBuilder::operator()(
|
||||||
|
std::function<void(void)> fun) {
|
||||||
|
if (fun)
|
||||||
|
fun();
|
||||||
|
for (auto &lit : reverse(loops))
|
||||||
|
lit({});
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle bh, Append) {
|
mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle bh, Append) {
|
||||||
|
1
third_party/mlir/lib/EDSC/CMakeLists.txt
vendored
1
third_party/mlir/lib/EDSC/CMakeLists.txt
vendored
@ -17,6 +17,7 @@ add_dependencies(MLIREDSC MLIRReferenceImplementationTestGen)
|
|||||||
target_link_libraries(MLIREDSC
|
target_link_libraries(MLIREDSC
|
||||||
PUBLIC
|
PUBLIC
|
||||||
MLIRAffineOps
|
MLIRAffineOps
|
||||||
|
MLIRLoopOps
|
||||||
MLIRStandardOps
|
MLIRStandardOps
|
||||||
MLIRTransformUtils
|
MLIRTransformUtils
|
||||||
MLIRVectorOps
|
MLIRVectorOps
|
||||||
|
Loading…
Reference in New Issue
Block a user