Adds support for unrolling single-result vector operations with iterator type lists and indexing maps to a target vector size.
Adds unit tests for unrolling the vector ContractionOp with different iteration orders. PiperOrigin-RevId: 283747503 Change-Id: Ib7e4f757d15760cd89fc09fedc49b7a2dc2a1fe6
This commit is contained in:
parent
222977dffd
commit
8f661bace2
@ -157,6 +157,18 @@ def Vector_ContractionOp :
|
|||||||
static StringRef getParallelIteratorTypeName() {
|
static StringRef getParallelIteratorTypeName() {
|
||||||
return "parallel";
|
return "parallel";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the bounds of each dimension in the iteration space spanned
|
||||||
|
// by the iterator types of this operation.
|
||||||
|
void getIterationBounds(SmallVectorImpl<int64_t> &iterationBounds);
|
||||||
|
|
||||||
|
// Returns a list of index maps, where there is a list entry for each
|
||||||
|
// op indexing map attribute (i.e. one for each input and output, with
|
||||||
|
// the output listed last). Each index map, maps from this operations
|
||||||
|
// iteration space, to vector dimensions of the maps input/output.
|
||||||
|
void getIterationIndexMap(
|
||||||
|
std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap);
|
||||||
|
|
||||||
std::vector<std::pair<int64_t, int64_t>> getContractingDimMap();
|
std::vector<std::pair<int64_t, int64_t>> getContractingDimMap();
|
||||||
std::vector<std::pair<int64_t, int64_t>> getBatchDimMap();
|
std::vector<std::pair<int64_t, int64_t>> getBatchDimMap();
|
||||||
}];
|
}];
|
||||||
|
@ -40,4 +40,9 @@ def : Pat<(AddFOp:$op_results $a, $b),
|
|||||||
(UnrollVectorOp<[2, 2]> $op_results, $a, $b),
|
(UnrollVectorOp<[2, 2]> $op_results, $a, $b),
|
||||||
[(Constraint<HasShape<[4, 4]>> $a)]>;
|
[(Constraint<HasShape<[4, 4]>> $a)]>;
|
||||||
|
|
||||||
|
// TODO(andydavis) Add Constraints on lhs/rhs shapes.
|
||||||
|
def : Pat<(Vector_ContractionOp:$op_results $a, $b, $c, $masks, $attr0, $attr1),
|
||||||
|
(UnrollVectorOp<[2, 2, 2]> $op_results, $a, $b, $c),
|
||||||
|
[(Constraint<HasShape<[4, 4]>> $c)]>;
|
||||||
|
|
||||||
#endif // VECTOR_TRANSFORMS
|
#endif // VECTOR_TRANSFORMS
|
||||||
|
@ -271,6 +271,44 @@ getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
|
|||||||
return dimMap;
|
return dimMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ContractionOp::getIterationBounds(
|
||||||
|
SmallVectorImpl<int64_t> &iterationBounds) {
|
||||||
|
auto lhsShape = getLhsType().getShape();
|
||||||
|
auto resShape = getResultType().getShape();
|
||||||
|
SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
|
||||||
|
SmallVector<int64_t, 2> iterationShape;
|
||||||
|
for (auto it : llvm::enumerate(iterator_types())) {
|
||||||
|
// Search lhs/rhs map results for 'targetExpr'.
|
||||||
|
auto targetExpr = getAffineDimExpr(it.index(), getContext());
|
||||||
|
auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
|
||||||
|
if (iteratorTypeName == getReductionIteratorTypeName()) {
|
||||||
|
// Get reduction dim size from lhs shape (same size in rhsShape).
|
||||||
|
int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
|
||||||
|
assert(lhsDimIndex >= 0);
|
||||||
|
iterationBounds.push_back(lhsShape[lhsDimIndex]);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// Get parallel dimension size from result shape.
|
||||||
|
int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
|
||||||
|
assert(resDimIndex >= 0);
|
||||||
|
iterationBounds.push_back(resShape[resDimIndex]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ContractionOp::getIterationIndexMap(
|
||||||
|
std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
|
||||||
|
unsigned numMaps = indexing_maps().getValue().size();
|
||||||
|
iterationIndexMap.resize(numMaps);
|
||||||
|
for (auto it : llvm::enumerate(indexing_maps())) {
|
||||||
|
auto index = it.index();
|
||||||
|
auto map = it.value().cast<AffineMapAttr>().getValue();
|
||||||
|
for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
|
||||||
|
auto dim = map.getResult(i).cast<AffineDimExpr>();
|
||||||
|
iterationIndexMap[index][dim.getPosition()] = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
|
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
|
||||||
SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
|
SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
|
||||||
return getDimMap(indexingMaps, iterator_types(),
|
return getDimMap(indexingMaps, iterator_types(),
|
||||||
|
@ -77,6 +77,15 @@ static int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis) {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
|
||||||
|
static int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
|
||||||
|
assert(offsets.size() == basis.size());
|
||||||
|
int64_t linearIndex = 0;
|
||||||
|
for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
|
||||||
|
linearIndex += offsets[idx] * basis[idx];
|
||||||
|
return linearIndex;
|
||||||
|
}
|
||||||
|
|
||||||
/// Given a shape with sizes greater than 0 along all dimensions, returns the
|
/// Given a shape with sizes greater than 0 along all dimensions, returns the
|
||||||
/// delinearized components of linearIndex along shape.
|
/// delinearized components of linearIndex along shape.
|
||||||
static SmallVector<int64_t, 8> delinearize(int64_t linearIndex,
|
static SmallVector<int64_t, 8> delinearize(int64_t linearIndex,
|
||||||
@ -151,9 +160,9 @@ static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
|
|||||||
Location loc, Operation *op,
|
Location loc, Operation *op,
|
||||||
ArrayRef<Value *> operands,
|
ArrayRef<Value *> operands,
|
||||||
ArrayRef<Type> resultTypes) {
|
ArrayRef<Type> resultTypes) {
|
||||||
OperationState *res = new OperationState(loc, op->getName().getStringRef(),
|
OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
|
||||||
operands, resultTypes, {});
|
op->getAttrs());
|
||||||
return builder.createOperation(*res);
|
return builder.createOperation(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function for Tablegen.
|
// Helper function for Tablegen.
|
||||||
@ -164,6 +173,223 @@ static bool hasShape(Value *v, ArrayRef<int64_t> shape) {
|
|||||||
return std::equal(t.getShape().begin(), t.getShape().end(), shape.begin());
|
return std::equal(t.getShape().begin(), t.getShape().end(), shape.begin());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Value *makeSplatZero(Location loc, PatternRewriter &rewriter,
|
||||||
|
VectorType vt) {
|
||||||
|
auto t = vt.getElementType();
|
||||||
|
Value *f = nullptr;
|
||||||
|
if (t.isBF16() || t.isF16())
|
||||||
|
f = rewriter.create<ConstantOp>(loc, t, rewriter.getF64FloatAttr(0.0f));
|
||||||
|
else if (t.isF32())
|
||||||
|
f = rewriter.create<ConstantOp>(loc, t, rewriter.getF32FloatAttr(0.0f));
|
||||||
|
else if (t.isF64())
|
||||||
|
f = rewriter.create<ConstantOp>(loc, t, rewriter.getF64FloatAttr(0.0f));
|
||||||
|
if (f)
|
||||||
|
return rewriter.create<SplatOp>(loc, vt, f);
|
||||||
|
llvm_unreachable("Unsupported type in `makeSplatZero`");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populates 'resultElements[indexMap[i]]' with elements from 'inputElements[i]'
|
||||||
|
// for each index 'i' in inputElements with a valid mapping in 'indexMap'.
|
||||||
|
static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
|
||||||
|
ArrayRef<int64_t> inputElements,
|
||||||
|
SmallVectorImpl<int64_t> &resultElements) {
|
||||||
|
assert(indexMap.size() == resultElements.size());
|
||||||
|
assert(inputElements.size() >= resultElements.size());
|
||||||
|
for (unsigned i = 0, e = inputElements.size(); i < e; ++i) {
|
||||||
|
auto it = indexMap.find(i);
|
||||||
|
if (it != indexMap.end())
|
||||||
|
resultElements[it->second] = inputElements[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnrolledOperandState aggregates per-operand state required for op unrolling.
|
||||||
|
struct UnrolledOperandState {
|
||||||
|
Value *operand;
|
||||||
|
SmallVector<int64_t, 4> unrolledShape;
|
||||||
|
SmallVector<int64_t, 4> unrollFactors;
|
||||||
|
SmallVector<int64_t, 8> basis;
|
||||||
|
int64_t numInstances;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Populates 'state' with unrolled shape, unroll factors, basis and
|
||||||
|
// num unrolled instances for 'operand'.
|
||||||
|
static void getUnrolledOperandState(Value *operand,
|
||||||
|
const DenseMap<int64_t, int64_t> &indexMap,
|
||||||
|
ArrayRef<int64_t> targetShape,
|
||||||
|
UnrolledOperandState &state) {
|
||||||
|
auto vectorType = operand->getType().cast<VectorType>();
|
||||||
|
state.operand = operand;
|
||||||
|
// Compute unrolled shape of 'operand'.
|
||||||
|
state.unrolledShape.resize(vectorType.getRank());
|
||||||
|
getMappedElements(indexMap, targetShape, state.unrolledShape);
|
||||||
|
// Compute unroll factors for unrolled shape.
|
||||||
|
auto maybeUnrollFactors =
|
||||||
|
shapeRatio(vectorType.getShape(), state.unrolledShape);
|
||||||
|
assert(maybeUnrollFactors.hasValue());
|
||||||
|
state.unrollFactors = *maybeUnrollFactors;
|
||||||
|
// Compute 'basis' and 'numInstances' based on 'state.unrollFactors'.
|
||||||
|
state.basis = computeStrides(state.unrollFactors);
|
||||||
|
state.numInstances = computeMaxLinearIndex(state.unrollFactors);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Computes and returns the linear index of the unrolled vector at
|
||||||
|
// 'vectorOffsets' within the vector operand represented by 'state'.
|
||||||
|
static int64_t
|
||||||
|
getUnrolledOperandLinearIndex(UnrolledOperandState &state,
|
||||||
|
ArrayRef<int64_t> vectorOffsets,
|
||||||
|
DenseMap<int64_t, int64_t> &indexMap) {
|
||||||
|
// Compute operand offsets.
|
||||||
|
SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
|
||||||
|
getMappedElements(indexMap, vectorOffsets, sliceOffsets);
|
||||||
|
// Compute and return linear index of 'sliceOffsets' w.r.t 'state.basis'.
|
||||||
|
return linearize(sliceOffsets, state.basis);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns an unrolled vector at 'vectorOffsets' within the vector operand
|
||||||
|
// represented by 'state'. The value is created if not present in 'cache'.
|
||||||
|
static Value *getOrCreateUnrolledOperandSlice(
|
||||||
|
Location loc, UnrolledOperandState &state, ArrayRef<int64_t> vectorOffsets,
|
||||||
|
ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
|
||||||
|
SmallVectorImpl<Value *> &cache, PatternRewriter &builder) {
|
||||||
|
// Compute operand offsets.
|
||||||
|
SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
|
||||||
|
getMappedElements(indexMap, offsets, sliceOffsets);
|
||||||
|
// TODO(b/144845578) Support non-1 strides.
|
||||||
|
SmallVector<int64_t, 4> sliceStrides(state.unrolledShape.size(), 1);
|
||||||
|
// Compute linear index of 'sliceOffsets' w.r.t 'state.basis'.
|
||||||
|
int64_t sliceLinearIndex =
|
||||||
|
getUnrolledOperandLinearIndex(state, vectorOffsets, indexMap);
|
||||||
|
assert(sliceLinearIndex < static_cast<int64_t>(cache.size()));
|
||||||
|
auto *operandSlice = cache[sliceLinearIndex];
|
||||||
|
if (operandSlice == nullptr) {
|
||||||
|
// Initialize 'cache' with slice from 'state.operand'.
|
||||||
|
operandSlice = builder.create<vector::StridedSliceOp>(
|
||||||
|
loc, state.operand, sliceOffsets, state.unrolledShape, sliceStrides);
|
||||||
|
// Store value back to 'cache'.
|
||||||
|
cache[sliceLinearIndex] = operandSlice;
|
||||||
|
}
|
||||||
|
return operandSlice;
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// unrollSingleResultStructuredOp
|
||||||
|
//
|
||||||
|
// Returns a value representing the result of structured operation 'op'
|
||||||
|
// with iteration bounds 'iterationBounds' unrolled to 'targetShape'.
|
||||||
|
// An iteration space index map argument 'iterationIndexMapList' must be
|
||||||
|
// specified, with a map for each structured op input and a single map for the
|
||||||
|
// single result. The last map in the list must be the single result map.
|
||||||
|
// Extra operands can be passed to unrolled instances of 'op' using the
|
||||||
|
// 'extraOperands' argument.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// // Before unrolling
|
||||||
|
//
|
||||||
|
// operand0 operand1 operand2
|
||||||
|
// \ | /
|
||||||
|
// -------------------- opA --------------------
|
||||||
|
//
|
||||||
|
// // After unrolling by 2
|
||||||
|
//
|
||||||
|
// operand0 operand1 operand2
|
||||||
|
// / \ / \ / \
|
||||||
|
// slice00 slice01 slice10 slice11 slice20 slice21
|
||||||
|
// \ | | | / |
|
||||||
|
// -------------------- opA0 -------------------- |
|
||||||
|
// | | | |
|
||||||
|
// \ | | /
|
||||||
|
// -------------------- opA1 -------------------
|
||||||
|
// | |
|
||||||
|
// \ /
|
||||||
|
// insertslice
|
||||||
|
// |
|
||||||
|
|
||||||
|
// TODO(andydavis) Generalize this to support structured ops beyond
|
||||||
|
// vector ContractionOp, and merge it with 'unrollSingleResultOpMatchingType'
|
||||||
|
static Value *unrollSingleResultStructuredOp(
|
||||||
|
Operation *op, ArrayRef<int64_t> iterationBounds,
|
||||||
|
std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMapList,
|
||||||
|
ArrayRef<int64_t> targetShape, ArrayRef<Value *> extraOperands,
|
||||||
|
PatternRewriter &builder) {
|
||||||
|
auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape())
|
||||||
|
assert(false && "Expected a statically shaped result type");
|
||||||
|
|
||||||
|
// Compute unroll factors for 'iterationBounds' based on 'targetShape'
|
||||||
|
auto maybeUnrollFactors = shapeRatio(iterationBounds, targetShape);
|
||||||
|
if (!maybeUnrollFactors.hasValue())
|
||||||
|
assert(false && "Failed to compute unroll factors for target shape");
|
||||||
|
auto unrollFactors = *maybeUnrollFactors;
|
||||||
|
|
||||||
|
// Compute unrolled operation state for each mapped operand.
|
||||||
|
unsigned numMaps = iterationIndexMapList.size();
|
||||||
|
SmallVector<UnrolledOperandState, 3> unrolledOperandState(numMaps);
|
||||||
|
assert(op->getNumOperands() >= numMaps);
|
||||||
|
for (unsigned i = 0; i < numMaps; ++i) {
|
||||||
|
getUnrolledOperandState(op->getOperand(i), iterationIndexMapList[i],
|
||||||
|
targetShape, unrolledOperandState[i]);
|
||||||
|
}
|
||||||
|
// Compute number of total unrolled instances.
|
||||||
|
auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
|
||||||
|
auto basis = computeStrides(unrollFactors);
|
||||||
|
|
||||||
|
auto &resultOperandState = unrolledOperandState[numMaps - 1];
|
||||||
|
auto unrolledResultType = VectorType::get(resultOperandState.unrolledShape,
|
||||||
|
shapedType.getElementType());
|
||||||
|
|
||||||
|
// Initialize caches for intermediate vector results.
|
||||||
|
std::vector<SmallVector<Value *, 4>> caches(numMaps);
|
||||||
|
for (unsigned i = 0; i < numMaps; ++i) {
|
||||||
|
caches[i].resize(unrolledOperandState[i].numInstances);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unroll 'numUnrolledInstances' of 'op', storing results in 'caches'.
|
||||||
|
for (unsigned i = 0; i < numUnrolledInstances; ++i) {
|
||||||
|
// De-linearize w.r.t. 'basis'.
|
||||||
|
auto vectorOffsets = delinearize(i, basis);
|
||||||
|
// Convert from unrolled vector-space offsets to element-space offsets.
|
||||||
|
auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
|
||||||
|
vectorOffsets, targetShape);
|
||||||
|
// Get cached slice (or create slice) for each operand at 'offsets'.
|
||||||
|
SmallVector<Value *, 3> operands;
|
||||||
|
operands.reserve(numMaps);
|
||||||
|
for (unsigned i = 0; i < numMaps; ++i) {
|
||||||
|
operands.push_back(getOrCreateUnrolledOperandSlice(
|
||||||
|
op->getLoc(), unrolledOperandState[i], vectorOffsets, offsets,
|
||||||
|
iterationIndexMapList[i], caches[i], builder));
|
||||||
|
}
|
||||||
|
// Create op on sliced vector arguments.
|
||||||
|
operands.append(extraOperands.begin(), extraOperands.end());
|
||||||
|
auto resultVector =
|
||||||
|
cloneOpWithOperandsAndTypes(builder, op->getLoc(), op, operands,
|
||||||
|
unrolledResultType)
|
||||||
|
->getResult(0);
|
||||||
|
|
||||||
|
// Compute linear result index.
|
||||||
|
int64_t resultIndex = getUnrolledOperandLinearIndex(
|
||||||
|
resultOperandState, vectorOffsets, iterationIndexMapList[numMaps - 1]);
|
||||||
|
// Update result cache at 'resultIndex'.
|
||||||
|
caches[numMaps - 1][resultIndex] = resultVector;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make zero splat into which we will insert results from 'cache[numMaps - 1]'
|
||||||
|
auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
|
||||||
|
auto *res = makeSplatZero(op->getLoc(), builder, resultVectorType);
|
||||||
|
SmallVector<int64_t, 4> strides(resultOperandState.unrollFactors.size(), 1);
|
||||||
|
// Insert vector accumulators into output.
|
||||||
|
for (unsigned i = 0; i < resultOperandState.numInstances; ++i) {
|
||||||
|
auto vectorOffsets = delinearize(i, resultOperandState.basis);
|
||||||
|
// Convert from unrolled vector-space offsets to element-space offsets.
|
||||||
|
auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
|
||||||
|
vectorOffsets, resultOperandState.unrolledShape);
|
||||||
|
res = builder.create<vector::InsertStridedSliceOp>(
|
||||||
|
op->getLoc(), caches[numMaps - 1][i], res, offsets, strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
// Entry point for unrolling declarative pattern rewrites.
|
// Entry point for unrolling declarative pattern rewrites.
|
||||||
// `op` is unrolled to the `targetShape` as follows, for each of its operands:
|
// `op` is unrolled to the `targetShape` as follows, for each of its operands:
|
||||||
// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
|
// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
|
||||||
@ -200,6 +426,26 @@ static bool hasShape(Value *v, ArrayRef<int64_t> shape) {
|
|||||||
Value * mlir::vector::unrollSingleResultOpMatchingType(PatternRewriter &builder,
|
Value * mlir::vector::unrollSingleResultOpMatchingType(PatternRewriter &builder,
|
||||||
Operation *op,
|
Operation *op,
|
||||||
ArrayRef<int64_t> targetShape) {
|
ArrayRef<int64_t> targetShape) {
|
||||||
|
if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
|
||||||
|
// Get contraction op iteration bounds.
|
||||||
|
SmallVector<int64_t, 6> iterationBounds;
|
||||||
|
contractionOp.getIterationBounds(iterationBounds);
|
||||||
|
assert(iterationBounds.size() == targetShape.size());
|
||||||
|
// Get map from iteration space index to lhs/rhs/result shape index.
|
||||||
|
std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
|
||||||
|
contractionOp.getIterationIndexMap(iterationIndexMapList);
|
||||||
|
// TODO(andydavis) Support unrollable vector masks.
|
||||||
|
SmallVector<Value *, 2> masks(contractionOp.masks().begin(),
|
||||||
|
contractionOp.masks().end());
|
||||||
|
// Unroll 'op' 'iterationBounds' to 'targetShape'.
|
||||||
|
return unrollSingleResultStructuredOp(op, iterationBounds,
|
||||||
|
iterationIndexMapList, targetShape,
|
||||||
|
masks, builder);
|
||||||
|
}
|
||||||
|
// TODO(andydavis) Create trivial iteration bounds and index map for
|
||||||
|
// elementwise operations and call 'unrollSingleResultStructuredOp'. Remove
|
||||||
|
// fakefork/join if possible.
|
||||||
|
|
||||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
|
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
|
||||||
"]: unrollSingleResultOpMatchingType on func:\n");
|
"]: unrollSingleResultOpMatchingType on func:\n");
|
||||||
LLVM_DEBUG(op->getParentOfType<FuncOp>().print(dbgs()));
|
LLVM_DEBUG(op->getParentOfType<FuncOp>().print(dbgs()));
|
||||||
@ -365,24 +611,6 @@ struct ConvertFakeForkFromBlockArgsOp : public RewritePattern {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static Value *makeSplatZero(Location loc, PatternRewriter &rewriter,
|
|
||||||
VectorType vt) {
|
|
||||||
auto t = vt.getElementType();
|
|
||||||
Value *f = nullptr;
|
|
||||||
if (t.isBF16() || t.isF16())
|
|
||||||
f = rewriter.create<ConstantOp>(loc, t, rewriter.getF16FloatAttr(0.0f))
|
|
||||||
.getResult();
|
|
||||||
else if (t.isF32())
|
|
||||||
f = rewriter.create<ConstantOp>(loc, t, rewriter.getF32FloatAttr(0.0f))
|
|
||||||
.getResult();
|
|
||||||
else if (t.isF64())
|
|
||||||
f = rewriter.create<ConstantOp>(loc, t, rewriter.getF64FloatAttr(0.0f))
|
|
||||||
.getResult();
|
|
||||||
if (f)
|
|
||||||
return rewriter.create<SplatOp>(loc, vt, f).getResult();
|
|
||||||
llvm_unreachable("Unsupported type in `makeSplatZero`");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rewrites a fakeJoin, whose (unique) operand is a blockArgument, into multiple
|
// Rewrites a fakeJoin, whose (unique) operand is a blockArgument, into multiple
|
||||||
// vector.strided_slice ops.
|
// vector.strided_slice ops.
|
||||||
struct ConvertFakeJoinOp : public RewritePattern {
|
struct ConvertFakeJoinOp : public RewritePattern {
|
||||||
|
Loading…
Reference in New Issue
Block a user