Replace the channel dimension index op trait by op interface

PiperOrigin-RevId: 290144065
Change-Id: If77b13f2685b7e2c5ba3f3aa5a44358d13ea1dae
This commit is contained in:
Feng Liu 2020-01-16 14:19:50 -08:00 committed by TensorFlower Gardener
parent 7b9b1de47b
commit 3bcfb829bb
6 changed files with 36 additions and 59 deletions

View File

@ -191,7 +191,6 @@ cc_library(
],
hdrs = [
"ir/tfl_ops.h",
"ir/tfl_traits.h",
"transforms/passes.h",
"utils/attribute_utils.h",
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",

View File

@ -27,7 +27,6 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h"

View File

@ -268,10 +268,20 @@ def TFL_StatefulOp : OpInterface<"StatefulOpInterface"> {
}
//===----------------------------------------------------------------------===//
// TFL native op trait for channel indices.
// TFL op interface for output channel index.
class ChannelDimIndex<int index>
: ParamNativeOpTrait<"TFL::ChannelDimIndex", !cast<string>(index)>;
// Interface implemented by ops (Conv2D, DepthwiseConv2D, FullyConnected) whose
// filter has a channel dimension; replaces the old ChannelDimIndex native trait.
def TFL_ChannelDimIndexInterface : OpInterface<"ChannelDimIndexInterface"> {
  let description = [{
    Interface for defining the index of the output channel dimension.
  }];

  let methods = [
    InterfaceMethod<
      [{Returns the dimension index of the output channels.}],
      "int", "GetChannelDimIndex", (ins)
    >,
  ];
}
//===----------------------------------------------------------------------===//
// TFL op base class.
@ -300,7 +310,7 @@ class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
class TFL_ConvOp<string mnemonic, string opSummary, int index> :
TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
ChannelDimIndex<index>, AffineOpCoefficient<index, 1>]> {
TFL_ChannelDimIndexInterface, AffineOpCoefficient<index, 1>]> {
let summary = opSummary # " operator";
let description = [{
@ -630,7 +640,12 @@ def TFL_ExternalConstOp : Op<TFL_Dialect, "external_const", [NoSideEffect]> {
let results = (outs AnyTensor:$output);
}
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0>;
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
let extraClassDeclaration = [{
// ChannelDimIndexInterface:
int GetChannelDimIndex() { return 0; }
}];
}
def TFL_CosOp: TFL_Op<"cos", [
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
@ -650,6 +665,11 @@ def TFL_CosOp: TFL_Op<"cos", [
def TFL_DepthwiseConv2DOp :
TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> {
let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier));
let extraClassDeclaration = [{
// ChannelDimIndexInterface:
int GetChannelDimIndex() { return 3; }
}];
}
def TFL_FCWO_Default : StrEnumAttrCase<"DEFAULT">;
@ -663,7 +683,8 @@ def TFL_FullyConnectedOptionsWeightFormatAttr :
// TODO(jpienaar): Update post discussion on semantics of FC OP.
def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
NoSideEffect, AccumulatorUniformScale<2, 0, 1>, ChannelDimIndex<0>,
NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
TFL_ChannelDimIndexInterface,
AffineOpCoefficient<-1, 1>]> {
let summary = "Fully connected op";
@ -685,6 +706,11 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
let verifier = [{ return Verify(*this); }];
let hasOptions = 1;
let extraClassDeclaration = [{
// ChannelDimIndexInterface:
int GetChannelDimIndex() { return 0; }
}];
}
def TFL_GatherOp : TFL_Op<"gather", [

View File

@ -1,47 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the op traits used in the MLIR TensorFlow Lite dialect.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h" // TF:llvm-project
namespace mlir {
namespace OpTrait {
namespace TFL {
// The trait to specify the channel dimension index of the input (first operand)
// of an affine TFL op (Conv2D, DepthwiseConv2D, FullyConnected).
//
// class Conv2DOp
// : public Op<Conv2DOp, OpTrait::TFL::ChannelDimIndex<0>::Impl> {
//
template <int Index>
class ChannelDimIndex {
public:
template <typename ConcreteType>
class Impl : public TraitBase<ConcreteType, ChannelDimIndex<Index>::Impl> {
public:
static int GetChannelDimIndex() { return Index; }
};
};
} // namespace TFL
} // namespace OpTrait
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_

View File

@ -35,7 +35,6 @@ limitations under the License.
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"

View File

@ -366,7 +366,8 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
// so we have to update the bias.
if (llvm::isa<SubOp>(binary_op)) cst_value.changeSign();
auto bias_and_slice = GetBiasDimAndSliceSize(filter_type.getShape());
auto bias_and_slice =
GetBiasDimAndSliceSize(filter_type.getShape(), fc_op);
int64_t bias_size = bias_and_slice.first;
int64_t slice_size = bias_and_slice.second;
ShapedType new_bias_type =
@ -438,10 +439,10 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
// has tailing channel dimension. This function is to provide a utility to
// create the above information from the op property.
static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
ArrayRef<int64_t> filter_shape) {
ArrayRef<int64_t> filter_shape, AffineOpType op) {
// Channel dimension index is specified as op property
auto channel_index_iter = filter_shape.begin();
std::advance(channel_index_iter, AffineOpType::GetChannelDimIndex());
std::advance(channel_index_iter, op.GetChannelDimIndex());
// The slide size is the size of the data in higher dimensions.
int64_t slice_size =
std::accumulate(std::next(channel_index_iter), filter_shape.end(), 1,