Add fine-grained shape constraints per type to Add, Mul, and Sub ops

PiperOrigin-RevId: 315784886
Change-Id: I385d61510eb905ad6c8a0679c99a78bbfad6dc44
Jaesung Chung 2020-06-10 15:47:44 -07:00 committed by TensorFlower Gardener
parent a2868d9d55
commit 83af443dc7
5 changed files with 187 additions and 9 deletions
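As an illustration (not part of the diff itself): the new per-type constraints distinguish, for example, between f32 and i32 additions — f32 operands may broadcast up to rank 5, while i32 operands may broadcast only up to rank 4. A hypothetical test case in the style of the ops.mlir cases added below (the error text comes from the new TFL_RuntimePredOpTrait description):

func @add_i32_rank5_broadcast(%arg0: tensor<1x1x1x3x1xi32>, %arg1: tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32> {
  // The rank-5 broadcast is rejected for i32 operands; the same shapes with
  // f32 operands would be accepted, since f32 allows broadcasting up to rank 5.
  // expected-error @+1 {{Operands do not have valid shapes}}
  %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32>
  return %0 : tensor<1x1x1x3x4xi32>
}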


@@ -525,11 +525,16 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
auto *val = trait.getDef().getValue("tflRuntimePredicate");
if (!val) continue;
auto desc = trait.getDef().getValueAsString("tflRuntimeDescription");
mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
os << tgfmt(
" if (!($0)) {\n "
" return ::mlir::LogicalResult::Failure;\n }\n",
&verify_ctx, tgfmt(pred.getCondition(), &verify_ctx));
" if (failure_on_operand_type_mismatch) {\n"
" return top.emitOpError(\"failed to verify that $1\");\n"
" } else {\n"
" return ::mlir::LogicalResult::Failure;\n }\n }\n",
&verify_ctx, tgfmt(pred.getCondition(), &verify_ctx), desc);
}
os << " return top.verify();\n}\n";
}


@@ -50,9 +50,8 @@ namespace TFL {
// broadcastable shape within the given rank. If any given shapes are
// non-static and maximum rank is within the given rank, this method returns
// true.
bool IsOperandsHaveSameShapesOrBroadcastableShape(Operation *op,
ArrayRef<unsigned> indices,
int max_bcast_rank) {
bool VerifyOperandsHaveSameShapesOrBroadcastableShape(
Operation *op, ArrayRef<unsigned> indices, int max_bcast_rank) {
if (indices.empty()) return true;
// First, check whether there are any inputs that have an unknown rank.
@@ -110,6 +109,122 @@ bool IsOperandsHaveSameShapesOrBroadcastableShape(Operation *op,
return has_same_shape || max_rank <= max_bcast_rank;
}
// Return true when the given element_type is QI8.
bool IsQI8Type(Type element_type) {
auto quantized_type = element_type.dyn_cast<QuantizedType>();
return quantized_type != nullptr &&
quantized_type.getStorageTypeIntegralWidth() == 8 &&
quantized_type.isSigned();
}
// Return true when the given element_type is QUI8.
bool IsQUI8Type(Type element_type) {
auto quantized_type = element_type.dyn_cast<QuantizedType>();
return quantized_type != nullptr &&
quantized_type.getStorageTypeIntegralWidth() == 8 &&
!quantized_type.isSigned();
}
// Return true when the given element_type is QI16.
bool IsQI16Type(Type element_type) {
auto quantized_type = element_type.dyn_cast<QuantizedType>();
return quantized_type != nullptr &&
quantized_type.getStorageTypeIntegralWidth() == 16 &&
quantized_type.isSigned();
}
// Return true when the given element_type is I32.
bool IsI32Type(Type element_type) {
return element_type.isInteger(32) && !element_type.isUnsignedInteger();
}
// Return true if the given Add operation has shapes supported by the CPU kernel.
bool VerifyAddOpShapeConstraints(AddOp op) {
auto element_type = getElementTypeOrSelf(op.output().getType());
// Allows F32, QI8, and QUI8 outputs when the operands have valid shapes,
// which are broadcastable shapes up to five dimensions or the same shapes.
if (element_type.isF32() || IsQI8Type(element_type) ||
IsQUI8Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows I32 output when the operands have valid shapes, which are
// broadcastable shapes up to four dimensions or the same shapes.
if (IsI32Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4);
}
// Allows QI16 output when operands have the same shape.
if (IsQI16Type(element_type)) {
return succeeded(
mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
}
return false;
}
// Return true if the given Sub operation has shapes supported by the CPU kernel.
bool VerifySubOpShapeConstraints(SubOp op) {
auto element_type = getElementTypeOrSelf(op.output().getType());
// Allows F32, I32, QUI8, and QI16 outputs when the operands have valid
// shapes, which are broadcastable shapes up to five dimensions or the same
// shapes.
if (element_type.isF32() || IsI32Type(element_type) ||
IsQUI8Type(element_type) || IsQI16Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows QI8 output when the operands have valid shapes, which are
// broadcastable shapes up to four dimensions or the same shapes.
if (IsQI8Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4);
}
return false;
}
// Return true if the given Mul operation has shapes supported by the CPU kernel.
bool VerifyMulOpShapeConstraints(MulOp op) {
auto element_type = getElementTypeOrSelf(op.output().getType());
// Allows QI8 and QUI8 outputs with broadcastable shapes up to five
// dimensions or the same shapes. If the operands are QI16, only operands
// with the same shape are allowed.
if (IsQI8Type(element_type) || IsQUI8Type(element_type)) {
if (IsQI16Type(getElementTypeOrSelf(op.lhs().getType()))) {
return succeeded(
mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
}
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows F32 output when the operands have valid shapes, which are
// broadcastable shapes up to five dimensions or the same shapes.
if (element_type.isF32()) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows I32 and QI16 outputs when the operands have valid shapes, which
// are broadcastable shapes up to four dimensions or the same shapes.
if (IsI32Type(element_type) || IsQI16Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4);
}
return false;
}
//===----------------------------------------------------------------------===//
// TensorFlowLiteDialect
//===----------------------------------------------------------------------===//


@@ -127,7 +127,7 @@ class TFL_OperandsHaveSameShapesOrBroadcastableShape<
list<int> indices, int max_bcast_rank> :
TFL_RuntimePredOpTrait<"operands do not have the same shape or "
"broadcastable shapes within the rank " # max_bcast_rank,
CPred<"TFL::IsOperandsHaveSameShapesOrBroadcastableShape("
CPred<"TFL::VerifyOperandsHaveSameShapesOrBroadcastableShape("
"$_op, llvm::ArrayRef<unsigned>({" # StrJoinInt<indices>.result #
"}), " # max_bcast_rank # ")">>;
@@ -491,7 +491,8 @@ an output element, this operation computes \\(y = |x|\\).
}
def TFL_AddOp : TFL_Op<"add", [
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
TFL_RuntimePredOpTrait<"Operands do not have valid shapes",
CPred<"TFL::VerifyAddOpShapeConstraints(llvm::cast<AddOp>($_op))">>,
ResultsBroadcastableShape,
NoSideEffect,
Commutative,
@@ -2171,7 +2172,8 @@ def TFL_MulOp : TFL_Op<"mul", [
NoSideEffect,
Commutative,
BinaryOpSameElementTypeConstraint,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
TFL_RuntimePredOpTrait<"Operands do not have valid shapes",
CPred<"TFL::VerifyMulOpShapeConstraints(llvm::cast<MulOp>($_op))">>,
TFL_GpuTargetOp]> {
let summary = "Multiplication operator";
@@ -2832,7 +2834,8 @@ def TFL_SquareOp: TFL_Op<"square", [
def TFL_SubOp : TFL_Op<"sub", [
ResultsBroadcastableShape,
BinaryOpSameElementTypeConstraint,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
TFL_RuntimePredOpTrait<"Operands do not have valid shapes",
CPred<"TFL::VerifySubOpShapeConstraints(llvm::cast<SubOp>($_op))">>,
NoSideEffect]> {
let summary = "Subtraction operator";


@@ -1548,3 +1548,12 @@ func @maximum_with_6d_broadcasting(%arg0: tensor<1x1x1x1x8x16xf32>, %arg1: tenso
// CHECK-LABEL: maximum_with_6d_broadcasting
// CHECK: "tf.Maximum"(%arg0, %arg1)
}
// -----
func @add_with_int32_5d_inputs(%arg0: tensor<1x1x1x3x1xi32>, %arg1 : tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32> {
%0 = "tf.Add"(%arg0, %arg1): (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32>
return %0 : tensor<1x1x1x3x4xi32>
// CHECK-LABEL: add_with_int32_5d_inputs
// CHECK: "tf.Add"(%arg0, %arg1)
}


@@ -277,6 +277,52 @@ func @testMul(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
return %0#0 : tensor<? x i32>
}
// -----
func @add_with_quantized_i16_broadcasting(tensor<2x2xf32>, tensor<1xf32>) -> tensor<2x2x!quant.any<i16:f32>> {
^bb0(%arg0: tensor<2x2xf32>, %arg1: tensor<1xf32>):
// expected-error @+1 {{Operands do not have valid shapes}}
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<2x2xf32>, tensor<1xf32>) -> tensor<2x2x!quant.any<i16:f32>>
return %0#0 : tensor<2x2x!quant.any<i16:f32>>
}
// -----
func @sub_with_quantized_i8_five_dim_broadcasting(tensor<1x1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x1x1x1x1x!quant.any<i8:f32>> {
^bb0(%arg0: tensor<1x1x1x1x1xf32>, %arg1: tensor<1xf32>):
// expected-error @+1 {{Operands do not have valid shapes}}
%0 = "tfl.sub"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<1x1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x1x1x1x1x!quant.any<i8:f32>>
return %0#0 : tensor<1x1x1x1x1x!quant.any<i8:f32>>
}
// -----
func @mul_with_i32_five_dim_broadcasting(tensor<1x1x1x1x1xi32>, tensor<1xi32>) -> tensor<1x1x1x1x1xi32> {
^bb0(%arg0: tensor<1x1x1x1x1xi32>, %arg1: tensor<1xi32>):
// expected-error @+1 {{Operands do not have valid shapes}}
%0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<1x1x1x1x1xi32>, tensor<1xi32>) -> tensor<1x1x1x1x1xi32>
return %0#0 : tensor<1x1x1x1x1xi32>
}
// -----
func @mul_with_quantized_i16_five_dim_broadcasting(tensor<1x1x1x1x1x!quant.any<i16:f32>>, tensor<1x!quant.any<i16:f32>>) -> tensor<1x1x1x1x1x!quant.any<i16:f32>> {
^bb0(%arg0: tensor<1x1x1x1x1x!quant.any<i16:f32>>, %arg1: tensor<1x!quant.any<i16:f32>>):
// expected-error @+1 {{Operands do not have valid shapes}}
%0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<1x1x1x1x1x!quant.any<i16:f32>>, tensor<1x!quant.any<i16:f32>>) -> tensor<1x1x1x1x1x!quant.any<i16:f32>>
return %0#0 : tensor<1x1x1x1x1x!quant.any<i16:f32>>
}
// -----
func @mul_with_quantized_i16_to_uint8_broadcasting(tensor<1x1x!quant.any<i16:f32>>, tensor<1x!quant.any<i16:f32>>) -> tensor<1x1x!quant.any<ui8:f32>> {
^bb0(%arg0: tensor<1x1x!quant.any<i16:f32>>, %arg1: tensor<1x!quant.any<i16:f32>>):
// expected-error @+1 {{Operands do not have valid shapes}}
%0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<1x1x!quant.any<i16:f32>>, tensor<1x!quant.any<i16:f32>>) -> tensor<1x1x!quant.any<ui8:f32>>
return %0#0 : tensor<1x1x!quant.any<ui8:f32>>
}
// -----
// CHECK-LABEL: testMulNonQuantizedOperandsandQuantizedResult
func @testMulNonQuantizedOperandsandQuantizedResult(tensor<? x f32>, tensor<? x f32>) -> tensor<? x !quant.any<i16:f32>> {
^bb0(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>):