From f6182b859495e665adf3dc5654a27993acce4c22 Mon Sep 17 00:00:00 2001 From: Haoliang Zhang Date: Thu, 1 Aug 2019 18:01:45 -0700 Subject: [PATCH] Add legalization for tf.Any. PiperOrigin-RevId: 261236414 --- tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 27 ++++++ .../compiler/mlir/lite/tests/legalize-tf.mlir | 8 ++ tensorflow/compiler/mlir/lite/tests/ops.mlir | 8 ++ .../mlir/lite/transforms/legalize_patterns.td | 3 + .../mlir/tensorflow/ir/tf_generated_ops.td | 95 ++++++++++++++++++- 5 files changed, 140 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 4e580870eb4..8d06b879748 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -411,6 +411,33 @@ def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect]> { ); } +def TFL_ReduceAnyOp : TFL_Op<"reduce_any", [NoSideEffect]> { + let summary = [{ +Computes the "logical or" of elements across dimensions of a tensor. + }]; + + let description = [{ +Reduces `input` along the dimensions given in `axis`. Unless +`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +`axis`. If `keep_dims` is true, the reduced dimensions are +retained with length 1. + }]; + + let arguments = (ins + I1Tensor:$input, + I32Tensor:$reduction_indices, + + DefaultValuedAttr:$keep_dims + ); + + let results = (outs + I1Tensor:$output + ); + + let hasOptions = 1; + let customOption = "ReducerOptions"; +} + def TFL_AveragePool2DOp: TFL_Op<"average_pool_2d", [NoSideEffect, TFL_SameOperandsAndResultsScale]> { let summary = "Average_pool_2d operator"; diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index bce4ad03010..eece76a7935 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -303,6 +303,14 @@ func @abs(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { // CHECK: %0 = "tfl.abs"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> } +func @any(%arg0: tensor<2x2xi1>, %arg1: tensor) -> tensor { + %0 = "tf.Any"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor) -> tensor + return %0 : tensor + +// CHECK-LABEL:any +// CHECK: %0 = "tfl.reduce_any"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor) -> tensor +} + func @ceil(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { %0 = "tf.Ceil"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> return %0 : tensor<8x16xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 24e6fa9fd64..03995f90bbb 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -1036,4 +1036,12 @@ func @transpose_1d_perm(%arg0 : tensor<2x2xi32>, %arg1 : tensor<2x2xi32>) -> ten // expected-error @+1 {{op failed to verify that operand 1 is 1-D}} %0 = "tfl.transpose"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> +} + +// ----- + +func @anyWithI64Axis(%arg0: tensor<2x2xi1>, %arg1: tensor) -> tensor { + // expected-error @+1 {{tfl.reduce_any' op operand #1 must be tensor of 32-bit integer values}} + %0 = "tfl.reduce_any"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor) -> tensor + return %0 : tensor } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index ae5b33e324e..a91902a981b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -268,6 +268,9 @@ def : Pat<(TF_MaxOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceMaxOp $arg0, $arg1 def : Pat<(TF_ProdOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceProdOp $arg0, $arg1, $arg2)>; +def : Pat<(TF_AnyOp $input, $reduction_indices, $keep_dims), + (TFL_ReduceAnyOp $input, $reduction_indices, $keep_dims)>; + def : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>; def : Pat<(TF_BatchToSpaceNDOp $input, $block_shape, $crops), (TFL_BatchToSpaceNdOp $input, $block_shape, $crops)>; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index ff0b78a2793..0dc23bea59e 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -123,6 +123,32 @@ def TF_AddV2Op : TF_Op<"AddV2", [Broadcastable, Commutative, NoSideEffect]>, let hasCanonicalizer = 1; } +def TF_AnyOp : TF_Op<"Any", [NoSideEffect]> { + let summary = [{ +Computes the "logical or" of elements across dimensions of a tensor. + }]; + + let description = [{ +Reduces `input` along the dimensions given in `axis`. Unless +`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +`axis`. If `keep_dims` is true, the reduced dimensions are +retained with length 1. + }]; + + let arguments = (ins + I1Tensor:$input, + TF_I32OrI64Tensor:$reduction_indices, + + DefaultValuedAttr:$keep_dims + ); + + let results = (outs + I1Tensor:$output + ); + + TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; +} + def TF_ArgMaxOp : TF_Op<"ArgMax", [NoSideEffect]> { let summary = [{ Returns the index with the largest value across dimensions of a tensor. @@ -1000,7 +1026,7 @@ Gather slices from `params` into a Tensor with shape specified by `indices`. }]; let description = [{ -`indices` is an K-dimensional integer tensor, best thought of as a +`indices` is a K-dimensional integer tensor, best thought of as a (K-1)-dimensional tensor of indices into `params`, where each element defines a slice of `params`: @@ -1279,8 +1305,10 @@ for dtype in dtype_list: input_tensor, bitwise_ops.invert(input_tensor)), bitwise_ops.invert( tf.constant(0, dtype=dtype))] + expected = tf.constant([0, 0, 0, 0], dtype=tf.float32) tf.assert_equal(tf.cast(not_a_and_a, tf.float32), expected) + expected = tf.cast([not_0] * 4, tf.float32) tf.assert_equal(tf.cast(not_a_or_a, tf.float32), expected) @@ -2847,6 +2875,71 @@ with the following options: "NCHW": `[ batch, channels, height, width ]` "NCHW_VECT_C": `qint8 [ batch, channels / 4, height, width, 4 ]` + +It is useful to consider the operation as transforming a 6-D Tensor. +e.g. for data_format = NHWC, + Each element in the input tensor can be specified via 6 coordinates, + ordered by decreasing memory layout significance as: + n,oY,bY,oX,bX,iC (where n=batch index, oX, oY means X or Y coordinates + within the output image, bX, bY means coordinates + within the input block, iC means input channels). + The output would be a transpose to the following layout: + n,oY,oX,bY,bX,iC + +This operation is useful for resizing the activations between convolutions +(but keeping all data), e.g. instead of pooling. It is also useful for training +purely convolutional models. + +For example, given an input of shape `[1, 2, 2, 1]`, data_format = "NHWC" and +block_size = 2: + +``` +x = [[[[1], [2]], + [[3], [4]]]] +``` + +This operation will output a tensor of shape `[1, 1, 1, 4]`: + +``` +[[[[1, 2, 3, 4]]]] +``` + +Here, the input has a batch of 1 and each batch element has shape `[2, 2, 1]`, +the corresponding output will have a single element (i.e. width and height are +both 1) and will have a depth of 4 channels (1 * block_size * block_size). +The output element shape is `[1, 1, 4]`. + +For an input tensor with larger depth, here of shape `[1, 2, 2, 3]`, e.g. + +``` +x = [[[[1, 2, 3], [4, 5, 6]], + [[7, 8, 9], [10, 11, 12]]]] +``` + +This operation, for block_size of 2, will return the following tensor of shape +`[1, 1, 1, 12]` + +``` +[[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]] +``` + +Similarly, for the following input of shape `[1 4 4 1]`, and a block size of 2: + +``` +x = [[[[1], [2], [5], [6]], + [[3], [4], [7], [8]], + [[9], [10], [13], [14]], + [[11], [12], [15], [16]]]] +``` + +the operator will return the following tensor of shape `[1 2 2 4]`: + +``` +x = [[[[1, 2, 3, 4], + [5, 6, 7, 8]], + [[9, 10, 11, 12], + [13, 14, 15, 16]]]] +``` }]; let arguments = (ins