Replace the channel dimension index op trait by op interface
PiperOrigin-RevId: 290144065 Change-Id: If77b13f2685b7e2c5ba3f3aa5a44358d13ea1dae
This commit is contained in:
parent
7b9b1de47b
commit
3bcfb829bb
@ -191,7 +191,6 @@ cc_library(
|
|||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"ir/tfl_ops.h",
|
"ir/tfl_ops.h",
|
||||||
"ir/tfl_traits.h",
|
|
||||||
"transforms/passes.h",
|
"transforms/passes.h",
|
||||||
"utils/attribute_utils.h",
|
"utils/attribute_utils.h",
|
||||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
|
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
|
||||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
|||||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||||
#include "mlir/Support/Functional.h" // TF:llvm-project
|
#include "mlir/Support/Functional.h" // TF:llvm-project
|
||||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
|
|
||||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
|
||||||
|
@ -268,10 +268,20 @@ def TFL_StatefulOp : OpInterface<"StatefulOpInterface"> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TFL native op trait for channel indices.
|
// TFL op interface for output channel index.
|
||||||
|
|
||||||
class ChannelDimIndex<int index>
|
def TFL_ChannelDimIndexInterface : OpInterface<"ChannelDimIndexInterface"> {
|
||||||
: ParamNativeOpTrait<"TFL::ChannelDimIndex", !cast<string>(index)>;
|
let description = [{
|
||||||
|
Interface for defining the index of out channel index.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let methods = [
|
||||||
|
InterfaceMethod<
|
||||||
|
[{Returns the dimension index of the output channels.}],
|
||||||
|
"int", "GetChannelDimIndex", (ins)
|
||||||
|
>,
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TFL op base class.
|
// TFL op base class.
|
||||||
@ -300,7 +310,7 @@ class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||||||
|
|
||||||
class TFL_ConvOp<string mnemonic, string opSummary, int index> :
|
class TFL_ConvOp<string mnemonic, string opSummary, int index> :
|
||||||
TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
|
TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
|
||||||
ChannelDimIndex<index>, AffineOpCoefficient<index, 1>]> {
|
TFL_ChannelDimIndexInterface, AffineOpCoefficient<index, 1>]> {
|
||||||
let summary = opSummary # " operator";
|
let summary = opSummary # " operator";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
@ -630,7 +640,12 @@ def TFL_ExternalConstOp : Op<TFL_Dialect, "external_const", [NoSideEffect]> {
|
|||||||
let results = (outs AnyTensor:$output);
|
let results = (outs AnyTensor:$output);
|
||||||
}
|
}
|
||||||
|
|
||||||
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0>;
|
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// StatefulOpInterface:
|
||||||
|
int GetChannelDimIndex() { return 0; }
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def TFL_CosOp: TFL_Op<"cos", [
|
def TFL_CosOp: TFL_Op<"cos", [
|
||||||
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
|
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
|
||||||
@ -650,6 +665,11 @@ def TFL_CosOp: TFL_Op<"cos", [
|
|||||||
def TFL_DepthwiseConv2DOp :
|
def TFL_DepthwiseConv2DOp :
|
||||||
TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> {
|
TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> {
|
||||||
let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier));
|
let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier));
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// StatefulOpInterface:
|
||||||
|
int GetChannelDimIndex() { return 3; }
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TFL_FCWO_Default : StrEnumAttrCase<"DEFAULT">;
|
def TFL_FCWO_Default : StrEnumAttrCase<"DEFAULT">;
|
||||||
@ -663,7 +683,8 @@ def TFL_FullyConnectedOptionsWeightFormatAttr :
|
|||||||
|
|
||||||
// TODO(jpienaar): Update post discussion on semantics of FC OP.
|
// TODO(jpienaar): Update post discussion on semantics of FC OP.
|
||||||
def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
|
def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
|
||||||
NoSideEffect, AccumulatorUniformScale<2, 0, 1>, ChannelDimIndex<0>,
|
NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
|
||||||
|
TFL_ChannelDimIndexInterface,
|
||||||
AffineOpCoefficient<-1, 1>]> {
|
AffineOpCoefficient<-1, 1>]> {
|
||||||
let summary = "Fully connected op";
|
let summary = "Fully connected op";
|
||||||
|
|
||||||
@ -685,6 +706,11 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
|
|||||||
let verifier = [{ return Verify(*this); }];
|
let verifier = [{ return Verify(*this); }];
|
||||||
|
|
||||||
let hasOptions = 1;
|
let hasOptions = 1;
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// ChannelDimIndexInterface:
|
||||||
|
int GetChannelDimIndex() { return 0; }
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TFL_GatherOp : TFL_Op<"gather", [
|
def TFL_GatherOp : TFL_Op<"gather", [
|
||||||
|
@ -1,47 +0,0 @@
|
|||||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
// This file defines the op traits used in the MLIR TensorFlow Lite dialect.
|
|
||||||
|
|
||||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
|
|
||||||
#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
|
|
||||||
|
|
||||||
#include "mlir/IR/OpDefinition.h"
|
|
||||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
|
||||||
|
|
||||||
namespace mlir {
|
|
||||||
namespace OpTrait {
|
|
||||||
namespace TFL {
|
|
||||||
// The trait to specify the channel dimension index of the input (first operand)
|
|
||||||
// of an affine TFL op (Conv2D, DepthwiseConv2D, FullyConnected).
|
|
||||||
//
|
|
||||||
// class Conv2DOp
|
|
||||||
// : public Op<Conv2DOp, OpTrait::TFL::ChannelDimIndex<0>::Impl> {
|
|
||||||
//
|
|
||||||
template <int Index>
|
|
||||||
class ChannelDimIndex {
|
|
||||||
public:
|
|
||||||
template <typename ConcreteType>
|
|
||||||
class Impl : public TraitBase<ConcreteType, ChannelDimIndex<Index>::Impl> {
|
|
||||||
public:
|
|
||||||
static int GetChannelDimIndex() { return Index; }
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace TFL
|
|
||||||
} // namespace OpTrait
|
|
||||||
} // namespace mlir
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_
|
|
@ -35,7 +35,6 @@ limitations under the License.
|
|||||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
|
|
||||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
@ -366,7 +366,8 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
|||||||
// so we have to update the bias.
|
// so we have to update the bias.
|
||||||
if (llvm::isa<SubOp>(binary_op)) cst_value.changeSign();
|
if (llvm::isa<SubOp>(binary_op)) cst_value.changeSign();
|
||||||
|
|
||||||
auto bias_and_slice = GetBiasDimAndSliceSize(filter_type.getShape());
|
auto bias_and_slice =
|
||||||
|
GetBiasDimAndSliceSize(filter_type.getShape(), fc_op);
|
||||||
int64_t bias_size = bias_and_slice.first;
|
int64_t bias_size = bias_and_slice.first;
|
||||||
int64_t slice_size = bias_and_slice.second;
|
int64_t slice_size = bias_and_slice.second;
|
||||||
ShapedType new_bias_type =
|
ShapedType new_bias_type =
|
||||||
@ -438,10 +439,10 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
|||||||
// has tailing channel dimension. This function is to provide a utility to
|
// has tailing channel dimension. This function is to provide a utility to
|
||||||
// create the above information from the op property.
|
// create the above information from the op property.
|
||||||
static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
|
static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
|
||||||
ArrayRef<int64_t> filter_shape) {
|
ArrayRef<int64_t> filter_shape, AffineOpType op) {
|
||||||
// Channel dimension index is specified as op property
|
// Channel dimension index is specified as op property
|
||||||
auto channel_index_iter = filter_shape.begin();
|
auto channel_index_iter = filter_shape.begin();
|
||||||
std::advance(channel_index_iter, AffineOpType::GetChannelDimIndex());
|
std::advance(channel_index_iter, op.GetChannelDimIndex());
|
||||||
// The slide size is the size of the data in higher dimensions.
|
// The slide size is the size of the data in higher dimensions.
|
||||||
int64_t slice_size =
|
int64_t slice_size =
|
||||||
std::accumulate(std::next(channel_index_iter), filter_shape.end(), 1,
|
std::accumulate(std::next(channel_index_iter), filter_shape.end(), 1,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user