diff --git a/third_party/mlir/lib/IR/Operation.cpp b/third_party/mlir/lib/IR/Operation.cpp index 25302e5ff06..27681d37f17 100644 --- a/third_party/mlir/lib/IR/Operation.cpp +++ b/third_party/mlir/lib/IR/Operation.cpp @@ -767,7 +767,7 @@ static LogicalResult verifyShapeMatch(Type type1, Type type2) { } LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { - if (op->getNumOperands() == 0) + if (failed(verifyAtLeastNOperands(op, 1))) return failure(); auto type = op->getOperand(0)->getType(); @@ -779,7 +779,8 @@ LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { } LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { - if (op->getNumOperands() == 0 || op->getNumResults() == 0) + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) return failure(); auto type = op->getOperand(0)->getType(); @@ -797,7 +798,7 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { } LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { - if (op->getNumOperands() == 0) + if (failed(verifyAtLeastNOperands(op, 1))) return failure(); auto type = op->getOperand(0)->getType().dyn_cast(); @@ -818,7 +819,8 @@ LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { LogicalResult OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { - if (op->getNumOperands() == 0 || op->getNumResults() == 0) + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) return failure(); auto type = op->getResult(0)->getType().dyn_cast(); @@ -850,7 +852,8 @@ OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { } LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { - if (op->getNumOperands() == 0 || op->getNumResults() == 0) + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) return failure(); auto type = op->getResult(0)->getType(); diff --git a/third_party/mlir/test/lib/TestDialect/TestOps.td b/third_party/mlir/test/lib/TestDialect/TestOps.td index e419b7ef3b1..944ce79a182 100644 --- a/third_party/mlir/test/lib/TestDialect/TestOps.td +++ b/third_party/mlir/test/lib/TestDialect/TestOps.td @@ -219,19 +219,19 @@ def SameOperandElementTypeOp : TEST_Op<"same_operand_type", def SameOperandAndResultElementTypeOp : TEST_Op<"same_operand_and_result_type", [SameOperandsAndResultElementType]> { - let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y); - let results = (outs AnyVectorOrTensor:$res); + let arguments = (ins Variadic:$args); + let results = (outs Variadic:$res); } def SameOperandShapeOp : TEST_Op<"same_operand_shape", [SameOperandsShape]> { - let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y); + let arguments = (ins Variadic:$args); let results = (outs AnyVectorOrTensor:$res); } def SameOperandAndResultShapeOp : TEST_Op<"same_operand_and_result_shape", [SameOperandsAndResultShape]> { - let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y); - let results = (outs AnyVectorOrTensor:$res); + let arguments = (ins Variadic:$args); + let results = (outs Variadic:$res); } def ArgAndResHaveFixedElementTypesOp :