diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index 131b7e7410f..efb4dccb0f8 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -221,12 +221,15 @@ def TFL_ComparisonBinaryBuilder : OpBuilder<
 }]>;
 
 //===----------------------------------------------------------------------===//
-// TFL native op trait for stateful operands.
+// TFL native op trait for stateful operands and channel indices.
 
 class StatefulOperands<list<int> operands>
   : ParamNativeOpTrait<"TFL::StatefulOperands", StrJoinInt<operands>.result>;
 
+class ChannelDimIndex<int index>
+  : ParamNativeOpTrait<"TFL::ChannelDimIndex", !cast<string>(index)>;
+
 //===----------------------------------------------------------------------===//
 // TFL op base class.
 //===----------------------------------------------------------------------===//
 
@@ -252,8 +255,9 @@ class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
   string customOption = ?;
 }
 
-class TFL_ConvOp<string mnemonic, string opSummary> :
-    TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>]> {
+class TFL_ConvOp<string mnemonic, string opSummary, int index> :
+    TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
+                      ChannelDimIndex<index>]> {
   let summary = opSummary # " operator";
 
   let description = [{
@@ -536,7 +540,7 @@ def TFL_ConstOp : Op<TFL_Dialect, "pseudo_const", [NoSideEffect,
   }]>];
 }
 
-def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution">;
+def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0>;
 
 def TFL_CosOp: TFL_Op<"cos", [NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Cosine operator";
@@ -553,7 +557,7 @@ def TFL_CosOp: TFL_Op<"cos", [NoSideEffect, SameOperandsAndResultType]> {
 }
 
 def TFL_DepthwiseConv2DOp :
-    TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution"> {
+    TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> {
   let arguments = !con(TFL_Conv2DOp.arguments,
                        (ins I32Attr:$depth_multiplier));
 }
@@ -568,7 +572,7 @@ def TFL_FullyConnectedOptionsWeightFormatAttr :
 
 // TODO(jpienaar): Update post discussion on semantics of FC OP.
 def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
-    NoSideEffect, AccumulatorUniformScale<2, 0, 1>]> {
+    NoSideEffect, AccumulatorUniformScale<2, 0, 1>, ChannelDimIndex<0>]> {
   let summary = "Fully connected op";
 
   let arguments = (ins
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_traits.h b/tensorflow/compiler/mlir/lite/ir/tfl_traits.h
index af8c707a04e..0ec63531658 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_traits.h
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_traits.h
@@ -44,6 +44,22 @@ class StatefulOperands {
   };
 };
 
+// The trait to specify the channel dimension index of the filter (the second
+// operand) of an affine TFL op (Conv2D, DepthwiseConv2D, FullyConnected).
+//
+//   class Conv2DOp
+//       : public Op<Conv2DOp, OpTrait::TFL::ChannelDimIndex<0>::Impl> {
+//
+template <int Index>
+class ChannelDimIndex {
+ public:
+  template <typename ConcreteType>
+  class Impl : public TraitBase<ConcreteType, ChannelDimIndex<Index>::Impl> {
+   public:
+    static int GetChannelDimIndex() { return Index; }
+  };
+};
+
 }  // namespace TFL
 }  // namespace OpTrait
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index 8a121fb879b..7ff84f16509 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include <climits>
 #include <cstdint>
 #include <functional>
+#include <iterator>
 #include <map>
 #include <numeric>
 
@@ -318,8 +319,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
     // so we have to update the bias.
     if (llvm::isa<SubOp>(binary_op)) cst_value.changeSign();
 
-    auto bias_and_slice =
-        ConcreteType::GetBiasDimAndSliceSize(filter_type.getShape());
+    auto bias_and_slice = GetBiasDimAndSliceSize(filter_type.getShape());
     int64_t bias_size = bias_and_slice.first;
     int64_t slice_size = bias_and_slice.second;
     ShapedType new_bias_type =
@@ -383,6 +383,24 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
     }
     return this->matchSuccess();
   }
+
+ private:
+  // Returns the length of the channel dimension and also the slice size per
+  // position along the channel dimension. tfl.conv_2d and tfl.fully_connected
+  // have a leading channel dimension in their filter, while
+  // tfl.depthwise_conv_2d has a trailing one. This helper derives both values
+  // from the op's ChannelDimIndex trait.
+  static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
+      ArrayRef<int64_t> filter_shape) {
+    // The channel dimension index is specified as an op trait.
+    auto channel_index_iter = filter_shape.begin();
+    std::advance(channel_index_iter, AffineOpType::GetChannelDimIndex());
+    // The slice size is the product of the dimensions past the channel index.
+    int64_t slice_size =
+        std::accumulate(std::next(channel_index_iter), filter_shape.end(),
+                        int64_t{1}, std::multiplies<int64_t>());
+    return {*channel_index_iter, slice_size};
+  }
 };
 
 class FuseBinaryOpToFollowingFullyConnected
@@ -394,17 +412,6 @@ class FuseBinaryOpToFollowingFullyConnected
                                       FullyConnectedOp>;
   explicit FuseBinaryOpToFollowingFullyConnected(MLIRContext *context)
       : BaseType(context) {}
-
-  // The first dimension of the fully-connected weight needs to match the last
-  // dimension of the op result and also the (broadcasted) size of bias. Then
-  // the size of higher-dimensions is considered as the slice size.
-  static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
-      ArrayRef<int64_t> filter_shape) {
-    int64_t depth =
-        std::accumulate(std::next(filter_shape.begin()), filter_shape.end(), 1,
-                        std::multiplies<int64_t>());
-    return {filter_shape.front(), depth};
-  }
 };
 
 class FuseBinaryOpToFollowingDepthwiseConv2D
@@ -416,14 +423,6 @@ class FuseBinaryOpToFollowingDepthwiseConv2D
                                       DepthwiseConv2DOp>;
   explicit FuseBinaryOpToFollowingDepthwiseConv2D(MLIRContext *context)
       : BaseType(context) {}
-
-  // The last dimension of the depthwise conv 2d weight needs to match the last
-  // dimension of the op result and also the (broadcasted) size of bias. Then
-  // slice number is just 1.
-  static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
-      ArrayRef<int64_t> filter_shape) {
-    return {filter_shape.back(), 1};
-  }
 };
 
 class FuseBinaryOpToFollowingConv2D
@@ -434,17 +433,6 @@ class FuseBinaryOpToFollowingConv2D
       FuseBinaryOpToFollowingAffineOp<FuseBinaryOpToFollowingConv2D, Conv2DOp>;
   explicit FuseBinaryOpToFollowingConv2D(MLIRContext *context)
       : BaseType(context) {}
-
-  // The first dimension of the conv 2d weight needs to match the last
-  // dimension of the op result and also the (broadcasted) size of bias. Then
-  // the size of higher-dimensions is considered as the slice size.
-  static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
-      ArrayRef<int64_t> filter_shape) {
-    int64_t depth =
-        std::accumulate(std::next(filter_shape.begin()), filter_shape.end(), 1,
-                        std::multiplies<int64_t>());
-    return {filter_shape.front(), depth};
-  }
 };
 
 void Optimize::runOnFunction() {
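
Note for reviewers: below is a minimal standalone sketch (plain C++, no MLIR dependencies) of the computation the consolidated GetBiasDimAndSliceSize helper performs once the channel dimension index comes from the ChannelDimIndex trait. The filter shapes are made up for illustration; the channel indices mirror the trait values this diff attaches to tfl.conv_2d (0), tfl.depthwise_conv_2d (3), and tfl.fully_connected (0).

#include <cstdint>
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <utility>
#include <vector>

// Given a filter shape and the channel dimension index (supplied in the real
// code by the op's ChannelDimIndex trait), returns {channel dimension length,
// slice size}, where the slice size is the product of all dimensions after
// the channel dimension.
std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
    const std::vector<int64_t>& filter_shape, int channel_dim_index) {
  auto channel_it = filter_shape.begin();
  std::advance(channel_it, channel_dim_index);
  int64_t slice_size =
      std::accumulate(std::next(channel_it), filter_shape.end(), int64_t{1},
                      std::multiplies<int64_t>());
  return {*channel_it, slice_size};
}

int main() {
  // tfl.conv_2d filter [out_channels, H, W, in_channels], channel index 0:
  // bias dim = 8, slice size = 3 * 3 * 16 = 144.
  auto conv = GetBiasDimAndSliceSize({8, 3, 3, 16}, 0);
  // tfl.depthwise_conv_2d filter [1, H, W, out_channels], channel index 3:
  // bias dim = 16, slice size = 1 (empty tail product).
  auto dw = GetBiasDimAndSliceSize({1, 3, 3, 16}, 3);
  // tfl.fully_connected weight [out_units, in_units], channel index 0:
  // bias dim = 4, slice size = 128.
  auto fc = GetBiasDimAndSliceSize({4, 128}, 0);
  std::cout << conv.first << ' ' << conv.second << '\n'  // 8 144
            << dw.first << ' ' << dw.second << '\n'      // 16 1
            << fc.first << ' ' << fc.second << '\n';     // 4 128
  return 0;
}

For the three fused patterns this reproduces what the deleted per-op overrides returned: {filter_shape.front(), product of the remaining dims} for conv_2d and fully_connected, and {filter_shape.back(), 1} for depthwise_conv_2d, where the trailing product is empty and so the slice size is 1.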