[VectorOps] Add a BroadcastOp to the VectorOps dialect

PiperOrigin-RevId: 282643305
Change-Id: Ide3a0cf42204977e275c86647475af079c32a7c3
This commit is contained in:
Aart Bik 2019-11-26 14:43:03 -08:00 committed by TensorFlower Gardener
parent d693915624
commit 0522d23e65
2 changed files with 69 additions and 0 deletions

View File

@ -162,6 +162,34 @@ def Vector_ContractionOp :
}]; }];
} }
def Vector_BroadcastOp :
Vector_Op<"broadcast", [NoSideEffect,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
PredOpTrait<"dest operand and result have same type",
TCresIsSameAsOpBase<0, 1>>]>,
Arguments<(ins AnyType:$source, AnyVector:$dest)>,
Results<(outs AnyVector:$vector)> {
let summary = "broadcast operation";
let description = [{
Broadcasts the scalar or k-D vector value in the source to the n-D
destination vector of a proper shape such that the broadcast makes sense.
Examples:
```
%0 = constant 0.0 : f32
%1 = vector.broadcast %0, %x : f32 into vector<16xf32>
%2 = vector.broadcast %1, %y : vector<16xf32> into vector<4x16xf32>
```
}];
let extraClassDeclaration = [{
Type getSourceType() { return source()->getType(); }
VectorType getDestVectorType() {
return dest()->getType().cast<VectorType>();
}
}];
}
def Vector_ExtractElementOp : def Vector_ExtractElementOp :
Vector_Op<"extractelement", [NoSideEffect, Vector_Op<"extractelement", [NoSideEffect,
PredOpTrait<"operand and result have same element type", PredOpTrait<"operand and result have same element type",

View File

@ -368,6 +368,47 @@ static LogicalResult verify(ExtractElementOp op) {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, BroadcastOp op) {
p << op.getOperationName() << " " << *op.source() << ", " << *op.dest();
p << " : " << op.getSourceType();
p << " into " << op.getDestVectorType();
}
static LogicalResult verify(BroadcastOp op) {
VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
VectorType dstVectorType = op.getDestVectorType();
// Scalar to vector broadcast is always valid. A vector
// to vector broadcast needs some additional checking.
if (srcVectorType) {
const int64_t srcRank = srcVectorType.getRank();
const int64_t dstRank = dstVectorType.getRank();
// TODO(ajcbik): implement proper rank testing for broadcast;
// this is just a temporary placeholder check.
if (srcRank > dstRank) {
return op.emitOpError("source rank higher than destination rank");
}
}
return success();
}
static ParseResult parseBroadcastOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType source, dest;
Type sourceType;
VectorType destType;
return failure(parser.parseOperand(source) || parser.parseComma() ||
parser.parseOperand(dest) ||
parser.parseColonType(sourceType) ||
parser.parseKeywordType("into", destType) ||
parser.resolveOperand(source, sourceType, result.operands) ||
parser.resolveOperand(dest, destType, result.operands) ||
parser.addTypeToList(destType, result.types));
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// InsertElementOp // InsertElementOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//