Use op trait to specify the channel dimension index of the affine ops
Previously the channel dimension index for these ops were specified as utility methods in the fusion pattern. This cl used an op trait to specify this information so it can be used by other passes (for example, per-channel quantization). PiperOrigin-RevId: 275947097 Change-Id: I6a4d3a85696b00507378737bdc2a9a90d8422a1c
This commit is contained in:
parent
7bf81e0bce
commit
807cf30585
@ -221,12 +221,15 @@ def TFL_ComparisonBinaryBuilder : OpBuilder<
|
||||
}]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TFL native op trait for stateful operands.
|
||||
// TFL native op trait for stateful operands and channel indices.
|
||||
|
||||
class StatefulOperands<list<int> operands>
|
||||
: ParamNativeOpTrait<"TFL::StatefulOperands", StrJoinInt<operands>.result>;
|
||||
|
||||
|
||||
class ChannelDimIndex<int index>
|
||||
: ParamNativeOpTrait<"TFL::ChannelDimIndex", !cast<string>(index)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TFL op base class.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -252,8 +255,9 @@ class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
string customOption = ?;
|
||||
}
|
||||
|
||||
class TFL_ConvOp<string mnemonic, string opSummary> :
|
||||
TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>]> {
|
||||
class TFL_ConvOp<string mnemonic, string opSummary, int index> :
|
||||
TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
|
||||
ChannelDimIndex<index>]> {
|
||||
let summary = opSummary # " operator";
|
||||
|
||||
let description = [{
|
||||
@ -536,7 +540,7 @@ def TFL_ConstOp : Op<TFL_Dialect, "pseudo_const", [NoSideEffect,
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution">;
|
||||
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0>;
|
||||
|
||||
def TFL_CosOp: TFL_Op<"cos", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Cosine operator";
|
||||
@ -553,7 +557,7 @@ def TFL_CosOp: TFL_Op<"cos", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
}
|
||||
|
||||
def TFL_DepthwiseConv2DOp :
|
||||
TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution"> {
|
||||
TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> {
|
||||
let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier));
|
||||
}
|
||||
|
||||
@ -568,7 +572,7 @@ def TFL_FullyConnectedOptionsWeightFormatAttr :
|
||||
|
||||
// TODO(jpienaar): Update post discussion on semantics of FC OP.
|
||||
def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
|
||||
NoSideEffect, AccumulatorUniformScale<2, 0, 1>]> {
|
||||
NoSideEffect, AccumulatorUniformScale<2, 0, 1>, ChannelDimIndex<0>]> {
|
||||
let summary = "Fully connected op";
|
||||
|
||||
let arguments = (ins
|
||||
|
@ -44,6 +44,22 @@ class StatefulOperands {
|
||||
};
|
||||
};
|
||||
|
||||
// The trait to specify the channel dimension index of the input (first operand)
|
||||
// of an affine TFL op (Conv2D, DepthwiseConv2D, FullyConnected).
|
||||
//
|
||||
// class Conv2DOp
|
||||
// : public Op<Conv2DOp, OpTrait::TFL::ChannelDimIndex<0>::Impl> {
|
||||
//
|
||||
template <int Index>
|
||||
class ChannelDimIndex {
|
||||
public:
|
||||
template <typename ConcreteType>
|
||||
class Impl : public TraitBase<ConcreteType, ChannelDimIndex<Index>::Impl> {
|
||||
public:
|
||||
static int GetChannelDimIndex() { return Index; }
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace OpTrait
|
||||
} // namespace mlir
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <climits>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <iterator>
|
||||
#include <map>
|
||||
#include <numeric>
|
||||
|
||||
@ -318,8 +319,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
||||
// so we have to update the bias.
|
||||
if (llvm::isa<SubOp>(binary_op)) cst_value.changeSign();
|
||||
|
||||
auto bias_and_slice =
|
||||
ConcreteType::GetBiasDimAndSliceSize(filter_type.getShape());
|
||||
auto bias_and_slice = GetBiasDimAndSliceSize(filter_type.getShape());
|
||||
int64_t bias_size = bias_and_slice.first;
|
||||
int64_t slice_size = bias_and_slice.second;
|
||||
ShapedType new_bias_type =
|
||||
@ -383,6 +383,24 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
||||
}
|
||||
return this->matchSuccess();
|
||||
}
|
||||
|
||||
private:
|
||||
// Returns the dimension length of the channel dimension and also the slide
|
||||
// size by each position in the channel dimension accordingly. tfl.conv2d and
|
||||
// tfl.fully_connected has heading channel dimension, but tfl.depthwise_conv2d
|
||||
// has tailing channel dimension. This function is to provide a utility to
|
||||
// create the above information from the op property.
|
||||
static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
|
||||
ArrayRef<int64_t> filter_shape) {
|
||||
// Channel dimension index is specified as op property
|
||||
auto channel_index_iter = filter_shape.begin();
|
||||
std::advance(channel_index_iter, AffineOpType::GetChannelDimIndex());
|
||||
// The slide size is the size of the data in higher dimensions.
|
||||
int64_t slice_size =
|
||||
std::accumulate(std::next(channel_index_iter), filter_shape.end(), 1,
|
||||
std::multiplies<int64_t>());
|
||||
return {*channel_index_iter, slice_size};
|
||||
}
|
||||
};
|
||||
|
||||
class FuseBinaryOpToFollowingFullyConnected
|
||||
@ -394,17 +412,6 @@ class FuseBinaryOpToFollowingFullyConnected
|
||||
FullyConnectedOp>;
|
||||
explicit FuseBinaryOpToFollowingFullyConnected(MLIRContext *context)
|
||||
: BaseType(context) {}
|
||||
|
||||
// The first dimension of the fully-connected weight needs to match the last
|
||||
// dimension of the op result and also the (broadcasted) size of bias. Then
|
||||
// the size of higher-dimensions is considered as the slice size.
|
||||
static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
|
||||
ArrayRef<int64_t> filter_shape) {
|
||||
int64_t depth =
|
||||
std::accumulate(std::next(filter_shape.begin()), filter_shape.end(), 1,
|
||||
std::multiplies<int64_t>());
|
||||
return {filter_shape.front(), depth};
|
||||
}
|
||||
};
|
||||
|
||||
class FuseBinaryOpToFollowingDepthwiseConv2D
|
||||
@ -416,14 +423,6 @@ class FuseBinaryOpToFollowingDepthwiseConv2D
|
||||
DepthwiseConv2DOp>;
|
||||
explicit FuseBinaryOpToFollowingDepthwiseConv2D(MLIRContext *context)
|
||||
: BaseType(context) {}
|
||||
|
||||
// The last dimension of the depthwise conv 2d weight needs to match the last
|
||||
// dimension of the op result and also the (broadcasted) size of bias. Then
|
||||
// slice number is just 1.
|
||||
static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
|
||||
ArrayRef<int64_t> filter_shape) {
|
||||
return {filter_shape.back(), 1};
|
||||
}
|
||||
};
|
||||
|
||||
class FuseBinaryOpToFollowingConv2D
|
||||
@ -434,17 +433,6 @@ class FuseBinaryOpToFollowingConv2D
|
||||
FuseBinaryOpToFollowingAffineOp<FuseBinaryOpToFollowingConv2D, Conv2DOp>;
|
||||
explicit FuseBinaryOpToFollowingConv2D(MLIRContext *context)
|
||||
: BaseType(context) {}
|
||||
|
||||
// The first dimension of the conv 2d weight needs to match the last
|
||||
// dimension of the op result and also the (broadcasted) size of bias. Then
|
||||
// the size of higher-dimensions is considered as the slice size.
|
||||
static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
|
||||
ArrayRef<int64_t> filter_shape) {
|
||||
int64_t depth =
|
||||
std::accumulate(std::next(filter_shape.begin()), filter_shape.end(), 1,
|
||||
std::multiplies<int64_t>());
|
||||
return {filter_shape.front(), depth};
|
||||
}
|
||||
};
|
||||
|
||||
void Optimize::runOnFunction() {
|
||||
|
Loading…
Reference in New Issue
Block a user