diff --git a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index a887a3e4e79..34c2fa97e53 100644 --- a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -162,6 +162,34 @@ def Vector_ContractionOp : }]; } +def Vector_BroadcastOp : + Vector_Op<"broadcast", [NoSideEffect, + PredOpTrait<"source operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>, + PredOpTrait<"dest operand and result have same type", + TCresIsSameAsOpBase<0, 1>>]>, + Arguments<(ins AnyType:$source, AnyVector:$dest)>, + Results<(outs AnyVector:$vector)> { + let summary = "broadcast operation"; + let description = [{ + Broadcasts the scalar or k-D vector value in the source to the n-D + destination vector of a proper shape such that the broadcast makes sense. + + Examples: + ``` + %0 = constant 0.0 : f32 + %1 = vector.broadcast %0, %x : f32 into vector<16xf32> + %2 = vector.broadcast %1, %y : vector<16xf32> into vector<4x16xf32> + ``` + }]; + let extraClassDeclaration = [{ + Type getSourceType() { return source()->getType(); } + VectorType getDestVectorType() { + return dest()->getType().cast(); + } + }]; +} + def Vector_ExtractElementOp : Vector_Op<"extractelement", [NoSideEffect, PredOpTrait<"operand and result have same element type", diff --git a/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp index b73b771d80d..d09fd0fc2f2 100644 --- a/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -368,6 +368,47 @@ static LogicalResult verify(ExtractElementOp op) { return success(); } +//===----------------------------------------------------------------------===// +// BroadcastOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, BroadcastOp op) { + p << op.getOperationName() << " " << *op.source() << ", " << *op.dest(); + p << " : " << op.getSourceType(); + p << " into " << op.getDestVectorType(); +} + +static LogicalResult verify(BroadcastOp op) { + VectorType srcVectorType = op.getSourceType().dyn_cast(); + VectorType dstVectorType = op.getDestVectorType(); + // Scalar to vector broadcast is always valid. A vector + // to vector broadcast needs some additional checking. + if (srcVectorType) { + const int64_t srcRank = srcVectorType.getRank(); + const int64_t dstRank = dstVectorType.getRank(); + // TODO(ajcbik): implement proper rank testing for broadcast; + // this is just a temporary placeholder check. + if (srcRank > dstRank) { + return op.emitOpError("source rank higher than destination rank"); + } + } + return success(); +} + +static ParseResult parseBroadcastOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::OperandType source, dest; + Type sourceType; + VectorType destType; + return failure(parser.parseOperand(source) || parser.parseComma() || + parser.parseOperand(dest) || + parser.parseColonType(sourceType) || + parser.parseKeywordType("into", destType) || + parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(dest, destType, result.operands) || + parser.addTypeToList(destType, result.types)); +} + //===----------------------------------------------------------------------===// // InsertElementOp //===----------------------------------------------------------------------===//