From 0f85e658ab8a50e713ec390593bfc8b06cba14e8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Dec 2019 16:04:59 -0800 Subject: [PATCH] [VectorOps] Update vector transfer_read/write ops to operatate on memrefs with vector element type. Update vector transfer_read/write ops to operatate on memrefs with vector element type. This handle cases where the memref vector element type represents the minimal memory transfer unit (or multiple of the minimal memory transfer unit). PiperOrigin-RevId: 286482115 Change-Id: I063aa16a82a171d8ca63adde2f6391af2485c1c1 --- .../mlir/Dialect/VectorOps/VectorOps.td | 38 ++++-- .../mlir/lib/Dialect/VectorOps/VectorOps.cpp | 121 +++++++++++++----- 2 files changed, 116 insertions(+), 43 deletions(-) diff --git a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index 7dcac62a585..d5e84314357 100644 --- a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -746,10 +746,15 @@ def Vector_TransferReadOp : let description = [{ The `vector.transfer_read` op performs a blocking read from a slice within - a scalar [MemRef](../LangRef.md#memref-type) supplied as its first operand - into a [vector](../LangRef.md#vector-type) of the same elemental type. The - slice is further defined by a full-rank index within the MemRef, supplied as - the operands `2 .. 1 + rank(memref)`. The permutation_map + a [MemRef](../LangRef.md#memref-type) supplied as its first operand + into a [vector](../LangRef.md#vector-type) of the same base elemental type. + + A vector memref operand must have its vector element type match a suffix + (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>, + vector<1x1x4x3xf32>). + + The slice is further defined by a full-rank index within the MemRef, + supplied as the operands `2 .. 1 + rank(memref)`. The permutation_map [attribute](../LangRef.md#attributes) is an [affine-map](Affine.md#affine-maps) which specifies the transposition on the slice to match the vector shape. The size of the slice is specified by the @@ -854,6 +859,11 @@ def Vector_TransferReadOp : memref, vector<128xf32> } } + + // Read from a memref with vector element type. + %4 = vector.transfer_read %arg1[%c3, %c3], %vf0 + {permutation_map = (d0, d1)->(d0, d1)} + : memref>, vector<1x1x4x3xf32> ``` }]; @@ -878,10 +888,15 @@ def Vector_TransferWriteOp : let description = [{ The `vector.transfer_write` performs a blocking write from a [vector](../LangRef.md#vector-type), supplied as its first operand, into a - slice within a scalar [MemRef](../LangRef.md#memref-type) of the same - elemental type, supplied as its second operand. The slice is further defined - by a full-rank index within the MemRef, supplied as the operands - `3 .. 2 + rank(memref)`. + slice within a [MemRef](../LangRef.md#memref-type) of the same base + elemental type, supplied as its second operand. + + A vector memref operand must have its vector element type match a suffix + (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>, + vector<1x1x4x3xf32>). + + The slice is further defined by a full-rank index within the MemRef, + supplied as the operands `3 .. 2 + rank(memref)`. The permutation_map [attribute](../LangRef.md#attributes) is an [affine-map](Affine.md#affine-maps) which specifies the transposition on the slice to match the vector shape. The size of the slice is specified by the @@ -915,6 +930,11 @@ def Vector_TransferWriteOp : {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} : vector<16x32x64xf32>, memref }}}} + + // write to a memref with vector element type. + vector.transfer_write %4, %arg1[%c3, %c3] + {permutation_map = (d0, d1)->(d0, d1)} + : vector<1x1x4x3xf32>, memref> ``` }]; @@ -1048,7 +1068,7 @@ def Vector_TupleOp : Note that this operation is used during the vector op unrolling transformation and should be removed before lowering to lower-level dialects. - + Examples: ``` diff --git a/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 541b5427af9..8a6946792b2 100644 --- a/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -1420,6 +1420,59 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap, return success(); } +static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType, + VectorType vectorType, + AffineMap permutationMap) { + auto memrefElementType = memrefType.getElementType(); + if (auto memrefVectorElementType = memrefElementType.dyn_cast()) { + // Memref has vector element type. + + // Check that 'memrefVectorElementType' and vector element types match. + if (memrefVectorElementType.getElementType() != vectorType.getElementType()) + return op->emitOpError( + "requires memref and vector types of the same elemental type"); + + // Check that memref vector type is a suffix of 'vectorType. + unsigned memrefVecEltRank = memrefVectorElementType.getRank(); + unsigned resultVecRank = vectorType.getRank(); + if (memrefVecEltRank > resultVecRank) + return op->emitOpError( + "requires memref vector element and vector result ranks to match."); + // TODO(b/146516564) Move this to isSuffix in VectorOps/Utils.h. + unsigned rankOffset = resultVecRank - memrefVecEltRank; + auto memrefVecEltShape = memrefVectorElementType.getShape(); + auto resultVecShape = vectorType.getShape(); + for (unsigned i = 0; i < memrefVecEltRank; ++i) + if (memrefVecEltShape[i] != resultVecShape[rankOffset + i]) + return op->emitOpError( + "requires memref vector element shape to match suffix of " + "vector result shape."); + // Check that permutation map results match 'rankOffset' of vector type. + if (permutationMap.getNumResults() != rankOffset) + return op->emitOpError("requires a permutation_map with result dims of " + "the same rank as the vector type"); + } else { + // Memref has scalar element type. + + // Check that memref and vector element types match. + if (memrefType.getElementType() != vectorType.getElementType()) + return op->emitOpError( + "requires memref and vector types of the same elemental type"); + + // Check that permutation map results match rank of vector type. + if (permutationMap.getNumResults() != vectorType.getRank()) + return op->emitOpError("requires a permutation_map with result dims of " + "the same rank as the vector type"); + } + + if (permutationMap.getNumSymbols() != 0) + return op->emitOpError("requires permutation_map without symbols"); + if (permutationMap.getNumInputs() != memrefType.getRank()) + return op->emitOpError("requires a permutation_map with input dims of the " + "same rank as the memref type"); + return success(); +} + static void print(OpAsmPrinter &p, TransferReadOp op) { p << op.getOperationName() << " " << op.memref() << "[" << op.indices() << "], " << op.padding() << " "; @@ -1459,26 +1512,35 @@ static LogicalResult verify(TransferReadOp op) { // Consistency of elemental types in memref and vector. MemRefType memrefType = op.getMemRefType(); VectorType vectorType = op.getVectorType(); - if (memrefType.getElementType() != vectorType.getElementType()) - return op.emitOpError( - "requires memref and vector types of the same elemental type"); - auto elementalType = op.padding()->getType(); - if (!VectorType::isValidElementType(elementalType)) - return op.emitOpError("requires valid padding vector elemental type"); - if (elementalType != vectorType.getElementType()) - return op.emitOpError( - "requires formal padding and vector of the same elemental type"); - if (llvm::size(op.indices()) != memrefType.getRank()) - return op.emitOpError("requires ") << memrefType.getRank() << " indices"; + auto paddingType = op.padding()->getType(); auto permutationMap = op.permutation_map(); - if (permutationMap.getNumSymbols() != 0) - return op.emitOpError("requires permutation_map without symbols"); - if (permutationMap.getNumInputs() != memrefType.getRank()) - return op.emitOpError("requires a permutation_map with input dims of the " - "same rank as the memref type"); - if (permutationMap.getNumResults() != vectorType.getRank()) - return op.emitOpError("requires a permutation_map with result dims of the " - "same rank as the vector type"); + auto memrefElementType = memrefType.getElementType(); + + if (static_cast(op.indices().size()) != memrefType.getRank()) + return op.emitOpError("requires ") << memrefType.getRank() << " indices"; + + if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType, + permutationMap))) + return failure(); + + if (auto memrefVectorElementType = memrefElementType.dyn_cast()) { + // Memref has vector element type. + // Check that 'memrefVectorElementType' and 'paddingType' types match. + if (memrefVectorElementType != paddingType) + return op.emitOpError( + "requires memref element type and padding type to match."); + + } else { + // Check that 'paddingType' is valid to store in a vector type. + if (!VectorType::isValidElementType(paddingType)) + return op.emitOpError("requires valid padding vector elemental type"); + + // Check that padding type and vector element types match. + if (paddingType != vectorType.getElementType()) + return op.emitOpError( + "requires formal padding and vector of the same elemental type"); + } + return verifyPermutationMap(permutationMap, [&op](Twine t) { return op.emitOpError(t); }); } @@ -1519,24 +1581,15 @@ static LogicalResult verify(TransferWriteOp op) { // Consistency of elemental types in memref and vector. MemRefType memrefType = op.getMemRefType(); VectorType vectorType = op.getVectorType(); - if (memrefType.getElementType() != vectorType.getElementType()) - return op.emitOpError( - "requires memref and vector types of the same elemental type"); + auto permutationMap = op.permutation_map(); + if (llvm::size(op.indices()) != memrefType.getRank()) return op.emitOpError("requires ") << memrefType.getRank() << " indices"; - // Consistency of AffineMap attribute. - auto permutationMap = op.permutation_map(); - if (permutationMap.getNumSymbols() != 0) - return op.emitOpError("requires a symbol-less permutation_map"); - if (permutationMap.getNumInputs() != memrefType.getRank()) - return op.emitOpError("requires a permutation_map with input dims of the " - "same rank as the memref type: ") - << permutationMap.getNumInputs() << " vs " << memrefType; - if (permutationMap.getNumResults() != vectorType.getRank()) - return op.emitOpError("requires a permutation_map with result dims of the " - "same rank as the vector type.") - << permutationMap.getNumResults() << " vs " << vectorType; + if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType, + permutationMap))) + return failure(); + return verifyPermutationMap(permutationMap, [&op](Twine t) { return op.emitOpError(t); }); }