Add fine-grained shape constraints per type to Add, Mul, and Sub ops
PiperOrigin-RevId: 315784886
Change-Id: I385d61510eb905ad6c8a0679c99a78bbfad6dc44
commit 83af443dc7 (parent a2868d9d55)
@@ -525,11 +525,16 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
     auto *val = trait.getDef().getValue("tflRuntimePredicate");
     if (!val) continue;
 
+    auto desc = trait.getDef().getValueAsString("tflRuntimeDescription");
+
     mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
     os << tgfmt(
-        "    if (!($0)) {\n      "
-        "return ::mlir::LogicalResult::Failure;\n    }\n",
-        &verify_ctx, tgfmt(pred.getCondition(), &verify_ctx));
+        "    if (!($0)) {\n"
+        "      if (failure_on_operand_type_mismatch) {\n"
+        "        return top.emitOpError(\"failed to verify that $1\");\n"
+        "      } else {\n"
+        "        return ::mlir::LogicalResult::Failure;\n      }\n    }\n",
+        &verify_ctx, tgfmt(pred.getCondition(), &verify_ctx), desc);
   }
   os << "  return top.verify();\n}\n";
 }
@@ -50,9 +50,8 @@ namespace TFL {
 // broadcastable shape within the given rank. If any given shapes are
 // non-static and maximum rank is within the given rank, this method returns
 // true.
-bool IsOperandsHaveSameShapesOrBroadcastableShape(Operation *op,
-                                                  ArrayRef<unsigned> indices,
-                                                  int max_bcast_rank) {
+bool VerifyOperandsHaveSameShapesOrBroadcastableShape(
+    Operation *op, ArrayRef<unsigned> indices, int max_bcast_rank) {
   if (indices.empty()) return true;
 
   // First, check whether any of the inputs has an unknown rank.
@@ -110,6 +109,122 @@ bool IsOperandsHaveSameShapesOrBroadcastableShape(Operation *op,
   return has_same_shape || max_rank <= max_bcast_rank;
 }
 
+// Return true when the given element_type is QI8.
+bool IsQI8Type(Type element_type) {
+  auto quantized_type = element_type.dyn_cast<QuantizedType>();
+  return quantized_type != nullptr &&
+         quantized_type.getStorageTypeIntegralWidth() == 8 &&
+         quantized_type.isSigned();
+}
+
+// Return true when the given element_type is QUI8.
+bool IsQUI8Type(Type element_type) {
+  auto quantized_type = element_type.dyn_cast<QuantizedType>();
+  return quantized_type != nullptr &&
+         quantized_type.getStorageTypeIntegralWidth() == 8 &&
+         !quantized_type.isSigned();
+}
+
+// Return true when the given element_type is QI16.
+bool IsQI16Type(Type element_type) {
+  auto quantized_type = element_type.dyn_cast<QuantizedType>();
+  return quantized_type != nullptr &&
+         quantized_type.getStorageTypeIntegralWidth() == 16 &&
+         quantized_type.isSigned();
+}
+
+// Return true when the given element_type is I32.
+bool IsI32Type(Type element_type) {
+  return element_type.isInteger(32) && !element_type.isUnsignedInteger();
+}
+
+// Return true if the given Add operation has shapes supported by the CPU
+// kernel.
+bool VerifyAddOpShapeConstraints(AddOp op) {
+  auto element_type = getElementTypeOrSelf(op.output().getType());
+
+  // Allows F32, QI8, and QUI8 outputs when the operands have valid shapes:
+  // broadcastable shapes up to five dimensions, or the same shapes.
+  if (element_type.isF32() || IsQI8Type(element_type) ||
+      IsQUI8Type(element_type)) {
+    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
+        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
+        /*max_bcast_rank=*/5);
+  }
+
+  // Allows I32 output when the operands have valid shapes: broadcastable
+  // shapes up to four dimensions, or the same shapes.
+  if (IsI32Type(element_type)) {
+    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
+        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
+        /*max_bcast_rank=*/4);
+  }
+
+  // Allows QI16 output when the operands have the same shape.
+  if (IsQI16Type(element_type)) {
+    return succeeded(
+        mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
+  }
+  return false;
+}
+
+// Return true if the given Sub operation has shapes supported by the CPU
+// kernel.
+bool VerifySubOpShapeConstraints(SubOp op) {
+  auto element_type = getElementTypeOrSelf(op.output().getType());
+
+  // Allows F32, I32, QUI8, and QI16 outputs when the operands have valid
+  // shapes: broadcastable shapes up to five dimensions, or the same shapes.
+  if (element_type.isF32() || IsI32Type(element_type) ||
+      IsQUI8Type(element_type) || IsQI16Type(element_type)) {
+    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
+        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
+        /*max_bcast_rank=*/5);
+  }
+
+  // Allows QI8 output when the operands have valid shapes: broadcastable
+  // shapes up to four dimensions, or the same shapes.
+  if (IsQI8Type(element_type)) {
+    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
+        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
+        /*max_bcast_rank=*/4);
+  }
+  return false;
+}
+
+// Return true if the given Mul operation has shapes supported by the CPU
+// kernel.
+bool VerifyMulOpShapeConstraints(MulOp op) {
+  auto element_type = getElementTypeOrSelf(op.output().getType());
+
+  // Allows QI8 and QUI8 outputs with broadcasting up to five dimensions, but
+  // requires same-shape operands when the lhs element type is QI16.
+  if (IsQI8Type(element_type) || IsQUI8Type(element_type)) {
+    if (IsQI16Type(getElementTypeOrSelf(op.lhs().getType()))) {
+      return succeeded(
+          mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
+    }
+    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
+        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
+        /*max_bcast_rank=*/5);
+  }
+
+  // Allows F32 output when the operands have valid shapes: broadcastable
+  // shapes up to five dimensions, or the same shapes.
+  if (element_type.isF32()) {
+    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
+        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
+        /*max_bcast_rank=*/5);
+  }
+
+  // Allows I32 and QI16 outputs when the operands have valid shapes:
+  // broadcastable shapes up to four dimensions, or the same shapes.
+  if (IsI32Type(element_type) || IsQI16Type(element_type)) {
+    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
+        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
+        /*max_bcast_rank=*/4);
+  }
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // TensorFlowLiteDialect
 //===----------------------------------------------------------------------===//
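For reference, an illustrative sketch in the tests' .mlir notation of the element types each helper predicate above matches; the function name is invented, and the !quant.any types mirror the ones used in the ops.mlir tests further down:

// Illustrative only: one tensor argument per helper predicate.
func @element_type_examples(
    %qi8  : tensor<4x!quant.any<i8:f32>>,   // IsQI8Type: signed 8-bit storage
    %qui8 : tensor<4x!quant.any<ui8:f32>>,  // IsQUI8Type: unsigned 8-bit storage
    %qi16 : tensor<4x!quant.any<i16:f32>>,  // IsQI16Type: signed 16-bit storage
    %i32  : tensor<4xi32>) {                // IsI32Type: signless 32-bit integer
  return
}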
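To make the Add constraints concrete, here is a hypothetical pair of test functions in the style of ops.mlir (names and shapes invented): an F32 output may broadcast within rank 5, while an I32 output is capped at rank 4, so only the second op is rejected.

func @add_f32_5d_broadcast(%arg0: tensor<1x1x1x3x1xf32>, %arg1: tensor<1x1x1x1x4xf32>) -> tensor<1x1x1x3x4xf32> {
  // Accepted: F32 output with broadcastable shapes within rank 5.
  %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x1x1x3x1xf32>, tensor<1x1x1x1x4xf32>) -> tensor<1x1x1x3x4xf32>
  return %0 : tensor<1x1x1x3x4xf32>
}

func @add_i32_5d_broadcast(%arg0: tensor<1x1x1x3x1xi32>, %arg1: tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32> {
  // Rejected: I32 output allows broadcastable shapes only within rank 4.
  // expected-error @+1 {{Operands do not have valid shapes}}
  %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32>
  return %0 : tensor<1x1x1x3x4xi32>
}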
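Likewise for Sub, a hypothetical passing counterpart to the failing five-dimensional QI8 test added to ops.mlir below: dropping to rank 4 keeps a QI8 output within the CPU kernel's broadcast limit, so no shape error is expected.

func @sub_qi8_4d_broadcast(%arg0: tensor<1x1x1x1xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x1x1x!quant.any<i8:f32>> {
  // Accepted: QI8 output with broadcastable shapes within rank 4.
  %0 = "tfl.sub"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x1x1x1x!quant.any<i8:f32>>
  return %0 : tensor<1x1x1x1x!quant.any<i8:f32>>
}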
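For Mul, a hypothetical sketch of the QI16-lhs special case: with a QUI8 output and QI16 operands, broadcasting is disallowed and only identically shaped operands pass the shape check (the broadcasting variant of this combination appears as a failing test in ops.mlir below).

func @mul_qi16_operands_same_shape(%arg0: tensor<2x2x!quant.any<i16:f32>>, %arg1: tensor<2x2x!quant.any<i16:f32>>) -> tensor<2x2x!quant.any<ui8:f32>> {
  // Accepted: operands have identical shapes, so verifyCompatibleShape succeeds.
  %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<2x2x!quant.any<i16:f32>>, tensor<2x2x!quant.any<i16:f32>>) -> tensor<2x2x!quant.any<ui8:f32>>
  return %0 : tensor<2x2x!quant.any<ui8:f32>>
}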
@@ -127,7 +127,7 @@ class TFL_OperandsHaveSameShapesOrBroadcastableShape<
     list<int> indices, int max_bcast_rank> :
   TFL_RuntimePredOpTrait<"operands do not have the same shape or "
       "broadcastable shapes within the rank " # max_bcast_rank,
-    CPred<"TFL::IsOperandsHaveSameShapesOrBroadcastableShape("
+    CPred<"TFL::VerifyOperandsHaveSameShapesOrBroadcastableShape("
       "$_op, llvm::ArrayRef<unsigned>({" # StrJoinInt<indices>.result #
       "}), " # max_bcast_rank # ")">>;
 
@@ -491,7 +491,8 @@ an output element, this operation computes \\(y = |x|\\).
 }
 
 def TFL_AddOp : TFL_Op<"add", [
-    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
+    TFL_RuntimePredOpTrait<"Operands do not have valid shapes",
+      CPred<"TFL::VerifyAddOpShapeConstraints(llvm::cast<AddOp>($_op))">>,
     ResultsBroadcastableShape,
     NoSideEffect,
     Commutative,
@@ -2171,7 +2172,8 @@ def TFL_MulOp : TFL_Op<"mul", [
     NoSideEffect,
     Commutative,
     BinaryOpSameElementTypeConstraint,
-    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
+    TFL_RuntimePredOpTrait<"Operands do not have valid shapes",
+      CPred<"TFL::VerifyMulOpShapeConstraints(llvm::cast<MulOp>($_op))">>,
     TFL_GpuTargetOp]> {
   let summary = "Multiplication operator";
 
@@ -2832,7 +2834,8 @@ def TFL_SquareOp: TFL_Op<"square", [
 def TFL_SubOp : TFL_Op<"sub", [
   ResultsBroadcastableShape,
   BinaryOpSameElementTypeConstraint,
-  TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
+  TFL_RuntimePredOpTrait<"Operands do not have valid shapes",
+    CPred<"TFL::VerifySubOpShapeConstraints(llvm::cast<SubOp>($_op))">>,
   NoSideEffect]> {
   let summary = "Subtraction operator";
 
@@ -1548,3 +1548,12 @@ func @maximum_with_6d_broadcasting(%arg0: tensor<1x1x1x1x8x16xf32>, %arg1: tenso
 // CHECK-LABEL: maximum_with_6d_broadcasting
 // CHECK: "tf.Maximum"(%arg0, %arg1)
 }
+
+// -----
+
+func @add_with_int32_5d_inputs(%arg0: tensor<1x1x1x3x1xi32>, %arg1 : tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32> {
+  %0 = "tf.Add"(%arg0, %arg1): (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32>
+  return %0 : tensor<1x1x1x3x4xi32>
+// CHECK-LABEL: add_with_int32_5d_inputs
+// CHECK: "tf.Add"(%arg0, %arg1)
+}
@@ -277,6 +277,52 @@ func @testMul(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
   return %0#0 : tensor<? x i32>
 }
 
+// -----
+
+func @add_with_quantized_i16_broadcasting(tensor<2x2xf32>, tensor<1xf32>) -> tensor<2x2x!quant.any<i16:f32>> {
+^bb0(%arg0: tensor<2x2xf32>, %arg1: tensor<1xf32>):
+  // expected-error @+1 {{Operands do not have valid shapes}}
+  %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<2x2xf32>, tensor<1xf32>) -> tensor<2x2x!quant.any<i16:f32>>
+  return %0#0 : tensor<2x2x!quant.any<i16:f32>>
+}
+
+// -----
+
+func @sub_with_quantized_i8_five_dim_broadcasting(tensor<1x1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x1x1x1x1x!quant.any<i8:f32>> {
+^bb0(%arg0: tensor<1x1x1x1x1xf32>, %arg1: tensor<1xf32>):
+  // expected-error @+1 {{Operands do not have valid shapes}}
+  %0 = "tfl.sub"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<1x1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x1x1x1x1x!quant.any<i8:f32>>
+  return %0#0 : tensor<1x1x1x1x1x!quant.any<i8:f32>>
+}
+
+// -----
+
+func @mul_with_i32_five_dim_broadcasting(tensor<1x1x1x1x1xi32>, tensor<1xi32>) -> tensor<1x1x1x1x1xi32> {
+^bb0(%arg0: tensor<1x1x1x1x1xi32>, %arg1: tensor<1xi32>):
+  // expected-error @+1 {{Operands do not have valid shapes}}
+  %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<1x1x1x1x1xi32>, tensor<1xi32>) -> tensor<1x1x1x1x1xi32>
+  return %0#0 : tensor<1x1x1x1x1xi32>
+}
+
+// -----
+
+func @mul_with_quantized_i16_five_dim_broadcasting(tensor<1x1x1x1x1x!quant.any<i16:f32>>, tensor<1x!quant.any<i16:f32>>) -> tensor<1x1x1x1x1x!quant.any<i16:f32>> {
+^bb0(%arg0: tensor<1x1x1x1x1x!quant.any<i16:f32>>, %arg1: tensor<1x!quant.any<i16:f32>>):
+  // expected-error @+1 {{Operands do not have valid shapes}}
+  %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<1x1x1x1x1x!quant.any<i16:f32>>, tensor<1x!quant.any<i16:f32>>) -> tensor<1x1x1x1x1x!quant.any<i16:f32>>
+  return %0#0 : tensor<1x1x1x1x1x!quant.any<i16:f32>>
+}
+
+// -----
+
+func @mul_with_quantized_i16_to_uint8_broadcasting(tensor<1x1x!quant.any<i16:f32>>, tensor<1x!quant.any<i16:f32>>) -> tensor<1x1x!quant.any<ui8:f32>> {
+^bb0(%arg0: tensor<1x1x!quant.any<i16:f32>>, %arg1: tensor<1x!quant.any<i16:f32>>):
+  // expected-error @+1 {{Operands do not have valid shapes}}
+  %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<1x1x!quant.any<i16:f32>>, tensor<1x!quant.any<i16:f32>>) -> tensor<1x1x!quant.any<ui8:f32>>
+  return %0#0 : tensor<1x1x!quant.any<ui8:f32>>
+}
+
+// -----
+
 // CHECK-LABEL: testMulNonQuantizedOperandsandQuantizedResult
 func @testMulNonQuantizedOperandsandQuantizedResult(tensor<? x f32>, tensor<? x f32>) -> tensor<? x !quant.any<i16:f32>> {
 ^bb0(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>):