Make mhlo.gather -> tf.GatherNd legalization more generic

The new functionality handles the case where the start index map isn't an
iota by transposing the operand before passing it to tf.GatherNd.

PiperOrigin-RevId: 359773845
Change-Id: Ia8b3c7100cd39ed0c7e657a113e55f8081b5c7a3
This commit is contained in:
A. Unique TensorFlower 2021-02-26 09:24:52 -08:00 committed by TensorFlower Gardener
parent 9656b8f589
commit f4448aae4a
2 changed files with 47 additions and 9 deletions

View File

@ -1910,6 +1910,21 @@ func @convert_gather_nd(%arg0: tensor<98x128xf32>, %arg1: tensor<4x64xi32>) -> t
return %0 : tensor<4x64x128xf32> return %0 : tensor<4x64x128xf32>
} }
// CHECK-LABEL: func @convert_gather_transpose(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<128x256xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>) -> tensor<4x128xf32> {
// CHECK: %[[VAL_2:.*]] = "tf.Const"{{.*}}value = dense<[1, 0]> : tensor<2xi64>
// CHECK: %[[VAL_3:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_2]]) : {{.*}} -> tensor<256x128xf32>
// CHECK: %[[VAL_4:.*]] = "tf.GatherNd"(%[[VAL_3]], %[[VAL_1]]) : {{.*}} -> tensor<4x128xf32>
// CHECK: return %[[VAL_4]]
// CHECK: }
// Test the case where start_index_map isn't an iota, which requires a
// transpose of the operand before it can be converted to tf.GatherNd.
// Here start_index_map is [1], so the CHECK lines above expect the operand
// to be transposed with permutation [1, 0] ahead of the tf.GatherNd.
func @convert_gather_transpose(%arg0: tensor<128x256xf32>, %arg1: tensor<4x1xi32>) -> tensor<4x128xf32> {
%0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<1> : tensor<1xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<1> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[128, 1]> : tensor<2xi64>} : (tensor<128x256xf32>, tensor<4x1xi32>) -> tensor<4x128xf32>
return %0 : tensor<4x128xf32>
}
// CHECK-LABEL: func @convert_dynamic_slice( // CHECK-LABEL: func @convert_dynamic_slice(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<7x3xf32>, // CHECK-SAME: %[[VAL_0:.*]]: tensor<7x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<i32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<i32>,

View File

@ -1198,18 +1198,22 @@ class ConvertGatherOp : public OpConversionPattern<mhlo::GatherOp> {
return failure(); return failure();
} }
// Verify that start_index_map and collapsed_slice_dims are both an iota // Verify that start_index_map and collapsed_slice_dims contain the same
// with the same number of elements as the last dimension of start_indices. // values.
auto start_index_map = gather_op.dimension_numbers().start_index_map(); auto start_index_map = gather_op.dimension_numbers().start_index_map();
auto collapsed_slice_dims = auto collapsed_slice_dims =
gather_op.dimension_numbers().collapsed_slice_dims(); gather_op.dimension_numbers().collapsed_slice_dims();
if (!IsIotaAttr(start_index_map, start_indices_type.getShape().back()) || if (start_index_map.getNumElements() !=
!IsIotaAttr(collapsed_slice_dims, collapsed_slice_dims.getNumElements()) {
start_indices_type.getShape().back())) {
// TODO(tberghammer): Transform start_indices to support non-standard
// start_index_maps.
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
gather_op, "unsupported start index map and/or collapsed slice dims"); gather_op,
"different size for start index map and collapsed slice dims");
}
for (auto c : collapsed_slice_dims) {
if (llvm::count(start_index_map, c) == 0) {
return rewriter.notifyMatchFailure(
gather_op, "collapsed slice dim isn't present in start index map");
}
} }
// Verify that slice_sizes is 1 for the indexed dimensions and the full // Verify that slice_sizes is 1 for the indexed dimensions and the full
@ -1217,7 +1221,7 @@ class ConvertGatherOp : public OpConversionPattern<mhlo::GatherOp> {
auto slice_sizes = gather_op.slice_sizes(); auto slice_sizes = gather_op.slice_sizes();
int64_t index = 0; int64_t index = 0;
for (int64_t s : slice_sizes.getValues<int64_t>()) { for (int64_t s : slice_sizes.getValues<int64_t>()) {
if (index < start_indices_type.getShape().back()) { if (llvm::count(start_index_map, index)) {
if (s != 1) { if (s != 1) {
return rewriter.notifyMatchFailure(gather_op, return rewriter.notifyMatchFailure(gather_op,
"unsupported slice sizes"); "unsupported slice sizes");
@ -1242,6 +1246,25 @@ class ConvertGatherOp : public OpConversionPattern<mhlo::GatherOp> {
++offset; ++offset;
} }
// Transpose the operand to handle non-iota start index map.
llvm::SmallVector<int64_t, 4> transpose_dimensions;
llvm::SmallVector<int64_t, 4> transpose_shape;
for (auto s : start_index_map) {
transpose_dimensions.push_back(s.getZExtValue());
transpose_shape.push_back(operand_type.getShape()[s.getZExtValue()]);
}
for (int64_t i = 0, e = operand_type.getRank(); i < e; ++i) {
if (llvm::count(start_index_map, i) == 0) {
transpose_dimensions.push_back(i);
transpose_shape.push_back(operand_type.getShape()[i]);
}
}
operand_type =
RankedTensorType::get(transpose_shape, operand_type.getElementType());
operand = rewriter.create<mhlo::TransposeOp>(
gather_op.getLoc(), operand_type, operand,
rewriter.getI64TensorAttr(transpose_dimensions));
rewriter.replaceOpWithNewOp<TF::GatherNdOp>(gather_op, result_type, operand, rewriter.replaceOpWithNewOp<TF::GatherNdOp>(gather_op, result_type, operand,
start_indices); start_indices);
return success(); return success();