Support Conv3d in MLIR converter

PiperOrigin-RevId: 354876135
Change-Id: I93710001e202a8b8ea7e9cbc3a1983e30fd7dcfe
Thai Nguyen 2021-01-31 23:34:07 -08:00 committed by TensorFlower Gardener
parent 5250b2e620
commit 6b23dbc159
7 changed files with 247 additions and 11 deletions

View File

@@ -195,6 +195,10 @@ def TFL_StatefulTensor : TypeAlias<AnyTensor, "stateful tensor">;
// Rank/Shape helpers.
//===----------------------------------------------------------------------===//
// Returns true if the operand is of none type.
class TFL_OperandIsNoneType<int i> :
CPred<"$_op.getOperand(" # i # ").getType().isa<NoneType>()">;
class TFL_OperandIsUnrankedPred<int n> :
CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;
@@ -256,6 +260,44 @@ class TFL_Operand0DOr1ElementTensor<int x> :
Or<[TFL_OperandHasKnownRank<x, 0>,
And<[TFL_OperandHasKnownRank<x, 1>, TFL_OperandDimEquals<x, 0, 1>]>]>>;
// Returns true if the i-th dim of the x-th operand is the same as the j-th
// dim of the y-th operand, or if either operand does not have a static shape.
class TFL_OperandsHaveSameDims<int x, int y, int i, int j> :
Or<[TFL_OperandIsUnrankedPred<x>,
TFL_OperandIsUnrankedPred<y>,
CPred<"!$_op.getOperand(" # x #
").getType().cast<ShapedType>().hasStaticShape()">,
CPred<"!$_op.getOperand(" # y #
").getType().cast<ShapedType>().hasStaticShape()">,
CPred<"$_op.getOperand(" # x #
").getType().cast<ShapedType>().getShape()[" # i # "] == "
"$_op.getOperand(" # y #
").getType().cast<ShapedType>().getShape()[" # j # "]">]>;
class TFL_OperandsHaveSameDimsTrait<int x, int y, int i, int j> :
PredOpTrait<"dim " # i # " of operand " # x # " equals to dim " # j #
" of operand " # y,
TFL_OperandsHaveSameDims<x, y, i, j>>;
// Returns true if the number of elements of the x-th operand is the same as
// the j-th dim of the y-th operand, or if either operand does not have a
// static shape.
class TFL_NumElementsEqualsDim<int x, int y, int j> :
Or<[TFL_OperandIsUnrankedPred<x>,
TFL_OperandIsUnrankedPred<y>,
CPred<"!$_op.getOperand(" # x #
").getType().cast<ShapedType>().hasStaticShape()">,
CPred<"!$_op.getOperand(" # y #
").getType().cast<ShapedType>().hasStaticShape()">,
CPred<"$_op.getOperand(" # x #
").getType().cast<ShapedType>().getNumElements() == "
"$_op.getOperand(" # y #
").getType().cast<ShapedType>().getShape()[" # j # "]">]>;
class TFL_NumElementsEqualsDimTrait<int x, int y, int j> :
PredOpTrait<"operand " # x # " has num of elements equals to dim " # j #
" of operand " # y,
TFL_NumElementsEqualsDim<x, y, j>>;
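Each CPred above splices a C++ expression into the op verifier that TableGen generates. As a rough illustration (a hand-written sketch, not the actual generated code, which inlines the expressions), TFL_OperandsHaveSameDims<x, y, i, j> behaves like:

#include "mlir/IR/BuiltinTypes.h"  // ShapedType, UnrankedTensorType
#include "mlir/IR/Operation.h"

// Sketch of the check TFL_OperandsHaveSameDims<x, y, i, j> performs.
// The predicate passes trivially when either operand is unranked or
// lacks a static shape; only fully static shapes are compared.
static bool OperandsHaveSameDims(mlir::Operation *op, int x, int y,
                                 int i, int j) {
  mlir::Type tx = op->getOperand(x).getType();
  mlir::Type ty = op->getOperand(y).getType();
  if (tx.isa<mlir::UnrankedTensorType>() || ty.isa<mlir::UnrankedTensorType>())
    return true;
  auto sx = tx.cast<mlir::ShapedType>();
  auto sy = ty.cast<mlir::ShapedType>();
  if (!sx.hasStaticShape() || !sy.hasStaticShape()) return true;
  return sx.getShape()[i] == sy.getShape()[j];
}

TFL_NumElementsEqualsDim is analogous, comparing sx.getNumElements() against sy.getShape()[j].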
// tf.uint8 and tf.quint8 are mapped to the same tflite types, so they are equal
// when used as element types.
class TFL_TFTypesWithSameBits<int i, int j, int num> :
@@ -275,7 +317,7 @@ class TFL_TFOperandTypesWithSameBits<int i, int j, int num> :
class TFL_OperandIsNoneOrHasRank<int n, int m> :
PredOpTrait<"operand " # n # " is " # m # "-D",
Or<[
CPred<"$_op.getOperand(" # n # ").getType().isa<NoneType>()">,
TFL_OperandIsNoneType<n>,
TFL_OperandIsUnrankedPred<n>,
CPred<"$_op.getOperand(" # n #
").getType().cast<ShapedType>().getRank() == " # m>]>>;
@@ -283,7 +325,7 @@ class TFL_OperandIsNoneOrHasRank<int n, int m> :
class TFL_OperandIsNoneOrHasRankAtMost<int n, int m> :
PredOpTrait<"operand " # n # " is at most " # m # "-D",
Or<[
CPred<"$_op.getOperand(" # n # ").getType().isa<NoneType>()">,
TFL_OperandIsNoneType<n>,
TFL_OperandIsUnrankedPred<n>,
CPred<"$_op.getOperand(" # n #
").getType().cast<ShapedType>().getRank() <= " # m>]>>;
@@ -4556,7 +4598,6 @@ the dimension is padded with zeros.
);
}
def TFL_AssignVariableOp : TFL_Op<"assign_variable", []> {
let summary = "Assigns a new value to a variable.";
@@ -4593,4 +4634,52 @@ Read variable data identified by 'resource_id'.
let results = (outs TFL_TensorOf<[F32]>:$result);
}
def TFL_Conv3DOp : TFL_Op<"conv_3d", [
NoSideEffect,
AccumulatorUniformScale<2, 0, 1>,
TFL_OperandHasRank<0, 5>,
TFL_OperandHasRank<1, 5>,
// The channel dimension of the input (dim 4) must match the input-channel
// dimension of the filter (dim 3).
TFL_OperandsHaveSameDimsTrait<0, 1, 4, 3>,
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
PredOpTrait<"bias and output must have same element type",
Or<[
TFL_OperandIsNoneType<2>,
TFL_TCresVTEtIsSameAsOp<0, 2>]>>,
PredOpTrait<"bias must has num of elements equals to 4th dim of filter",
Or<[
TFL_OperandIsNoneType<2>,
TFL_NumElementsEqualsDim<2, 1, 4>]>>]> {
let summary = "Convolution 3D operator";
let description = [{
Performs a convolution operation on 3D inputs.
Inputs:
`inputs[0]`: required: the input activation tensor
`inputs[1]`: required: the filter weight tensor
`inputs[2]`: optional: the bias tensor
}];
let arguments = (ins
TFL_TensorOf<[F32]>:$input,
TFL_TensorOf<[F32]>:$filter,
TFL_TensorOfOrNone<[F32]>:$bias,
I32Attr:$dilation_d_factor,
I32Attr:$dilation_h_factor,
I32Attr:$dilation_w_factor,
TFL_AFAttr:$fused_activation_function,
TFL_PaddingAttr:$padding,
I32Attr:$stride_d,
I32Attr:$stride_h,
I32Attr:$stride_w
);
let results = (outs TFL_TensorOf<[F32]>:$output);
let hasOptions = 1;
let customOption = "Conv3DOptions";
}
#endif // TFL_OPS

View File

@@ -1996,3 +1996,22 @@ func @rfft2d_invalid(%arg0: tensor<10x20x10x30xf64>, %arg1: tensor<2xi32>) -> te
// CHECK-LABEL: rfft2d_invalid
// CHECK-NOT: "tfl.RFFT2D"
}
func @conv3d_valid(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
%0 = "tf.Conv3D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
return %0 : tensor<?x?x?x?x?xf32>
// CHECK-LABEL: conv3d_valid
// CHECK: %cst = constant unit
// CHECK: [[BCT:%.*]] = "tfl.conv_3d"(%arg0, %arg1, %cst) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>, none) -> tensor<?x?x?x?x?xf32>
// CHECK: return [[BCT]] : tensor<?x?x?x?x?xf32>
}
func @conv3d_invalid_strides(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
%0 = "tf.Conv3D"(%arg0, %arg1) {padding = "SAME", strides = [2, 1, 1, 1, 1]} : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
return %0 : tensor<?x?x?x?x?xf32>
// CHECK-LABEL: conv3d_invalid_strides
// CHECK: [[BCT:%.*]] = "tf.Conv3D"(%arg0, %arg1) {padding = "SAME", strides = [2, 1, 1, 1, 1]} : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK: return [[BCT]] : tensor<?x?x?x?x?xf32>
}

View File

@@ -2520,3 +2520,51 @@ func @testFillWithQI8(%arg0: tensor<1x4xi32>, %arg1: tensor<? x !quant.uniform<i
%0 = "tfl.fill"(%arg0, %arg1): (tensor<1x4xi32>, tensor<? x !quant.uniform<i8:f32, 0.1>>) -> tensor<? x !quant.uniform<i8:f32, 0.1>>
return %0 : tensor<? x !quant.uniform<i8:f32, 0.1>>
}
// -----
// CHECK-LABEL: testConv3dWithFloatInput
func @testConv3dWithFloatInput(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>, %arg2: tensor<?xf32>) -> tensor<?x?x?x?x?xf32> {
// CHECK: "tfl.conv_3d"(%arg0, %arg1, %arg2)
%0 = "tfl.conv_3d"(%arg0, %arg1, %arg2) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "NONE"}: (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>, tensor<?xf32>) -> tensor<?x?x?x?x?xf32>
return %0 : tensor<?x?x?x?x?xf32>
}
// CHECK-LABEL: testConv3dNoBiasInput
func @testConv3dNoBiasInput(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>, %arg2: none) -> tensor<?x?x?x?x?xf32> {
// CHECK: "tfl.conv_3d"(%arg0, %arg1, %arg2)
%0 = "tfl.conv_3d"(%arg0, %arg1, %arg2) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "NONE"}: (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>, none) -> tensor<?x?x?x?x?xf32>
return %0 : tensor<?x?x?x?x?xf32>
}
// -----
func @testConv3dInvalidFilterShape(%arg0: tensor<2x3x4x5x2xf32>, %arg1: tensor<2x2x2x3x3xf32>, %arg2: tensor<?xf32>) -> tensor<?x?x?x?x?xf32> {
// expected-error @+1 {{failed to verify that dim 4 of operand 0 equals dim 3 of operand 1}}
%0 = "tfl.conv_3d"(%arg0, %arg1, %arg2) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "NONE"}: (tensor<2x3x4x5x2xf32>, tensor<2x2x2x3x3xf32>, tensor<?xf32>) -> tensor<?x?x?x?x?xf32>
return %0 : tensor<?x?x?x?x?xf32>
}
// -----
func @testConv3dInvalidBiasShape(%arg0: tensor<2x3x4x5x2xf32>, %arg1: tensor<2x2x2x2x3xf32>, %arg2: tensor<4xf32>) -> tensor<?x?x?x?x?xf32> {
// expected-error @+1 {{failed to verify that bias must have number of elements equal to 4th dim of filter}}
%0 = "tfl.conv_3d"(%arg0, %arg1, %arg2) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "NONE"}: (tensor<2x3x4x5x2xf32>, tensor<2x2x2x2x3xf32>, tensor<4xf32>) -> tensor<?x?x?x?x?xf32>
return %0 : tensor<?x?x?x?x?xf32>
}
// -----
func @testConv3dMisMatchInputType(%arg0: tensor<2x3x4x5x2xi32>, %arg1: tensor<2x2x2x2x3xf32>, %arg2: tensor<3xf32>) -> tensor<?x?x?x?x?xf32> {
// expected-error @+1 {{op failed to verify that input and output must have same element type}}
%0 = "tfl.conv_3d"(%arg0, %arg1, %arg2) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "NONE"}: (tensor<2x3x4x5x2xi32>, tensor<2x2x2x2x3xf32>, tensor<3xf32>) -> tensor<?x?x?x?x?xf32>
return %0 : tensor<?x?x?x?x?xf32>
}
// -----
func @testConv3dMisMatchBiasType(%arg0: tensor<2x3x4x5x2xf32>, %arg1: tensor<2x2x2x2x3xf32>, %arg2: tensor<3xi32>) -> tensor<?x?x?x?x?xf32> {
// expected-error @+1 {{failed to verify that bias and output must have same element type}}
%0 = "tfl.conv_3d"(%arg0, %arg1, %arg2) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "NONE"}: (tensor<2x3x4x5x2xf32>, tensor<2x2x2x2x3xf32>, tensor<3xi32>) -> tensor<?x?x?x?x?xf32>
return %0 : tensor<?x?x?x?x?xf32>
}

View File

@@ -150,6 +150,7 @@ DECL_CONVERT_OP(Split);
DECL_CONVERT_OP(SplitV);
DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(RandomUniform);
DECL_CONVERT_OP(Conv3D);
#undef DECL_CONVERT_OP
@@ -338,6 +339,46 @@ LogicalResult ConvertTFUnpackOp::matchAndRewrite(
return success();
}
LogicalResult ConvertTFConv3DOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
if (!TFDataFormatIsNDHWC(op)) return failure();
auto tf_op = cast<TF::Conv3DOp>(op);
IntegerAttr stride_depth, stride_height, stride_width;
if (!TFIntListIs1XYZ1(op, "strides", &stride_depth, &stride_height,
&stride_width))
return failure();
IntegerAttr dilation_depth_factor, dilation_height_factor,
dilation_width_factor;
if (!TFIntListIs1XYZ1(op, "dilations", &dilation_depth_factor,
&dilation_height_factor, &dilation_width_factor)) {
// If the 'dilations' attribute is missing, we use the default value (1)
// for the depth, height, and width dilation factors.
dilation_depth_factor = rewriter.getI32IntegerAttr(1);
dilation_height_factor = rewriter.getI32IntegerAttr(1);
dilation_width_factor = rewriter.getI32IntegerAttr(1);
}
StringAttr padding;
if (!TFPaddingIsSameOrValid(op, &padding)) return failure();
// TensorFlow Conv3D has no bias; optimization patterns that fuse Conv3D
// with other ops can fill in the bias.
Value none = rewriter.create<mlir::ConstantOp>(
op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
rewriter.replaceOpWithNewOp<TFL::Conv3DOp>(
op, tf_op.getType(), tf_op.input(), tf_op.filter(),
/*bias=*/none, dilation_depth_factor, dilation_height_factor,
dilation_width_factor,
/*fused_activation_function=*/rewriter.getStringAttr("NONE"), padding,
stride_depth, stride_height, stride_width);
return success();
}
// MatrixDiagV3 is MatrixDiagV2 with an alignment attribute. This attribute
// only has effects when processing multiple diagonals. Since TFLite converts
// MatrixDiagV{2,3} to MatrixDiag, which only takes single-diagonal inputs, we
@@ -692,7 +733,7 @@ void addPatterns(MLIRContext* context, OwningRewritePatternList& patterns) {
.insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op,
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFSplitOp,
ConvertTFSplitVOp, ConvertTFUnpackOp, ConvertTFAssertOp,
- ConvertTFRandomUniformOp>(context);
+ ConvertTFRandomUniformOp, ConvertTFConv3DOp>(context);
// Ophint python converter converted tf node pattern.
patterns.insert<LegalizeUnidirectionalSequenceLstm,

View File

@@ -61,6 +61,33 @@ bool TFIntListIs1XY1(const ArrayAttr &attr) {
return true;
}
// Returns true if the given `op`
// * has an attribute with the given `name`,
// * and the attribute is an integer list of the form [1, X, Y, Z, 1],
// and writes X, Y, Z as 32-bit integer attributes to `x`, `y`, `z`.
bool TFIntListIs1XYZ1(Operation *op, StringRef name, IntegerAttr *x,
IntegerAttr *y, IntegerAttr *z) {
auto attr = op->getAttrOfType<ArrayAttr>(name);
if (!attr) return false;
auto elements = attr.getValue();
if (elements.size() != 5 ||
std::any_of(elements.begin(), elements.end(),
[](Attribute e) { return !e.isa<IntegerAttr>(); }))
return false;
if (elements.front().cast<IntegerAttr>().getInt() != 1 ||
elements.back().cast<IntegerAttr>().getInt() != 1)
return false;
Builder b(op->getContext());
*x = b.getI32IntegerAttr(elements[1].cast<IntegerAttr>().getInt());
*y = b.getI32IntegerAttr(elements[2].cast<IntegerAttr>().getInt());
*z = b.getI32IntegerAttr(elements[3].cast<IntegerAttr>().getInt());
return true;
}
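A worked example with hypothetical values (mirroring how the converter above calls this helper): for an op carrying strides = [1, 2, 3, 4, 1], the out-parameters are filled as sketched below; a list whose first or last entry is not 1 is rejected.

// Hypothetical caller, illustrative only.
void ExampleStrides(mlir::Operation *op) {
  mlir::IntegerAttr stride_d, stride_h, stride_w;
  if (TFIntListIs1XYZ1(op, "strides", &stride_d, &stride_h, &stride_w)) {
    // With strides = [1, 2, 3, 4, 1]:
    //   stride_d == 2, stride_h == 3, stride_w == 4 (as i32 attributes).
  }
  // With strides = [2, 1, 1, 1, 1] the helper returns false, because the
  // first and last entries must both be 1 (cf. the conv3d_invalid_strides
  // test above, which is left unconverted for exactly this reason).
}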
// Returns true if every element of the attribute is 1. All elements of `attr`
// must be `IntegerAttr`.
bool TFIntListIsAllOnes(const ArrayAttr &attr) {

View File

@@ -35,6 +35,14 @@ inline bool TFDataFormatIsNHWC(Operation *op) {
return !attr || attr.getValue() == "NHWC";
}
// Returns true if the given TensorFlow op does not have a `data_format`
// attribute (in which case it defaults to "NDHWC"), or its `data_format`
// attribute is "NDHWC".
inline bool TFDataFormatIsNDHWC(Operation *op) {
auto attr = op->getAttrOfType<StringAttr>("data_format");
return !attr || attr.getValue() == "NDHWC";
}
// Returns true if the given `op`
// * has an attribute with the given `name`,
// * and the attribute is an integer list of the form [1, X, Y, 1],
@@ -45,6 +53,13 @@ bool TFIntListIs1XY1(Operation *op, StringRef name, IntegerAttr *x,
// Returns true if the attribute is an integer list of the form [1, X, Y, 1],
bool TFIntListIs1XY1(const ArrayAttr &attr);
// Returns true if the given `op`
// * has an attribute with the given `name`,
// * and the attribute is an integer list of the form [1, X, Y, Z, 1],
// and writes X, Y, Z as 32-bit integer attributes to `x`, `y`, `z`.
bool TFIntListIs1XYZ1(Operation *op, StringRef name, IntegerAttr *x,
IntegerAttr *y, IntegerAttr *z);
// Returns true if every element of the attribute is 1. All elements of `attr`
// must be `IntegerAttr`.
bool TFIntListIsAllOnes(const ArrayAttr &attr);

View File

@@ -50,8 +50,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
// Check number of inputs/outputs.
- bool has_bias = node->inputs->size == 3;
- TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
+ TF_LITE_ENSURE(context, node->inputs->size == 2 || node->inputs->size == 3);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
@@ -74,9 +73,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_type);
// Check bias.
- const TfLiteTensor* bias = nullptr;
- if (has_bias) {
- TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &bias));
+ const TfLiteTensor* bias = GetInput(context, node, 2);
+ if (bias) {
TF_LITE_ENSURE_TYPES_EQ(context, bias->type, input_type);
TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 4));
}
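The switch from GetInputSafe to GetInput works because of how TFLite resolves optional inputs: a none bias from the MLIR converter is exported as tensor index kTfLiteOptionalTensor (-1) in the operator's input list, for which GetInput returns nullptr, whereas GetInputSafe reports an error. A simplified sketch of that lookup (not the exact source, which lives in tensorflow/lite/kernels/kernel_util.cc):

#include "tensorflow/lite/c/common.h"  // TfLiteContext, kTfLiteOptionalTensor

// Simplified sketch of GetInput's optional-input handling.
const TfLiteTensor* GetInputSketch(const TfLiteContext* context,
                                   const TfLiteNode* node, int index) {
  if (index < 0 || index >= node->inputs->size) return nullptr;
  const int tensor_index = node->inputs->data[index];
  // A none/omitted bias is recorded as kTfLiteOptionalTensor (-1).
  if (tensor_index == kTfLiteOptionalTensor) return nullptr;
  return &context->tensors[tensor_index];
}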
@@ -120,8 +118,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));
- bool has_bias = node->inputs->size == 3;
- const TfLiteTensor* bias = has_bias ? GetInput(context, node, 2) : nullptr;
+ const TfLiteTensor* bias = GetInput(context, node, 2);
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,