Replace the channel dimension index op trait by op interface
PiperOrigin-RevId: 290144065 Change-Id: If77b13f2685b7e2c5ba3f3aa5a44358d13ea1dae
This commit is contained in:
parent
7b9b1de47b
commit
3bcfb829bb
@ -191,7 +191,6 @@ cc_library(
|
||||
],
|
||||
hdrs = [
|
||||
"ir/tfl_ops.h",
|
||||
"ir/tfl_traits.h",
|
||||
"transforms/passes.h",
|
||||
"utils/attribute_utils.h",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
|
||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/Support/Functional.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
|
@ -268,10 +268,20 @@ def TFL_StatefulOp : OpInterface<"StatefulOpInterface"> {
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TFL native op trait for channel indices.
|
||||
// TFL op interface for output channel index.
|
||||
|
||||
class ChannelDimIndex<int index>
|
||||
: ParamNativeOpTrait<"TFL::ChannelDimIndex", !cast<string>(index)>;
|
||||
def TFL_ChannelDimIndexInterface : OpInterface<"ChannelDimIndexInterface"> {
|
||||
let description = [{
|
||||
Interface for defining the index of out channel index.
|
||||
}];
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<
|
||||
[{Returns the dimension index of the output channels.}],
|
||||
"int", "GetChannelDimIndex", (ins)
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TFL op base class.
|
||||
@ -300,7 +310,7 @@ class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
|
||||
class TFL_ConvOp<string mnemonic, string opSummary, int index> :
|
||||
TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
|
||||
ChannelDimIndex<index>, AffineOpCoefficient<index, 1>]> {
|
||||
TFL_ChannelDimIndexInterface, AffineOpCoefficient<index, 1>]> {
|
||||
let summary = opSummary # " operator";
|
||||
|
||||
let description = [{
|
||||
@ -630,7 +640,12 @@ def TFL_ExternalConstOp : Op<TFL_Dialect, "external_const", [NoSideEffect]> {
|
||||
let results = (outs AnyTensor:$output);
|
||||
}
|
||||
|
||||
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0>;
|
||||
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
|
||||
let extraClassDeclaration = [{
|
||||
// StatefulOpInterface:
|
||||
int GetChannelDimIndex() { return 0; }
|
||||
}];
|
||||
}
|
||||
|
||||
def TFL_CosOp: TFL_Op<"cos", [
|
||||
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
|
||||
@ -650,6 +665,11 @@ def TFL_CosOp: TFL_Op<"cos", [
|
||||
def TFL_DepthwiseConv2DOp :
|
||||
TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> {
|
||||
let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier));
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// StatefulOpInterface:
|
||||
int GetChannelDimIndex() { return 3; }
|
||||
}];
|
||||
}
|
||||
|
||||
def TFL_FCWO_Default : StrEnumAttrCase<"DEFAULT">;
|
||||
@ -663,7 +683,8 @@ def TFL_FullyConnectedOptionsWeightFormatAttr :
|
||||
|
||||
// TODO(jpienaar): Update post discussion on semantics of FC OP.
|
||||
def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
|
||||
NoSideEffect, AccumulatorUniformScale<2, 0, 1>, ChannelDimIndex<0>,
|
||||
NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
|
||||
TFL_ChannelDimIndexInterface,
|
||||
AffineOpCoefficient<-1, 1>]> {
|
||||
let summary = "Fully connected op";
|
||||
|
||||
@ -685,6 +706,11 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
|
||||
let hasOptions = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// ChannelDimIndexInterface:
|
||||
int GetChannelDimIndex() { return 0; }
|
||||
}];
|
||||
}
|
||||
|
||||
def TFL_GatherOp : TFL_Op<"gather", [
|
||||
|
@ -1,47 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file defines the op traits used in the MLIR TensorFlow Lite dialect.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace OpTrait {
|
||||
namespace TFL {
|
||||
// The trait to specify the channel dimension index of the input (first operand)
|
||||
// of an affine TFL op (Conv2D, DepthwiseConv2D, FullyConnected).
|
||||
//
|
||||
// class Conv2DOp
|
||||
// : public Op<Conv2DOp, OpTrait::TFL::ChannelDimIndex<0>::Impl> {
|
||||
//
|
||||
template <int Index>
|
||||
class ChannelDimIndex {
|
||||
public:
|
||||
template <typename ConcreteType>
|
||||
class Impl : public TraitBase<ConcreteType, ChannelDimIndex<Index>::Impl> {
|
||||
public:
|
||||
static int GetChannelDimIndex() { return Index; }
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace OpTrait
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
|
@ -35,7 +35,6 @@ limitations under the License.
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
@ -366,7 +366,8 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
||||
// so we have to update the bias.
|
||||
if (llvm::isa<SubOp>(binary_op)) cst_value.changeSign();
|
||||
|
||||
auto bias_and_slice = GetBiasDimAndSliceSize(filter_type.getShape());
|
||||
auto bias_and_slice =
|
||||
GetBiasDimAndSliceSize(filter_type.getShape(), fc_op);
|
||||
int64_t bias_size = bias_and_slice.first;
|
||||
int64_t slice_size = bias_and_slice.second;
|
||||
ShapedType new_bias_type =
|
||||
@ -438,10 +439,10 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
||||
// has tailing channel dimension. This function is to provide a utility to
|
||||
// create the above information from the op property.
|
||||
static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
|
||||
ArrayRef<int64_t> filter_shape) {
|
||||
ArrayRef<int64_t> filter_shape, AffineOpType op) {
|
||||
// Channel dimension index is specified as op property
|
||||
auto channel_index_iter = filter_shape.begin();
|
||||
std::advance(channel_index_iter, AffineOpType::GetChannelDimIndex());
|
||||
std::advance(channel_index_iter, op.GetChannelDimIndex());
|
||||
// The slide size is the size of the data in higher dimensions.
|
||||
int64_t slice_size =
|
||||
std::accumulate(std::next(channel_index_iter), filter_shape.end(), 1,
|
||||
|
Loading…
x
Reference in New Issue
Block a user