diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 2e9191846c1..c0884d19585 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -191,7 +191,6 @@ cc_library( ], hdrs = [ "ir/tfl_ops.h", - "ir/tfl_traits.h", "transforms/passes.h", "utils/attribute_utils.h", "//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h", diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index d5584cb6687..8ad9aae8c44 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -27,7 +27,6 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // TF:llvm-project #include "mlir/Support/Functional.h" // TF:llvm-project #include "mlir/Support/LLVM.h" // TF:llvm-project -#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" #include "tensorflow/lite/schema/schema_generated.h" diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index e5ac19e2549..116448e70fb 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -268,10 +268,20 @@ def TFL_StatefulOp : OpInterface<"StatefulOpInterface"> { } //===----------------------------------------------------------------------===// -// TFL native op trait for channel indices. +// TFL op interface for output channel index. -class ChannelDimIndex - : ParamNativeOpTrait<"TFL::ChannelDimIndex", !cast(index)>; +def TFL_ChannelDimIndexInterface : OpInterface<"ChannelDimIndexInterface"> { + let description = [{ + Interface for defining the index of the output channel. 
+ }]; + + let methods = [ + InterfaceMethod< + [{Returns the dimension index of the output channels.}], + "int", "GetChannelDimIndex", (ins) + >, + ]; +} //===----------------------------------------------------------------------===// // TFL op base class. @@ -300,7 +310,7 @@ class TFL_Op traits = []> : class TFL_ConvOp : TFL_Op, - ChannelDimIndex, AffineOpCoefficient]> { + TFL_ChannelDimIndexInterface, AffineOpCoefficient]> { let summary = opSummary # " operator"; let description = [{ @@ -630,7 +640,12 @@ def TFL_ExternalConstOp : Op { let results = (outs AnyTensor:$output); } -def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0>; +def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> { + let extraClassDeclaration = [{ + // ChannelDimIndexInterface: + int GetChannelDimIndex() { return 0; } + }]; +} def TFL_CosOp: TFL_Op<"cos", [ NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> { @@ -650,6 +665,11 @@ def TFL_CosOp: TFL_Op<"cos", [ def TFL_DepthwiseConv2DOp : TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> { let arguments = !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier)); + + let extraClassDeclaration = [{ + // ChannelDimIndexInterface: + int GetChannelDimIndex() { return 3; } + }]; } def TFL_FCWO_Default : StrEnumAttrCase<"DEFAULT">; @@ -663,7 +683,8 @@ def TFL_FullyConnectedOptionsWeightFormatAttr : // TODO(jpienaar): Update post discussion on semantics of FC OP. 
def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [ - NoSideEffect, AccumulatorUniformScale<2, 0, 1>, ChannelDimIndex<0>, + NoSideEffect, AccumulatorUniformScale<2, 0, 1>, + TFL_ChannelDimIndexInterface, AffineOpCoefficient<-1, 1>]> { let summary = "Fully connected op"; @@ -685,6 +706,11 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [ let verifier = [{ return Verify(*this); }]; let hasOptions = 1; + + let extraClassDeclaration = [{ + // ChannelDimIndexInterface: + int GetChannelDimIndex() { return 0; } + }]; } def TFL_GatherOp : TFL_Op<"gather", [ diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_traits.h b/tensorflow/compiler/mlir/lite/ir/tfl_traits.h deleted file mode 100644 index 5a697664591..00000000000 --- a/tensorflow/compiler/mlir/lite/ir/tfl_traits.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file defines the op traits used in the MLIR TensorFlow Lite dialect. - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_ - -#include "mlir/IR/OpDefinition.h" -#include "mlir/Support/LLVM.h" // TF:llvm-project - -namespace mlir { -namespace OpTrait { -namespace TFL { -// The trait to specify the channel dimension index of the input (first operand) -// of an affine TFL op (Conv2D, DepthwiseConv2D, FullyConnected). 
-// -// class Conv2DOp -// : public Op::Impl> { -// -template -class ChannelDimIndex { - public: - template - class Impl : public TraitBase::Impl> { - public: - static int GetChannelDimIndex() { return Index; } - }; -}; - -} // namespace TFL -} // namespace OpTrait -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index 3fd1ff2ac94..dc79aa8b07f 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -35,7 +35,6 @@ limitations under the License. #include "mlir/IR/Value.h" // TF:llvm-project #include "mlir/Support/LLVM.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 2761fa2c85c..39e309a86ff 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -366,7 +366,8 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern { // so we have to update the bias. if (llvm::isa(binary_op)) cst_value.changeSign(); - auto bias_and_slice = GetBiasDimAndSliceSize(filter_type.getShape()); + auto bias_and_slice = + GetBiasDimAndSliceSize(filter_type.getShape(), fc_op); int64_t bias_size = bias_and_slice.first; int64_t slice_size = bias_and_slice.second; ShapedType new_bias_type = @@ -438,10 +439,10 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern { // has tailing channel dimension. 
This function is to provide a utility to // create the above information from the op property. static std::pair GetBiasDimAndSliceSize( - ArrayRef filter_shape) { + ArrayRef filter_shape, AffineOpType op) { // Channel dimension index is specified as op property auto channel_index_iter = filter_shape.begin(); - std::advance(channel_index_iter, AffineOpType::GetChannelDimIndex()); + std::advance(channel_index_iter, op.GetChannelDimIndex()); // The slide size is the size of the data in higher dimensions. int64_t slice_size = std::accumulate(std::next(channel_index_iter), filter_shape.end(), 1,