NFC: Remove unnecessary 'llvm::' prefix from uses of llvm symbols declared in mlir
namespace.
Aside from being cleaner, this also makes the codebase more consistent. PiperOrigin-RevId: 286206974 Change-Id: I2eb3b84bfa317f1f3e3f04aeabe65e8e752ebb4a
This commit is contained in:
parent
7c6c941f87
commit
6103719ffa
1
third_party/mlir/BUILD
vendored
1
third_party/mlir/BUILD
vendored
@ -2362,6 +2362,7 @@ cc_library(
|
|||||||
":IR",
|
":IR",
|
||||||
":QuantOps",
|
":QuantOps",
|
||||||
":StandardOps",
|
":StandardOps",
|
||||||
|
":Support",
|
||||||
"@llvm//:support",
|
"@llvm//:support",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -24,6 +24,7 @@
|
|||||||
#ifndef MLIR_ANALYSIS_AFFINE_ANALYSIS_H
|
#ifndef MLIR_ANALYSIS_AFFINE_ANALYSIS_H
|
||||||
#define MLIR_ANALYSIS_AFFINE_ANALYSIS_H
|
#define MLIR_ANALYSIS_AFFINE_ANALYSIS_H
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "mlir/Support/LogicalResult.h"
|
#include "mlir/Support/LogicalResult.h"
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "llvm/ADT/Optional.h"
|
#include "llvm/ADT/Optional.h"
|
||||||
@ -41,9 +42,8 @@ class Value;
|
|||||||
/// Returns in `affineApplyOps`, the sequence of those AffineApplyOp
|
/// Returns in `affineApplyOps`, the sequence of those AffineApplyOp
|
||||||
/// Operations that are reachable via a search starting from `operands` and
|
/// Operations that are reachable via a search starting from `operands` and
|
||||||
/// ending at those operands that are not the result of an AffineApplyOp.
|
/// ending at those operands that are not the result of an AffineApplyOp.
|
||||||
void getReachableAffineApplyOps(
|
void getReachableAffineApplyOps(ArrayRef<Value *> operands,
|
||||||
llvm::ArrayRef<Value *> operands,
|
SmallVectorImpl<Operation *> &affineApplyOps);
|
||||||
llvm::SmallVectorImpl<Operation *> &affineApplyOps);
|
|
||||||
|
|
||||||
/// Builds a system of constraints with dimensional identifiers corresponding to
|
/// Builds a system of constraints with dimensional identifiers corresponding to
|
||||||
/// the loop IVs of the forOps appearing in that order. Bounds of the loop are
|
/// the loop IVs of the forOps appearing in that order. Bounds of the loop are
|
||||||
@ -51,14 +51,14 @@ void getReachableAffineApplyOps(
|
|||||||
/// operands are added as symbols in the system. Returns failure for the yet
|
/// operands are added as symbols in the system. Returns failure for the yet
|
||||||
/// unimplemented cases.
|
/// unimplemented cases.
|
||||||
// TODO(bondhugula): handle non-unit strides.
|
// TODO(bondhugula): handle non-unit strides.
|
||||||
LogicalResult getIndexSet(llvm::MutableArrayRef<AffineForOp> forOps,
|
LogicalResult getIndexSet(MutableArrayRef<AffineForOp> forOps,
|
||||||
FlatAffineConstraints *domain);
|
FlatAffineConstraints *domain);
|
||||||
|
|
||||||
/// Encapsulates a memref load or store access information.
|
/// Encapsulates a memref load or store access information.
|
||||||
struct MemRefAccess {
|
struct MemRefAccess {
|
||||||
Value *memref;
|
Value *memref;
|
||||||
Operation *opInst;
|
Operation *opInst;
|
||||||
llvm::SmallVector<Value *, 4> indices;
|
SmallVector<Value *, 4> indices;
|
||||||
|
|
||||||
/// Constructs a MemRefAccess from a load or store operation.
|
/// Constructs a MemRefAccess from a load or store operation.
|
||||||
// TODO(b/119949820): add accessors to standard op's load, store, DMA op's to
|
// TODO(b/119949820): add accessors to standard op's load, store, DMA op's to
|
||||||
@ -94,9 +94,9 @@ struct DependenceComponent {
|
|||||||
// The AffineForOp Operation associated with this dependence component.
|
// The AffineForOp Operation associated with this dependence component.
|
||||||
Operation *op;
|
Operation *op;
|
||||||
// The lower bound of the dependence distance.
|
// The lower bound of the dependence distance.
|
||||||
llvm::Optional<int64_t> lb;
|
Optional<int64_t> lb;
|
||||||
// The upper bound of the dependence distance (inclusive).
|
// The upper bound of the dependence distance (inclusive).
|
||||||
llvm::Optional<int64_t> ub;
|
Optional<int64_t> ub;
|
||||||
DependenceComponent() : lb(llvm::None), ub(llvm::None) {}
|
DependenceComponent() : lb(llvm::None), ub(llvm::None) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -122,7 +122,7 @@ struct DependenceResult {
|
|||||||
DependenceResult checkMemrefAccessDependence(
|
DependenceResult checkMemrefAccessDependence(
|
||||||
const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
|
const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
|
||||||
unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
|
unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
|
||||||
llvm::SmallVector<DependenceComponent, 2> *dependenceComponents,
|
SmallVector<DependenceComponent, 2> *dependenceComponents,
|
||||||
bool allowRAR = false);
|
bool allowRAR = false);
|
||||||
|
|
||||||
/// Utility function that returns true if the provided DependenceResult
|
/// Utility function that returns true if the provided DependenceResult
|
||||||
@ -136,7 +136,7 @@ inline bool hasDependence(DependenceResult result) {
|
|||||||
/// [1, maxLoopDepth].
|
/// [1, maxLoopDepth].
|
||||||
void getDependenceComponents(
|
void getDependenceComponents(
|
||||||
AffineForOp forOp, unsigned maxLoopDepth,
|
AffineForOp forOp, unsigned maxLoopDepth,
|
||||||
std::vector<llvm::SmallVector<DependenceComponent, 2>> *depCompsVec);
|
std::vector<SmallVector<DependenceComponent, 2>> *depCompsVec);
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
|
@ -795,9 +795,9 @@ AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims,
|
|||||||
/// 'cst' contains constraints that connect newly introduced local identifiers
|
/// 'cst' contains constraints that connect newly introduced local identifiers
|
||||||
/// to existing dimensional and symbolic identifiers. See documentation for
|
/// to existing dimensional and symbolic identifiers. See documentation for
|
||||||
/// AffineExprFlattener on how mod's and div's are flattened.
|
/// AffineExprFlattener on how mod's and div's are flattened.
|
||||||
LogicalResult
|
LogicalResult getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
|
||||||
getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
|
unsigned numSymbols,
|
||||||
llvm::SmallVectorImpl<int64_t> *flattenedExpr,
|
SmallVectorImpl<int64_t> *flattenedExpr,
|
||||||
FlatAffineConstraints *cst = nullptr);
|
FlatAffineConstraints *cst = nullptr);
|
||||||
|
|
||||||
/// Flattens the result expressions of the map to their corresponding flattened
|
/// Flattens the result expressions of the map to their corresponding flattened
|
||||||
@ -810,11 +810,13 @@ getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
|
|||||||
/// method should be used instead of repeatedly calling getFlattenedAffineExpr
|
/// method should be used instead of repeatedly calling getFlattenedAffineExpr
|
||||||
/// since local variables added to deal with div's and mod's will be reused
|
/// since local variables added to deal with div's and mod's will be reused
|
||||||
/// across expressions.
|
/// across expressions.
|
||||||
LogicalResult getFlattenedAffineExprs(
|
LogicalResult
|
||||||
AffineMap map, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
|
getFlattenedAffineExprs(AffineMap map,
|
||||||
|
std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
|
||||||
FlatAffineConstraints *cst = nullptr);
|
FlatAffineConstraints *cst = nullptr);
|
||||||
LogicalResult getFlattenedAffineExprs(
|
LogicalResult
|
||||||
IntegerSet set, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
|
getFlattenedAffineExprs(IntegerSet set,
|
||||||
|
std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
|
||||||
FlatAffineConstraints *cst = nullptr);
|
FlatAffineConstraints *cst = nullptr);
|
||||||
|
|
||||||
} // end namespace mlir.
|
} // end namespace mlir.
|
||||||
|
@ -30,9 +30,8 @@ namespace mlir {
|
|||||||
|
|
||||||
/// A callable is either a symbol, or an SSA value, that is referenced by a
|
/// A callable is either a symbol, or an SSA value, that is referenced by a
|
||||||
/// call-like operation. This represents the destination of the call.
|
/// call-like operation. This represents the destination of the call.
|
||||||
struct CallInterfaceCallable
|
struct CallInterfaceCallable : public PointerUnion<SymbolRefAttr, Value *> {
|
||||||
: public llvm::PointerUnion<SymbolRefAttr, Value *> {
|
using PointerUnion<SymbolRefAttr, Value *>::PointerUnion;
|
||||||
using llvm::PointerUnion<SymbolRefAttr, Value *>::PointerUnion;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#include "mlir/Analysis/CallInterfaces.h.inc"
|
#include "mlir/Analysis/CallInterfaces.h.inc"
|
||||||
|
@ -56,7 +56,7 @@ protected:
|
|||||||
bool properlyDominates(Block *a, Block *b);
|
bool properlyDominates(Block *a, Block *b);
|
||||||
|
|
||||||
/// A mapping of regions to their base dominator tree.
|
/// A mapping of regions to their base dominator tree.
|
||||||
llvm::DenseMap<Region *, std::unique_ptr<base>> dominanceInfos;
|
DenseMap<Region *, std::unique_ptr<base>> dominanceInfos;
|
||||||
};
|
};
|
||||||
} // end namespace detail
|
} // end namespace detail
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
|
|||||||
}],
|
}],
|
||||||
/*retTy=*/"LogicalResult",
|
/*retTy=*/"LogicalResult",
|
||||||
/*methodName=*/"inferReturnTypes",
|
/*methodName=*/"inferReturnTypes",
|
||||||
/*args=*/(ins "llvm::Optional<Location>":$location,
|
/*args=*/(ins "Optional<Location>":$location,
|
||||||
"ValueRange":$operands,
|
"ValueRange":$operands,
|
||||||
"ArrayRef<NamedAttribute>":$attributes,
|
"ArrayRef<NamedAttribute>":$attributes,
|
||||||
"RegionRange":$regions,
|
"RegionRange":$regions,
|
||||||
|
@ -50,7 +50,7 @@ void buildTripCountMapAndOperands(AffineForOp forOp, AffineMap *map,
|
|||||||
/// Returns the trip count of the loop if it's a constant, None otherwise. This
|
/// Returns the trip count of the loop if it's a constant, None otherwise. This
|
||||||
/// uses affine expression analysis and is able to determine constant trip count
|
/// uses affine expression analysis and is able to determine constant trip count
|
||||||
/// in non-trivial cases.
|
/// in non-trivial cases.
|
||||||
llvm::Optional<uint64_t> getConstantTripCount(AffineForOp forOp);
|
Optional<uint64_t> getConstantTripCount(AffineForOp forOp);
|
||||||
|
|
||||||
/// Returns the greatest known integral divisor of the trip count. Affine
|
/// Returns the greatest known integral divisor of the trip count. Affine
|
||||||
/// expression analysis is used (indirectly through getTripCount), and
|
/// expression analysis is used (indirectly through getTripCount), and
|
||||||
@ -66,8 +66,8 @@ uint64_t getLargestDivisorOfTripCount(AffineForOp forOp);
|
|||||||
///
|
///
|
||||||
/// Emits a note if it encounters a chain of affine.apply and conservatively
|
/// Emits a note if it encounters a chain of affine.apply and conservatively
|
||||||
/// those cases.
|
/// those cases.
|
||||||
llvm::DenseSet<Value *, llvm::DenseMapInfo<Value *>>
|
DenseSet<Value *, DenseMapInfo<Value *>>
|
||||||
getInvariantAccesses(Value *iv, llvm::ArrayRef<Value *> indices);
|
getInvariantAccesses(Value *iv, ArrayRef<Value *> indices);
|
||||||
|
|
||||||
using VectorizableLoopFun = std::function<bool(AffineForOp)>;
|
using VectorizableLoopFun = std::function<bool(AffineForOp)>;
|
||||||
|
|
||||||
@ -91,7 +91,7 @@ bool isVectorizableLoopBody(AffineForOp loop, int *memRefDim,
|
|||||||
/// 'def' and all its uses have the same shift factor.
|
/// 'def' and all its uses have the same shift factor.
|
||||||
// TODO(mlir-team): extend this to check for memory-based dependence
|
// TODO(mlir-team): extend this to check for memory-based dependence
|
||||||
// violation when we have the support.
|
// violation when we have the support.
|
||||||
bool isInstwiseShiftValid(AffineForOp forOp, llvm::ArrayRef<uint64_t> shifts);
|
bool isInstwiseShiftValid(AffineForOp forOp, ArrayRef<uint64_t> shifts);
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
#endif // MLIR_ANALYSIS_LOOP_ANALYSIS_H
|
#endif // MLIR_ANALYSIS_LOOP_ANALYSIS_H
|
||||||
|
@ -538,17 +538,17 @@ bool isValidSymbol(Value *value);
|
|||||||
/// dimensional operands
|
/// dimensional operands
|
||||||
/// 4. propagate constant operands and drop them
|
/// 4. propagate constant operands and drop them
|
||||||
void canonicalizeMapAndOperands(AffineMap *map,
|
void canonicalizeMapAndOperands(AffineMap *map,
|
||||||
llvm::SmallVectorImpl<Value *> *operands);
|
SmallVectorImpl<Value *> *operands);
|
||||||
/// Canonicalizes an integer set the same way canonicalizeMapAndOperands does
|
/// Canonicalizes an integer set the same way canonicalizeMapAndOperands does
|
||||||
/// for affine maps.
|
/// for affine maps.
|
||||||
void canonicalizeSetAndOperands(IntegerSet *set,
|
void canonicalizeSetAndOperands(IntegerSet *set,
|
||||||
llvm::SmallVectorImpl<Value *> *operands);
|
SmallVectorImpl<Value *> *operands);
|
||||||
|
|
||||||
/// Returns a composed AffineApplyOp by composing `map` and `operands` with
|
/// Returns a composed AffineApplyOp by composing `map` and `operands` with
|
||||||
/// other AffineApplyOps supplying those operands. The operands of the resulting
|
/// other AffineApplyOps supplying those operands. The operands of the resulting
|
||||||
/// AffineApplyOp do not change the length of AffineApplyOp chains.
|
/// AffineApplyOp do not change the length of AffineApplyOp chains.
|
||||||
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
|
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
|
||||||
llvm::ArrayRef<Value *> operands);
|
ArrayRef<Value *> operands);
|
||||||
|
|
||||||
/// Given an affine map `map` and its input `operands`, this method composes
|
/// Given an affine map `map` and its input `operands`, this method composes
|
||||||
/// into `map`, maps of AffineApplyOps whose results are the values in
|
/// into `map`, maps of AffineApplyOps whose results are the values in
|
||||||
@ -558,7 +558,7 @@ AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
|
|||||||
/// terminal symbol, i.e., a symbol defined at the top level or a block/function
|
/// terminal symbol, i.e., a symbol defined at the top level or a block/function
|
||||||
/// argument.
|
/// argument.
|
||||||
void fullyComposeAffineMapAndOperands(AffineMap *map,
|
void fullyComposeAffineMapAndOperands(AffineMap *map,
|
||||||
llvm::SmallVectorImpl<Value *> *operands);
|
SmallVectorImpl<Value *> *operands);
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "mlir/Dialect/AffineOps/AffineOps.h.inc"
|
#include "mlir/Dialect/AffineOps/AffineOps.h.inc"
|
||||||
|
@ -35,8 +35,8 @@ namespace mlir {
|
|||||||
template <class AttrElementT,
|
template <class AttrElementT,
|
||||||
class ElementValueT = typename AttrElementT::ValueType,
|
class ElementValueT = typename AttrElementT::ValueType,
|
||||||
class CalculationT =
|
class CalculationT =
|
||||||
llvm::function_ref<ElementValueT(ElementValueT, ElementValueT)>>
|
function_ref<ElementValueT(ElementValueT, ElementValueT)>>
|
||||||
Attribute constFoldBinaryOp(llvm::ArrayRef<Attribute> operands,
|
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
|
||||||
const CalculationT &calculate) {
|
const CalculationT &calculate) {
|
||||||
assert(operands.size() == 2 && "binary op takes two operands");
|
assert(operands.size() == 2 && "binary op takes two operands");
|
||||||
if (!operands[0] || !operands[1])
|
if (!operands[0] || !operands[1])
|
||||||
|
@ -443,7 +443,7 @@ def GPU_LaunchOp : GPU_Op<"launch", [IsolatedFromAbove]>,
|
|||||||
KernelDim3 getBlockSizeOperandValues();
|
KernelDim3 getBlockSizeOperandValues();
|
||||||
|
|
||||||
/// Get the SSA values of the kernel arguments.
|
/// Get the SSA values of the kernel arguments.
|
||||||
llvm::iterator_range<Block::args_iterator> getKernelArguments();
|
iterator_range<Block::args_iterator> getKernelArguments();
|
||||||
|
|
||||||
/// Erase the `index`-th kernel argument. Both the entry block argument and
|
/// Erase the `index`-th kernel argument. Both the entry block argument and
|
||||||
/// the operand will be dropped. The block argument must not have any uses.
|
/// the operand will be dropped. The block argument must not have any uses.
|
||||||
|
@ -156,7 +156,7 @@ private:
|
|||||||
/// Get an LLVMType with an llvm type that may cause changes to the underlying
|
/// Get an LLVMType with an llvm type that may cause changes to the underlying
|
||||||
/// llvm context when constructed.
|
/// llvm context when constructed.
|
||||||
static LLVMType getLocked(LLVMDialect *dialect,
|
static LLVMType getLocked(LLVMDialect *dialect,
|
||||||
llvm::function_ref<llvm::Type *()> typeBuilder);
|
function_ref<llvm::Type *()> typeBuilder);
|
||||||
};
|
};
|
||||||
|
|
||||||
///// Ops /////
|
///// Ops /////
|
||||||
|
@ -637,7 +637,7 @@ def LLVM_LLVMFuncOp
|
|||||||
def LLVM_NullOp
|
def LLVM_NullOp
|
||||||
: LLVM_OneResultOp<"mlir.null", [NoSideEffect]>,
|
: LLVM_OneResultOp<"mlir.null", [NoSideEffect]>,
|
||||||
LLVM_Builder<"$res = llvm::ConstantPointerNull::get("
|
LLVM_Builder<"$res = llvm::ConstantPointerNull::get("
|
||||||
" llvm::cast<llvm::PointerType>($_resultType));"> {
|
" cast<llvm::PointerType>($_resultType));"> {
|
||||||
let parser = [{ return parseNullOp(parser, result); }];
|
let parser = [{ return parseNullOp(parser, result); }];
|
||||||
let printer = [{ printNullOp(p, *this); }];
|
let printer = [{ printNullOp(p, *this); }];
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let verifier = [{ return ::verify(*this); }];
|
||||||
|
@ -66,10 +66,10 @@ public:
|
|||||||
// 2. dst in the case of dependencesIntoGraphs.
|
// 2. dst in the case of dependencesIntoGraphs.
|
||||||
Value *indexingView;
|
Value *indexingView;
|
||||||
};
|
};
|
||||||
using LinalgDependences = llvm::SmallVector<LinalgDependenceGraphElem, 8>;
|
using LinalgDependences = SmallVector<LinalgDependenceGraphElem, 8>;
|
||||||
using DependenceGraph = DenseMap<Operation *, LinalgDependences>;
|
using DependenceGraph = DenseMap<Operation *, LinalgDependences>;
|
||||||
using dependence_iterator = LinalgDependences::const_iterator;
|
using dependence_iterator = LinalgDependences::const_iterator;
|
||||||
using dependence_range = llvm::iterator_range<dependence_iterator>;
|
using dependence_range = iterator_range<dependence_iterator>;
|
||||||
|
|
||||||
enum DependenceType { RAR = 0, RAW, WAR, WAW, NumTypes };
|
enum DependenceType { RAR = 0, RAW, WAR, WAW, NumTypes };
|
||||||
|
|
||||||
|
@ -77,11 +77,11 @@ private:
|
|||||||
|
|
||||||
inline void defaultRegionBuilder(ArrayRef<BlockArgument *> args) {}
|
inline void defaultRegionBuilder(ArrayRef<BlockArgument *> args) {}
|
||||||
|
|
||||||
Operation *makeLinalgGenericOp(
|
Operation *makeLinalgGenericOp(ArrayRef<IterType> iteratorTypes,
|
||||||
ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
|
ArrayRef<StructuredIndexed> inputs,
|
||||||
ArrayRef<StructuredIndexed> outputs,
|
ArrayRef<StructuredIndexed> outputs,
|
||||||
llvm::function_ref<void(ArrayRef<BlockArgument *>)> regionBuilder =
|
function_ref<void(ArrayRef<BlockArgument *>)>
|
||||||
defaultRegionBuilder,
|
regionBuilder = defaultRegionBuilder,
|
||||||
ArrayRef<Value *> otherValues = {},
|
ArrayRef<Value *> otherValues = {},
|
||||||
ArrayRef<Attribute> otherAttributes = {});
|
ArrayRef<Attribute> otherAttributes = {});
|
||||||
|
|
||||||
@ -120,7 +120,7 @@ void macRegionBuilder(ArrayRef<BlockArgument *> args);
|
|||||||
/// with in-place semantics and parallelism.
|
/// with in-place semantics and parallelism.
|
||||||
|
|
||||||
/// Unary pointwise operation (with broadcast) entry point.
|
/// Unary pointwise operation (with broadcast) entry point.
|
||||||
using UnaryPointwiseOpBuilder = llvm::function_ref<Value *(ValueHandle)>;
|
using UnaryPointwiseOpBuilder = function_ref<Value *(ValueHandle)>;
|
||||||
Operation *linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
|
Operation *linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
|
||||||
StructuredIndexed I, StructuredIndexed O);
|
StructuredIndexed I, StructuredIndexed O);
|
||||||
|
|
||||||
@ -131,7 +131,7 @@ Operation *linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O);
|
|||||||
|
|
||||||
/// Binary pointwise operation (with broadcast) entry point.
|
/// Binary pointwise operation (with broadcast) entry point.
|
||||||
using BinaryPointwiseOpBuilder =
|
using BinaryPointwiseOpBuilder =
|
||||||
llvm::function_ref<Value *(ValueHandle, ValueHandle)>;
|
function_ref<Value *(ValueHandle, ValueHandle)>;
|
||||||
Operation *linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
|
Operation *linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
|
||||||
StructuredIndexed I1, StructuredIndexed I2,
|
StructuredIndexed I1, StructuredIndexed I2,
|
||||||
StructuredIndexed O);
|
StructuredIndexed O);
|
||||||
|
@ -101,13 +101,13 @@ def LinalgLibraryInterface : OpInterface<"LinalgOp"> {
|
|||||||
Query the index of the given input value, or `None` if the value is not
|
Query the index of the given input value, or `None` if the value is not
|
||||||
an input.
|
an input.
|
||||||
}],
|
}],
|
||||||
"llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value *":$view)
|
"Optional<unsigned>", "getIndexOfInput", (ins "Value *":$view)
|
||||||
>,
|
>,
|
||||||
InterfaceMethod<[{
|
InterfaceMethod<[{
|
||||||
Query the index of the given view value, or `None` if the value is not
|
Query the index of the given view value, or `None` if the value is not
|
||||||
an view.
|
an view.
|
||||||
}],
|
}],
|
||||||
"llvm::Optional<unsigned>", "getIndexOfOutput", (ins "Value *":$view)
|
"Optional<unsigned>", "getIndexOfOutput", (ins "Value *":$view)
|
||||||
>,
|
>,
|
||||||
InterfaceMethod<[{
|
InterfaceMethod<[{
|
||||||
Query the type of the input view at the given index.
|
Query the type of the input view at the given index.
|
||||||
|
@ -128,7 +128,7 @@ def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
|
|||||||
|
|
||||||
// Get the subset of indexings that are of RangeType.
|
// Get the subset of indexings that are of RangeType.
|
||||||
SmallVector<Value *, 8> getRanges() {
|
SmallVector<Value *, 8> getRanges() {
|
||||||
llvm::SmallVector<Value *, 8> res;
|
SmallVector<Value *, 8> res;
|
||||||
for (auto *operand : indexings())
|
for (auto *operand : indexings())
|
||||||
if (!operand->getType().isa<IndexType>())
|
if (!operand->getType().isa<IndexType>())
|
||||||
res.push_back(operand);
|
res.push_back(operand);
|
||||||
|
@ -83,7 +83,7 @@ public:
|
|||||||
}
|
}
|
||||||
/// Return the index of `view` in the list of input views if found, llvm::None
|
/// Return the index of `view` in the list of input views if found, llvm::None
|
||||||
/// otherwise.
|
/// otherwise.
|
||||||
llvm::Optional<unsigned> getIndexOfInput(Value *view) {
|
Optional<unsigned> getIndexOfInput(Value *view) {
|
||||||
auto it = llvm::find(getInputs(), view);
|
auto it = llvm::find(getInputs(), view);
|
||||||
if (it != getInputs().end())
|
if (it != getInputs().end())
|
||||||
return it - getInputs().begin();
|
return it - getInputs().begin();
|
||||||
@ -104,7 +104,7 @@ public:
|
|||||||
}
|
}
|
||||||
/// Return the index of `view` in the list of output views if found,
|
/// Return the index of `view` in the list of output views if found,
|
||||||
/// llvm::None otherwise.
|
/// llvm::None otherwise.
|
||||||
llvm::Optional<unsigned> getIndexOfOutput(Value *view) {
|
Optional<unsigned> getIndexOfOutput(Value *view) {
|
||||||
auto it = llvm::find(getOutputs(), view);
|
auto it = llvm::find(getOutputs(), view);
|
||||||
if (it != getOutputs().end())
|
if (it != getOutputs().end())
|
||||||
return it - getOutputs().begin();
|
return it - getOutputs().begin();
|
||||||
|
@ -39,7 +39,7 @@ namespace detail {
|
|||||||
// Implementation detail of isProducedByOpOfType avoids the need for explicit
|
// Implementation detail of isProducedByOpOfType avoids the need for explicit
|
||||||
// template instantiations.
|
// template instantiations.
|
||||||
bool isProducedByOpOfTypeImpl(Operation *consumerOp, Value *consumedView,
|
bool isProducedByOpOfTypeImpl(Operation *consumerOp, Value *consumedView,
|
||||||
llvm::function_ref<bool(Operation *)> isaOpType);
|
function_ref<bool(Operation *)> isaOpType);
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
// Returns true if the `consumedView` value use in `consumerOp` is produced by
|
// Returns true if the `consumedView` value use in `consumerOp` is produced by
|
||||||
|
@ -62,16 +62,16 @@ public:
|
|||||||
/// directly. In the current implementation it produces loop.for operations.
|
/// directly. In the current implementation it produces loop.for operations.
|
||||||
class LoopNestRangeBuilder {
|
class LoopNestRangeBuilder {
|
||||||
public:
|
public:
|
||||||
LoopNestRangeBuilder(llvm::ArrayRef<edsc::ValueHandle *> ivs,
|
LoopNestRangeBuilder(ArrayRef<edsc::ValueHandle *> ivs,
|
||||||
llvm::ArrayRef<edsc::ValueHandle> ranges);
|
ArrayRef<edsc::ValueHandle> ranges);
|
||||||
LoopNestRangeBuilder(llvm::ArrayRef<edsc::ValueHandle *> ivs,
|
LoopNestRangeBuilder(ArrayRef<edsc::ValueHandle *> ivs,
|
||||||
llvm::ArrayRef<Value *> ranges);
|
ArrayRef<Value *> ranges);
|
||||||
LoopNestRangeBuilder(llvm::ArrayRef<edsc::ValueHandle *> ivs,
|
LoopNestRangeBuilder(ArrayRef<edsc::ValueHandle *> ivs,
|
||||||
llvm::ArrayRef<SubViewOp::Range> ranges);
|
ArrayRef<SubViewOp::Range> ranges);
|
||||||
edsc::ValueHandle operator()(std::function<void(void)> fun = nullptr);
|
edsc::ValueHandle operator()(std::function<void(void)> fun = nullptr);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
llvm::SmallVector<LoopRangeBuilder, 4> loops;
|
SmallVector<LoopRangeBuilder, 4> loops;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace edsc
|
} // namespace edsc
|
||||||
@ -150,7 +150,7 @@ struct TiledLinalgOp {
|
|||||||
/// When non-null, the optional pointer `folder` is used to call into the
|
/// When non-null, the optional pointer `folder` is used to call into the
|
||||||
/// `createAndFold` builder method. If `folder` is null, the regular `create`
|
/// `createAndFold` builder method. If `folder` is null, the regular `create`
|
||||||
/// method is called.
|
/// method is called.
|
||||||
llvm::Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
|
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
|
||||||
ArrayRef<Value *> tileSizes,
|
ArrayRef<Value *> tileSizes,
|
||||||
ArrayRef<unsigned> permutation = {},
|
ArrayRef<unsigned> permutation = {},
|
||||||
OperationFolder *folder = nullptr);
|
OperationFolder *folder = nullptr);
|
||||||
@ -170,13 +170,13 @@ llvm::Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
|
|||||||
/// When non-null, the optional pointer `folder` is used to call into the
|
/// When non-null, the optional pointer `folder` is used to call into the
|
||||||
/// `createAndFold` builder method. If `folder` is null, the regular `create`
|
/// `createAndFold` builder method. If `folder` is null, the regular `create`
|
||||||
/// method is called.
|
/// method is called.
|
||||||
llvm::Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
|
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
|
||||||
ArrayRef<int64_t> tileSizes,
|
ArrayRef<int64_t> tileSizes,
|
||||||
ArrayRef<unsigned> permutation = {},
|
ArrayRef<unsigned> permutation = {},
|
||||||
OperationFolder *folder = nullptr);
|
OperationFolder *folder = nullptr);
|
||||||
|
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
llvm::Optional<TiledLinalgOp> tileLinalgOperation(OpBuilder &b, Operation *op,
|
Optional<TiledLinalgOp> tileLinalgOperation(OpBuilder &b, Operation *op,
|
||||||
Args... args) {
|
Args... args) {
|
||||||
return tileLinalgOp(b, cast<LinalgOp>(op), args...);
|
return tileLinalgOp(b, cast<LinalgOp>(op), args...);
|
||||||
}
|
}
|
||||||
@ -198,14 +198,14 @@ struct PromotionInfo {
|
|||||||
///
|
///
|
||||||
/// Returns a list of PromotionInfo which hold the promoted buffer and the
|
/// Returns a list of PromotionInfo which hold the promoted buffer and the
|
||||||
/// full and partial views indexing into the buffer.
|
/// full and partial views indexing into the buffer.
|
||||||
llvm::SmallVector<PromotionInfo, 8>
|
SmallVector<PromotionInfo, 8>
|
||||||
promoteSubViews(OpBuilder &b, Location loc, ArrayRef<Value *> subViews,
|
promoteSubViews(OpBuilder &b, Location loc, ArrayRef<Value *> subViews,
|
||||||
bool dynamicBuffers = false, OperationFolder *folder = nullptr);
|
bool dynamicBuffers = false, OperationFolder *folder = nullptr);
|
||||||
|
|
||||||
/// Returns all the operands of `linalgOp` that are not views.
|
/// Returns all the operands of `linalgOp` that are not views.
|
||||||
/// Asserts that these operands are value types to allow transformations like
|
/// Asserts that these operands are value types to allow transformations like
|
||||||
/// tiling to just use the values when cloning `linalgOp`.
|
/// tiling to just use the values when cloning `linalgOp`.
|
||||||
llvm::SmallVector<Value *, 4> getAssumedNonViewOperands(LinalgOp linalgOp);
|
SmallVector<Value *, 4> getAssumedNonViewOperands(LinalgOp linalgOp);
|
||||||
|
|
||||||
/// Apply the permutation defined by `permutation` to `inVec`.
|
/// Apply the permutation defined by `permutation` to `inVec`.
|
||||||
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
|
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
|
||||||
|
@ -361,7 +361,7 @@ public:
|
|||||||
|
|
||||||
/// Verifies construction invariants and issues errors/warnings.
|
/// Verifies construction invariants and issues errors/warnings.
|
||||||
static LogicalResult verifyConstructionInvariants(
|
static LogicalResult verifyConstructionInvariants(
|
||||||
llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
|
Optional<Location> loc, MLIRContext *context, unsigned flags,
|
||||||
Type storageType, Type expressedType, ArrayRef<double> scales,
|
Type storageType, Type expressedType, ArrayRef<double> scales,
|
||||||
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
|
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
|
||||||
int64_t storageTypeMin, int64_t storageTypeMax);
|
int64_t storageTypeMin, int64_t storageTypeMax);
|
||||||
|
@ -85,7 +85,7 @@ public:
|
|||||||
clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint),
|
clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint),
|
||||||
clampMinDouble(clampMin), clampMaxDouble(clampMax),
|
clampMinDouble(clampMin), clampMaxDouble(clampMax),
|
||||||
storageBitWidth(storageBitWidth), isSigned(isSigned),
|
storageBitWidth(storageBitWidth), isSigned(isSigned),
|
||||||
roundMode(llvm::APFloat::rmNearestTiesToAway) {}
|
roundMode(APFloat::rmNearestTiesToAway) {}
|
||||||
|
|
||||||
UniformQuantizedValueConverter(double scale, double zeroPoint,
|
UniformQuantizedValueConverter(double scale, double zeroPoint,
|
||||||
APFloat clampMin, APFloat clampMax,
|
APFloat clampMin, APFloat clampMax,
|
||||||
@ -95,7 +95,7 @@ public:
|
|||||||
clampMinDouble(clampMin.convertToDouble()),
|
clampMinDouble(clampMin.convertToDouble()),
|
||||||
clampMaxDouble(clampMax.convertToDouble()),
|
clampMaxDouble(clampMax.convertToDouble()),
|
||||||
storageBitWidth(storageBitWidth), isSigned(isSigned),
|
storageBitWidth(storageBitWidth), isSigned(isSigned),
|
||||||
roundMode(llvm::APFloat::rmNearestTiesToAway) {}
|
roundMode(APFloat::rmNearestTiesToAway) {}
|
||||||
|
|
||||||
virtual APInt quantizeFloatToInt(APFloat expressedValue) const {
|
virtual APInt quantizeFloatToInt(APFloat expressedValue) const {
|
||||||
// This function is a performance critical code path in quantization
|
// This function is a performance critical code path in quantization
|
||||||
@ -154,8 +154,7 @@ private:
|
|||||||
} else {
|
} else {
|
||||||
signlessResult = static_cast<uint8_t>(clamped);
|
signlessResult = static_cast<uint8_t>(clamped);
|
||||||
}
|
}
|
||||||
llvm::APInt result(storageBitWidth, signlessResult);
|
return APInt(storageBitWidth, signlessResult);
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Keep both APFloat and double versions of the quantization parameters
|
// Keep both APFloat and double versions of the quantization parameters
|
||||||
|
@ -115,7 +115,7 @@ public:
|
|||||||
SmallVectorImpl<SDBMExpr> &inequalities,
|
SmallVectorImpl<SDBMExpr> &inequalities,
|
||||||
SmallVectorImpl<SDBMExpr> &equalities);
|
SmallVectorImpl<SDBMExpr> &equalities);
|
||||||
|
|
||||||
void print(llvm::raw_ostream &os);
|
void print(raw_ostream &os);
|
||||||
void dump();
|
void dump();
|
||||||
|
|
||||||
IntInfty operator()(int i, int j) { return at(i, j); }
|
IntInfty operator()(int i, int j) { return at(i, j); }
|
||||||
@ -198,7 +198,7 @@ private:
|
|||||||
/// temporaries can appear in these expressions. This removes the need to
|
/// temporaries can appear in these expressions. This removes the need to
|
||||||
/// iteratively substitute definitions of the temporaries in the reverse
|
/// iteratively substitute definitions of the temporaries in the reverse
|
||||||
/// conversion.
|
/// conversion.
|
||||||
llvm::DenseMap<unsigned, SDBMExpr> stripeToPoint;
|
DenseMap<unsigned, SDBMExpr> stripeToPoint;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
@ -41,7 +41,7 @@ namespace spirv {
|
|||||||
///
|
///
|
||||||
/// Get the function that can be used to symbolize an enum value.
|
/// Get the function that can be used to symbolize an enum value.
|
||||||
/// template <typename EnumClass>
|
/// template <typename EnumClass>
|
||||||
/// llvm::Optional<EnumClass> (*)(StringRef) symbolizeEnum();
|
/// Optional<EnumClass> (*)(StringRef) symbolizeEnum();
|
||||||
#include "mlir/Dialect/SPIRV/SPIRVOpUtils.inc"
|
#include "mlir/Dialect/SPIRV/SPIRVOpUtils.inc"
|
||||||
|
|
||||||
} // end namespace spirv
|
} // end namespace spirv
|
||||||
|
@ -345,7 +345,7 @@ ParseResult parseDimAndSymbolList(OpAsmParser &parser,
|
|||||||
SmallVectorImpl<Value *> &operands,
|
SmallVectorImpl<Value *> &operands,
|
||||||
unsigned &numDims);
|
unsigned &numDims);
|
||||||
|
|
||||||
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, SubViewOp::Range &range);
|
raw_ostream &operator<<(raw_ostream &os, SubViewOp::Range &range);
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
|
@ -46,15 +46,15 @@ class VectorType;
|
|||||||
/// - shapeRatio({3, 4, 5, 8}, {2, 5, 2}) returns {3, 2, 1, 4}
|
/// - shapeRatio({3, 4, 5, 8}, {2, 5, 2}) returns {3, 2, 1, 4}
|
||||||
/// - shapeRatio({3, 4, 4, 8}, {2, 5, 2}) returns None
|
/// - shapeRatio({3, 4, 4, 8}, {2, 5, 2}) returns None
|
||||||
/// - shapeRatio({1, 2, 10, 32}, {2, 5, 2}) returns {1, 1, 2, 16}
|
/// - shapeRatio({1, 2, 10, 32}, {2, 5, 2}) returns {1, 1, 2, 16}
|
||||||
llvm::Optional<llvm::SmallVector<int64_t, 4>>
|
Optional<SmallVector<int64_t, 4>> shapeRatio(ArrayRef<int64_t> superShape,
|
||||||
shapeRatio(ArrayRef<int64_t> superShape, ArrayRef<int64_t> subShape);
|
ArrayRef<int64_t> subShape);
|
||||||
|
|
||||||
/// Computes and returns the multi-dimensional ratio of the shapes of
|
/// Computes and returns the multi-dimensional ratio of the shapes of
|
||||||
/// `superVector` to `subVector`. If integral division is not possible, returns
|
/// `superVector` to `subVector`. If integral division is not possible, returns
|
||||||
/// None.
|
/// None.
|
||||||
/// Assumes and enforces that the VectorTypes have the same elemental type.
|
/// Assumes and enforces that the VectorTypes have the same elemental type.
|
||||||
llvm::Optional<llvm::SmallVector<int64_t, 4>>
|
Optional<SmallVector<int64_t, 4>> shapeRatio(VectorType superVectorType,
|
||||||
shapeRatio(VectorType superVectorType, VectorType subVectorType);
|
VectorType subVectorType);
|
||||||
|
|
||||||
/// Constructs a permutation map of invariant memref indices to vector
|
/// Constructs a permutation map of invariant memref indices to vector
|
||||||
/// dimension.
|
/// dimension.
|
||||||
@ -121,9 +121,9 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType);
|
|||||||
/// Meaning that vector.transfer_read will be responsible of reading the slice
|
/// Meaning that vector.transfer_read will be responsible of reading the slice
|
||||||
/// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast.
|
/// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast.
|
||||||
///
|
///
|
||||||
AffineMap makePermutationMap(
|
AffineMap
|
||||||
Operation *op, ArrayRef<Value *> indices,
|
makePermutationMap(Operation *op, ArrayRef<Value *> indices,
|
||||||
const llvm::DenseMap<Operation *, unsigned> &loopToVectorDim);
|
const DenseMap<Operation *, unsigned> &loopToVectorDim);
|
||||||
|
|
||||||
namespace matcher {
|
namespace matcher {
|
||||||
|
|
||||||
|
15
third_party/mlir/include/mlir/EDSC/Builders.h
vendored
15
third_party/mlir/include/mlir/EDSC/Builders.h
vendored
@ -78,7 +78,7 @@ private:
|
|||||||
/// Top level OpBuilder.
|
/// Top level OpBuilder.
|
||||||
OpBuilder &builder;
|
OpBuilder &builder;
|
||||||
/// The previous insertion point of the builder.
|
/// The previous insertion point of the builder.
|
||||||
llvm::Optional<OpBuilder::InsertPoint> prevBuilderInsertPoint;
|
Optional<OpBuilder::InsertPoint> prevBuilderInsertPoint;
|
||||||
/// Current location.
|
/// Current location.
|
||||||
Location location;
|
Location location;
|
||||||
/// Parent context we return into.
|
/// Parent context we return into.
|
||||||
@ -178,7 +178,7 @@ public:
|
|||||||
/// The only purpose of this operator is to serve as a sequence point so that
|
/// The only purpose of this operator is to serve as a sequence point so that
|
||||||
/// the evaluation of `fun` (which build IR snippets in a scoped fashion) is
|
/// the evaluation of `fun` (which build IR snippets in a scoped fashion) is
|
||||||
/// scoped within a LoopBuilder.
|
/// scoped within a LoopBuilder.
|
||||||
void operator()(llvm::function_ref<void(void)> fun = nullptr);
|
void operator()(function_ref<void(void)> fun = nullptr);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
LoopBuilder() = default;
|
LoopBuilder() = default;
|
||||||
@ -217,7 +217,7 @@ public:
|
|||||||
AffineLoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
|
AffineLoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
|
||||||
ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps);
|
ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps);
|
||||||
|
|
||||||
void operator()(llvm::function_ref<void(void)> fun = nullptr);
|
void operator()(function_ref<void(void)> fun = nullptr);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
SmallVector<LoopBuilder, 4> loops;
|
SmallVector<LoopBuilder, 4> loops;
|
||||||
@ -228,13 +228,12 @@ private:
|
|||||||
/// loop.for.
|
/// loop.for.
|
||||||
class LoopNestBuilder {
|
class LoopNestBuilder {
|
||||||
public:
|
public:
|
||||||
LoopNestBuilder(llvm::ArrayRef<edsc::ValueHandle *> ivs,
|
LoopNestBuilder(ArrayRef<edsc::ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
|
||||||
ArrayRef<ValueHandle> lbs, ArrayRef<ValueHandle> ubs,
|
ArrayRef<ValueHandle> ubs, ArrayRef<ValueHandle> steps);
|
||||||
ArrayRef<ValueHandle> steps);
|
|
||||||
void operator()(std::function<void(void)> fun = nullptr);
|
void operator()(std::function<void(void)> fun = nullptr);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
llvm::SmallVector<LoopBuilder, 4> loops;
|
SmallVector<LoopBuilder, 4> loops;
|
||||||
};
|
};
|
||||||
|
|
||||||
// This class exists solely to handle the C++ vexing parse case when
|
// This class exists solely to handle the C++ vexing parse case when
|
||||||
@ -264,7 +263,7 @@ public:
|
|||||||
/// The only purpose of this operator is to serve as a sequence point so that
|
/// The only purpose of this operator is to serve as a sequence point so that
|
||||||
/// the evaluation of `fun` (which build IR snippets in a scoped fashion) is
|
/// the evaluation of `fun` (which build IR snippets in a scoped fashion) is
|
||||||
/// scoped within a BlockBuilder.
|
/// scoped within a BlockBuilder.
|
||||||
void operator()(llvm::function_ref<void(void)> fun = nullptr);
|
void operator()(function_ref<void(void)> fun = nullptr);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
BlockBuilder(BlockBuilder &) = delete;
|
BlockBuilder(BlockBuilder &) = delete;
|
||||||
|
8
third_party/mlir/include/mlir/EDSC/Helpers.h
vendored
8
third_party/mlir/include/mlir/EDSC/Helpers.h
vendored
@ -137,12 +137,12 @@ public:
|
|||||||
TemplatedIndexedValue operator()(ValueHandle index, Args... indices) {
|
TemplatedIndexedValue operator()(ValueHandle index, Args... indices) {
|
||||||
return TemplatedIndexedValue(base, index).append(indices...);
|
return TemplatedIndexedValue(base, index).append(indices...);
|
||||||
}
|
}
|
||||||
TemplatedIndexedValue operator()(llvm::ArrayRef<ValueHandle> indices) {
|
TemplatedIndexedValue operator()(ArrayRef<ValueHandle> indices) {
|
||||||
return TemplatedIndexedValue(base, indices);
|
return TemplatedIndexedValue(base, indices);
|
||||||
}
|
}
|
||||||
TemplatedIndexedValue operator()(llvm::ArrayRef<IndexHandle> indices) {
|
TemplatedIndexedValue operator()(ArrayRef<IndexHandle> indices) {
|
||||||
return TemplatedIndexedValue(
|
return TemplatedIndexedValue(
|
||||||
base, llvm::ArrayRef<ValueHandle>(indices.begin(), indices.end()));
|
base, ArrayRef<ValueHandle>(indices.begin(), indices.end()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Emits a `store`.
|
/// Emits a `store`.
|
||||||
@ -215,7 +215,7 @@ private:
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
ValueHandle base;
|
ValueHandle base;
|
||||||
llvm::SmallVector<ValueHandle, 8> indices;
|
SmallVector<ValueHandle, 8> indices;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Operator overloadings.
|
/// Operator overloadings.
|
||||||
|
@ -106,14 +106,14 @@ public:
|
|||||||
values.append(vals.begin(), vals.end());
|
values.append(vals.begin(), vals.end());
|
||||||
}
|
}
|
||||||
ValueHandleArray(ArrayRef<index_t> vals) {
|
ValueHandleArray(ArrayRef<index_t> vals) {
|
||||||
llvm::SmallVector<IndexHandle, 8> tmp(vals.begin(), vals.end());
|
SmallVector<IndexHandle, 8> tmp(vals.begin(), vals.end());
|
||||||
values.append(tmp.begin(), tmp.end());
|
values.append(tmp.begin(), tmp.end());
|
||||||
}
|
}
|
||||||
operator ArrayRef<Value *>() { return values; }
|
operator ArrayRef<Value *>() { return values; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ValueHandleArray() = default;
|
ValueHandleArray() = default;
|
||||||
llvm::SmallVector<Value *, 8> values;
|
SmallVector<Value *, 8> values;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T> inline T unpack(T value) { return value; }
|
template <typename T> inline T unpack(T value) { return value; }
|
||||||
|
@ -50,7 +50,7 @@ public:
|
|||||||
std::unique_ptr<llvm::MemoryBuffer> getObject(const llvm::Module *M) override;
|
std::unique_ptr<llvm::MemoryBuffer> getObject(const llvm::Module *M) override;
|
||||||
|
|
||||||
/// Dump cached object to output file `filename`.
|
/// Dump cached object to output file `filename`.
|
||||||
void dumpToObjectFile(llvm::StringRef filename);
|
void dumpToObjectFile(StringRef filename);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
llvm::StringMap<std::unique_ptr<llvm::MemoryBuffer>> cachedObjects;
|
llvm::StringMap<std::unique_ptr<llvm::MemoryBuffer>> cachedObjects;
|
||||||
@ -103,7 +103,7 @@ public:
|
|||||||
static bool setupTargetTriple(llvm::Module *llvmModule);
|
static bool setupTargetTriple(llvm::Module *llvmModule);
|
||||||
|
|
||||||
/// Dump object code to output file `filename`.
|
/// Dump object code to output file `filename`.
|
||||||
void dumpToObjectFile(llvm::StringRef filename);
|
void dumpToObjectFile(StringRef filename);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Ordering of llvmContext and jit is important for destruction purposes: the
|
// Ordering of llvmContext and jit is important for destruction purposes: the
|
||||||
@ -124,7 +124,7 @@ llvm::Error ExecutionEngine::invoke(StringRef name, Args &... args) {
|
|||||||
return expectedFPtr.takeError();
|
return expectedFPtr.takeError();
|
||||||
auto fptr = *expectedFPtr;
|
auto fptr = *expectedFPtr;
|
||||||
|
|
||||||
llvm::SmallVector<void *, 8> packedArgs{static_cast<void *>(&args)...};
|
SmallVector<void *, 8> packedArgs{static_cast<void *>(&args)...};
|
||||||
(*fptr)(packedArgs.data());
|
(*fptr)(packedArgs.data());
|
||||||
|
|
||||||
return llvm::Error::success();
|
return llvm::Error::success();
|
||||||
|
@ -272,7 +272,7 @@ AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims,
|
|||||||
/// flattened.
|
/// flattened.
|
||||||
bool getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
|
bool getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
|
||||||
unsigned numSymbols,
|
unsigned numSymbols,
|
||||||
llvm::SmallVectorImpl<int64_t> *flattenedExpr);
|
SmallVectorImpl<int64_t> *flattenedExpr);
|
||||||
|
|
||||||
/// Flattens the result expressions of the map to their corresponding flattened
|
/// Flattens the result expressions of the map to their corresponding flattened
|
||||||
/// forms and set in 'flattenedExprs'. Returns true on success or false
|
/// forms and set in 'flattenedExprs'. Returns true on success or false
|
||||||
@ -282,9 +282,9 @@ bool getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
|
|||||||
/// repeatedly calling getFlattenedAffineExpr since local variables added to
|
/// repeatedly calling getFlattenedAffineExpr since local variables added to
|
||||||
/// deal with div's and mod's will be reused across expressions.
|
/// deal with div's and mod's will be reused across expressions.
|
||||||
bool getFlattenedAffineExprs(
|
bool getFlattenedAffineExprs(
|
||||||
AffineMap map, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs);
|
AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs);
|
||||||
bool getFlattenedAffineExprs(
|
bool getFlattenedAffineExprs(
|
||||||
IntegerSet set, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs);
|
IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs);
|
||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
template <int N> void bindDims(MLIRContext *ctx) {}
|
template <int N> void bindDims(MLIRContext *ctx) {}
|
||||||
|
2
third_party/mlir/include/mlir/IR/AffineMap.h
vendored
2
third_party/mlir/include/mlir/IR/AffineMap.h
vendored
@ -227,7 +227,7 @@ AffineMap inversePermutation(AffineMap map);
|
|||||||
/// ```{.mlir}
|
/// ```{.mlir}
|
||||||
/// (i, j, k) -> (i, k, k, j, i, j)
|
/// (i, j, k) -> (i, k, k, j, i, j)
|
||||||
/// ```
|
/// ```
|
||||||
AffineMap concatAffineMaps(llvm::ArrayRef<AffineMap> maps);
|
AffineMap concatAffineMaps(ArrayRef<AffineMap> maps);
|
||||||
|
|
||||||
inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
|
inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
|
||||||
map.print(os);
|
map.print(os);
|
||||||
|
24
third_party/mlir/include/mlir/IR/Attributes.h
vendored
24
third_party/mlir/include/mlir/IR/Attributes.h
vendored
@ -583,16 +583,14 @@ public:
|
|||||||
/// Generates a new ElementsAttr by mapping each int value to a new
|
/// Generates a new ElementsAttr by mapping each int value to a new
|
||||||
/// underlying APInt. The new values can represent either a integer or float.
|
/// underlying APInt. The new values can represent either a integer or float.
|
||||||
/// This ElementsAttr should contain integers.
|
/// This ElementsAttr should contain integers.
|
||||||
ElementsAttr
|
ElementsAttr mapValues(Type newElementType,
|
||||||
mapValues(Type newElementType,
|
function_ref<APInt(const APInt &)> mapping) const;
|
||||||
llvm::function_ref<APInt(const APInt &)> mapping) const;
|
|
||||||
|
|
||||||
/// Generates a new ElementsAttr by mapping each float value to a new
|
/// Generates a new ElementsAttr by mapping each float value to a new
|
||||||
/// underlying APInt. The new values can represent either a integer or float.
|
/// underlying APInt. The new values can represent either a integer or float.
|
||||||
/// This ElementsAttr should contain floats.
|
/// This ElementsAttr should contain floats.
|
||||||
ElementsAttr
|
ElementsAttr mapValues(Type newElementType,
|
||||||
mapValues(Type newElementType,
|
function_ref<APInt(const APFloat &)> mapping) const;
|
||||||
llvm::function_ref<APInt(const APFloat &)> mapping) const;
|
|
||||||
|
|
||||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||||
static bool classof(Attribute attr) {
|
static bool classof(Attribute attr) {
|
||||||
@ -921,16 +919,15 @@ public:
|
|||||||
/// Generates a new DenseElementsAttr by mapping each int value to a new
|
/// Generates a new DenseElementsAttr by mapping each int value to a new
|
||||||
/// underlying APInt. The new values can represent either a integer or float.
|
/// underlying APInt. The new values can represent either a integer or float.
|
||||||
/// This underlying type must be an DenseIntElementsAttr.
|
/// This underlying type must be an DenseIntElementsAttr.
|
||||||
DenseElementsAttr
|
DenseElementsAttr mapValues(Type newElementType,
|
||||||
mapValues(Type newElementType,
|
function_ref<APInt(const APInt &)> mapping) const;
|
||||||
llvm::function_ref<APInt(const APInt &)> mapping) const;
|
|
||||||
|
|
||||||
/// Generates a new DenseElementsAttr by mapping each float value to a new
|
/// Generates a new DenseElementsAttr by mapping each float value to a new
|
||||||
/// underlying APInt. the new values can represent either a integer or float.
|
/// underlying APInt. the new values can represent either a integer or float.
|
||||||
/// This underlying type must be an DenseFPElementsAttr.
|
/// This underlying type must be an DenseFPElementsAttr.
|
||||||
DenseElementsAttr
|
DenseElementsAttr
|
||||||
mapValues(Type newElementType,
|
mapValues(Type newElementType,
|
||||||
llvm::function_ref<APInt(const APFloat &)> mapping) const;
|
function_ref<APInt(const APFloat &)> mapping) const;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/// Return the raw storage data held by this attribute.
|
/// Return the raw storage data held by this attribute.
|
||||||
@ -993,7 +990,7 @@ public:
|
|||||||
/// constructing the DenseElementsAttr given the new element type.
|
/// constructing the DenseElementsAttr given the new element type.
|
||||||
DenseElementsAttr
|
DenseElementsAttr
|
||||||
mapValues(Type newElementType,
|
mapValues(Type newElementType,
|
||||||
llvm::function_ref<APInt(const APFloat &)> mapping) const;
|
function_ref<APInt(const APFloat &)> mapping) const;
|
||||||
|
|
||||||
/// Iterator access to the float element values.
|
/// Iterator access to the float element values.
|
||||||
iterator begin() const { return float_value_begin(); }
|
iterator begin() const { return float_value_begin(); }
|
||||||
@ -1029,9 +1026,8 @@ public:
|
|||||||
|
|
||||||
/// Generates a new DenseElementsAttr by mapping each value attribute, and
|
/// Generates a new DenseElementsAttr by mapping each value attribute, and
|
||||||
/// constructing the DenseElementsAttr given the new element type.
|
/// constructing the DenseElementsAttr given the new element type.
|
||||||
DenseElementsAttr
|
DenseElementsAttr mapValues(Type newElementType,
|
||||||
mapValues(Type newElementType,
|
function_ref<APInt(const APInt &)> mapping) const;
|
||||||
llvm::function_ref<APInt(const APInt &)> mapping) const;
|
|
||||||
|
|
||||||
/// Iterator access to the integer element values.
|
/// Iterator access to the integer element values.
|
||||||
iterator begin() const { return raw_int_begin(); }
|
iterator begin() const { return raw_int_begin(); }
|
||||||
|
10
third_party/mlir/include/mlir/IR/Block.h
vendored
10
third_party/mlir/include/mlir/IR/Block.h
vendored
@ -89,7 +89,7 @@ public:
|
|||||||
BlockArgument *addArgument(Type type);
|
BlockArgument *addArgument(Type type);
|
||||||
|
|
||||||
/// Add one argument to the argument list for each type specified in the list.
|
/// Add one argument to the argument list for each type specified in the list.
|
||||||
llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type> types);
|
iterator_range<args_iterator> addArguments(ArrayRef<Type> types);
|
||||||
|
|
||||||
/// Erase the argument at 'index' and remove it from the argument list. If
|
/// Erase the argument at 'index' and remove it from the argument list. If
|
||||||
/// 'updatePredTerms' is set to true, this argument is also removed from the
|
/// 'updatePredTerms' is set to true, this argument is also removed from the
|
||||||
@ -175,7 +175,7 @@ public:
|
|||||||
template <typename OpT>
|
template <typename OpT>
|
||||||
class op_iterator : public llvm::mapped_iterator<op_filter_iterator<OpT>,
|
class op_iterator : public llvm::mapped_iterator<op_filter_iterator<OpT>,
|
||||||
OpT (*)(Operation &)> {
|
OpT (*)(Operation &)> {
|
||||||
static OpT unwrap(Operation &op) { return llvm::cast<OpT>(op); }
|
static OpT unwrap(Operation &op) { return cast<OpT>(op); }
|
||||||
|
|
||||||
public:
|
public:
|
||||||
using reference = OpT;
|
using reference = OpT;
|
||||||
@ -191,7 +191,7 @@ public:
|
|||||||
|
|
||||||
/// Return an iterator range over the operations within this block that are of
|
/// Return an iterator range over the operations within this block that are of
|
||||||
/// 'OpT'.
|
/// 'OpT'.
|
||||||
template <typename OpT> llvm::iterator_range<op_iterator<OpT>> getOps() {
|
template <typename OpT> iterator_range<op_iterator<OpT>> getOps() {
|
||||||
auto endIt = end();
|
auto endIt = end();
|
||||||
return {op_filter_iterator<OpT>(begin(), endIt),
|
return {op_filter_iterator<OpT>(begin(), endIt),
|
||||||
op_filter_iterator<OpT>(endIt, endIt)};
|
op_filter_iterator<OpT>(endIt, endIt)};
|
||||||
@ -205,7 +205,7 @@ public:
|
|||||||
|
|
||||||
/// Return an iterator range over the operation within this block excluding
|
/// Return an iterator range over the operation within this block excluding
|
||||||
/// the terminator operation at the end.
|
/// the terminator operation at the end.
|
||||||
llvm::iterator_range<iterator> without_terminator() {
|
iterator_range<iterator> without_terminator() {
|
||||||
if (begin() == end())
|
if (begin() == end())
|
||||||
return {begin(), end()};
|
return {begin(), end()};
|
||||||
auto endIt = --end();
|
auto endIt = --end();
|
||||||
@ -230,7 +230,7 @@ public:
|
|||||||
return pred_iterator((BlockOperand *)getFirstUse());
|
return pred_iterator((BlockOperand *)getFirstUse());
|
||||||
}
|
}
|
||||||
pred_iterator pred_end() { return pred_iterator(nullptr); }
|
pred_iterator pred_end() { return pred_iterator(nullptr); }
|
||||||
llvm::iterator_range<pred_iterator> getPredecessors() {
|
iterator_range<pred_iterator> getPredecessors() {
|
||||||
return {pred_begin(), pred_end()};
|
return {pred_begin(), pred_end()};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,7 +85,7 @@ private:
|
|||||||
return it != valueMap.end() ? static_cast<T *>(it->second) : value;
|
return it != valueMap.end() ? static_cast<T *>(it->second) : value;
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::DenseMap<IRObjectWithUseList *, IRObjectWithUseList *> valueMap;
|
DenseMap<IRObjectWithUseList *, IRObjectWithUseList *> valueMap;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
20
third_party/mlir/include/mlir/IR/Diagnostics.h
vendored
20
third_party/mlir/include/mlir/IR/Diagnostics.h
vendored
@ -239,10 +239,10 @@ public:
|
|||||||
Diagnostic &operator<<(OperationName val);
|
Diagnostic &operator<<(OperationName val);
|
||||||
|
|
||||||
/// Stream in a range.
|
/// Stream in a range.
|
||||||
template <typename T> Diagnostic &operator<<(llvm::iterator_range<T> range) {
|
template <typename T> Diagnostic &operator<<(iterator_range<T> range) {
|
||||||
return appendRange(range);
|
return appendRange(range);
|
||||||
}
|
}
|
||||||
template <typename T> Diagnostic &operator<<(llvm::ArrayRef<T> range) {
|
template <typename T> Diagnostic &operator<<(ArrayRef<T> range) {
|
||||||
return appendRange(range);
|
return appendRange(range);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -277,16 +277,16 @@ public:
|
|||||||
/// Attaches a note to this diagnostic. A new location may be optionally
|
/// Attaches a note to this diagnostic. A new location may be optionally
|
||||||
/// provided, if not, then the location defaults to the one specified for this
|
/// provided, if not, then the location defaults to the one specified for this
|
||||||
/// diagnostic. Notes may not be attached to other notes.
|
/// diagnostic. Notes may not be attached to other notes.
|
||||||
Diagnostic &attachNote(llvm::Optional<Location> noteLoc = llvm::None);
|
Diagnostic &attachNote(Optional<Location> noteLoc = llvm::None);
|
||||||
|
|
||||||
using note_iterator = NoteIteratorImpl<NoteVector::iterator>;
|
using note_iterator = NoteIteratorImpl<NoteVector::iterator>;
|
||||||
using const_note_iterator = NoteIteratorImpl<NoteVector::const_iterator>;
|
using const_note_iterator = NoteIteratorImpl<NoteVector::const_iterator>;
|
||||||
|
|
||||||
/// Returns the notes held by this diagnostic.
|
/// Returns the notes held by this diagnostic.
|
||||||
llvm::iterator_range<note_iterator> getNotes() {
|
iterator_range<note_iterator> getNotes() {
|
||||||
return {notes.begin(), notes.end()};
|
return {notes.begin(), notes.end()};
|
||||||
}
|
}
|
||||||
llvm::iterator_range<const_note_iterator> getNotes() const {
|
iterator_range<const_note_iterator> getNotes() const {
|
||||||
return {notes.begin(), notes.end()};
|
return {notes.begin(), notes.end()};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -360,7 +360,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Attaches a note to this diagnostic.
|
/// Attaches a note to this diagnostic.
|
||||||
Diagnostic &attachNote(llvm::Optional<Location> noteLoc = llvm::None) {
|
Diagnostic &attachNote(Optional<Location> noteLoc = llvm::None) {
|
||||||
assert(isActive() && "diagnostic not active");
|
assert(isActive() && "diagnostic not active");
|
||||||
return impl->attachNote(noteLoc);
|
return impl->attachNote(noteLoc);
|
||||||
}
|
}
|
||||||
@ -394,7 +394,7 @@ private:
|
|||||||
DiagnosticEngine *owner = nullptr;
|
DiagnosticEngine *owner = nullptr;
|
||||||
|
|
||||||
/// The raw diagnostic that is inflight to be reported.
|
/// The raw diagnostic that is inflight to be reported.
|
||||||
llvm::Optional<Diagnostic> impl;
|
Optional<Diagnostic> impl;
|
||||||
};
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -551,7 +551,7 @@ struct SourceMgrDiagnosticHandlerImpl;
|
|||||||
class SourceMgrDiagnosticHandler : public ScopedDiagnosticHandler {
|
class SourceMgrDiagnosticHandler : public ScopedDiagnosticHandler {
|
||||||
public:
|
public:
|
||||||
SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx,
|
SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx,
|
||||||
llvm::raw_ostream &os);
|
raw_ostream &os);
|
||||||
SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx);
|
SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx);
|
||||||
~SourceMgrDiagnosticHandler();
|
~SourceMgrDiagnosticHandler();
|
||||||
|
|
||||||
@ -570,7 +570,7 @@ protected:
|
|||||||
llvm::SourceMgr &mgr;
|
llvm::SourceMgr &mgr;
|
||||||
|
|
||||||
/// The output stream to use when printing diagnostics.
|
/// The output stream to use when printing diagnostics.
|
||||||
llvm::raw_ostream &os;
|
raw_ostream &os;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// Convert a location into the given memory buffer into an SMLoc.
|
/// Convert a location into the given memory buffer into an SMLoc.
|
||||||
@ -597,7 +597,7 @@ struct SourceMgrDiagnosticVerifierHandlerImpl;
|
|||||||
class SourceMgrDiagnosticVerifierHandler : public SourceMgrDiagnosticHandler {
|
class SourceMgrDiagnosticVerifierHandler : public SourceMgrDiagnosticHandler {
|
||||||
public:
|
public:
|
||||||
SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr &srcMgr, MLIRContext *ctx,
|
SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr &srcMgr, MLIRContext *ctx,
|
||||||
llvm::raw_ostream &out);
|
raw_ostream &out);
|
||||||
SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr &srcMgr, MLIRContext *ctx);
|
SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr &srcMgr, MLIRContext *ctx);
|
||||||
~SourceMgrDiagnosticVerifierHandler();
|
~SourceMgrDiagnosticVerifierHandler();
|
||||||
|
|
||||||
|
2
third_party/mlir/include/mlir/IR/Function.h
vendored
2
third_party/mlir/include/mlir/IR/Function.h
vendored
@ -50,7 +50,7 @@ public:
|
|||||||
static FuncOp create(Location location, StringRef name, FunctionType type,
|
static FuncOp create(Location location, StringRef name, FunctionType type,
|
||||||
ArrayRef<NamedAttribute> attrs = {});
|
ArrayRef<NamedAttribute> attrs = {});
|
||||||
static FuncOp create(Location location, StringRef name, FunctionType type,
|
static FuncOp create(Location location, StringRef name, FunctionType type,
|
||||||
llvm::iterator_range<dialect_attr_iterator> attrs);
|
iterator_range<dialect_attr_iterator> attrs);
|
||||||
static FuncOp create(Location location, StringRef name, FunctionType type,
|
static FuncOp create(Location location, StringRef name, FunctionType type,
|
||||||
ArrayRef<NamedAttribute> attrs,
|
ArrayRef<NamedAttribute> attrs,
|
||||||
ArrayRef<NamedAttributeList> argAttrs);
|
ArrayRef<NamedAttributeList> argAttrs);
|
||||||
|
@ -55,7 +55,7 @@ void addArgAndResultAttrs(Builder &builder, OperationState &result,
|
|||||||
/// function arguments and results, VariadicFlag indicates whether the function
|
/// function arguments and results, VariadicFlag indicates whether the function
|
||||||
/// should have variadic arguments; in case of error, it may populate the last
|
/// should have variadic arguments; in case of error, it may populate the last
|
||||||
/// argument with a message.
|
/// argument with a message.
|
||||||
using FuncTypeBuilder = llvm::function_ref<Type(
|
using FuncTypeBuilder = function_ref<Type(
|
||||||
Builder &, ArrayRef<Type>, ArrayRef<Type>, VariadicFlag, std::string &)>;
|
Builder &, ArrayRef<Type>, ArrayRef<Type>, VariadicFlag, std::string &)>;
|
||||||
|
|
||||||
/// Parses a function signature using `parser`. The `allowVariadic` argument
|
/// Parses a function signature using `parser`. The `allowVariadic` argument
|
||||||
|
@ -191,7 +191,7 @@ public:
|
|||||||
using args_iterator = Block::args_iterator;
|
using args_iterator = Block::args_iterator;
|
||||||
args_iterator args_begin() { return front().args_begin(); }
|
args_iterator args_begin() { return front().args_begin(); }
|
||||||
args_iterator args_end() { return front().args_end(); }
|
args_iterator args_end() { return front().args_end(); }
|
||||||
llvm::iterator_range<args_iterator> getArguments() {
|
iterator_range<args_iterator> getArguments() {
|
||||||
return {args_begin(), args_end()};
|
return {args_begin(), args_end()};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,7 +108,7 @@ public:
|
|||||||
|
|
||||||
/// Walk all of the AffineExpr's in this set's constraints. Each node in an
|
/// Walk all of the AffineExpr's in this set's constraints. Each node in an
|
||||||
/// expression tree is visited in postorder.
|
/// expression tree is visited in postorder.
|
||||||
void walkExprs(llvm::function_ref<void(AffineExpr)> callback) const;
|
void walkExprs(function_ref<void(AffineExpr)> callback) const;
|
||||||
|
|
||||||
void print(raw_ostream &os) const;
|
void print(raw_ostream &os) const;
|
||||||
void dump() const;
|
void dump() const;
|
||||||
|
2
third_party/mlir/include/mlir/IR/Module.h
vendored
2
third_party/mlir/include/mlir/IR/Module.h
vendored
@ -81,7 +81,7 @@ public:
|
|||||||
|
|
||||||
/// This returns a range of operations of the given type 'T' held within the
|
/// This returns a range of operations of the given type 'T' held within the
|
||||||
/// module.
|
/// module.
|
||||||
template <typename T> llvm::iterator_range<Block::op_iterator<T>> getOps() {
|
template <typename T> iterator_range<Block::op_iterator<T>> getOps() {
|
||||||
return getBody()->getOps<T>();
|
return getBody()->getOps<T>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
13
third_party/mlir/include/mlir/IR/OpDefinition.h
vendored
13
third_party/mlir/include/mlir/IR/OpDefinition.h
vendored
@ -54,7 +54,7 @@ public:
|
|||||||
explicit operator bool() const { return failed(*this); }
|
explicit operator bool() const { return failed(*this); }
|
||||||
};
|
};
|
||||||
/// This class implements `Optional` functionality for ParseResult. We don't
|
/// This class implements `Optional` functionality for ParseResult. We don't
|
||||||
/// directly use llvm::Optional here, because it provides an implicit conversion
|
/// directly use Optional here, because it provides an implicit conversion
|
||||||
/// to 'bool' which we want to avoid. This class is used to implement tri-state
|
/// to 'bool' which we want to avoid. This class is used to implement tri-state
|
||||||
/// 'parseOptional' functions that may have a failure mode when parsing that
|
/// 'parseOptional' functions that may have a failure mode when parsing that
|
||||||
/// shouldn't be attributed to "not present".
|
/// shouldn't be attributed to "not present".
|
||||||
@ -85,9 +85,8 @@ namespace impl {
|
|||||||
/// region's only block if it does not have a terminator already. If the region
|
/// region's only block if it does not have a terminator already. If the region
|
||||||
/// is empty, insert a new block first. `buildTerminatorOp` should return the
|
/// is empty, insert a new block first. `buildTerminatorOp` should return the
|
||||||
/// terminator operation to insert.
|
/// terminator operation to insert.
|
||||||
void ensureRegionTerminator(
|
void ensureRegionTerminator(Region ®ion, Location loc,
|
||||||
Region ®ion, Location loc,
|
function_ref<Operation *()> buildTerminatorOp);
|
||||||
llvm::function_ref<Operation *()> buildTerminatorOp);
|
|
||||||
/// Templated version that fills the generates the provided operation type.
|
/// Templated version that fills the generates the provided operation type.
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
void ensureRegionTerminator(Region ®ion, Builder &builder, Location loc) {
|
void ensureRegionTerminator(Region ®ion, Builder &builder, Location loc) {
|
||||||
@ -258,8 +257,8 @@ inline bool operator!=(OpState lhs, OpState rhs) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// This class represents a single result from folding an operation.
|
/// This class represents a single result from folding an operation.
|
||||||
class OpFoldResult : public llvm::PointerUnion<Attribute, Value *> {
|
class OpFoldResult : public PointerUnion<Attribute, Value *> {
|
||||||
using llvm::PointerUnion<Attribute, Value *>::PointerUnion;
|
using PointerUnion<Attribute, Value *>::PointerUnion;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// This template defines the foldHook as used by AbstractOperation.
|
/// This template defines the foldHook as used by AbstractOperation.
|
||||||
@ -1142,7 +1141,7 @@ private:
|
|||||||
/// };
|
/// };
|
||||||
/// template <typename OpT> class Model {
|
/// template <typename OpT> class Model {
|
||||||
/// unsigned getNumInputs(Operation *op) final {
|
/// unsigned getNumInputs(Operation *op) final {
|
||||||
/// return llvm::cast<OpT>(op).getNumInputs();
|
/// return cast<OpT>(op).getNumInputs();
|
||||||
/// }
|
/// }
|
||||||
/// };
|
/// };
|
||||||
/// };
|
/// };
|
||||||
|
4
third_party/mlir/include/mlir/IR/Operation.h
vendored
4
third_party/mlir/include/mlir/IR/Operation.h
vendored
@ -132,7 +132,7 @@ public:
|
|||||||
template <typename OpTy> OpTy getParentOfType() {
|
template <typename OpTy> OpTy getParentOfType() {
|
||||||
auto *op = this;
|
auto *op = this;
|
||||||
while ((op = op->getParentOp()))
|
while ((op = op->getParentOp()))
|
||||||
if (auto parentOp = llvm::dyn_cast<OpTy>(op))
|
if (auto parentOp = dyn_cast<OpTy>(op))
|
||||||
return parentOp;
|
return parentOp;
|
||||||
return OpTy();
|
return OpTy();
|
||||||
}
|
}
|
||||||
@ -339,7 +339,7 @@ public:
|
|||||||
// Allow access to the constructor.
|
// Allow access to the constructor.
|
||||||
friend Operation;
|
friend Operation;
|
||||||
};
|
};
|
||||||
using dialect_attr_range = llvm::iterator_range<dialect_attr_iterator>;
|
using dialect_attr_range = iterator_range<dialect_attr_iterator>;
|
||||||
|
|
||||||
/// Return a range corresponding to the dialect attributes for this operation.
|
/// Return a range corresponding to the dialect attributes for this operation.
|
||||||
dialect_attr_range getDialectAttrs() {
|
dialect_attr_range getDialectAttrs() {
|
||||||
|
@ -212,7 +212,7 @@ private:
|
|||||||
class OperationName {
|
class OperationName {
|
||||||
public:
|
public:
|
||||||
using RepresentationUnion =
|
using RepresentationUnion =
|
||||||
llvm::PointerUnion<Identifier, const AbstractOperation *>;
|
PointerUnion<Identifier, const AbstractOperation *>;
|
||||||
|
|
||||||
OperationName(AbstractOperation *op) : representation(op) {}
|
OperationName(AbstractOperation *op) : representation(op) {}
|
||||||
OperationName(StringRef name, MLIRContext *context);
|
OperationName(StringRef name, MLIRContext *context);
|
||||||
@ -511,7 +511,7 @@ public:
|
|||||||
private:
|
private:
|
||||||
/// Elide large elements attributes if the number of elements is larger than
|
/// Elide large elements attributes if the number of elements is larger than
|
||||||
/// the upper limit.
|
/// the upper limit.
|
||||||
llvm::Optional<int64_t> elementsAttrElementLimit;
|
Optional<int64_t> elementsAttrElementLimit;
|
||||||
|
|
||||||
/// Print debug information.
|
/// Print debug information.
|
||||||
bool printDebugInfoFlag : 1;
|
bool printDebugInfoFlag : 1;
|
||||||
@ -616,9 +616,8 @@ private:
|
|||||||
/// parameter.
|
/// parameter.
|
||||||
class ValueRange final
|
class ValueRange final
|
||||||
: public detail::indexed_accessor_range_base<
|
: public detail::indexed_accessor_range_base<
|
||||||
ValueRange,
|
ValueRange, PointerUnion<Value *const *, OpOperand *, OpResult *>,
|
||||||
llvm::PointerUnion<Value *const *, OpOperand *, OpResult *>, Value *,
|
Value *, Value *, Value *> {
|
||||||
Value *, Value *> {
|
|
||||||
public:
|
public:
|
||||||
using RangeBaseT::RangeBaseT;
|
using RangeBaseT::RangeBaseT;
|
||||||
|
|
||||||
@ -646,7 +645,7 @@ public:
|
|||||||
private:
|
private:
|
||||||
/// The type representing the owner of this range. This is either a list of
|
/// The type representing the owner of this range. This is either a list of
|
||||||
/// values, operands, or results.
|
/// values, operands, or results.
|
||||||
using OwnerT = llvm::PointerUnion<Value *const *, OpOperand *, OpResult *>;
|
using OwnerT = PointerUnion<Value *const *, OpOperand *, OpResult *>;
|
||||||
|
|
||||||
/// See `detail::indexed_accessor_range_base` for details.
|
/// See `detail::indexed_accessor_range_base` for details.
|
||||||
static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index);
|
static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index);
|
||||||
|
10
third_party/mlir/include/mlir/IR/PatternMatch.h
vendored
10
third_party/mlir/include/mlir/IR/PatternMatch.h
vendored
@ -202,7 +202,7 @@ protected:
|
|||||||
|
|
||||||
/// A list of the potential operations that may be generated when rewriting
|
/// A list of the potential operations that may be generated when rewriting
|
||||||
/// an op with this pattern.
|
/// an op with this pattern.
|
||||||
llvm::SmallVector<OperationName, 2> generatedOps;
|
SmallVector<OperationName, 2> generatedOps;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// OpRewritePattern is a wrapper around RewritePattern that allows for
|
/// OpRewritePattern is a wrapper around RewritePattern that allows for
|
||||||
@ -217,17 +217,17 @@ template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
|
|||||||
/// Wrappers around the RewritePattern methods that pass the derived op type.
|
/// Wrappers around the RewritePattern methods that pass the derived op type.
|
||||||
void rewrite(Operation *op, std::unique_ptr<PatternState> state,
|
void rewrite(Operation *op, std::unique_ptr<PatternState> state,
|
||||||
PatternRewriter &rewriter) const final {
|
PatternRewriter &rewriter) const final {
|
||||||
rewrite(llvm::cast<SourceOp>(op), std::move(state), rewriter);
|
rewrite(cast<SourceOp>(op), std::move(state), rewriter);
|
||||||
}
|
}
|
||||||
void rewrite(Operation *op, PatternRewriter &rewriter) const final {
|
void rewrite(Operation *op, PatternRewriter &rewriter) const final {
|
||||||
rewrite(llvm::cast<SourceOp>(op), rewriter);
|
rewrite(cast<SourceOp>(op), rewriter);
|
||||||
}
|
}
|
||||||
PatternMatchResult match(Operation *op) const final {
|
PatternMatchResult match(Operation *op) const final {
|
||||||
return match(llvm::cast<SourceOp>(op));
|
return match(cast<SourceOp>(op));
|
||||||
}
|
}
|
||||||
PatternMatchResult matchAndRewrite(Operation *op,
|
PatternMatchResult matchAndRewrite(Operation *op,
|
||||||
PatternRewriter &rewriter) const final {
|
PatternRewriter &rewriter) const final {
|
||||||
return matchAndRewrite(llvm::cast<SourceOp>(op), rewriter);
|
return matchAndRewrite(cast<SourceOp>(op), rewriter);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Rewrite and Match methods that operate on the SourceOp type. These must be
|
/// Rewrite and Match methods that operate on the SourceOp type. These must be
|
||||||
|
9
third_party/mlir/include/mlir/IR/Region.h
vendored
9
third_party/mlir/include/mlir/IR/Region.h
vendored
@ -117,7 +117,7 @@ public:
|
|||||||
/// Emit errors if `noteLoc` is provided; this location is used to point
|
/// Emit errors if `noteLoc` is provided; this location is used to point
|
||||||
/// to the operation containing the region, the actual error is reported at
|
/// to the operation containing the region, the actual error is reported at
|
||||||
/// the operation with an offending use.
|
/// the operation with an offending use.
|
||||||
bool isIsolatedFromAbove(llvm::Optional<Location> noteLoc = llvm::None);
|
bool isIsolatedFromAbove(Optional<Location> noteLoc = llvm::None);
|
||||||
|
|
||||||
/// Drop all operand uses from operations within this region, which is
|
/// Drop all operand uses from operations within this region, which is
|
||||||
/// an essential step in breaking cyclic dependences between references when
|
/// an essential step in breaking cyclic dependences between references when
|
||||||
@ -150,7 +150,7 @@ public:
|
|||||||
/// depends on Graphviz to generate the graph.
|
/// depends on Graphviz to generate the graph.
|
||||||
/// This function is defined in ViewRegionGraph and only works with that
|
/// This function is defined in ViewRegionGraph and only works with that
|
||||||
/// target linked.
|
/// target linked.
|
||||||
void viewGraph(const llvm::Twine ®ionName);
|
void viewGraph(const Twine ®ionName);
|
||||||
void viewGraph();
|
void viewGraph();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -167,12 +167,11 @@ private:
|
|||||||
/// parameter.
|
/// parameter.
|
||||||
class RegionRange
|
class RegionRange
|
||||||
: public detail::indexed_accessor_range_base<
|
: public detail::indexed_accessor_range_base<
|
||||||
RegionRange,
|
RegionRange, PointerUnion<Region *, const std::unique_ptr<Region> *>,
|
||||||
llvm::PointerUnion<Region *, const std::unique_ptr<Region> *>,
|
|
||||||
Region *, Region *, Region *> {
|
Region *, Region *, Region *> {
|
||||||
/// The type representing the owner of this range. This is either a list of
|
/// The type representing the owner of this range. This is either a list of
|
||||||
/// values, operands, or results.
|
/// values, operands, or results.
|
||||||
using OwnerT = llvm::PointerUnion<Region *, const std::unique_ptr<Region> *>;
|
using OwnerT = PointerUnion<Region *, const std::unique_ptr<Region> *>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
using RangeBaseT::RangeBaseT;
|
using RangeBaseT::RangeBaseT;
|
||||||
|
@ -464,9 +464,9 @@ public:
|
|||||||
Location location);
|
Location location);
|
||||||
|
|
||||||
/// Verify the construction of a unranked memref type.
|
/// Verify the construction of a unranked memref type.
|
||||||
static LogicalResult
|
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
|
||||||
verifyConstructionInvariants(llvm::Optional<Location> loc,
|
MLIRContext *context,
|
||||||
MLIRContext *context, Type elementType,
|
Type elementType,
|
||||||
unsigned memorySpace);
|
unsigned memorySpace);
|
||||||
|
|
||||||
ArrayRef<int64_t> getShape() const { return llvm::None; }
|
ArrayRef<int64_t> getShape() const { return llvm::None; }
|
||||||
|
@ -84,8 +84,7 @@ private:
|
|||||||
static Type unwrap(Value *value);
|
static Type unwrap(Value *value);
|
||||||
};
|
};
|
||||||
|
|
||||||
using OperandElementTypeRange =
|
using OperandElementTypeRange = iterator_range<OperandElementTypeIterator>;
|
||||||
llvm::iterator_range<OperandElementTypeIterator>;
|
|
||||||
|
|
||||||
// An iterator for the tensor element types of an op's results of shaped types.
|
// An iterator for the tensor element types of an op's results of shaped types.
|
||||||
class ResultElementTypeIterator final
|
class ResultElementTypeIterator final
|
||||||
@ -102,7 +101,7 @@ private:
|
|||||||
static Type unwrap(Value *value);
|
static Type unwrap(Value *value);
|
||||||
};
|
};
|
||||||
|
|
||||||
using ResultElementTypeRange = llvm::iterator_range<ResultElementTypeIterator>;
|
using ResultElementTypeRange = iterator_range<ResultElementTypeIterator>;
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
|
@ -46,7 +46,7 @@ public:
|
|||||||
inline bool hasOneUse() const;
|
inline bool hasOneUse() const;
|
||||||
|
|
||||||
using use_iterator = ValueUseIterator<IROperand>;
|
using use_iterator = ValueUseIterator<IROperand>;
|
||||||
using use_range = llvm::iterator_range<use_iterator>;
|
using use_range = iterator_range<use_iterator>;
|
||||||
|
|
||||||
inline use_iterator use_begin() const;
|
inline use_iterator use_begin() const;
|
||||||
inline use_iterator use_end() const;
|
inline use_iterator use_end() const;
|
||||||
@ -55,7 +55,7 @@ public:
|
|||||||
inline use_range getUses() const;
|
inline use_range getUses() const;
|
||||||
|
|
||||||
using user_iterator = ValueUserIterator<IROperand>;
|
using user_iterator = ValueUserIterator<IROperand>;
|
||||||
using user_range = llvm::iterator_range<user_iterator>;
|
using user_range = iterator_range<user_iterator>;
|
||||||
|
|
||||||
inline user_iterator user_begin() const;
|
inline user_iterator user_begin() const;
|
||||||
inline user_iterator user_end() const;
|
inline user_iterator user_end() const;
|
||||||
|
4
third_party/mlir/include/mlir/IR/Value.h
vendored
4
third_party/mlir/include/mlir/IR/Value.h
vendored
@ -82,7 +82,7 @@ public:
|
|||||||
Region *getParentRegion();
|
Region *getParentRegion();
|
||||||
|
|
||||||
using use_iterator = ValueUseIterator<OpOperand>;
|
using use_iterator = ValueUseIterator<OpOperand>;
|
||||||
using use_range = llvm::iterator_range<use_iterator>;
|
using use_range = iterator_range<use_iterator>;
|
||||||
|
|
||||||
inline use_iterator use_begin();
|
inline use_iterator use_begin();
|
||||||
inline use_iterator use_end();
|
inline use_iterator use_end();
|
||||||
@ -112,7 +112,7 @@ inline auto Value::use_begin() -> use_iterator {
|
|||||||
|
|
||||||
inline auto Value::use_end() -> use_iterator { return use_iterator(nullptr); }
|
inline auto Value::use_end() -> use_iterator { return use_iterator(nullptr); }
|
||||||
|
|
||||||
inline auto Value::getUses() -> llvm::iterator_range<use_iterator> {
|
inline auto Value::getUses() -> iterator_range<use_iterator> {
|
||||||
return {use_begin(), use_end()};
|
return {use_begin(), use_end()};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
2
third_party/mlir/include/mlir/IR/Visitors.h
vendored
2
third_party/mlir/include/mlir/IR/Visitors.h
vendored
@ -94,7 +94,7 @@ template <
|
|||||||
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
|
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
|
||||||
typename std::enable_if<std::is_same<ArgT, Operation *>::value, RetT>::type
|
typename std::enable_if<std::is_same<ArgT, Operation *>::value, RetT>::type
|
||||||
walkOperations(Operation *op, FuncTy &&callback) {
|
walkOperations(Operation *op, FuncTy &&callback) {
|
||||||
return detail::walkOperations(op, llvm::function_ref<RetT(ArgT)>(callback));
|
return detail::walkOperations(op, function_ref<RetT(ArgT)>(callback));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Walk all of the operations of type 'ArgT' nested under and including the
|
/// Walk all of the operations of type 'ArgT' nested under and including the
|
||||||
|
@ -128,10 +128,10 @@ template <typename AnalysisT> struct AnalysisModel : public AnalysisConcept {
|
|||||||
class AnalysisMap {
|
class AnalysisMap {
|
||||||
/// A mapping between an analysis id and an existing analysis instance.
|
/// A mapping between an analysis id and an existing analysis instance.
|
||||||
using ConceptMap =
|
using ConceptMap =
|
||||||
llvm::DenseMap<const AnalysisID *, std::unique_ptr<AnalysisConcept>>;
|
DenseMap<const AnalysisID *, std::unique_ptr<AnalysisConcept>>;
|
||||||
|
|
||||||
/// Utility to return the name of the given analysis class.
|
/// Utility to return the name of the given analysis class.
|
||||||
template <typename AnalysisT> static llvm::StringRef getAnalysisName() {
|
template <typename AnalysisT> static StringRef getAnalysisName() {
|
||||||
StringRef name = llvm::getTypeName<AnalysisT>();
|
StringRef name = llvm::getTypeName<AnalysisT>();
|
||||||
if (!name.consume_front("mlir::"))
|
if (!name.consume_front("mlir::"))
|
||||||
name.consume_front("(anonymous namespace)::");
|
name.consume_front("(anonymous namespace)::");
|
||||||
@ -165,7 +165,7 @@ public:
|
|||||||
|
|
||||||
/// Get a cached analysis instance if one exists, otherwise return null.
|
/// Get a cached analysis instance if one exists, otherwise return null.
|
||||||
template <typename AnalysisT>
|
template <typename AnalysisT>
|
||||||
llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() const {
|
Optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() const {
|
||||||
auto res = analyses.find(AnalysisID::getID<AnalysisT>());
|
auto res = analyses.find(AnalysisID::getID<AnalysisT>());
|
||||||
if (res == analyses.end())
|
if (res == analyses.end())
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
@ -206,7 +206,7 @@ struct NestedAnalysisMap {
|
|||||||
void invalidate(const PreservedAnalyses &pa);
|
void invalidate(const PreservedAnalyses &pa);
|
||||||
|
|
||||||
/// The cached analyses for nested operations.
|
/// The cached analyses for nested operations.
|
||||||
llvm::DenseMap<Operation *, std::unique_ptr<NestedAnalysisMap>> childAnalyses;
|
DenseMap<Operation *, std::unique_ptr<NestedAnalysisMap>> childAnalyses;
|
||||||
|
|
||||||
/// The analyses for the owning module.
|
/// The analyses for the owning module.
|
||||||
detail::AnalysisMap analyses;
|
detail::AnalysisMap analyses;
|
||||||
@ -224,8 +224,8 @@ class ModuleAnalysisManager;
|
|||||||
/// accessible via 'slice'. This class is intended to be passed around by value,
|
/// accessible via 'slice'. This class is intended to be passed around by value,
|
||||||
/// and cannot be constructed directly.
|
/// and cannot be constructed directly.
|
||||||
class AnalysisManager {
|
class AnalysisManager {
|
||||||
using ParentPointerT = llvm::PointerUnion<const ModuleAnalysisManager *,
|
using ParentPointerT =
|
||||||
const AnalysisManager *>;
|
PointerUnion<const ModuleAnalysisManager *, const AnalysisManager *>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
using PreservedAnalyses = detail::PreservedAnalyses;
|
using PreservedAnalyses = detail::PreservedAnalyses;
|
||||||
@ -233,7 +233,7 @@ public:
|
|||||||
// Query for a cached analysis on the given parent operation. The analysis may
|
// Query for a cached analysis on the given parent operation. The analysis may
|
||||||
// not exist and if it does it may be out-of-date.
|
// not exist and if it does it may be out-of-date.
|
||||||
template <typename AnalysisT>
|
template <typename AnalysisT>
|
||||||
llvm::Optional<std::reference_wrapper<AnalysisT>>
|
Optional<std::reference_wrapper<AnalysisT>>
|
||||||
getCachedParentAnalysis(Operation *parentOp) const {
|
getCachedParentAnalysis(Operation *parentOp) const {
|
||||||
ParentPointerT curParent = parent;
|
ParentPointerT curParent = parent;
|
||||||
while (auto *parentAM = curParent.dyn_cast<const AnalysisManager *>()) {
|
while (auto *parentAM = curParent.dyn_cast<const AnalysisManager *>()) {
|
||||||
@ -251,7 +251,7 @@ public:
|
|||||||
|
|
||||||
// Query for a cached entry of the given analysis on the current operation.
|
// Query for a cached entry of the given analysis on the current operation.
|
||||||
template <typename AnalysisT>
|
template <typename AnalysisT>
|
||||||
llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() const {
|
Optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() const {
|
||||||
return impl->analyses.getCachedAnalysis<AnalysisT>();
|
return impl->analyses.getCachedAnalysis<AnalysisT>();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -262,7 +262,7 @@ public:
|
|||||||
|
|
||||||
/// Query for a cached analysis of a child operation, or return null.
|
/// Query for a cached analysis of a child operation, or return null.
|
||||||
template <typename AnalysisT>
|
template <typename AnalysisT>
|
||||||
llvm::Optional<std::reference_wrapper<AnalysisT>>
|
Optional<std::reference_wrapper<AnalysisT>>
|
||||||
getCachedChildAnalysis(Operation *op) const {
|
getCachedChildAnalysis(Operation *op) const {
|
||||||
assert(op->getParentOp() == impl->getOperation());
|
assert(op->getParentOp() == impl->getOperation());
|
||||||
auto it = impl->childAnalyses.find(op);
|
auto it = impl->childAnalyses.find(op);
|
||||||
@ -297,8 +297,7 @@ private:
|
|||||||
|
|
||||||
/// A reference to the parent analysis manager, or the top-level module
|
/// A reference to the parent analysis manager, or the top-level module
|
||||||
/// analysis manager.
|
/// analysis manager.
|
||||||
llvm::PointerUnion<const ModuleAnalysisManager *, const AnalysisManager *>
|
ParentPointerT parent;
|
||||||
parent;
|
|
||||||
|
|
||||||
/// A reference to the impl analysis map within the parent analysis manager.
|
/// A reference to the impl analysis map within the parent analysis manager.
|
||||||
detail::NestedAnalysisMap *impl;
|
detail::NestedAnalysisMap *impl;
|
||||||
|
19
third_party/mlir/include/mlir/Pass/Pass.h
vendored
19
third_party/mlir/include/mlir/Pass/Pass.h
vendored
@ -68,7 +68,7 @@ public:
|
|||||||
|
|
||||||
/// Returns the name of the operation that this pass operates on, or None if
|
/// Returns the name of the operation that this pass operates on, or None if
|
||||||
/// this is a generic OperationPass.
|
/// this is a generic OperationPass.
|
||||||
llvm::Optional<StringRef> getOpName() const { return opName; }
|
Optional<StringRef> getOpName() const { return opName; }
|
||||||
|
|
||||||
/// Prints out the pass in the textual representation of pipelines. If this is
|
/// Prints out the pass in the textual representation of pipelines. If this is
|
||||||
/// an adaptor pass, print with the op_name(sub_pass,...) format.
|
/// an adaptor pass, print with the op_name(sub_pass,...) format.
|
||||||
@ -100,8 +100,7 @@ public:
|
|||||||
MutableArrayRef<Statistic *> getStatistics() { return statistics; }
|
MutableArrayRef<Statistic *> getStatistics() { return statistics; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
explicit Pass(const PassID *passID,
|
explicit Pass(const PassID *passID, Optional<StringRef> opName = llvm::None)
|
||||||
llvm::Optional<StringRef> opName = llvm::None)
|
|
||||||
: passID(passID), opName(opName) {}
|
: passID(passID), opName(opName) {}
|
||||||
|
|
||||||
/// Returns the current pass state.
|
/// Returns the current pass state.
|
||||||
@ -143,10 +142,10 @@ private:
|
|||||||
|
|
||||||
/// The name of the operation that this pass operates on, or None if this is a
|
/// The name of the operation that this pass operates on, or None if this is a
|
||||||
/// generic OperationPass.
|
/// generic OperationPass.
|
||||||
llvm::Optional<StringRef> opName;
|
Optional<StringRef> opName;
|
||||||
|
|
||||||
/// The current execution state for the pass.
|
/// The current execution state for the pass.
|
||||||
llvm::Optional<detail::PassExecutionState> passState;
|
Optional<detail::PassExecutionState> passState;
|
||||||
|
|
||||||
/// The set of statistics held by this pass.
|
/// The set of statistics held by this pass.
|
||||||
std::vector<Statistic *> statistics;
|
std::vector<Statistic *> statistics;
|
||||||
@ -170,7 +169,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
explicit PassModel(llvm::Optional<StringRef> opName = llvm::None)
|
explicit PassModel(Optional<StringRef> opName = llvm::None)
|
||||||
: BasePassT(PassID::getID<PassT>(), opName) {}
|
: BasePassT(PassID::getID<PassT>(), opName) {}
|
||||||
|
|
||||||
/// Signal that some invariant was broken when running. The IR is allowed to
|
/// Signal that some invariant was broken when running. The IR is allowed to
|
||||||
@ -187,7 +186,7 @@ protected:
|
|||||||
/// Query a cached instance of an analysis for the current ir unit if one
|
/// Query a cached instance of an analysis for the current ir unit if one
|
||||||
/// exists.
|
/// exists.
|
||||||
template <typename AnalysisT>
|
template <typename AnalysisT>
|
||||||
llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() {
|
Optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() {
|
||||||
return this->getAnalysisManager().template getCachedAnalysis<AnalysisT>();
|
return this->getAnalysisManager().template getCachedAnalysis<AnalysisT>();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -219,13 +218,13 @@ protected:
|
|||||||
|
|
||||||
/// Returns the analysis for the parent operation if it exists.
|
/// Returns the analysis for the parent operation if it exists.
|
||||||
template <typename AnalysisT>
|
template <typename AnalysisT>
|
||||||
llvm::Optional<std::reference_wrapper<AnalysisT>>
|
Optional<std::reference_wrapper<AnalysisT>>
|
||||||
getCachedParentAnalysis(Operation *parent) {
|
getCachedParentAnalysis(Operation *parent) {
|
||||||
return this->getAnalysisManager()
|
return this->getAnalysisManager()
|
||||||
.template getCachedParentAnalysis<AnalysisT>(parent);
|
.template getCachedParentAnalysis<AnalysisT>(parent);
|
||||||
}
|
}
|
||||||
template <typename AnalysisT>
|
template <typename AnalysisT>
|
||||||
llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedParentAnalysis() {
|
Optional<std::reference_wrapper<AnalysisT>> getCachedParentAnalysis() {
|
||||||
return this->getAnalysisManager()
|
return this->getAnalysisManager()
|
||||||
.template getCachedParentAnalysis<AnalysisT>(
|
.template getCachedParentAnalysis<AnalysisT>(
|
||||||
this->getOperation()->getParentOp());
|
this->getOperation()->getParentOp());
|
||||||
@ -233,7 +232,7 @@ protected:
|
|||||||
|
|
||||||
/// Returns the analysis for the given child operation if it exists.
|
/// Returns the analysis for the given child operation if it exists.
|
||||||
template <typename AnalysisT>
|
template <typename AnalysisT>
|
||||||
llvm::Optional<std::reference_wrapper<AnalysisT>>
|
Optional<std::reference_wrapper<AnalysisT>>
|
||||||
getCachedChildAnalysis(Operation *child) {
|
getCachedChildAnalysis(Operation *child) {
|
||||||
return this->getAnalysisManager()
|
return this->getAnalysisManager()
|
||||||
.template getCachedChildAnalysis<AnalysisT>(child);
|
.template getCachedChildAnalysis<AnalysisT>(child);
|
||||||
|
@ -83,14 +83,14 @@ public:
|
|||||||
/// A callback to run before an analysis is computed. This function takes the
|
/// A callback to run before an analysis is computed. This function takes the
|
||||||
/// name of the analysis to be computed, its AnalysisID, as well as the
|
/// name of the analysis to be computed, its AnalysisID, as well as the
|
||||||
/// current operation being analyzed.
|
/// current operation being analyzed.
|
||||||
virtual void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id,
|
virtual void runBeforeAnalysis(StringRef name, AnalysisID *id,
|
||||||
Operation *op) {}
|
Operation *op) {}
|
||||||
|
|
||||||
/// A callback to run before an analysis is computed. This function takes the
|
/// A callback to run before an analysis is computed. This function takes the
|
||||||
/// name of the analysis that was computed, its AnalysisID, as well as the
|
/// name of the analysis that was computed, its AnalysisID, as well as the
|
||||||
/// current operation being analyzed.
|
/// current operation being analyzed.
|
||||||
virtual void runAfterAnalysis(llvm::StringRef name, AnalysisID *id,
|
virtual void runAfterAnalysis(StringRef name, AnalysisID *id, Operation *op) {
|
||||||
Operation *op) {}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// This class holds a collection of PassInstrumentation objects, and invokes
|
/// This class holds a collection of PassInstrumentation objects, and invokes
|
||||||
@ -122,10 +122,10 @@ public:
|
|||||||
void runAfterPassFailed(Pass *pass, Operation *op);
|
void runAfterPassFailed(Pass *pass, Operation *op);
|
||||||
|
|
||||||
/// See PassInstrumentation::runBeforeAnalysis for details.
|
/// See PassInstrumentation::runBeforeAnalysis for details.
|
||||||
void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, Operation *op);
|
void runBeforeAnalysis(StringRef name, AnalysisID *id, Operation *op);
|
||||||
|
|
||||||
/// See PassInstrumentation::runAfterAnalysis for details.
|
/// See PassInstrumentation::runAfterAnalysis for details.
|
||||||
void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, Operation *op);
|
void runAfterAnalysis(StringRef name, AnalysisID *id, Operation *op);
|
||||||
|
|
||||||
/// Add the given instrumentation to the collection.
|
/// Add the given instrumentation to the collection.
|
||||||
void addInstrumentation(std::unique_ptr<PassInstrumentation> pi);
|
void addInstrumentation(std::unique_ptr<PassInstrumentation> pi);
|
||||||
|
@ -62,7 +62,7 @@ public:
|
|||||||
llvm::pointee_iterator<std::vector<std::unique_ptr<Pass>>::iterator>;
|
llvm::pointee_iterator<std::vector<std::unique_ptr<Pass>>::iterator>;
|
||||||
pass_iterator begin();
|
pass_iterator begin();
|
||||||
pass_iterator end();
|
pass_iterator end();
|
||||||
llvm::iterator_range<pass_iterator> getPasses() { return {begin(), end()}; }
|
iterator_range<pass_iterator> getPasses() { return {begin(), end()}; }
|
||||||
|
|
||||||
/// Run the held passes over the given operation.
|
/// Run the held passes over the given operation.
|
||||||
LogicalResult run(Operation *op, AnalysisManager am);
|
LogicalResult run(Operation *op, AnalysisManager am);
|
||||||
|
@ -74,7 +74,7 @@ public:
|
|||||||
return candidateTypes[index];
|
return candidateTypes[index];
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::ArrayRef<CandidateQuantizedType> getCandidateTypes() const {
|
ArrayRef<CandidateQuantizedType> getCandidateTypes() const {
|
||||||
return candidateTypes;
|
return candidateTypes;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,8 +84,8 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Gets a mask with every candidate type except those in the given mask.
|
/// Gets a mask with every candidate type except those in the given mask.
|
||||||
llvm::SmallBitVector getCandidateTypeDisabledExceptMask(
|
llvm::SmallBitVector
|
||||||
llvm::ArrayRef<unsigned> exceptOrdinals) const {
|
getCandidateTypeDisabledExceptMask(ArrayRef<unsigned> exceptOrdinals) const {
|
||||||
llvm::SmallBitVector disabled(allCandidateTypesMask);
|
llvm::SmallBitVector disabled(allCandidateTypesMask);
|
||||||
for (unsigned ordinal : exceptOrdinals) {
|
for (unsigned ordinal : exceptOrdinals) {
|
||||||
disabled.reset(ordinal);
|
disabled.reset(ordinal);
|
||||||
|
@ -68,7 +68,7 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Vector and iterator over nodes.
|
// Vector and iterator over nodes.
|
||||||
using node_vector = llvm::SmallVector<CAGNode *, 1>;
|
using node_vector = SmallVector<CAGNode *, 1>;
|
||||||
using iterator = node_vector::iterator;
|
using iterator = node_vector::iterator;
|
||||||
using const_iterator = node_vector::const_iterator;
|
using const_iterator = node_vector::const_iterator;
|
||||||
|
|
||||||
@ -100,12 +100,11 @@ public:
|
|||||||
const TargetConfiguration &config) {}
|
const TargetConfiguration &config) {}
|
||||||
|
|
||||||
/// Prints the node label, suitable for one-line display.
|
/// Prints the node label, suitable for one-line display.
|
||||||
virtual void printLabel(llvm::raw_ostream &os) const;
|
virtual void printLabel(raw_ostream &os) const;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T> void findChildrenOfKind(SmallVectorImpl<T *> &found) {
|
||||||
void findChildrenOfKind(llvm::SmallVectorImpl<T *> &found) {
|
|
||||||
for (CAGNode *child : *this) {
|
for (CAGNode *child : *this) {
|
||||||
T *ofKind = llvm::dyn_cast<T>(child);
|
T *ofKind = dyn_cast<T>(child);
|
||||||
if (ofKind) {
|
if (ofKind) {
|
||||||
found.push_back(ofKind);
|
found.push_back(ofKind);
|
||||||
}
|
}
|
||||||
@ -173,7 +172,7 @@ public:
|
|||||||
void propagate(SolverContext &solverContext,
|
void propagate(SolverContext &solverContext,
|
||||||
const TargetConfiguration &config) override;
|
const TargetConfiguration &config) override;
|
||||||
|
|
||||||
void printLabel(llvm::raw_ostream &os) const override;
|
void printLabel(raw_ostream &os) const override;
|
||||||
|
|
||||||
/// Given the anchor metadata and resolved solutions, chooses the most
|
/// Given the anchor metadata and resolved solutions, chooses the most
|
||||||
/// salient and returns an appropriate type to represent it.
|
/// salient and returns an appropriate type to represent it.
|
||||||
@ -213,7 +212,7 @@ public:
|
|||||||
|
|
||||||
Value *getValue() const final { return op->getOperand(operandIdx); }
|
Value *getValue() const final { return op->getOperand(operandIdx); }
|
||||||
|
|
||||||
void printLabel(llvm::raw_ostream &os) const override;
|
void printLabel(raw_ostream &os) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Operation *op;
|
Operation *op;
|
||||||
@ -234,7 +233,7 @@ public:
|
|||||||
Operation *getOp() const final { return resultValue->getDefiningOp(); }
|
Operation *getOp() const final { return resultValue->getDefiningOp(); }
|
||||||
Value *getValue() const final { return resultValue; }
|
Value *getValue() const final { return resultValue; }
|
||||||
|
|
||||||
void printLabel(llvm::raw_ostream &os) const override;
|
void printLabel(raw_ostream &os) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Value *resultValue;
|
Value *resultValue;
|
||||||
@ -275,8 +274,7 @@ public:
|
|||||||
/// Adds a relation constraint with incoming 'from' anchors and outgoing 'to'
|
/// Adds a relation constraint with incoming 'from' anchors and outgoing 'to'
|
||||||
/// anchors.
|
/// anchors.
|
||||||
template <typename T, typename... Args>
|
template <typename T, typename... Args>
|
||||||
T *addUniqueConstraint(llvm::ArrayRef<CAGAnchorNode *> anchors,
|
T *addUniqueConstraint(ArrayRef<CAGAnchorNode *> anchors, Args... args) {
|
||||||
Args... args) {
|
|
||||||
static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
|
static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
|
||||||
"T must be a CAGConstraingNode");
|
"T must be a CAGConstraingNode");
|
||||||
T *constraintNode = addNode(std::make_unique<T>(args...));
|
T *constraintNode = addNode(std::make_unique<T>(args...));
|
||||||
@ -288,7 +286,7 @@ public:
|
|||||||
/// Adds a unidirectional constraint from a node to an array of target nodes.
|
/// Adds a unidirectional constraint from a node to an array of target nodes.
|
||||||
template <typename T, typename... Args>
|
template <typename T, typename... Args>
|
||||||
T *addUnidirectionalConstraint(CAGAnchorNode *fromAnchor,
|
T *addUnidirectionalConstraint(CAGAnchorNode *fromAnchor,
|
||||||
llvm::ArrayRef<CAGAnchorNode *> toAnchors,
|
ArrayRef<CAGAnchorNode *> toAnchors,
|
||||||
Args... args) {
|
Args... args) {
|
||||||
static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
|
static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
|
||||||
"T must be a CAGConstraingNode");
|
"T must be a CAGConstraingNode");
|
||||||
@ -301,10 +299,10 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T *addClusteredConstraint(llvm::ArrayRef<CAGAnchorNode *> anchors) {
|
T *addClusteredConstraint(ArrayRef<CAGAnchorNode *> anchors) {
|
||||||
static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
|
static_assert(std::is_convertible<T *, CAGConstraintNode *>(),
|
||||||
"T must be a CAGConstraingNode");
|
"T must be a CAGConstraingNode");
|
||||||
llvm::SmallVector<T *, 8> cluster;
|
SmallVector<T *, 8> cluster;
|
||||||
for (auto *anchor : anchors) {
|
for (auto *anchor : anchors) {
|
||||||
anchor->findChildrenOfKind<T>(cluster);
|
anchor->findChildrenOfKind<T>(cluster);
|
||||||
}
|
}
|
||||||
@ -356,14 +354,11 @@ private:
|
|||||||
|
|
||||||
SolverContext &context;
|
SolverContext &context;
|
||||||
std::vector<CAGNode *> allNodes;
|
std::vector<CAGNode *> allNodes;
|
||||||
llvm::DenseMap<std::pair<Operation *, unsigned>, CAGOperandAnchor *>
|
DenseMap<std::pair<Operation *, unsigned>, CAGOperandAnchor *> operandAnchors;
|
||||||
operandAnchors;
|
DenseMap<std::pair<Operation *, unsigned>, CAGResultAnchor *> resultAnchors;
|
||||||
llvm::DenseMap<std::pair<Operation *, unsigned>, CAGResultAnchor *>
|
|
||||||
resultAnchors;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
|
inline raw_ostream &operator<<(raw_ostream &os, const CAGNode &node) {
|
||||||
const CAGNode &node) {
|
|
||||||
node.printLabel(os);
|
node.printLabel(os);
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
@ -101,7 +101,7 @@ struct CAGUniformMetadata {
|
|||||||
DiscreteScaleZeroPointFact explicitScaleZeroPoint;
|
DiscreteScaleZeroPointFact explicitScaleZeroPoint;
|
||||||
|
|
||||||
/// Prints a summary of the metadata suitable for display in a graph label.
|
/// Prints a summary of the metadata suitable for display in a graph label.
|
||||||
void printSummary(llvm::raw_ostream &os) const;
|
void printSummary(raw_ostream &os) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end namespace quantizer
|
} // end namespace quantizer
|
||||||
|
@ -85,8 +85,7 @@ private:
|
|||||||
Attribute attr;
|
Attribute attr;
|
||||||
};
|
};
|
||||||
|
|
||||||
llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
|
raw_ostream &operator<<(raw_ostream &os, const TensorAxisStatistics &stats);
|
||||||
const TensorAxisStatistics &stats);
|
|
||||||
|
|
||||||
} // end namespace quantizer
|
} // end namespace quantizer
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
@ -18,6 +18,7 @@
|
|||||||
#ifndef MLIR_SUPPORT_FUNCTIONAL_H_
|
#ifndef MLIR_SUPPORT_FUNCTIONAL_H_
|
||||||
#define MLIR_SUPPORT_FUNCTIONAL_H_
|
#define MLIR_SUPPORT_FUNCTIONAL_H_
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
@ -34,10 +35,9 @@ namespace functional {
|
|||||||
/// Map with iterators.
|
/// Map with iterators.
|
||||||
template <typename Fn, typename IterType>
|
template <typename Fn, typename IterType>
|
||||||
auto map(Fn fun, IterType begin, IterType end)
|
auto map(Fn fun, IterType begin, IterType end)
|
||||||
-> llvm::SmallVector<typename std::result_of<Fn(decltype(*begin))>::type,
|
-> SmallVector<typename std::result_of<Fn(decltype(*begin))>::type, 8> {
|
||||||
8> {
|
|
||||||
using R = typename std::result_of<Fn(decltype(*begin))>::type;
|
using R = typename std::result_of<Fn(decltype(*begin))>::type;
|
||||||
llvm::SmallVector<R, 8> res;
|
SmallVector<R, 8> res;
|
||||||
// auto i works with both pointer types and value types with an operator*.
|
// auto i works with both pointer types and value types with an operator*.
|
||||||
// auto *i only works for pointer types.
|
// auto *i only works for pointer types.
|
||||||
for (auto i = begin; i != end; ++i) {
|
for (auto i = begin; i != end; ++i) {
|
||||||
@ -58,13 +58,12 @@ auto map(Fn fun, ContainerType input)
|
|||||||
/// TODO(ntv): make variadic when needed.
|
/// TODO(ntv): make variadic when needed.
|
||||||
template <typename Fn, typename ContainerType1, typename ContainerType2>
|
template <typename Fn, typename ContainerType1, typename ContainerType2>
|
||||||
auto zipMap(Fn fun, ContainerType1 input1, ContainerType2 input2)
|
auto zipMap(Fn fun, ContainerType1 input1, ContainerType2 input2)
|
||||||
-> llvm::SmallVector<
|
-> SmallVector<typename std::result_of<Fn(decltype(*input1.begin()),
|
||||||
typename std::result_of<Fn(decltype(*input1.begin()),
|
|
||||||
decltype(*input2.begin()))>::type,
|
decltype(*input2.begin()))>::type,
|
||||||
8> {
|
8> {
|
||||||
using R = typename std::result_of<Fn(decltype(*input1.begin()),
|
using R = typename std::result_of<Fn(decltype(*input1.begin()),
|
||||||
decltype(*input2.begin()))>::type;
|
decltype(*input2.begin()))>::type;
|
||||||
llvm::SmallVector<R, 8> res;
|
SmallVector<R, 8> res;
|
||||||
auto zipIter = llvm::zip(input1, input2);
|
auto zipIter = llvm::zip(input1, input2);
|
||||||
for (auto it : zipIter) {
|
for (auto it : zipIter) {
|
||||||
res.push_back(fun(std::get<0>(it), std::get<1>(it)));
|
res.push_back(fun(std::get<0>(it), std::get<1>(it)));
|
||||||
@ -104,7 +103,7 @@ void zipApply(Fn fun, ContainerType1 input1, ContainerType2 input2) {
|
|||||||
/// Operation::operand_range types.
|
/// Operation::operand_range types.
|
||||||
template <typename T, typename ToType = T>
|
template <typename T, typename ToType = T>
|
||||||
inline std::function<ToType *(T *)> makePtrDynCaster() {
|
inline std::function<ToType *(T *)> makePtrDynCaster() {
|
||||||
return [](T *val) { return llvm::dyn_cast<ToType>(val); };
|
return [](T *val) { return dyn_cast<ToType>(val); };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Simple ScopeGuard.
|
/// Simple ScopeGuard.
|
||||||
|
@ -26,6 +26,7 @@ namespace llvm {
|
|||||||
class raw_ostream;
|
class raw_ostream;
|
||||||
class MemoryBuffer;
|
class MemoryBuffer;
|
||||||
} // end namespace llvm
|
} // end namespace llvm
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
struct LogicalResult;
|
struct LogicalResult;
|
||||||
class PassPipelineCLParser;
|
class PassPipelineCLParser;
|
||||||
|
@ -202,7 +202,7 @@ private:
|
|||||||
/// Implementation for getting/creating an instance of a derived type with
|
/// Implementation for getting/creating an instance of a derived type with
|
||||||
/// complex storage.
|
/// complex storage.
|
||||||
BaseStorage *getImpl(unsigned kind, unsigned hashValue,
|
BaseStorage *getImpl(unsigned kind, unsigned hashValue,
|
||||||
llvm::function_ref<bool(const BaseStorage *)> isEqual,
|
function_ref<bool(const BaseStorage *)> isEqual,
|
||||||
std::function<BaseStorage *(StorageAllocator &)> ctorFn);
|
std::function<BaseStorage *(StorageAllocator &)> ctorFn);
|
||||||
|
|
||||||
/// Implementation for getting/creating an instance of a derived type with
|
/// Implementation for getting/creating an instance of a derived type with
|
||||||
@ -213,7 +213,7 @@ private:
|
|||||||
/// Implementation for erasing an instance of a derived type with complex
|
/// Implementation for erasing an instance of a derived type with complex
|
||||||
/// storage.
|
/// storage.
|
||||||
void eraseImpl(unsigned kind, unsigned hashValue,
|
void eraseImpl(unsigned kind, unsigned hashValue,
|
||||||
llvm::function_ref<bool(const BaseStorage *)> isEqual,
|
function_ref<bool(const BaseStorage *)> isEqual,
|
||||||
std::function<void(BaseStorage *)> cleanupFn);
|
std::function<void(BaseStorage *)> cleanupFn);
|
||||||
|
|
||||||
/// The internal implementation class.
|
/// The internal implementation class.
|
||||||
@ -263,7 +263,7 @@ private:
|
|||||||
::llvm::hash_code>::type
|
::llvm::hash_code>::type
|
||||||
getHash(unsigned kind, const DerivedKey &derivedKey) {
|
getHash(unsigned kind, const DerivedKey &derivedKey) {
|
||||||
return llvm::hash_combine(
|
return llvm::hash_combine(
|
||||||
kind, llvm::DenseMapInfo<DerivedKey>::getHashValue(derivedKey));
|
kind, DenseMapInfo<DerivedKey>::getHashValue(derivedKey));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
@ -22,6 +22,7 @@
|
|||||||
#ifndef MLIR_SUPPORT_TOOLUTILITIES_H
|
#ifndef MLIR_SUPPORT_TOOLUTILITIES_H
|
||||||
#define MLIR_SUPPORT_TOOLUTILITIES_H
|
#define MLIR_SUPPORT_TOOLUTILITIES_H
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
@ -32,8 +33,8 @@ class MemoryBuffer;
|
|||||||
namespace mlir {
|
namespace mlir {
|
||||||
struct LogicalResult;
|
struct LogicalResult;
|
||||||
|
|
||||||
using ChunkBufferHandler = llvm::function_ref<LogicalResult(
|
using ChunkBufferHandler = function_ref<LogicalResult(
|
||||||
std::unique_ptr<llvm::MemoryBuffer> chunkBuffer, llvm::raw_ostream &os)>;
|
std::unique_ptr<llvm::MemoryBuffer> chunkBuffer, raw_ostream &os)>;
|
||||||
|
|
||||||
/// Splits the specified buffer on a marker (`// -----`), processes each chunk
|
/// Splits the specified buffer on a marker (`// -----`), processes each chunk
|
||||||
/// independently according to the normal `processChunkBuffer` logic, and writes
|
/// independently according to the normal `processChunkBuffer` logic, and writes
|
||||||
@ -43,8 +44,7 @@ using ChunkBufferHandler = llvm::function_ref<LogicalResult(
|
|||||||
/// into a single file.
|
/// into a single file.
|
||||||
LogicalResult
|
LogicalResult
|
||||||
splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer,
|
splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer,
|
||||||
ChunkBufferHandler processChunkBuffer,
|
ChunkBufferHandler processChunkBuffer, raw_ostream &os);
|
||||||
llvm::raw_ostream &os);
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif // MLIR_SUPPORT_TOOLUTILITIES_H
|
#endif // MLIR_SUPPORT_TOOLUTILITIES_H
|
||||||
|
@ -46,7 +46,7 @@ namespace tblgen {
|
|||||||
// is shared among multiple patterns to avoid creating the wrapper object for
|
// is shared among multiple patterns to avoid creating the wrapper object for
|
||||||
// the same op again and again. But this map will continuously grow.
|
// the same op again and again. But this map will continuously grow.
|
||||||
using RecordOperatorMap =
|
using RecordOperatorMap =
|
||||||
llvm::DenseMap<const llvm::Record *, std::unique_ptr<Operator>>;
|
DenseMap<const llvm::Record *, std::unique_ptr<Operator>>;
|
||||||
|
|
||||||
class Pattern;
|
class Pattern;
|
||||||
|
|
||||||
|
@ -116,13 +116,13 @@ private:
|
|||||||
std::unique_ptr<llvm::Module> llvmModule;
|
std::unique_ptr<llvm::Module> llvmModule;
|
||||||
|
|
||||||
// Mappings between llvm.mlir.global definitions and corresponding globals.
|
// Mappings between llvm.mlir.global definitions and corresponding globals.
|
||||||
llvm::DenseMap<Operation *, llvm::GlobalValue *> globalsMapping;
|
DenseMap<Operation *, llvm::GlobalValue *> globalsMapping;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// Mappings between original and translated values, used for lookups.
|
// Mappings between original and translated values, used for lookups.
|
||||||
llvm::StringMap<llvm::Function *> functionMapping;
|
llvm::StringMap<llvm::Function *> functionMapping;
|
||||||
llvm::DenseMap<Value *, llvm::Value *> valueMapping;
|
DenseMap<Value *, llvm::Value *> valueMapping;
|
||||||
llvm::DenseMap<Block *, llvm::BasicBlock *> blockMapping;
|
DenseMap<Block *, llvm::BasicBlock *> blockMapping;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace LLVM
|
} // namespace LLVM
|
||||||
|
@ -94,7 +94,7 @@ public:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
/// The remapping information for each of the original arguments.
|
/// The remapping information for each of the original arguments.
|
||||||
SmallVector<llvm::Optional<InputMapping>, 4> remappedInputs;
|
SmallVector<Optional<InputMapping>, 4> remappedInputs;
|
||||||
|
|
||||||
/// The set of new argument types.
|
/// The set of new argument types.
|
||||||
SmallVector<Type, 4> argTypes;
|
SmallVector<Type, 4> argTypes;
|
||||||
@ -133,7 +133,7 @@ public:
|
|||||||
/// This function converts the type signature of the given block, by invoking
|
/// This function converts the type signature of the given block, by invoking
|
||||||
/// 'convertSignatureArg' for each argument. This function should return a
|
/// 'convertSignatureArg' for each argument. This function should return a
|
||||||
/// valid conversion for the signature on success, None otherwise.
|
/// valid conversion for the signature on success, None otherwise.
|
||||||
llvm::Optional<SignatureConversion> convertBlockSignature(Block *block);
|
Optional<SignatureConversion> convertBlockSignature(Block *block);
|
||||||
|
|
||||||
/// This hook allows for materializing a conversion from a set of types into
|
/// This hook allows for materializing a conversion from a set of types into
|
||||||
/// one result type by generating a cast operation of some kind. The generated
|
/// one result type by generating a cast operation of some kind. The generated
|
||||||
@ -236,13 +236,13 @@ struct OpConversionPattern : public ConversionPattern {
|
|||||||
/// type.
|
/// type.
|
||||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||||
ConversionPatternRewriter &rewriter) const final {
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
rewrite(llvm::cast<SourceOp>(op), operands, rewriter);
|
rewrite(cast<SourceOp>(op), operands, rewriter);
|
||||||
}
|
}
|
||||||
void rewrite(Operation *op, ArrayRef<Value *> properOperands,
|
void rewrite(Operation *op, ArrayRef<Value *> properOperands,
|
||||||
ArrayRef<Block *> destinations,
|
ArrayRef<Block *> destinations,
|
||||||
ArrayRef<ArrayRef<Value *>> operands,
|
ArrayRef<ArrayRef<Value *>> operands,
|
||||||
ConversionPatternRewriter &rewriter) const final {
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
rewrite(llvm::cast<SourceOp>(op), properOperands, destinations, operands,
|
rewrite(cast<SourceOp>(op), properOperands, destinations, operands,
|
||||||
rewriter);
|
rewriter);
|
||||||
}
|
}
|
||||||
PatternMatchResult
|
PatternMatchResult
|
||||||
@ -250,13 +250,13 @@ struct OpConversionPattern : public ConversionPattern {
|
|||||||
ArrayRef<Block *> destinations,
|
ArrayRef<Block *> destinations,
|
||||||
ArrayRef<ArrayRef<Value *>> operands,
|
ArrayRef<ArrayRef<Value *>> operands,
|
||||||
ConversionPatternRewriter &rewriter) const final {
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
return matchAndRewrite(llvm::cast<SourceOp>(op), properOperands,
|
return matchAndRewrite(cast<SourceOp>(op), properOperands, destinations,
|
||||||
destinations, operands, rewriter);
|
operands, rewriter);
|
||||||
}
|
}
|
||||||
PatternMatchResult
|
PatternMatchResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||||
ConversionPatternRewriter &rewriter) const final {
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
return matchAndRewrite(llvm::cast<SourceOp>(op), operands, rewriter);
|
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(b/142763075): Use OperandAdaptor when it supports access to unnamed
|
// TODO(b/142763075): Use OperandAdaptor when it supports access to unnamed
|
||||||
|
@ -66,10 +66,10 @@ public:
|
|||||||
/// before it is replaced. 'processGeneratedConstants' is invoked for any new
|
/// before it is replaced. 'processGeneratedConstants' is invoked for any new
|
||||||
/// operations generated when folding. If the op was completely folded it is
|
/// operations generated when folding. If the op was completely folded it is
|
||||||
/// erased.
|
/// erased.
|
||||||
LogicalResult tryToFold(
|
LogicalResult
|
||||||
Operation *op,
|
tryToFold(Operation *op,
|
||||||
llvm::function_ref<void(Operation *)> processGeneratedConstants = nullptr,
|
function_ref<void(Operation *)> processGeneratedConstants = nullptr,
|
||||||
llvm::function_ref<void(Operation *)> preReplaceAction = nullptr);
|
function_ref<void(Operation *)> preReplaceAction = nullptr);
|
||||||
|
|
||||||
/// Notifies that the given constant `op` should be remove from this
|
/// Notifies that the given constant `op` should be remove from this
|
||||||
/// OperationFolder's internal bookkeeping.
|
/// OperationFolder's internal bookkeeping.
|
||||||
@ -125,9 +125,9 @@ private:
|
|||||||
|
|
||||||
/// Tries to perform folding on the given `op`. If successful, populates
|
/// Tries to perform folding on the given `op`. If successful, populates
|
||||||
/// `results` with the results of the folding.
|
/// `results` with the results of the folding.
|
||||||
LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value *> &results,
|
LogicalResult tryToFold(
|
||||||
llvm::function_ref<void(Operation *)>
|
Operation *op, SmallVectorImpl<Value *> &results,
|
||||||
processGeneratedConstants = nullptr);
|
function_ref<void(Operation *)> processGeneratedConstants = nullptr);
|
||||||
|
|
||||||
/// Try to get or create a new constant entry. On success this returns the
|
/// Try to get or create a new constant entry. On success this returns the
|
||||||
/// constant operation, nullptr otherwise.
|
/// constant operation, nullptr otherwise.
|
||||||
|
@ -143,7 +143,7 @@ public:
|
|||||||
/// Process a set of blocks that have been inlined. This callback is invoked
|
/// Process a set of blocks that have been inlined. This callback is invoked
|
||||||
/// *before* inlined terminator operations have been processed.
|
/// *before* inlined terminator operations have been processed.
|
||||||
virtual void
|
virtual void
|
||||||
processInlinedBlocks(llvm::iterator_range<Region::iterator> inlinedBlocks) {}
|
processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) {}
|
||||||
|
|
||||||
/// These hooks mirror the hooks for the DialectInlinerInterface, with default
|
/// These hooks mirror the hooks for the DialectInlinerInterface, with default
|
||||||
/// implementations that call the hook on the handler for the dialect 'op' is
|
/// implementations that call the hook on the handler for the dialect 'op' is
|
||||||
@ -188,7 +188,7 @@ public:
|
|||||||
LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
|
LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
|
||||||
Operation *inlinePoint, BlockAndValueMapping &mapper,
|
Operation *inlinePoint, BlockAndValueMapping &mapper,
|
||||||
ArrayRef<Value *> resultsToReplace,
|
ArrayRef<Value *> resultsToReplace,
|
||||||
llvm::Optional<Location> inlineLoc = llvm::None,
|
Optional<Location> inlineLoc = llvm::None,
|
||||||
bool shouldCloneInlinedRegion = true);
|
bool shouldCloneInlinedRegion = true);
|
||||||
|
|
||||||
/// This function is an overload of the above 'inlineRegion' that allows for
|
/// This function is an overload of the above 'inlineRegion' that allows for
|
||||||
@ -198,7 +198,7 @@ LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
|
|||||||
Operation *inlinePoint,
|
Operation *inlinePoint,
|
||||||
ArrayRef<Value *> inlinedOperands,
|
ArrayRef<Value *> inlinedOperands,
|
||||||
ArrayRef<Value *> resultsToReplace,
|
ArrayRef<Value *> resultsToReplace,
|
||||||
llvm::Optional<Location> inlineLoc = llvm::None,
|
Optional<Location> inlineLoc = llvm::None,
|
||||||
bool shouldCloneInlinedRegion = true);
|
bool shouldCloneInlinedRegion = true);
|
||||||
|
|
||||||
/// This function inlines a given region, 'src', of a callable operation,
|
/// This function inlines a given region, 'src', of a callable operation,
|
||||||
|
@ -24,6 +24,7 @@
|
|||||||
#ifndef MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
|
#ifndef MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
|
||||||
#define MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
|
#define MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
@ -64,11 +65,11 @@ FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
|
|||||||
/// loop body.
|
/// loop body.
|
||||||
struct LoopNestStats {
|
struct LoopNestStats {
|
||||||
/// Map from AffineForOp to immediate child AffineForOps in its loop body.
|
/// Map from AffineForOp to immediate child AffineForOps in its loop body.
|
||||||
llvm::DenseMap<Operation *, llvm::SmallVector<AffineForOp, 2>> loopMap;
|
DenseMap<Operation *, SmallVector<AffineForOp, 2>> loopMap;
|
||||||
/// Map from AffineForOp to count of operations in its loop body.
|
/// Map from AffineForOp to count of operations in its loop body.
|
||||||
llvm::DenseMap<Operation *, uint64_t> opCountMap;
|
DenseMap<Operation *, uint64_t> opCountMap;
|
||||||
/// Map from AffineForOp to its constant trip count.
|
/// Map from AffineForOp to its constant trip count.
|
||||||
llvm::DenseMap<Operation *, uint64_t> tripCountMap;
|
DenseMap<Operation *, uint64_t> tripCountMap;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Collect loop nest statistics (eg. loop trip count and operation count)
|
/// Collect loop nest statistics (eg. loop trip count and operation count)
|
||||||
|
@ -44,7 +44,7 @@ std::unique_ptr<Pass> createCSEPass();
|
|||||||
/// Creates a pass to vectorize loops, operations and data types using a
|
/// Creates a pass to vectorize loops, operations and data types using a
|
||||||
/// target-independent, n-D super-vector abstraction.
|
/// target-independent, n-D super-vector abstraction.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>>
|
std::unique_ptr<OpPassBase<FuncOp>>
|
||||||
createVectorizePass(llvm::ArrayRef<int64_t> virtualVectorSize);
|
createVectorizePass(ArrayRef<int64_t> virtualVectorSize);
|
||||||
|
|
||||||
/// Creates a pass to allow independent testing of vectorizer functionality with
|
/// Creates a pass to allow independent testing of vectorizer functionality with
|
||||||
/// FileCheck.
|
/// FileCheck.
|
||||||
@ -52,7 +52,7 @@ std::unique_ptr<OpPassBase<FuncOp>> createVectorizerTestPass();
|
|||||||
|
|
||||||
/// Creates a pass to lower super-vectors to target-dependent HW vectors.
|
/// Creates a pass to lower super-vectors to target-dependent HW vectors.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>>
|
std::unique_ptr<OpPassBase<FuncOp>>
|
||||||
createMaterializeVectorsPass(llvm::ArrayRef<int64_t> vectorSize);
|
createMaterializeVectorsPass(ArrayRef<int64_t> vectorSize);
|
||||||
|
|
||||||
/// Creates a loop unrolling pass with the provided parameters.
|
/// Creates a loop unrolling pass with the provided parameters.
|
||||||
/// 'getUnrollFactor' is a function callback for clients to supply a function
|
/// 'getUnrollFactor' is a function callback for clients to supply a function
|
||||||
|
@ -47,7 +47,7 @@ void visitUsedValuesDefinedAbove(Region ®ion, Region &limit,
|
|||||||
|
|
||||||
/// Calls `callback` for each use of a value within any of the regions provided
|
/// Calls `callback` for each use of a value within any of the regions provided
|
||||||
/// that was defined in one of the ancestors.
|
/// that was defined in one of the ancestors.
|
||||||
void visitUsedValuesDefinedAbove(llvm::MutableArrayRef<Region> regions,
|
void visitUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
|
||||||
function_ref<void(OpOperand *)> callback);
|
function_ref<void(OpOperand *)> callback);
|
||||||
|
|
||||||
/// Fill `values` with a list of values defined at the ancestors of the `limit`
|
/// Fill `values` with a list of values defined at the ancestors of the `limit`
|
||||||
@ -57,14 +57,14 @@ void getUsedValuesDefinedAbove(Region ®ion, Region &limit,
|
|||||||
|
|
||||||
/// Fill `values` with a list of values used within any of the regions provided
|
/// Fill `values` with a list of values used within any of the regions provided
|
||||||
/// but defined in one of the ancestors.
|
/// but defined in one of the ancestors.
|
||||||
void getUsedValuesDefinedAbove(llvm::MutableArrayRef<Region> regions,
|
void getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
|
||||||
llvm::SetVector<Value *> &values);
|
llvm::SetVector<Value *> &values);
|
||||||
|
|
||||||
/// Run a set of structural simplifications over the given regions. This
|
/// Run a set of structural simplifications over the given regions. This
|
||||||
/// includes transformations like unreachable block elimination, dead argument
|
/// includes transformations like unreachable block elimination, dead argument
|
||||||
/// elimination, as well as some other DCE. This function returns success if any
|
/// elimination, as well as some other DCE. This function returns success if any
|
||||||
/// of the regions were simplified, failure otherwise.
|
/// of the regions were simplified, failure otherwise.
|
||||||
LogicalResult simplifyRegions(llvm::MutableArrayRef<Region> regions);
|
LogicalResult simplifyRegions(MutableArrayRef<Region> regions);
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
@ -37,13 +37,13 @@ void viewGraph(Block &block, const Twine &name, bool shortNames = false,
|
|||||||
const Twine &title = "",
|
const Twine &title = "",
|
||||||
llvm::GraphProgram::Name program = llvm::GraphProgram::DOT);
|
llvm::GraphProgram::Name program = llvm::GraphProgram::DOT);
|
||||||
|
|
||||||
llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Block &block,
|
raw_ostream &writeGraph(raw_ostream &os, Block &block, bool shortNames = false,
|
||||||
bool shortNames = false, const Twine &title = "");
|
const Twine &title = "");
|
||||||
|
|
||||||
/// Creates a pass to print op graphs.
|
/// Creates a pass to print op graphs.
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>>
|
std::unique_ptr<OpPassBase<ModuleOp>>
|
||||||
createPrintOpGraphPass(llvm::raw_ostream &os = llvm::errs(),
|
createPrintOpGraphPass(raw_ostream &os = llvm::errs(), bool shortNames = false,
|
||||||
bool shortNames = false, const llvm::Twine &title = "");
|
const Twine &title = "");
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
|
@ -37,13 +37,13 @@ void viewGraph(Region ®ion, const Twine &name, bool shortNames = false,
|
|||||||
const Twine &title = "",
|
const Twine &title = "",
|
||||||
llvm::GraphProgram::Name program = llvm::GraphProgram::DOT);
|
llvm::GraphProgram::Name program = llvm::GraphProgram::DOT);
|
||||||
|
|
||||||
llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Region ®ion,
|
raw_ostream &writeGraph(raw_ostream &os, Region ®ion,
|
||||||
bool shortNames = false, const Twine &title = "");
|
bool shortNames = false, const Twine &title = "");
|
||||||
|
|
||||||
/// Creates a pass to print CFG graphs.
|
/// Creates a pass to print CFG graphs.
|
||||||
std::unique_ptr<mlir::OpPassBase<mlir::FuncOp>>
|
std::unique_ptr<mlir::OpPassBase<mlir::FuncOp>>
|
||||||
createPrintCFGGraphPass(llvm::raw_ostream &os = llvm::errs(),
|
createPrintCFGGraphPass(raw_ostream &os = llvm::errs(), bool shortNames = false,
|
||||||
bool shortNames = false, const llvm::Twine &title = "");
|
const Twine &title = "");
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
|
@ -619,7 +619,7 @@ static void computeDirectionVector(
|
|||||||
const FlatAffineConstraints &srcDomain,
|
const FlatAffineConstraints &srcDomain,
|
||||||
const FlatAffineConstraints &dstDomain, unsigned loopDepth,
|
const FlatAffineConstraints &dstDomain, unsigned loopDepth,
|
||||||
FlatAffineConstraints *dependenceDomain,
|
FlatAffineConstraints *dependenceDomain,
|
||||||
llvm::SmallVector<DependenceComponent, 2> *dependenceComponents) {
|
SmallVector<DependenceComponent, 2> *dependenceComponents) {
|
||||||
// Find the number of common loops shared by src and dst accesses.
|
// Find the number of common loops shared by src and dst accesses.
|
||||||
SmallVector<AffineForOp, 4> commonLoops;
|
SmallVector<AffineForOp, 4> commonLoops;
|
||||||
unsigned numCommonLoops =
|
unsigned numCommonLoops =
|
||||||
@ -772,8 +772,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
|
|||||||
DependenceResult mlir::checkMemrefAccessDependence(
|
DependenceResult mlir::checkMemrefAccessDependence(
|
||||||
const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
|
const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
|
||||||
unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
|
unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
|
||||||
llvm::SmallVector<DependenceComponent, 2> *dependenceComponents,
|
SmallVector<DependenceComponent, 2> *dependenceComponents, bool allowRAR) {
|
||||||
bool allowRAR) {
|
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Checking for dependence at depth: "
|
LLVM_DEBUG(llvm::dbgs() << "Checking for dependence at depth: "
|
||||||
<< Twine(loopDepth) << " between:\n";);
|
<< Twine(loopDepth) << " between:\n";);
|
||||||
LLVM_DEBUG(srcAccess.opInst->dump(););
|
LLVM_DEBUG(srcAccess.opInst->dump(););
|
||||||
@ -865,7 +864,7 @@ DependenceResult mlir::checkMemrefAccessDependence(
|
|||||||
/// rooted at 'forOp' at loop depths in range [1, maxLoopDepth].
|
/// rooted at 'forOp' at loop depths in range [1, maxLoopDepth].
|
||||||
void mlir::getDependenceComponents(
|
void mlir::getDependenceComponents(
|
||||||
AffineForOp forOp, unsigned maxLoopDepth,
|
AffineForOp forOp, unsigned maxLoopDepth,
|
||||||
std::vector<llvm::SmallVector<DependenceComponent, 2>> *depCompsVec) {
|
std::vector<SmallVector<DependenceComponent, 2>> *depCompsVec) {
|
||||||
// Collect all load and store ops in loop nest rooted at 'forOp'.
|
// Collect all load and store ops in loop nest rooted at 'forOp'.
|
||||||
SmallVector<Operation *, 8> loadAndStoreOpInsts;
|
SmallVector<Operation *, 8> loadAndStoreOpInsts;
|
||||||
forOp.getOperation()->walk([&](Operation *opInst) {
|
forOp.getOperation()->walk([&](Operation *opInst) {
|
||||||
@ -883,7 +882,7 @@ void mlir::getDependenceComponents(
|
|||||||
MemRefAccess dstAccess(dstOpInst);
|
MemRefAccess dstAccess(dstOpInst);
|
||||||
|
|
||||||
FlatAffineConstraints dependenceConstraints;
|
FlatAffineConstraints dependenceConstraints;
|
||||||
llvm::SmallVector<DependenceComponent, 2> depComps;
|
SmallVector<DependenceComponent, 2> depComps;
|
||||||
// TODO(andydavis,bondhugula) Explore whether it would be profitable
|
// TODO(andydavis,bondhugula) Explore whether it would be profitable
|
||||||
// to pre-compute and store deps instead of repeatedly checking.
|
// to pre-compute and store deps instead of repeatedly checking.
|
||||||
DependenceResult result = checkMemrefAccessDependence(
|
DependenceResult result = checkMemrefAccessDependence(
|
||||||
|
@ -24,6 +24,7 @@
|
|||||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
#include "mlir/IR/AffineExprVisitor.h"
|
#include "mlir/IR/AffineExprVisitor.h"
|
||||||
#include "mlir/IR/IntegerSet.h"
|
#include "mlir/IR/IntegerSet.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "mlir/Support/MathExtras.h"
|
#include "mlir/Support/MathExtras.h"
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
@ -34,7 +35,6 @@
|
|||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using llvm::SmallDenseMap;
|
using llvm::SmallDenseMap;
|
||||||
using llvm::SmallDenseSet;
|
using llvm::SmallDenseSet;
|
||||||
using llvm::SmallPtrSet;
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@ -73,9 +73,10 @@ private:
|
|||||||
|
|
||||||
// Flattens the expressions in map. Returns failure if 'expr' was unable to be
|
// Flattens the expressions in map. Returns failure if 'expr' was unable to be
|
||||||
// flattened (i.e., semi-affine expressions not handled yet).
|
// flattened (i.e., semi-affine expressions not handled yet).
|
||||||
static LogicalResult getFlattenedAffineExprs(
|
static LogicalResult
|
||||||
ArrayRef<AffineExpr> exprs, unsigned numDims, unsigned numSymbols,
|
getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
|
||||||
std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
|
unsigned numSymbols,
|
||||||
|
std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
|
||||||
FlatAffineConstraints *localVarCst) {
|
FlatAffineConstraints *localVarCst) {
|
||||||
if (exprs.empty()) {
|
if (exprs.empty()) {
|
||||||
localVarCst->reset(numDims, numSymbols);
|
localVarCst->reset(numDims, numSymbols);
|
||||||
@ -109,7 +110,7 @@ static LogicalResult getFlattenedAffineExprs(
|
|||||||
LogicalResult
|
LogicalResult
|
||||||
mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
|
mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
|
||||||
unsigned numSymbols,
|
unsigned numSymbols,
|
||||||
llvm::SmallVectorImpl<int64_t> *flattenedExpr,
|
SmallVectorImpl<int64_t> *flattenedExpr,
|
||||||
FlatAffineConstraints *localVarCst) {
|
FlatAffineConstraints *localVarCst) {
|
||||||
std::vector<SmallVector<int64_t, 8>> flattenedExprs;
|
std::vector<SmallVector<int64_t, 8>> flattenedExprs;
|
||||||
LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
|
LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
|
||||||
@ -121,7 +122,7 @@ mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
|
|||||||
/// Flattens the expressions in map. Returns failure if 'expr' was unable to be
|
/// Flattens the expressions in map. Returns failure if 'expr' was unable to be
|
||||||
/// flattened (i.e., semi-affine expressions not handled yet).
|
/// flattened (i.e., semi-affine expressions not handled yet).
|
||||||
LogicalResult mlir::getFlattenedAffineExprs(
|
LogicalResult mlir::getFlattenedAffineExprs(
|
||||||
AffineMap map, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
|
AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
|
||||||
FlatAffineConstraints *localVarCst) {
|
FlatAffineConstraints *localVarCst) {
|
||||||
if (map.getNumResults() == 0) {
|
if (map.getNumResults() == 0) {
|
||||||
localVarCst->reset(map.getNumDims(), map.getNumSymbols());
|
localVarCst->reset(map.getNumDims(), map.getNumSymbols());
|
||||||
@ -133,7 +134,7 @@ LogicalResult mlir::getFlattenedAffineExprs(
|
|||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult mlir::getFlattenedAffineExprs(
|
LogicalResult mlir::getFlattenedAffineExprs(
|
||||||
IntegerSet set, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
|
IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
|
||||||
FlatAffineConstraints *localVarCst) {
|
FlatAffineConstraints *localVarCst) {
|
||||||
if (set.getNumConstraints() == 0) {
|
if (set.getNumConstraints() == 0) {
|
||||||
localVarCst->reset(set.getNumDims(), set.getNumSymbols());
|
localVarCst->reset(set.getNumDims(), set.getNumSymbols());
|
||||||
|
@ -97,7 +97,7 @@ void mlir::buildTripCountMapAndOperands(
|
|||||||
// being an analysis utility, it shouldn't. Replace with a version that just
|
// being an analysis utility, it shouldn't. Replace with a version that just
|
||||||
// works with analysis structures (FlatAffineConstraints) and thus doesn't
|
// works with analysis structures (FlatAffineConstraints) and thus doesn't
|
||||||
// update the IR.
|
// update the IR.
|
||||||
llvm::Optional<uint64_t> mlir::getConstantTripCount(AffineForOp forOp) {
|
Optional<uint64_t> mlir::getConstantTripCount(AffineForOp forOp) {
|
||||||
SmallVector<Value *, 4> operands;
|
SmallVector<Value *, 4> operands;
|
||||||
AffineMap map;
|
AffineMap map;
|
||||||
buildTripCountMapAndOperands(forOp, &map, &operands);
|
buildTripCountMapAndOperands(forOp, &map, &operands);
|
||||||
@ -197,9 +197,9 @@ static bool isAccessIndexInvariant(Value *iv, Value *index) {
|
|||||||
return !(AffineValueMap(composeOp).isFunctionOf(0, iv));
|
return !(AffineValueMap(composeOp).isFunctionOf(0, iv));
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::DenseSet<Value *>
|
DenseSet<Value *> mlir::getInvariantAccesses(Value *iv,
|
||||||
mlir::getInvariantAccesses(Value *iv, llvm::ArrayRef<Value *> indices) {
|
ArrayRef<Value *> indices) {
|
||||||
llvm::DenseSet<Value *> res;
|
DenseSet<Value *> res;
|
||||||
for (unsigned idx = 0, n = indices.size(); idx < n; ++idx) {
|
for (unsigned idx = 0, n = indices.size(); idx < n; ++idx) {
|
||||||
auto *val = indices[idx];
|
auto *val = indices[idx];
|
||||||
if (isAccessIndexInvariant(iv, val)) {
|
if (isAccessIndexInvariant(iv, val)) {
|
||||||
|
4
third_party/mlir/lib/Analysis/OpStats.cpp
vendored
4
third_party/mlir/lib/Analysis/OpStats.cpp
vendored
@ -27,7 +27,7 @@ using namespace mlir;
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> {
|
struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> {
|
||||||
explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs()) : os(os) {}
|
explicit PrintOpStatsPass(raw_ostream &os = llvm::errs()) : os(os) {}
|
||||||
|
|
||||||
// Prints the resultant operation statistics post iterating over the module.
|
// Prints the resultant operation statistics post iterating over the module.
|
||||||
void runOnModule() override;
|
void runOnModule() override;
|
||||||
@ -37,7 +37,7 @@ struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
llvm::StringMap<int64_t> opCount;
|
llvm::StringMap<int64_t> opCount;
|
||||||
llvm::raw_ostream &os;
|
raw_ostream &os;
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
@ -94,7 +94,7 @@ static void checkDependences(ArrayRef<Operation *> loadsAndStores) {
|
|||||||
getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
|
getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
|
||||||
for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
|
for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
|
||||||
FlatAffineConstraints dependenceConstraints;
|
FlatAffineConstraints dependenceConstraints;
|
||||||
llvm::SmallVector<DependenceComponent, 2> dependenceComponents;
|
SmallVector<DependenceComponent, 2> dependenceComponents;
|
||||||
DependenceResult result = checkMemrefAccessDependence(
|
DependenceResult result = checkMemrefAccessDependence(
|
||||||
srcAccess, dstAccess, d, &dependenceConstraints,
|
srcAccess, dstAccess, d, &dependenceConstraints,
|
||||||
&dependenceComponents);
|
&dependenceComponents);
|
||||||
|
@ -94,7 +94,7 @@ private:
|
|||||||
|
|
||||||
Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcName);
|
Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcName);
|
||||||
if (funcOp)
|
if (funcOp)
|
||||||
return llvm::cast<LLVMFuncOp>(*funcOp);
|
return cast<LLVMFuncOp>(*funcOp);
|
||||||
|
|
||||||
mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
|
mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
|
||||||
return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
|
return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
|
||||||
|
@ -370,13 +370,13 @@ private:
|
|||||||
[&] {
|
[&] {
|
||||||
Value *shflValue = rewriter.create<LLVM::ExtractValueOp>(
|
Value *shflValue = rewriter.create<LLVM::ExtractValueOp>(
|
||||||
loc, type, shfl, rewriter.getIndexArrayAttr(0));
|
loc, type, shfl, rewriter.getIndexArrayAttr(0));
|
||||||
return llvm::SmallVector<Value *, 1>{
|
return SmallVector<Value *, 1>{
|
||||||
accumFactory(loc, value, shflValue, rewriter)};
|
accumFactory(loc, value, shflValue, rewriter)};
|
||||||
},
|
},
|
||||||
[&] { return llvm::makeArrayRef(value); });
|
[&] { return llvm::makeArrayRef(value); });
|
||||||
value = rewriter.getInsertionBlock()->getArgument(0);
|
value = rewriter.getInsertionBlock()->getArgument(0);
|
||||||
}
|
}
|
||||||
return llvm::SmallVector<Value *, 1>{value};
|
return SmallVector<Value *, 1>{value};
|
||||||
},
|
},
|
||||||
// Generate a reduction over the entire warp. This is a specialization
|
// Generate a reduction over the entire warp. This is a specialization
|
||||||
// of the above reduction with unconditional accumulation.
|
// of the above reduction with unconditional accumulation.
|
||||||
@ -394,7 +394,7 @@ private:
|
|||||||
/*return_value_and_is_valid=*/UnitAttr());
|
/*return_value_and_is_valid=*/UnitAttr());
|
||||||
value = accumFactory(loc, value, shflValue, rewriter);
|
value = accumFactory(loc, value, shflValue, rewriter);
|
||||||
}
|
}
|
||||||
return llvm::SmallVector<Value *, 1>{value};
|
return SmallVector<Value *, 1>{value};
|
||||||
});
|
});
|
||||||
return rewriter.getInsertionBlock()->getArgument(0);
|
return rewriter.getInsertionBlock()->getArgument(0);
|
||||||
}
|
}
|
||||||
|
@ -1603,15 +1603,14 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
|
|||||||
|
|
||||||
// If ReturnOp has 0 or 1 operand, create it and return immediately.
|
// If ReturnOp has 0 or 1 operand, create it and return immediately.
|
||||||
if (numArguments == 0) {
|
if (numArguments == 0) {
|
||||||
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, llvm::ArrayRef<Value *>(),
|
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
|
||||||
llvm::ArrayRef<Block *>(),
|
op, ArrayRef<Value *>(), ArrayRef<Block *>(), op->getAttrs());
|
||||||
op->getAttrs());
|
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
if (numArguments == 1) {
|
if (numArguments == 1) {
|
||||||
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
|
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
|
||||||
op, llvm::ArrayRef<Value *>(operands.front()),
|
op, ArrayRef<Value *>(operands.front()), ArrayRef<Block *>(),
|
||||||
llvm::ArrayRef<Block *>(), op->getAttrs());
|
op->getAttrs());
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1626,9 +1625,8 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
|
|||||||
op->getLoc(), packedType, packed, operands[i],
|
op->getLoc(), packedType, packed, operands[i],
|
||||||
rewriter.getI64ArrayAttr(i));
|
rewriter.getI64ArrayAttr(i));
|
||||||
}
|
}
|
||||||
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, llvm::makeArrayRef(packed),
|
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
|
||||||
llvm::ArrayRef<Block *>(),
|
op, llvm::makeArrayRef(packed), ArrayRef<Block *>(), op->getAttrs());
|
||||||
op->getAttrs());
|
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -1971,7 +1969,7 @@ static void ensureDistinctSuccessors(Block &bb) {
|
|||||||
auto *terminator = bb.getTerminator();
|
auto *terminator = bb.getTerminator();
|
||||||
|
|
||||||
// Find repeated successors with arguments.
|
// Find repeated successors with arguments.
|
||||||
llvm::SmallDenseMap<Block *, llvm::SmallVector<int, 4>> successorPositions;
|
llvm::SmallDenseMap<Block *, SmallVector<int, 4>> successorPositions;
|
||||||
for (int i = 0, e = terminator->getNumSuccessors(); i < e; ++i) {
|
for (int i = 0, e = terminator->getNumSuccessors(); i < e; ++i) {
|
||||||
Block *successor = terminator->getSuccessor(i);
|
Block *successor = terminator->getSuccessor(i);
|
||||||
// Blocks with no arguments are safe even if they appear multiple times
|
// Blocks with no arguments are safe even if they appear multiple times
|
||||||
|
@ -155,7 +155,7 @@ void coalesceCopy(TransferOpTy transfer,
|
|||||||
/// Emits remote memory accesses that are clipped to the boundaries of the
|
/// Emits remote memory accesses that are clipped to the boundaries of the
|
||||||
/// MemRef.
|
/// MemRef.
|
||||||
template <typename TransferOpTy>
|
template <typename TransferOpTy>
|
||||||
llvm::SmallVector<edsc::ValueHandle, 8> clip(TransferOpTy transfer,
|
SmallVector<edsc::ValueHandle, 8> clip(TransferOpTy transfer,
|
||||||
edsc::MemRefView &view,
|
edsc::MemRefView &view,
|
||||||
ArrayRef<edsc::IndexHandle> ivs) {
|
ArrayRef<edsc::IndexHandle> ivs) {
|
||||||
using namespace mlir::edsc;
|
using namespace mlir::edsc;
|
||||||
@ -163,8 +163,8 @@ llvm::SmallVector<edsc::ValueHandle, 8> clip(TransferOpTy transfer,
|
|||||||
using edsc::intrinsics::select;
|
using edsc::intrinsics::select;
|
||||||
|
|
||||||
IndexHandle zero(index_t(0)), one(index_t(1));
|
IndexHandle zero(index_t(0)), one(index_t(1));
|
||||||
llvm::SmallVector<edsc::ValueHandle, 8> memRefAccess(transfer.indices());
|
SmallVector<edsc::ValueHandle, 8> memRefAccess(transfer.indices());
|
||||||
llvm::SmallVector<edsc::ValueHandle, 8> clippedScalarAccessExprs(
|
SmallVector<edsc::ValueHandle, 8> clippedScalarAccessExprs(
|
||||||
memRefAccess.size(), edsc::IndexHandle());
|
memRefAccess.size(), edsc::IndexHandle());
|
||||||
|
|
||||||
// Indices accessing to remote memory are clipped and their expressions are
|
// Indices accessing to remote memory are clipped and their expressions are
|
||||||
|
@ -616,9 +616,8 @@ AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
|
|||||||
// A symbol may appear as a dim in affine.apply operations. This function
|
// A symbol may appear as a dim in affine.apply operations. This function
|
||||||
// canonicalizes dims that are valid symbols into actual symbols.
|
// canonicalizes dims that are valid symbols into actual symbols.
|
||||||
template <class MapOrSet>
|
template <class MapOrSet>
|
||||||
static void
|
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
|
||||||
canonicalizePromotedSymbols(MapOrSet *mapOrSet,
|
SmallVectorImpl<Value *> *operands) {
|
||||||
llvm::SmallVectorImpl<Value *> *operands) {
|
|
||||||
if (!mapOrSet || operands->empty())
|
if (!mapOrSet || operands->empty())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
@ -662,7 +661,7 @@ canonicalizePromotedSymbols(MapOrSet *mapOrSet,
|
|||||||
template <class MapOrSet>
|
template <class MapOrSet>
|
||||||
static void
|
static void
|
||||||
canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
|
canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
|
||||||
llvm::SmallVectorImpl<Value *> *operands) {
|
SmallVectorImpl<Value *> *operands) {
|
||||||
static_assert(std::is_same<MapOrSet, AffineMap>::value ||
|
static_assert(std::is_same<MapOrSet, AffineMap>::value ||
|
||||||
std::is_same<MapOrSet, IntegerSet>::value,
|
std::is_same<MapOrSet, IntegerSet>::value,
|
||||||
"Argument must be either of AffineMap or IntegerSet type");
|
"Argument must be either of AffineMap or IntegerSet type");
|
||||||
@ -738,13 +737,13 @@ canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
|
|||||||
*operands = resultOperands;
|
*operands = resultOperands;
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::canonicalizeMapAndOperands(
|
void mlir::canonicalizeMapAndOperands(AffineMap *map,
|
||||||
AffineMap *map, llvm::SmallVectorImpl<Value *> *operands) {
|
SmallVectorImpl<Value *> *operands) {
|
||||||
canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
|
canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::canonicalizeSetAndOperands(
|
void mlir::canonicalizeSetAndOperands(IntegerSet *set,
|
||||||
IntegerSet *set, llvm::SmallVectorImpl<Value *> *operands) {
|
SmallVectorImpl<Value *> *operands) {
|
||||||
canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
|
canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -35,7 +35,7 @@ inline quant::UniformQuantizedType getUniformElementType(Type t) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline bool hasStorageBitWidth(quant::QuantizedType t,
|
inline bool hasStorageBitWidth(quant::QuantizedType t,
|
||||||
llvm::ArrayRef<unsigned> checkWidths) {
|
ArrayRef<unsigned> checkWidths) {
|
||||||
unsigned w = t.getStorageType().getIntOrFloatBitWidth();
|
unsigned w = t.getStorageType().getIntOrFloatBitWidth();
|
||||||
for (unsigned checkWidth : checkWidths) {
|
for (unsigned checkWidth : checkWidths) {
|
||||||
if (w == checkWidth)
|
if (w == checkWidth)
|
||||||
|
@ -237,7 +237,7 @@ KernelDim3 LaunchOp::getBlockSizeOperandValues() {
|
|||||||
return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
|
return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::iterator_range<Block::args_iterator> LaunchOp::getKernelArguments() {
|
iterator_range<Block::args_iterator> LaunchOp::getKernelArguments() {
|
||||||
auto args = body().getBlocks().front().getArguments();
|
auto args = body().getBlocks().front().getArguments();
|
||||||
return llvm::drop_begin(args, LaunchOp::kNumConfigRegionAttributes);
|
return llvm::drop_begin(args, LaunchOp::kNumConfigRegionAttributes);
|
||||||
}
|
}
|
||||||
|
@ -69,7 +69,7 @@ static gpu::LaunchFuncOp inlineBeneficiaryOps(gpu::GPUFuncOp kernelFunc,
|
|||||||
gpu::LaunchFuncOp launch) {
|
gpu::LaunchFuncOp launch) {
|
||||||
OpBuilder kernelBuilder(kernelFunc.getBody());
|
OpBuilder kernelBuilder(kernelFunc.getBody());
|
||||||
auto &firstBlock = kernelFunc.getBody().front();
|
auto &firstBlock = kernelFunc.getBody().front();
|
||||||
llvm::SmallVector<Value *, 8> newLaunchArgs;
|
SmallVector<Value *, 8> newLaunchArgs;
|
||||||
BlockAndValueMapping map;
|
BlockAndValueMapping map;
|
||||||
for (int i = 0, e = launch.getNumKernelOperands(); i < e; ++i) {
|
for (int i = 0, e = launch.getNumKernelOperands(); i < e; ++i) {
|
||||||
map.map(launch.getKernelOperand(i), kernelFunc.getArgument(i));
|
map.map(launch.getKernelOperand(i), kernelFunc.getArgument(i));
|
||||||
@ -195,7 +195,7 @@ private:
|
|||||||
SymbolTable symbolTable(kernelModule);
|
SymbolTable symbolTable(kernelModule);
|
||||||
symbolTable.insert(kernelFunc);
|
symbolTable.insert(kernelFunc);
|
||||||
|
|
||||||
llvm::SmallVector<Operation *, 8> symbolDefWorklist = {kernelFunc};
|
SmallVector<Operation *, 8> symbolDefWorklist = {kernelFunc};
|
||||||
while (!symbolDefWorklist.empty()) {
|
while (!symbolDefWorklist.empty()) {
|
||||||
if (Optional<SymbolTable::UseRange> symbolUses =
|
if (Optional<SymbolTable::UseRange> symbolUses =
|
||||||
SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
|
SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
|
||||||
|
@ -1227,7 +1227,7 @@ static ParseResult parseLLVMFuncOp(OpAsmParser &parser,
|
|||||||
|
|
||||||
auto *body = result.addRegion();
|
auto *body = result.addRegion();
|
||||||
return parser.parseOptionalRegion(
|
return parser.parseOptionalRegion(
|
||||||
*body, entryArgs, entryArgs.empty() ? llvm::ArrayRef<Type>() : argTypes);
|
*body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Print the LLVMFuncOp. Collects argument and result types and passes them to
|
// Print the LLVMFuncOp. Collects argument and result types and passes them to
|
||||||
@ -1499,7 +1499,7 @@ LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) {
|
|||||||
/// Get an LLVMType with an llvm type that may cause changes to the underlying
|
/// Get an LLVMType with an llvm type that may cause changes to the underlying
|
||||||
/// llvm context when constructed.
|
/// llvm context when constructed.
|
||||||
LLVMType LLVMType::getLocked(LLVMDialect *dialect,
|
LLVMType LLVMType::getLocked(LLVMDialect *dialect,
|
||||||
llvm::function_ref<llvm::Type *()> typeBuilder) {
|
function_ref<llvm::Type *()> typeBuilder) {
|
||||||
// Lock access to the llvm context and build the type.
|
// Lock access to the llvm context and build the type.
|
||||||
llvm::sys::SmartScopedLock<true> lock(dialect->impl->mutex);
|
llvm::sys::SmartScopedLock<true> lock(dialect->impl->mutex);
|
||||||
return get(dialect->getContext(), typeBuilder());
|
return get(dialect->getContext(), typeBuilder());
|
||||||
|
@ -44,7 +44,7 @@ static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
|
|||||||
Operation *mlir::edsc::makeLinalgGenericOp(
|
Operation *mlir::edsc::makeLinalgGenericOp(
|
||||||
ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
|
ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
|
||||||
ArrayRef<StructuredIndexed> outputs,
|
ArrayRef<StructuredIndexed> outputs,
|
||||||
llvm::function_ref<void(ArrayRef<BlockArgument *>)> regionBuilder,
|
function_ref<void(ArrayRef<BlockArgument *>)> regionBuilder,
|
||||||
ArrayRef<Value *> otherValues, ArrayRef<Attribute> otherAttributes) {
|
ArrayRef<Value *> otherValues, ArrayRef<Attribute> otherAttributes) {
|
||||||
auto &builder = edsc::ScopedContext::getBuilder();
|
auto &builder = edsc::ScopedContext::getBuilder();
|
||||||
auto *ctx = builder.getContext();
|
auto *ctx = builder.getContext();
|
||||||
|
@ -632,7 +632,7 @@ namespace linalg {
|
|||||||
} // namespace linalg
|
} // namespace linalg
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
static AffineMap extractOrIdentityMap(llvm::Optional<AffineMap> maybeMap,
|
static AffineMap extractOrIdentityMap(Optional<AffineMap> maybeMap,
|
||||||
unsigned rank, MLIRContext *context) {
|
unsigned rank, MLIRContext *context) {
|
||||||
if (maybeMap)
|
if (maybeMap)
|
||||||
return maybeMap.getValue();
|
return maybeMap.getValue();
|
||||||
|
@ -100,7 +100,7 @@ LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker(
|
|||||||
|
|
||||||
bool mlir::linalg::detail::isProducedByOpOfTypeImpl(
|
bool mlir::linalg::detail::isProducedByOpOfTypeImpl(
|
||||||
Operation *consumerOp, Value *consumedView,
|
Operation *consumerOp, Value *consumedView,
|
||||||
llvm::function_ref<bool(Operation *)> isaOpType) {
|
function_ref<bool(Operation *)> isaOpType) {
|
||||||
LinalgOp consumer = dyn_cast<LinalgOp>(consumerOp);
|
LinalgOp consumer = dyn_cast<LinalgOp>(consumerOp);
|
||||||
if (!consumer)
|
if (!consumer)
|
||||||
return false;
|
return false;
|
||||||
|
@ -315,7 +315,7 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
|
Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
|
||||||
OpBuilder &b, LinalgOp op, ArrayRef<Value *> tileSizes,
|
OpBuilder &b, LinalgOp op, ArrayRef<Value *> tileSizes,
|
||||||
ArrayRef<unsigned> permutation, OperationFolder *folder) {
|
ArrayRef<unsigned> permutation, OperationFolder *folder) {
|
||||||
// 1. Enforce the convention that "tiling by zero" skips tiling a particular
|
// 1. Enforce the convention that "tiling by zero" skips tiling a particular
|
||||||
@ -389,7 +389,7 @@ llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
|
|||||||
return TiledLinalgOp{res, loops};
|
return TiledLinalgOp{res, loops};
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
|
Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
|
||||||
OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
|
OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
|
||||||
ArrayRef<unsigned> permutation, OperationFolder *folder) {
|
ArrayRef<unsigned> permutation, OperationFolder *folder) {
|
||||||
if (tileSizes.empty())
|
if (tileSizes.empty())
|
||||||
|
17
third_party/mlir/lib/Dialect/SDBM/SDBM.cpp
vendored
17
third_party/mlir/lib/Dialect/SDBM/SDBM.cpp
vendored
@ -88,11 +88,11 @@ namespace {
|
|||||||
struct SDBMBuilderResult {
|
struct SDBMBuilderResult {
|
||||||
// Positions in the matrix of the variables taken with the "+" sign in the
|
// Positions in the matrix of the variables taken with the "+" sign in the
|
||||||
// difference expression, 0 if it is a constant rather than a variable.
|
// difference expression, 0 if it is a constant rather than a variable.
|
||||||
llvm::SmallVector<unsigned, 2> positivePos;
|
SmallVector<unsigned, 2> positivePos;
|
||||||
|
|
||||||
// Positions in the matrix of the variables taken with the "-" sign in the
|
// Positions in the matrix of the variables taken with the "-" sign in the
|
||||||
// difference expression, 0 if it is a constant rather than a variable.
|
// difference expression, 0 if it is a constant rather than a variable.
|
||||||
llvm::SmallVector<unsigned, 2> negativePos;
|
SmallVector<unsigned, 2> negativePos;
|
||||||
|
|
||||||
// Constant value in the difference expression.
|
// Constant value in the difference expression.
|
||||||
int64_t value = 0;
|
int64_t value = 0;
|
||||||
@ -184,13 +184,12 @@ public:
|
|||||||
return lhs;
|
return lhs;
|
||||||
}
|
}
|
||||||
|
|
||||||
SDBMBuilder(llvm::DenseMap<SDBMExpr, llvm::SmallVector<unsigned, 2>>
|
SDBMBuilder(DenseMap<SDBMExpr, SmallVector<unsigned, 2>> &pointExprToStripe,
|
||||||
&pointExprToStripe,
|
function_ref<unsigned(SDBMInputExpr)> callback)
|
||||||
llvm::function_ref<unsigned(SDBMInputExpr)> callback)
|
|
||||||
: pointExprToStripe(pointExprToStripe), linearPosition(callback) {}
|
: pointExprToStripe(pointExprToStripe), linearPosition(callback) {}
|
||||||
|
|
||||||
llvm::DenseMap<SDBMExpr, llvm::SmallVector<unsigned, 2>> &pointExprToStripe;
|
DenseMap<SDBMExpr, SmallVector<unsigned, 2>> &pointExprToStripe;
|
||||||
llvm::function_ref<unsigned(SDBMInputExpr)> linearPosition;
|
function_ref<unsigned(SDBMInputExpr)> linearPosition;
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
@ -239,7 +238,7 @@ SDBM SDBM::get(ArrayRef<SDBMExpr> inequalities, ArrayRef<SDBMExpr> equalities) {
|
|||||||
// expression. Keep track of those in pointExprToStripe.
|
// expression. Keep track of those in pointExprToStripe.
|
||||||
// There may also be multiple stripe expressions equal to the same variable.
|
// There may also be multiple stripe expressions equal to the same variable.
|
||||||
// Introduce a temporary variable for each of those.
|
// Introduce a temporary variable for each of those.
|
||||||
llvm::DenseMap<SDBMExpr, llvm::SmallVector<unsigned, 2>> pointExprToStripe;
|
DenseMap<SDBMExpr, SmallVector<unsigned, 2>> pointExprToStripe;
|
||||||
unsigned numTemporaries = 0;
|
unsigned numTemporaries = 0;
|
||||||
|
|
||||||
auto updateStripePointMaps = [&numTemporaries, &result, &pointExprToStripe,
|
auto updateStripePointMaps = [&numTemporaries, &result, &pointExprToStripe,
|
||||||
@ -512,7 +511,7 @@ void SDBM::getSDBMExpressions(SDBMDialect *dialect,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SDBM::print(llvm::raw_ostream &os) {
|
void SDBM::print(raw_ostream &os) {
|
||||||
unsigned numVariables = getNumVariables();
|
unsigned numVariables = getNumVariables();
|
||||||
|
|
||||||
// Helper function that prints the name of the variable given its linearized
|
// Helper function that prints the name of the variable given its linearized
|
||||||
|
@ -89,7 +89,7 @@ public:
|
|||||||
: subExprs(exprs.begin(), exprs.end()) {}
|
: subExprs(exprs.begin(), exprs.end()) {}
|
||||||
AffineExprMatcherStorage(AffineExprMatcher &a, AffineExprMatcher &b)
|
AffineExprMatcherStorage(AffineExprMatcher &a, AffineExprMatcher &b)
|
||||||
: subExprs({a, b}) {}
|
: subExprs({a, b}) {}
|
||||||
llvm::SmallVector<AffineExprMatcher, 0> subExprs;
|
SmallVector<AffineExprMatcher, 0> subExprs;
|
||||||
AffineExpr matched;
|
AffineExpr matched;
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -311,7 +311,7 @@ AffineExpr SDBMExpr::getAsAffineExpr() const {
|
|||||||
// LHS if the constant becomes zero. Otherwise, construct a sum expression.
|
// LHS if the constant becomes zero. Otherwise, construct a sum expression.
|
||||||
template <typename Result>
|
template <typename Result>
|
||||||
Result addConstantAndSink(SDBMDirectExpr expr, int64_t constant, bool negated,
|
Result addConstantAndSink(SDBMDirectExpr expr, int64_t constant, bool negated,
|
||||||
llvm::function_ref<Result(SDBMDirectExpr)> builder) {
|
function_ref<Result(SDBMDirectExpr)> builder) {
|
||||||
SDBMDialect *dialect = expr.getDialect();
|
SDBMDialect *dialect = expr.getDialect();
|
||||||
if (auto sumExpr = expr.dyn_cast<SDBMSumExpr>()) {
|
if (auto sumExpr = expr.dyn_cast<SDBMSumExpr>()) {
|
||||||
if (negated)
|
if (negated)
|
||||||
|
@ -33,10 +33,9 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
|
|||||||
return structType;
|
return structType;
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::SmallVector<Type, 4> memberTypes;
|
SmallVector<Type, 4> memberTypes;
|
||||||
llvm::SmallVector<VulkanLayoutUtils::Size, 4> layoutInfo;
|
SmallVector<VulkanLayoutUtils::Size, 4> layoutInfo;
|
||||||
llvm::SmallVector<spirv::StructType::MemberDecorationInfo, 4>
|
SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
|
||||||
memberDecorations;
|
|
||||||
|
|
||||||
VulkanLayoutUtils::Size structMemberOffset = 0;
|
VulkanLayoutUtils::Size structMemberOffset = 0;
|
||||||
VulkanLayoutUtils::Size maxMemberAlignment = 1;
|
VulkanLayoutUtils::Size maxMemberAlignment = 1;
|
||||||
|
@ -149,7 +149,7 @@ Optional<uint64_t> parseAndVerify<uint64_t>(SPIRVDialect const &dialect,
|
|||||||
DialectAsmParser &parser);
|
DialectAsmParser &parser);
|
||||||
|
|
||||||
static bool isValidSPIRVIntType(IntegerType type) {
|
static bool isValidSPIRVIntType(IntegerType type) {
|
||||||
return llvm::is_contained(llvm::ArrayRef<unsigned>({1, 8, 16, 32, 64}),
|
return llvm::is_contained(ArrayRef<unsigned>({1, 8, 16, 32, 64}),
|
||||||
type.getWidth());
|
type.getWidth());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,7 +80,7 @@ static LogicalResult extractValueFromConstOp(Operation *op,
|
|||||||
template <typename Ty>
|
template <typename Ty>
|
||||||
static ArrayAttr
|
static ArrayAttr
|
||||||
getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
|
getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
|
||||||
llvm::function_ref<StringRef(Ty)> stringifyFn) {
|
function_ref<StringRef(Ty)> stringifyFn) {
|
||||||
if (enumValues.empty()) {
|
if (enumValues.empty()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
@ -399,7 +399,7 @@ static unsigned getBitWidth(Type type) {
|
|||||||
/// emits errors with the given loc on failure.
|
/// emits errors with the given loc on failure.
|
||||||
static Type
|
static Type
|
||||||
getElementType(Type type, ArrayRef<int32_t> indices,
|
getElementType(Type type, ArrayRef<int32_t> indices,
|
||||||
llvm::function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
|
function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
|
||||||
if (indices.empty()) {
|
if (indices.empty()) {
|
||||||
emitErrorFn("expected at least one index for spv.CompositeExtract");
|
emitErrorFn("expected at least one index for spv.CompositeExtract");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -423,7 +423,7 @@ getElementType(Type type, ArrayRef<int32_t> indices,
|
|||||||
|
|
||||||
static Type
|
static Type
|
||||||
getElementType(Type type, Attribute indices,
|
getElementType(Type type, Attribute indices,
|
||||||
llvm::function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
|
function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
|
||||||
auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>();
|
auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>();
|
||||||
if (!indicesArrayAttr) {
|
if (!indicesArrayAttr) {
|
||||||
emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
|
emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
|
||||||
@ -2317,7 +2317,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
|
|||||||
auto &op = *moduleOp.getOperation();
|
auto &op = *moduleOp.getOperation();
|
||||||
auto *dialect = op.getDialect();
|
auto *dialect = op.getDialect();
|
||||||
auto &body = op.getRegion(0).front();
|
auto &body = op.getRegion(0).front();
|
||||||
llvm::DenseMap<std::pair<FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
|
DenseMap<std::pair<FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
|
||||||
entryPoints;
|
entryPoints;
|
||||||
SymbolTable table(moduleOp);
|
SymbolTable table(moduleOp);
|
||||||
|
|
||||||
|
@ -2366,7 +2366,7 @@ Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
|
|||||||
|
|
||||||
auto functionName = getFunctionSymbol(functionID);
|
auto functionName = getFunctionSymbol(functionID);
|
||||||
|
|
||||||
llvm::SmallVector<Value *, 4> arguments;
|
SmallVector<Value *, 4> arguments;
|
||||||
for (auto operand : llvm::drop_begin(operands, 3)) {
|
for (auto operand : llvm::drop_begin(operands, 3)) {
|
||||||
auto *value = getValue(operand);
|
auto *value = getValue(operand);
|
||||||
if (!value) {
|
if (!value) {
|
||||||
|
@ -69,7 +69,7 @@ static LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
|
|||||||
/// serialization of the merge block and the continue block, if exists, until
|
/// serialization of the merge block and the continue block, if exists, until
|
||||||
/// after all other blocks have been processed.
|
/// after all other blocks have been processed.
|
||||||
static LogicalResult visitInPrettyBlockOrder(
|
static LogicalResult visitInPrettyBlockOrder(
|
||||||
Block *headerBlock, llvm::function_ref<LogicalResult(Block *)> blockHandler,
|
Block *headerBlock, function_ref<LogicalResult(Block *)> blockHandler,
|
||||||
bool skipHeader = false, ArrayRef<Block *> skipBlocks = {}) {
|
bool skipHeader = false, ArrayRef<Block *> skipBlocks = {}) {
|
||||||
llvm::df_iterator_default_set<Block *, 4> doneBlocks;
|
llvm::df_iterator_default_set<Block *, 4> doneBlocks;
|
||||||
doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
|
doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
|
||||||
@ -301,7 +301,7 @@ private:
|
|||||||
/// instruction if this is a SPIR-V selection/loop header block.
|
/// instruction if this is a SPIR-V selection/loop header block.
|
||||||
LogicalResult
|
LogicalResult
|
||||||
processBlock(Block *block, bool omitLabel = false,
|
processBlock(Block *block, bool omitLabel = false,
|
||||||
llvm::function_ref<void()> actionBeforeTerminator = nullptr);
|
function_ref<void()> actionBeforeTerminator = nullptr);
|
||||||
|
|
||||||
/// Emits OpPhi instructions for the given block if it has block arguments.
|
/// Emits OpPhi instructions for the given block if it has block arguments.
|
||||||
LogicalResult emitPhiForBlockArguments(Block *block);
|
LogicalResult emitPhiForBlockArguments(Block *block);
|
||||||
@ -457,7 +457,7 @@ private:
|
|||||||
/// placed inside `functions`) here. And then after emitting all blocks, we
|
/// placed inside `functions`) here. And then after emitting all blocks, we
|
||||||
/// replace the dummy <id> 0 with the real result <id> by overwriting
|
/// replace the dummy <id> 0 with the real result <id> by overwriting
|
||||||
/// `functions[offset]`.
|
/// `functions[offset]`.
|
||||||
DenseMap<Value *, llvm::SmallVector<size_t, 1>> deferredPhiValues;
|
DenseMap<Value *, SmallVector<size_t, 1>> deferredPhiValues;
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
@ -1341,7 +1341,7 @@ uint32_t Serializer::getOrCreateBlockID(Block *block) {
|
|||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
Serializer::processBlock(Block *block, bool omitLabel,
|
Serializer::processBlock(Block *block, bool omitLabel,
|
||||||
llvm::function_ref<void()> actionBeforeTerminator) {
|
function_ref<void()> actionBeforeTerminator) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
|
LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
|
||||||
LLVM_DEBUG(block->print(llvm::dbgs()));
|
LLVM_DEBUG(block->print(llvm::dbgs()));
|
||||||
LLVM_DEBUG(llvm::dbgs() << '\n');
|
LLVM_DEBUG(llvm::dbgs() << '\n');
|
||||||
@ -1773,7 +1773,7 @@ Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
|
|||||||
auto funcName = op.callee();
|
auto funcName = op.callee();
|
||||||
uint32_t resTypeID = 0;
|
uint32_t resTypeID = 0;
|
||||||
|
|
||||||
llvm::SmallVector<Type, 1> resultTypes(op.getResultTypes());
|
SmallVector<Type, 1> resultTypes(op.getResultTypes());
|
||||||
if (failed(processType(op.getLoc(),
|
if (failed(processType(op.getLoc(),
|
||||||
(resultTypes.empty() ? getVoidType() : resultTypes[0]),
|
(resultTypes.empty() ? getVoidType() : resultTypes[0]),
|
||||||
resTypeID))) {
|
resTypeID))) {
|
||||||
|
@ -80,7 +80,7 @@ static TranslateToMLIRRegistration fromBinary(
|
|||||||
// Serialization registration
|
// Serialization registration
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult serializeModule(ModuleOp module, llvm::raw_ostream &output) {
|
LogicalResult serializeModule(ModuleOp module, raw_ostream &output) {
|
||||||
if (!module)
|
if (!module)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@ -105,7 +105,7 @@ LogicalResult serializeModule(ModuleOp module, llvm::raw_ostream &output) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static TranslateFromMLIRRegistration
|
static TranslateFromMLIRRegistration
|
||||||
toBinary("serialize-spirv", [](ModuleOp module, llvm::raw_ostream &output) {
|
toBinary("serialize-spirv", [](ModuleOp module, raw_ostream &output) {
|
||||||
return serializeModule(module, output);
|
return serializeModule(module, output);
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -113,8 +113,8 @@ static TranslateFromMLIRRegistration
|
|||||||
// Round-trip registration
|
// Round-trip registration
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr,
|
LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr, raw_ostream &output,
|
||||||
llvm::raw_ostream &output, MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
// Parse an MLIR module from the source manager.
|
// Parse an MLIR module from the source manager.
|
||||||
auto srcModule = OwningModuleRef(parseSourceFile(sourceMgr, context));
|
auto srcModule = OwningModuleRef(parseSourceFile(sourceMgr, context));
|
||||||
if (!srcModule)
|
if (!srcModule)
|
||||||
@ -147,9 +147,8 @@ LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr,
|
|||||||
return mlir::success();
|
return mlir::success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static TranslateRegistration
|
static TranslateRegistration roundtrip(
|
||||||
roundtrip("test-spirv-roundtrip",
|
"test-spirv-roundtrip",
|
||||||
[](llvm::SourceMgr &sourceMgr, llvm::raw_ostream &output,
|
[](llvm::SourceMgr &sourceMgr, raw_ostream &output, MLIRContext *context) {
|
||||||
MLIRContext *context) {
|
|
||||||
return roundTripModule(sourceMgr, output, context);
|
return roundTripModule(sourceMgr, output, context);
|
||||||
});
|
});
|
||||||
|
15
third_party/mlir/lib/Dialect/StandardOps/Ops.cpp
vendored
15
third_party/mlir/lib/Dialect/StandardOps/Ops.cpp
vendored
@ -2297,7 +2297,7 @@ static void print(OpAsmPrinter &p, ViewOp op) {
|
|||||||
|
|
||||||
Value *ViewOp::getDynamicOffset() {
|
Value *ViewOp::getDynamicOffset() {
|
||||||
int64_t offset;
|
int64_t offset;
|
||||||
llvm::SmallVector<int64_t, 4> strides;
|
SmallVector<int64_t, 4> strides;
|
||||||
auto result =
|
auto result =
|
||||||
succeeded(mlir::getStridesAndOffset(getType(), strides, offset));
|
succeeded(mlir::getStridesAndOffset(getType(), strides, offset));
|
||||||
assert(result);
|
assert(result);
|
||||||
@ -2341,7 +2341,7 @@ static LogicalResult verify(ViewOp op) {
|
|||||||
|
|
||||||
// Verify that the result memref type has a strided layout map.
|
// Verify that the result memref type has a strided layout map.
|
||||||
int64_t offset;
|
int64_t offset;
|
||||||
llvm::SmallVector<int64_t, 4> strides;
|
SmallVector<int64_t, 4> strides;
|
||||||
if (failed(getStridesAndOffset(viewType, strides, offset)))
|
if (failed(getStridesAndOffset(viewType, strides, offset)))
|
||||||
return op.emitError("result type ") << viewType << " is not strided";
|
return op.emitError("result type ") << viewType << " is not strided";
|
||||||
|
|
||||||
@ -2383,7 +2383,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
|
|||||||
|
|
||||||
// Get offset from old memref view type 'memRefType'.
|
// Get offset from old memref view type 'memRefType'.
|
||||||
int64_t oldOffset;
|
int64_t oldOffset;
|
||||||
llvm::SmallVector<int64_t, 4> oldStrides;
|
SmallVector<int64_t, 4> oldStrides;
|
||||||
if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
|
if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
|
||||||
return matchFailure();
|
return matchFailure();
|
||||||
|
|
||||||
@ -2585,13 +2585,13 @@ static LogicalResult verify(SubViewOp op) {
|
|||||||
|
|
||||||
// Verify that the base memref type has a strided layout map.
|
// Verify that the base memref type has a strided layout map.
|
||||||
int64_t baseOffset;
|
int64_t baseOffset;
|
||||||
llvm::SmallVector<int64_t, 4> baseStrides;
|
SmallVector<int64_t, 4> baseStrides;
|
||||||
if (failed(getStridesAndOffset(baseType, baseStrides, baseOffset)))
|
if (failed(getStridesAndOffset(baseType, baseStrides, baseOffset)))
|
||||||
return op.emitError("base type ") << subViewType << " is not strided";
|
return op.emitError("base type ") << subViewType << " is not strided";
|
||||||
|
|
||||||
// Verify that the result memref type has a strided layout map.
|
// Verify that the result memref type has a strided layout map.
|
||||||
int64_t subViewOffset;
|
int64_t subViewOffset;
|
||||||
llvm::SmallVector<int64_t, 4> subViewStrides;
|
SmallVector<int64_t, 4> subViewStrides;
|
||||||
if (failed(getStridesAndOffset(subViewType, subViewStrides, subViewOffset)))
|
if (failed(getStridesAndOffset(subViewType, subViewStrides, subViewOffset)))
|
||||||
return op.emitError("result type ") << subViewType << " is not strided";
|
return op.emitError("result type ") << subViewType << " is not strided";
|
||||||
|
|
||||||
@ -2677,8 +2677,7 @@ static LogicalResult verify(SubViewOp op) {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::raw_ostream &mlir::operator<<(llvm::raw_ostream &os,
|
raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) {
|
||||||
SubViewOp::Range &range) {
|
|
||||||
return os << "range " << *range.offset << ":" << *range.size << ":"
|
return os << "range " << *range.offset << ":" << *range.size << ":"
|
||||||
<< *range.stride;
|
<< *range.stride;
|
||||||
}
|
}
|
||||||
@ -2734,7 +2733,7 @@ static bool hasConstantOffsetSizesAndStrides(MemRefType memrefType) {
|
|||||||
return false;
|
return false;
|
||||||
// Get offset and strides.
|
// Get offset and strides.
|
||||||
int64_t offset;
|
int64_t offset;
|
||||||
llvm::SmallVector<int64_t, 4> strides;
|
SmallVector<int64_t, 4> strides;
|
||||||
if (failed(getStridesAndOffset(memrefType, strides, offset)))
|
if (failed(getStridesAndOffset(memrefType, strides, offset)))
|
||||||
return false;
|
return false;
|
||||||
// Return 'false' if any of offset or strides is dynamic.
|
// Return 'false' if any of offset or strides is dynamic.
|
||||||
|
5
third_party/mlir/lib/Dialect/Traits.cpp
vendored
5
third_party/mlir/lib/Dialect/Traits.cpp
vendored
@ -112,8 +112,7 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
|
|||||||
|
|
||||||
// Returns the type kind if the given type is a vector or ranked tensor type.
|
// Returns the type kind if the given type is a vector or ranked tensor type.
|
||||||
// Returns llvm::None otherwise.
|
// Returns llvm::None otherwise.
|
||||||
auto getCompositeTypeKind =
|
auto getCompositeTypeKind = [](Type type) -> Optional<StandardTypes::Kind> {
|
||||||
[](Type type) -> llvm::Optional<StandardTypes::Kind> {
|
|
||||||
if (type.isa<VectorType>() || type.isa<RankedTensorType>())
|
if (type.isa<VectorType>() || type.isa<RankedTensorType>())
|
||||||
return static_cast<StandardTypes::Kind>(type.getKind());
|
return static_cast<StandardTypes::Kind>(type.getKind());
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
@ -122,7 +121,7 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
|
|||||||
// Make sure the composite type, if has, is consistent.
|
// Make sure the composite type, if has, is consistent.
|
||||||
auto compositeKind1 = getCompositeTypeKind(type1);
|
auto compositeKind1 = getCompositeTypeKind(type1);
|
||||||
auto compositeKind2 = getCompositeTypeKind(type2);
|
auto compositeKind2 = getCompositeTypeKind(type2);
|
||||||
llvm::Optional<StandardTypes::Kind> resultCompositeKind;
|
Optional<StandardTypes::Kind> resultCompositeKind;
|
||||||
|
|
||||||
if (compositeKind1 && compositeKind2) {
|
if (compositeKind1 && compositeKind2) {
|
||||||
// Disallow mixing vector and tensor.
|
// Disallow mixing vector and tensor.
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user