Support Conv3d in MLIR converter
PiperOrigin-RevId: 354876135 Change-Id: I93710001e202a8b8ea7e9cbc3a1983e30fd7dcfe
This commit is contained in:
parent
5250b2e620
commit
6b23dbc159
@ -195,6 +195,10 @@ def TFL_StatefulTensor : TypeAlias<AnyTensor, "stateful tensor">;
|
||||
// Rank/Shape helpers.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Returns true of operand is none type.
|
||||
class TFL_OperandIsNoneType<int i> :
|
||||
CPred<"$_op.getOperand(" # i # ").getType().isa<NoneType>()">;
|
||||
|
||||
class TFL_OperandIsUnrankedPred<int n> :
|
||||
CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;
|
||||
|
||||
@ -256,6 +260,44 @@ class TFL_Operand0DOr1ElementTensor<int x> :
|
||||
Or<[TFL_OperandHasKnownRank<x, 0>,
|
||||
And<[TFL_OperandHasKnownRank<x, 1>, TFL_OperandDimEquals<x, 0, 1>]>]>>;
|
||||
|
||||
// Return true if i-th dim of x-th operand is the same as j-th dim of y-th
|
||||
// operand or any of those operands does not have static shape.
|
||||
class TFL_OperandsHaveSameDims<int x, int y, int i, int j> :
|
||||
Or<[TFL_OperandIsUnrankedPred<x>,
|
||||
TFL_OperandIsUnrankedPred<y>,
|
||||
CPred<"!$_op.getOperand(" # x #
|
||||
").getType().cast<ShapedType>().hasStaticShape()">,
|
||||
CPred<"!$_op.getOperand(" # y #
|
||||
").getType().cast<ShapedType>().hasStaticShape()">,
|
||||
CPred<"$_op.getOperand(" # x #
|
||||
").getType().cast<ShapedType>().getShape()[" # i # "] == "
|
||||
"$_op.getOperand(" # y #
|
||||
").getType().cast<ShapedType>().getShape()[" # j # "]">]>;
|
||||
|
||||
class TFL_OperandsHaveSameDimsTrait<int x, int y, int i, int j> :
|
||||
PredOpTrait<"dim " # i # " of operand " # x # " equals to dim " # j #
|
||||
" of operand " # y,
|
||||
TFL_OperandsHaveSameDims<x, y, i, j>>;
|
||||
|
||||
// Return true if number of elements of x-th operand is the same as j-th dim of
|
||||
// y-th operand or any of those operands does not have static shape.
|
||||
class TFL_NumElementsEqualsDim<int x, int y, int j> :
|
||||
Or<[TFL_OperandIsUnrankedPred<x>,
|
||||
TFL_OperandIsUnrankedPred<y>,
|
||||
CPred<"!$_op.getOperand(" # x #
|
||||
").getType().cast<ShapedType>().hasStaticShape()">,
|
||||
CPred<"!$_op.getOperand(" # y #
|
||||
").getType().cast<ShapedType>().hasStaticShape()">,
|
||||
CPred<"$_op.getOperand(" # x #
|
||||
").getType().cast<ShapedType>().getNumElements() == "
|
||||
"$_op.getOperand(" # y #
|
||||
").getType().cast<ShapedType>().getShape()[" # j # "]">]>;
|
||||
|
||||
class TFL_NumElementsEqualsDimTrait<int x, int y, int j> :
|
||||
PredOpTrait<"operand " # x # " has num of elements equals to dim " # j #
|
||||
" of operand " # y,
|
||||
TFL_NumElementsEqualsDim<x, y, j>>;
|
||||
|
||||
// tf.uint8 and tf.quint8 are mapped to the same tflite types, so they are equal
|
||||
// when used as element types.
|
||||
class TFL_TFTypesWithSameBits<int i, int j, int num> :
|
||||
@ -275,7 +317,7 @@ class TFL_TFOperandTypesWithSameBits<int i, int j, int num> :
|
||||
class TFL_OperandIsNoneOrHasRank<int n, int m> :
|
||||
PredOpTrait<"operand " # n # " is " # m # "-D",
|
||||
Or<[
|
||||
CPred<"$_op.getOperand(" # n # ").getType().isa<NoneType>()">,
|
||||
TFL_OperandIsNoneType<n>,
|
||||
TFL_OperandIsUnrankedPred<n>,
|
||||
CPred<"$_op.getOperand(" # n #
|
||||
").getType().cast<ShapedType>().getRank() == " # m>]>>;
|
||||
@ -283,7 +325,7 @@ class TFL_OperandIsNoneOrHasRank<int n, int m> :
|
||||
class TFL_OperandIsNoneOrHasRankAtMost<int n, int m> :
|
||||
PredOpTrait<"operand " # n # " is at most " # m # "-D",
|
||||
Or<[
|
||||
CPred<"$_op.getOperand(" # n # ").getType().isa<NoneType>()">,
|
||||
TFL_OperandIsNoneType<n>,
|
||||
TFL_OperandIsUnrankedPred<n>,
|
||||
CPred<"$_op.getOperand(" # n #
|
||||
").getType().cast<ShapedType>().getRank() <= " # m>]>>;
|
||||
@ -4556,7 +4598,6 @@ the dimension is padded with zeros.
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
def TFL_AssignVariableOp : TFL_Op<"assign_variable", []> {
|
||||
let summary = "Assigns a new value to a variable.";
|
||||
|
||||
@ -4593,4 +4634,52 @@ Read variable data identified by 'resource_id'.
|
||||
let results = (outs TFL_TensorOf<[F32]>:$result);
|
||||
}
|
||||
|
||||
def TFL_Conv3DOp : TFL_Op<"conv_3d", [
|
||||
NoSideEffect,
|
||||
AccumulatorUniformScale<2, 0, 1>,
|
||||
TFL_OperandHasRank<0, 5>,
|
||||
TFL_OperandHasRank<1, 5>,
|
||||
// Channel dimension in input and filter should match.
|
||||
TFL_OperandsHaveSameDimsTrait<0, 1, 4, 3>,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
PredOpTrait<"bias and output must have same element type",
|
||||
Or<[
|
||||
TFL_OperandIsNoneType<2>,
|
||||
TFL_TCresVTEtIsSameAsOp<0, 2>]>>,
|
||||
PredOpTrait<"bias must has num of elements equals to 4th dim of filter",
|
||||
Or<[
|
||||
TFL_OperandIsNoneType<2>,
|
||||
TFL_NumElementsEqualsDim<2, 1, 4>]>>]> {
|
||||
let summary = "Convolution 3D operator";
|
||||
|
||||
let description = [{
|
||||
Performs convolution operation on 3D inputs.
|
||||
Inputs:
|
||||
`inputs[0]`: required: the input activation tensor
|
||||
`inputs[1]`: required: the filter weight tensor
|
||||
`inputs[2]`: optional: the bias tensor
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32]>:$input,
|
||||
TFL_TensorOf<[F32]>:$filter,
|
||||
TFL_TensorOfOrNone<[F32]>:$bias,
|
||||
I32Attr:$dilation_d_factor,
|
||||
I32Attr:$dilation_h_factor,
|
||||
I32Attr:$dilation_w_factor,
|
||||
TFL_AFAttr:$fused_activation_function,
|
||||
TFL_PaddingAttr:$padding,
|
||||
I32Attr:$stride_d,
|
||||
I32Attr:$stride_h,
|
||||
I32Attr:$stride_w
|
||||
);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32]>:$output);
|
||||
|
||||
let hasOptions = 1;
|
||||
|
||||
let customOption = "Conv3DOptions";
|
||||
}
|
||||
|
||||
#endif // TFL_OPS
|
||||
|
@ -1996,3 +1996,22 @@ func @rfft2d_invalid(%arg0: tensor<10x20x10x30xf64>, %arg1: tensor<2xi32>) -> te
|
||||
// CHECK-LABEL: rfft2d_invalid
|
||||
// CHECK-NOT: "tfl.RFFT2D"
|
||||
}
|
||||
|
||||
|
||||
func @conv3d_valid(%arg0: tensor<?x?x?x?x?xf32>,%arg1: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
|
||||
%0 = "tf.Conv3D"(%arg0, %arg1) {padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
return %0: tensor<?x?x?x?x?xf32>
|
||||
|
||||
// CHECK-LABEL: conv3d_valid
|
||||
// CHECK: %cst = constant unit
|
||||
// CHECK: [[BCT:%.*]] = "tfl.conv_3d"(%arg0, %arg1, %cst) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>, none) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK: return [[BCT]] : tensor<?x?x?x?x?xf32>
|
||||
}
|
||||
|
||||
func @conv3d_invalid_strides(%arg0: tensor<?x?x?x?x?xf32>,%arg1: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
|
||||
%0 = "tf.Conv3D"(%arg0, %arg1) {padding = "SAME", strides = [2, 1, 1, 1, 1]} : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
return %0: tensor<?x?x?x?x?xf32>
|
||||
// CHECK-LABEL: conv3d_invalid_strides
|
||||
// CHECK: [[BCT:%.*]] = "tf.Conv3D"(%arg0, %arg1) {padding = "SAME", strides = [2, 1, 1, 1, 1]} : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK: return [[BCT]] : tensor<?x?x?x?x?xf32>
|
||||
}
|
||||
|
@ -2520,3 +2520,51 @@ func @testFillWithQI8(%arg0: tensor<1x4xi32>, %arg1: tensor<? x !quant.uniform<i
|
||||
%0 = "tfl.fill"(%arg0, %arg1): (tensor<1x4xi32>, tensor<? x !quant.uniform<i8:f32, 0.1>>) -> tensor<? x !quant.uniform<i8:f32, 0.1>>
|
||||
return %0 : tensor<? x !quant.uniform<i8:f32, 0.1>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testConv3dWithFloatInput
|
||||
func @testConv3dWithFloatInput(%arg0: tensor<?x?x?x?x?xf32>,%arg1: tensor<?x?x?x?x?xf32>,%arg2: tensor<?xf32>) -> tensor<?x?x?x?x?xf32> {
|
||||
// CHECK: "tfl.conv_3d"(%arg0, %arg1, %arg2)
|
||||
%0 = "tfl.conv_3d"(%arg0, %arg1, %arg2) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "NONE"}: (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>, tensor<?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
return %0 : tensor<?x?x?x?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testConv3dNoBiasInput
|
||||
func @testConv3dNoBiasInput(%arg0: tensor<?x?x?x?x?xf32>,%arg1: tensor<?x?x?x?x?xf32>,%arg2: none) -> tensor<?x?x?x?x?xf32> {
|
||||
// CHECK: "tfl.conv_3d"(%arg0, %arg1, %arg2)
|
||||
%0 = "tfl.conv_3d"(%arg0, %arg1, %arg2) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "NONE"}: (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>, none) -> tensor<?x?x?x?x?xf32>
|
||||
return %0 : tensor<?x?x?x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testConv3dInvalidFilterShape(%arg0: tensor<2x3x4x5x2xf32>,%arg1: tensor<2x2x2x3x3xf32>,%arg2: tensor<?xf32>) -> tensor<?x?x?x?x?xf32> {
|
||||
// expected-error @+1 {{failed to verify that dim 4 of operand 0 equals to dim 3 of operand 1}}
|
||||
%0 = "tfl.conv_3d"(%arg0, %arg1, %arg2) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "NONE"}: (tensor<2x3x4x5x2xf32>, tensor<2x2x2x3x3xf32>, tensor<?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
return %0 : tensor<?x?x?x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testConv3dInvalidBiasShape(%arg0: tensor<2x3x4x5x2xf32>,%arg1: tensor<2x2x2x2x3xf32>,%arg2: tensor<4xf32>) -> tensor<?x?x?x?x?xf32> {
|
||||
// expected-error @+1 {{failed to verify that bias must has num of elements equals to 4th dim of filter}}
|
||||
%0 = "tfl.conv_3d"(%arg0, %arg1, %arg2) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "NONE"}: (tensor<2x3x4x5x2xf32>, tensor<2x2x2x2x3xf32>, tensor<4xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
return %0 : tensor<?x?x?x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testConv3dMisMatchInputType(%arg0: tensor<2x3x4x5x2xi32>,%arg1: tensor<2x2x2x2x3xf32>,%arg2: tensor<3xf32>) -> tensor<?x?x?x?x?xf32> {
|
||||
// expected-error @+1 {{op failed to verify that input and output must have same element type}}
|
||||
%0 = "tfl.conv_3d"(%arg0, %arg1, %arg2) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "NONE"}: (tensor<2x3x4x5x2xi32>, tensor<2x2x2x2x3xf32>, tensor<3xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
return %0 : tensor<?x?x?x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testConv3dMisMatchBiasType(%arg0: tensor<2x3x4x5x2xf32>,%arg1: tensor<2x2x2x2x3xf32>,%arg2: tensor<3xi32>) -> tensor<?x?x?x?x?xf32> {
|
||||
// expected-error @+1 {{failed to verify that bias and output must have same element type}}
|
||||
%0 = "tfl.conv_3d"(%arg0, %arg1, %arg2) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "NONE"}: (tensor<2x3x4x5x2xf32>, tensor<2x2x2x2x3xf32>, tensor<3xi32>) -> tensor<?x?x?x?x?xf32>
|
||||
return %0 : tensor<?x?x?x?x?xf32>
|
||||
}
|
||||
|
@ -150,6 +150,7 @@ DECL_CONVERT_OP(Split);
|
||||
DECL_CONVERT_OP(SplitV);
|
||||
DECL_CONVERT_OP(Unpack);
|
||||
DECL_CONVERT_OP(RandomUniform);
|
||||
DECL_CONVERT_OP(Conv3D);
|
||||
|
||||
#undef DECL_CONVERT_OP
|
||||
|
||||
@ -338,6 +339,46 @@ LogicalResult ConvertTFUnpackOp::matchAndRewrite(
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConvertTFConv3DOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
if (!TFDataFormatIsNDHWC(op)) return failure();
|
||||
|
||||
auto tf_op = cast<TF::Conv3DOp>(op);
|
||||
|
||||
IntegerAttr stride_depth, stride_height, stride_width;
|
||||
if (!TFIntListIs1XYZ1(op, "strides", &stride_depth, &stride_height,
|
||||
&stride_width))
|
||||
return failure();
|
||||
|
||||
IntegerAttr dilation_depth_factor, dilation_height_factor,
|
||||
dilation_width_factor;
|
||||
if (!TFIntListIs1XYZ1(op, "dilations", &dilation_depth_factor,
|
||||
&dilation_height_factor, &dilation_width_factor)) {
|
||||
// If the 'dilations' attribute is missing, we use the default value (1)
|
||||
// for all dilation depth, height and width factor.
|
||||
dilation_depth_factor = rewriter.getI32IntegerAttr(1);
|
||||
dilation_height_factor = rewriter.getI32IntegerAttr(1);
|
||||
dilation_width_factor = rewriter.getI32IntegerAttr(1);
|
||||
}
|
||||
|
||||
StringAttr padding;
|
||||
if (!TFPaddingIsSameOrValid(op, &padding)) return failure();
|
||||
|
||||
// TensorFlow Conv3D has no bias, optimization patterns will fuse Conv3D
|
||||
// with other ops can fill the bias.
|
||||
Value none = rewriter.create<mlir::ConstantOp>(
|
||||
op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
|
||||
|
||||
rewriter.replaceOpWithNewOp<TFL::Conv3DOp>(
|
||||
op, tf_op.getType(), tf_op.input(), tf_op.filter(),
|
||||
/*bias=*/none, dilation_depth_factor, dilation_height_factor,
|
||||
dilation_width_factor,
|
||||
/*fused_activation_function=*/rewriter.getStringAttr("NONE"), padding,
|
||||
stride_depth, stride_height, stride_width);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// MatrixDiagV3 is MatrixDiagV2 with an alignment attribute. This attribute
|
||||
// only has effects when processing multiple diagonals. Since TFLite converts
|
||||
// MatrixDiagV{2,3} to MatrixDiag, which only takes single-diagonal inputs, we
|
||||
@ -692,7 +733,7 @@ void addPatterns(MLIRContext* context, OwningRewritePatternList& patterns) {
|
||||
.insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op,
|
||||
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFSplitOp,
|
||||
ConvertTFSplitVOp, ConvertTFUnpackOp, ConvertTFAssertOp,
|
||||
ConvertTFRandomUniformOp>(context);
|
||||
ConvertTFRandomUniformOp, ConvertTFConv3DOp>(context);
|
||||
|
||||
// Ophint python converter converted tf node pattern.
|
||||
patterns.insert<LegalizeUnidirectionalSequenceLstm,
|
||||
|
@ -61,6 +61,33 @@ bool TFIntListIs1XY1(const ArrayAttr &attr) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns true if the given `op`
|
||||
// * has an attribute with the given `name`,
|
||||
// * and the attribute is an integer list of the form [1, X, Y, Z, 1],
|
||||
// and writes X, Y as 32-bit integer attribute to `x`, `y`, z.
|
||||
bool TFIntListIs1XYZ1(Operation *op, StringRef name, IntegerAttr *x,
|
||||
IntegerAttr *y, IntegerAttr *z) {
|
||||
auto attr = op->getAttrOfType<ArrayAttr>(name);
|
||||
if (!attr) return false;
|
||||
|
||||
auto elements = attr.getValue();
|
||||
if (elements.size() != 5 ||
|
||||
std::any_of(elements.begin(), elements.end(),
|
||||
[](Attribute e) { return !e.isa<IntegerAttr>(); }))
|
||||
return false;
|
||||
|
||||
if (elements.front().cast<IntegerAttr>().getInt() != 1 ||
|
||||
elements.back().cast<IntegerAttr>().getInt() != 1)
|
||||
return false;
|
||||
|
||||
Builder b(op->getContext());
|
||||
*x = b.getI32IntegerAttr(elements[1].cast<IntegerAttr>().getInt());
|
||||
*y = b.getI32IntegerAttr(elements[2].cast<IntegerAttr>().getInt());
|
||||
*z = b.getI32IntegerAttr(elements[3].cast<IntegerAttr>().getInt());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns true if every element of the attribute is 1. All elements of `attr`
|
||||
// must be `IntegerAttr`.
|
||||
bool TFIntListIsAllOnes(const ArrayAttr &attr) {
|
||||
|
@ -35,6 +35,14 @@ inline bool TFDataFormatIsNHWC(Operation *op) {
|
||||
return !attr || attr.getValue() == "NHWC";
|
||||
}
|
||||
|
||||
// Returns true if the given TensorFlow op does not have a `data_format`
|
||||
// attribute (then default to "NDHWC"), or its `data_format` attribute is
|
||||
// "NDHWC".
|
||||
inline bool TFDataFormatIsNDHWC(Operation *op) {
|
||||
auto attr = op->getAttrOfType<StringAttr>("data_format");
|
||||
return !attr || attr.getValue() == "NDHWC";
|
||||
}
|
||||
|
||||
// Returns true if the given `op`
|
||||
// * has an attribute with the given `name`,
|
||||
// * and the attribute is an integer list of the form [1, X, Y, 1],
|
||||
@ -45,6 +53,13 @@ bool TFIntListIs1XY1(Operation *op, StringRef name, IntegerAttr *x,
|
||||
// Returns true if the attribute is an integer list of the form [1, X, Y, 1],
|
||||
bool TFIntListIs1XY1(const ArrayAttr &attr);
|
||||
|
||||
// Returns true if the given `op`
|
||||
// * has an attribute with the given `name`,
|
||||
// * and the attribute is an integer list of the form [1, X, Y, Z, 1],
|
||||
// and writes X, Y as 32-bit integer attribute to `x`, `y`, z.
|
||||
bool TFIntListIs1XYZ1(Operation *op, StringRef name, IntegerAttr *x,
|
||||
IntegerAttr *y, IntegerAttr *z);
|
||||
|
||||
// Returns true if every element of the attribute is 1. All elements of `attr`
|
||||
// must be `IntegerAttr`.
|
||||
bool TFIntListIsAllOnes(const ArrayAttr &attr);
|
||||
|
@ -50,8 +50,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
||||
|
||||
// Check number of inputs/outputs.
|
||||
bool has_bias = node->inputs->size == 3;
|
||||
TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
|
||||
TF_LITE_ENSURE(context, node->inputs->size == 2 || node->inputs->size == 3);
|
||||
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
|
||||
@ -74,9 +73,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_type);
|
||||
|
||||
// Check bias.
|
||||
const TfLiteTensor* bias = nullptr;
|
||||
if (has_bias) {
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &bias));
|
||||
const TfLiteTensor* bias = GetInput(context, node, 2);
|
||||
if (bias) {
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, bias->type, input_type);
|
||||
TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 4));
|
||||
}
|
||||
@ -120,8 +118,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
|
||||
const TfLiteTensor* filter;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));
|
||||
bool has_bias = node->inputs->size == 3;
|
||||
const TfLiteTensor* bias = has_bias ? GetInput(context, node, 2) : nullptr;
|
||||
const TfLiteTensor* bias = GetInput(context, node, 2);
|
||||
|
||||
float output_activation_min, output_activation_max;
|
||||
CalculateActivationRange(params->activation, &output_activation_min,
|
||||
|
Loading…
Reference in New Issue
Block a user