- Add layout optimization passes to the TFLite pass pipeline. This changes the data layout of layout-sensitive ops (e.g. Conv2D) to NHWC so they can be converted to TFLite (see the first sketch below).

- Update Transpose legalization from TF -> TFL to handle int64 permutations (see the second sketch below).
- Update the legalization unit tests to also run CSE. This simplifies some tests and avoids duplicate tensors, which cause problems when matching parameters (it is ambiguous which one to match).
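
A rough sketch of the layout rewrite, in MLIR. The function names, shapes, and attribute values here are assumed for illustration and are not taken from this change; the actual rewrite is performed by the -tf-layout-optimization pipeline wired up below.

// Before: Conv2D on NCHW data.
func @conv_nchw(%input: tensor<1x2x8x8xf32>, %filter: tensor<3x3x2x2xf32>) -> tensor<1x2x8x8xf32> {
  %0 = "tf.Conv2D"(%input, %filter) {data_format = "NCHW", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x2x8x8xf32>, tensor<3x3x2x2xf32>) -> tensor<1x2x8x8xf32>
  return %0 : tensor<1x2x8x8xf32>
}
// After forcing NHWC (sketch): the input is transposed to NHWC, the Conv2D runs
// in NHWC, and the result is transposed back; later prepare/legalize passes can
// then convert the NHWC Conv2D to tfl.conv_2d.
func @conv_nhwc(%input: tensor<1x2x8x8xf32>, %filter: tensor<3x3x2x2xf32>) -> tensor<1x2x8x8xf32> {
  %perm_in = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
  %nhwc = "tf.Transpose"(%input, %perm_in) : (tensor<1x2x8x8xf32>, tensor<4xi32>) -> tensor<1x8x8x2xf32>
  %conv = "tf.Conv2D"(%nhwc, %filter) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x8x8x2xf32>, tensor<3x3x2x2xf32>) -> tensor<1x8x8x2xf32>
  %perm_out = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
  %nchw = "tf.Transpose"(%conv, %perm_out) : (tensor<1x8x8x2xf32>, tensor<4xi32>) -> tensor<1x2x8x8xf32>
  return %nchw : tensor<1x2x8x8xf32>
}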
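
A rough sketch of the int64 Transpose legalization, also in MLIR. Value names are illustrative only; the point is that the int64 permutation constant is cast to int32 before feeding tfl.transpose, since the TFL transpose op expects an int32 permutation.

// Input: tf.Transpose with an int64 permutation constant.
func @transpose_i64(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
  %perm = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
  %0 = "tf.Transpose"(%arg0, %perm) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32>
  return %0 : tensor<3x2xf32>
}
// Expected shape of the legalized IR (sketch): the permutation is cast to i32
// and the op becomes tfl.transpose.
func @transpose_i64_legalized(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
  %perm = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
  %perm_i32 = "tf.Cast"(%perm) {Truncate = false} : (tensor<2xi64>) -> tensor<2xi32>
  %0 = "tfl.transpose"(%arg0, %perm_i32) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
  return %0 : tensor<3x2xf32>
}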

PiperOrigin-RevId: 331672423
Change-Id: I39d943d197bb0288bf441f0c4dd80b9c9deb8257
Karim Nosir 2020-09-14 18:19:05 -07:00 committed by TensorFlower Gardener
parent 8dee182572
commit d8f9b4a0d2
7 changed files with 307 additions and 12 deletions

@@ -407,6 +407,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
"//tensorflow/compiler/mlir/tensorflow:unroll_batch_matmul_pass",

@@ -0,0 +1,232 @@
# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,8,8,2 -tf-input-data-types=DT_FLOAT -tf-output-arrays=output_0 -print-function-result-mapping %s -o - 2>&1 | FileCheck %s
node {
name: "input"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
dim {
size: 8
}
dim {
size: 8
}
dim {
size: 2
}
}
}
}
}
node {
name: "conv_net_2d/conv_2d_0/w"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 3
}
dim {
size: 3
}
dim {
size: 2
}
dim {
size: 2
}
}
tensor_content: ";;\177<5\241i\275\312f\211>#\346j>\033W\325\275\253>\210=Vr\r\276\304\222\313\276\374\346\214>\016e\211>)\253\000>\3241\337\275\235g-\276*(\216\276\326#\367\274\023\213\300\276\227\031\206>PUF=\253\330\263<\337IL\276\334\320\215>\377\306v\276\372C\302\273baM>H\314\270<2\221\352=J\026{\276\221\243\245\276?\314\240=UW2\2755\207\253\274\256\207\333\273\335\372\227>\246\232;\276%\r\374<Z\346\204>"
}
}
}
}
node {
name: "conv_net_2d/conv_2d_0/w/read"
op: "Identity"
input: "conv_net_2d/conv_2d_0/w"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@conv_net_2d/conv_2d_0/w"
}
}
}
}
node {
name: "conv_net_2d_1/conv_2d_0/convolution"
op: "Conv2D"
input: "input"
input: "conv_net_2d/conv_2d_0/w/read"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "data_format"
value {
s: "NCHW"
}
}
attr {
key: "dilations"
value {
list {
i: 1
i: 1
i: 1
i: 1
}
}
}
attr {
key: "explicit_paddings"
value {
list {
}
}
}
attr {
key: "padding"
value {
s: "SAME"
}
}
attr {
key: "strides"
value {
list {
i: 1
i: 1
i: 1
i: 1
}
}
}
attr {
key: "use_cudnn_on_gpu"
value {
b: true
}
}
}
node {
name: "conv_net_2d/conv_2d_0/b"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 2
}
}
tensor_content: "\315\314\314=\315\314\314="
}
}
}
}
node {
name: "conv_net_2d/conv_2d_0/b/read"
op: "Identity"
input: "conv_net_2d/conv_2d_0/b"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@conv_net_2d/conv_2d_0/b"
}
}
}
}
node {
name: "conv_net_2d_1/conv_2d_0/BiasAdd"
op: "BiasAdd"
input: "conv_net_2d_1/conv_2d_0/convolution"
input: "conv_net_2d/conv_2d_0/b/read"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "data_format"
value {
s: "NHWC"
}
}
}
node {
name: "conv_net_2d_1/Relu"
op: "Relu"
input: "conv_net_2d_1/conv_2d_0/BiasAdd"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "output_0"
op: "Identity"
input: "conv_net_2d_1/Relu"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
library {
}
# CHECK: 'main' inputs:
# CHECK-NEXT: name: 'input'
# CHECK-NEXT: 'main' outputs:
# CHECK-NEXT: name: 'output_0'

@@ -1,4 +1,4 @@
// RUN: tf-opt %s -tfl-legalize-tf | FileCheck %s
// RUN: tf-opt %s -tfl-legalize-tf --cse | FileCheck %s
func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
@@ -196,7 +196,6 @@ func @shape(%arg0: tensor<?x1001xf32>) -> tensor<2xi32> {
// CHECK-LABEL: shape
// CHECK: "tfl.shape"(%arg0) : (tensor<?x1001xf32>) -> tensor<2xi32>
// CHECK: %1 = "tfl.shape"(%arg0) : (tensor<?x1001xf32>) -> tensor<2xi32>
}
func @fill(%arg0: tensor<3xi32>, %arg1: tensor<f32>) -> tensor<?x?x?xf32> {
@@ -719,9 +718,8 @@ func @matrix_diag_v2_no_match(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
// CHECK-SAME: [[VAL_0:%.*]]: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
// CHECK: [[VAL_1:%.*]] = constant dense<1> : tensor<1xi32>
// CHECK: [[VAL_2:%.*]] = constant dense<-1> : tensor<1xi32>
// CHECK: [[VAL_5:%.*]] = constant dense<-1> : tensor<1xi32>
// CHECK: [[VAL_3:%.*]] = constant dense<0> : tensor<2xi32>
// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV2"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_5]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32>
// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV2"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32>
// CHECK: return [[VAL_4]] : tensor<8x16x16xf32>
}
@@ -753,9 +751,8 @@ func @matrix_diag_v3_no_match(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
// CHECK-SAME: [[VAL_0:%.*]]: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
// CHECK: [[VAL_1:%.*]] = constant dense<1> : tensor<1xi32>
// CHECK: [[VAL_2:%.*]] = constant dense<-1> : tensor<1xi32>
// CHECK: [[VAL_5:%.*]] = constant dense<-1> : tensor<1xi32>
// CHECK: [[VAL_3:%.*]] = constant dense<0> : tensor<2xi32>
// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV3"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_5]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32>
// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV3"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32>
// CHECK: return [[VAL_4]] : tensor<8x16x16xf32>
}
@@ -1047,8 +1044,7 @@ func @matmul_transposed_a(%arg0: tensor<37x40xf32>, %arg1: tensor<37x40xf32>) ->
// CHECK-LABEL: matmul_transposed_a
// CHECK: %[[CST_0:.*]] = constant dense<[1, 0]> : tensor<2xi32>
// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
// CHECK: %[[CST_1:.*]] = constant dense<[1, 0]> : tensor<2xi32>
// CHECK: %[[ARG_1:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
// CHECK: %[[ARG_1:.*]] = "tfl.transpose"(%arg1, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
// CHECK: %[[CST_2:.*]] = constant unit
// CHECK: "tfl.fully_connected"(%[[ARG_0]], %[[ARG_1]], %[[CST_2]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
}
@@ -1359,10 +1355,7 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, %
// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
// CHECK: %[[CST_0:.*]] = constant unit
// CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
// CHECK: %[[CST_1:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32>
// CHECK: %[[ARG2:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
// CHECK: %[[CST_2:.*]] = constant unit
// CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG2]], %arg2, %[[CST_2]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
// CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
// CHECK: %[[RESULT:.*]] = tfl.add %[[ARG1]], %[[ARG3]] {fused_activation_function = "NONE"} : tensor<15x28x28x1xf32>
// CHECK: return %[[RESULT]] : tensor<15x28x28x1xf32>
}
@@ -1568,3 +1561,27 @@ func @add_with_int32_5d_inputs(%arg0: tensor<1x1x1x3x1xi32>, %arg1 : tensor<1x1x
// CHECK-LABEL: add_with_int32_5d_inputs
// CHECK: "tf.Add"(%arg0, %arg1)
}
func @tranpose_int32_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
%cst = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi32> } : () -> tensor<2xi32>
%0 = "tf.Transpose"(%arg0, %cst): (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
return %0 : tensor<3x2xf32>
// CHECK-LABEL: tranpose_int32_perm
// CHECK: "tfl.transpose"
}
func @tranpose_int64_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
%cst = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64>
%0 = "tf.Transpose"(%arg0, %cst): (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32>
return %0 : tensor<3x2xf32>
// CHECK-LABEL: tranpose_int64_perm
// CHECK: "tfl.transpose"
}
func @tranpose_arg(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> tensor<3x2xf32> {
%0 = "tf.Transpose"(%arg0, %arg1): (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
return %0 : tensor<3x2xf32>
// CHECK-LABEL: tranpose_arg
// CHECK: "tfl.transpose"
}

@@ -1,4 +1,5 @@
// RUN: tf-opt -tfl-prepare-tf %s | FileCheck %s
// RUN: tf-opt %s -tf-layout-optimization=force-data-format=NHWC -tfl-prepare-tf | FileCheck --check-prefix=LAYOUT --dump-input=always %s
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
@@ -53,6 +54,15 @@ func @depthwiseConv2D(tensor<256x32x32x3xf32>, tensor<3x3x3x4xf32>, tensor<256x3
// CHECK: %5 = "tf.DepthwiseConv2dNative"
}
func @Conv2dNCHW(%arg0: tensor<256x3x32x32xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x16x30x30xf32> {
%0 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x3x32x32xf32>, tensor<3x3x3x16xf32>) -> tensor<256x16x30x30xf32>
return %0 : tensor<256x16x30x30xf32>
// LAYOUT-LABEL: Conv2dNCHW
// LAYOUT: "tfl.conv_2d"
}
func @fusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) {
^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>):
// OK

@@ -38,6 +38,10 @@ CreateTFExecutorToControlDialectConversion();
} // namespace mlir
namespace tensorflow {
namespace {
// Data layout supported by TFLite.
const char kTFLiteDataLayout[] = "NHWC";
} // namespace
void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
mlir::OpPassManager* pass_manager) {
@@ -170,6 +174,12 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
if (pass_config.shape_inference) {
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
}
// Force the data layout supported by TFLite; this will transpose the data
// to match 'kTFLiteDataLayout'.
mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options;
layout_optimization_options.force_data_format = kTFLiteDataLayout;
mlir::TF::CreateLayoutOptimizationPipeline(*pass_manager,
layout_optimization_options);
// Prepare for TFLite dialect, rerun canonicalization, and then legalize to
// the TFLite dialect.
pass_manager->addPass(

@@ -27,6 +27,9 @@ def NonOpaqueElementsAttr : ElementsAttrBase<
def F32ElementsAttr : ElementsAttrBase<
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
def Int64ElementsAttr : ElementsAttrBase<
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isInteger(64)">, "Int 64 constant tensor">;
// Extract the ith int element from an ArrayAttr $0 as a 32-bit IntegerAttr
// with builder.
class ExtractI32At<int i> : NativeCodeCall<
@@ -50,6 +53,10 @@ def ExtractSingleElementAsInteger : NativeCodeCall<
def ExtractSingleElementAsInt32 : NativeCodeCall<
"$_builder.getI32IntegerAttr(ExtractSingleElementAsInteger($_self.cast<ElementsAttr>()).getInt())">;
// Converts a tensor with int64 elements to int32.
def CreateCastToInt32 : NativeCodeCall<
"CreateCastToInt32($0, $_loc, $_builder)">;
// Checks whether the given operation has static shapes and all of its inputs have the same shape.
def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0.getDefiningOp())">;
def HasSameStaticShapes : Constraint<HasSameStaticShapesPred, "op must have static same input shapes">;
@@ -208,8 +215,14 @@ def LegalizeSoftPlus : Pat<(TF_SoftplusOp F32Tensor:$arg0),
def LegalizeSqueeze : Pat<(TF_SqueezeOp $arg, $squeeze_dims),
(TFL_SqueezeOp $arg, $squeeze_dims)>;
def LegalizeTanh : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>;
def LegalizeTransposeInt64 : Pat<
(TF_TransposeOp $arg, (ConstantOp Int64ElementsAttr:$perm)),
(TFL_TransposeOp $arg, (CreateCastToInt32 $perm))>;
def LegalizeTranspose : Pat<(TF_TransposeOp $arg, $perm),
(TFL_TransposeOp $arg, $perm)>;
def LegalizeWhere : Pat<(TF_WhereOp $arg), (TFL_WhereOp $arg)>;
def LegalizeZerosLike : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;

@@ -49,6 +49,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -116,6 +117,17 @@ bool HasSameStaticShapes(Operation* op) {
return true;
}
// Util that casts 'val' to Int32 by adding a cast Op.
Value CreateCastToInt32(Attribute val, Location loc,
PatternRewriter& rewriter) {
auto shape = val.getType().dyn_cast<RankedTensorType>().getShape();
IntegerType new_ele_type = rewriter.getIntegerType(32);
ShapedType new_type = RankedTensorType::get(shape, new_ele_type);
return rewriter.create<TF::CastOp>(loc, new_type,
rewriter.create<TF::ConstOp>(loc, val),
rewriter.getBoolAttr(false));
}
#include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc"
#define DECL_CONVERT_OP(tf_op) \