From f4448aae4af6e3dd90d49468e1cc977d705163d7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 26 Feb 2021 09:24:52 -0800 Subject: [PATCH] Make mhlo.gather -> tf.GatherNd legalization more generic The new functionality handles the case when start index map isn't an iota by transposing the operand before passing it into the tf.GatherNd. PiperOrigin-RevId: 359773845 Change-Id: Ia8b3c7100cd39ed0c7e657a113e55f8081b5c7a3 --- .../mlir/tensorflow/tests/legalize_hlo.mlir | 15 +++++++ .../tensorflow/transforms/legalize_hlo.cc | 41 +++++++++++++++---- 2 files changed, 47 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 120eee547ba..9f9fa73302d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -1910,6 +1910,21 @@ func @convert_gather_nd(%arg0: tensor<98x128xf32>, %arg1: tensor<4x64xi32>) -> t return %0 : tensor<4x64x128xf32> } +// CHECK-LABEL: func @convert_gather_transpose( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<128x256xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>) -> tensor<4x128xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.Const"{{.*}}value = dense<[1, 0]> : tensor<2xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_2]]) : {{.*}} -> tensor<256x128xf32> +// CHECK: %[[VAL_4:.*]] = "tf.GatherNd"(%[[VAL_3]], %[[VAL_1]]) : {{.*}} -> tensor<4x128xf32> +// CHECK: return %[[VAL_4]] +// CHECK: } +// Test the case when start_index_map isn't an iota, which requires a transpose to +// convert it to tf.GatherNd. 
+func @convert_gather_transpose(%arg0: tensor<128x256xf32>, %arg1: tensor<4x1xi32>) -> tensor<4x128xf32> { + %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<1> : tensor<1xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<1> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[128, 1]> : tensor<2xi64>} : (tensor<128x256xf32>, tensor<4x1xi32>) -> tensor<4x128xf32> + return %0 : tensor<4x128xf32> +} + // CHECK-LABEL: func @convert_dynamic_slice( // CHECK-SAME: %[[VAL_0:.*]]: tensor<7x3xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor, diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index 5bda6b51a3f..255e7d9dffd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -1198,18 +1198,22 @@ class ConvertGatherOp : public OpConversionPattern { return failure(); } - // Verify that start_index_map and collapsed_slice_dims are both an iota - // with the same number of elements as the last dimension of start_indices. + // Verify that start_index_map and collapsed_slice_dims contain the same + // values. auto start_index_map = gather_op.dimension_numbers().start_index_map(); auto collapsed_slice_dims = gather_op.dimension_numbers().collapsed_slice_dims(); - if (!IsIotaAttr(start_index_map, start_indices_type.getShape().back()) || - !IsIotaAttr(collapsed_slice_dims, - start_indices_type.getShape().back())) { - // TODO(tberghammer): Transform start_indices to support non-standard - // start_index_maps. 
+ if (start_index_map.getNumElements() != + collapsed_slice_dims.getNumElements()) { return rewriter.notifyMatchFailure( - gather_op, "unsupported start index map and/or collapsed slice dims"); + gather_op, + "different size for start index map and collapsed slice dims"); + } + for (auto c : collapsed_slice_dims) { + if (llvm::count(start_index_map, c) == 0) { + return rewriter.notifyMatchFailure( + gather_op, "collapsed slice dim isn't present in start index map"); + } } // Verify that slice_sizes is 1 for the indexed dimensions and the full @@ -1217,7 +1221,7 @@ class ConvertGatherOp : public OpConversionPattern { auto slice_sizes = gather_op.slice_sizes(); int64_t index = 0; for (int64_t s : slice_sizes.getValues()) { - if (index < start_indices_type.getShape().back()) { + if (llvm::count(start_index_map, index)) { if (s != 1) { return rewriter.notifyMatchFailure(gather_op, "unsupported slice sizes"); @@ -1242,6 +1246,25 @@ class ConvertGatherOp : public OpConversionPattern { ++offset; } + // Transpose the operand to handle non-iota start index map. + llvm::SmallVector transpose_dimensions; + llvm::SmallVector transpose_shape; + for (auto s : start_index_map) { + transpose_dimensions.push_back(s.getZExtValue()); + transpose_shape.push_back(operand_type.getShape()[s.getZExtValue()]); + } + for (int64_t i = 0, e = operand_type.getRank(); i < e; ++i) { + if (llvm::count(start_index_map, i) == 0) { + transpose_dimensions.push_back(i); + transpose_shape.push_back(operand_type.getShape()[i]); + } + } + operand_type = + RankedTensorType::get(transpose_shape, operand_type.getElementType()); + operand = rewriter.create( + gather_op.getLoc(), operand_type, operand, + rewriter.getI64TensorAttr(transpose_dimensions)); + rewriter.replaceOpWithNewOp(gather_op, result_type, operand, start_indices); return success();