Make mhlo.gather -> tf.GatherNd legalization more generic
The new functionality handles the case when the start index map isn't an iota by transposing the operand before passing it into tf.GatherNd. PiperOrigin-RevId: 359773845 Change-Id: Ia8b3c7100cd39ed0c7e657a113e55f8081b5c7a3
This commit is contained in:
parent
9656b8f589
commit
f4448aae4a
@ -1910,6 +1910,21 @@ func @convert_gather_nd(%arg0: tensor<98x128xf32>, %arg1: tensor<4x64xi32>) -> t
|
|||||||
return %0 : tensor<4x64x128xf32>
|
return %0 : tensor<4x64x128xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @convert_gather_transpose(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: tensor<128x256xf32>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>) -> tensor<4x128xf32> {
|
||||||
|
// CHECK: %[[VAL_2:.*]] = "tf.Const"{{.*}}value = dense<[1, 0]> : tensor<2xi64>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_2]]) : {{.*}} -> tensor<256x128xf32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = "tf.GatherNd"(%[[VAL_3]], %[[VAL_1]]) : {{.*}} -> tensor<4x128xf32>
|
||||||
|
// CHECK: return %[[VAL_4]]
|
||||||
|
// CHECK: }
|
||||||
|
// Test the case when start_index_map isn't an iota, which requires a transpose to
|
||||||
|
// convert it to tf.GatherNd.
|
||||||
|
func @convert_gather_transpose(%arg0: tensor<128x256xf32>, %arg1: tensor<4x1xi32>) -> tensor<4x128xf32> {
|
||||||
|
%0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<1> : tensor<1xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<1> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[128, 1]> : tensor<2xi64>} : (tensor<128x256xf32>, tensor<4x1xi32>) -> tensor<4x128xf32>
|
||||||
|
return %0 : tensor<4x128xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @convert_dynamic_slice(
|
// CHECK-LABEL: func @convert_dynamic_slice(
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<7x3xf32>,
|
// CHECK-SAME: %[[VAL_0:.*]]: tensor<7x3xf32>,
|
||||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<i32>,
|
// CHECK-SAME: %[[VAL_1:.*]]: tensor<i32>,
|
||||||
|
@ -1198,18 +1198,22 @@ class ConvertGatherOp : public OpConversionPattern<mhlo::GatherOp> {
|
|||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify that start_index_map and collapsed_slice_dims are both an iota
|
// Verify that start_index_map and collapsed_slice_dims contain the same
|
||||||
// with the same number of elements as the last dimension of start_indices.
|
// values.
|
||||||
auto start_index_map = gather_op.dimension_numbers().start_index_map();
|
auto start_index_map = gather_op.dimension_numbers().start_index_map();
|
||||||
auto collapsed_slice_dims =
|
auto collapsed_slice_dims =
|
||||||
gather_op.dimension_numbers().collapsed_slice_dims();
|
gather_op.dimension_numbers().collapsed_slice_dims();
|
||||||
if (!IsIotaAttr(start_index_map, start_indices_type.getShape().back()) ||
|
if (start_index_map.getNumElements() !=
|
||||||
!IsIotaAttr(collapsed_slice_dims,
|
collapsed_slice_dims.getNumElements()) {
|
||||||
start_indices_type.getShape().back())) {
|
|
||||||
// TODO(tberghammer): Transform start_indices to support non-standard
|
|
||||||
// start_index_maps.
|
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
gather_op, "unsupported start index map and/or collapsed slice dims");
|
gather_op,
|
||||||
|
"different size for start index map and collapsed slice dims");
|
||||||
|
}
|
||||||
|
for (auto c : collapsed_slice_dims) {
|
||||||
|
if (llvm::count(start_index_map, c) == 0) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
gather_op, "collapsed slice dim isn't present in start index map");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify that slice_sizes is 1 for the indexed dimensions and the full
|
// Verify that slice_sizes is 1 for the indexed dimensions and the full
|
||||||
@ -1217,7 +1221,7 @@ class ConvertGatherOp : public OpConversionPattern<mhlo::GatherOp> {
|
|||||||
auto slice_sizes = gather_op.slice_sizes();
|
auto slice_sizes = gather_op.slice_sizes();
|
||||||
int64_t index = 0;
|
int64_t index = 0;
|
||||||
for (int64_t s : slice_sizes.getValues<int64_t>()) {
|
for (int64_t s : slice_sizes.getValues<int64_t>()) {
|
||||||
if (index < start_indices_type.getShape().back()) {
|
if (llvm::count(start_index_map, index)) {
|
||||||
if (s != 1) {
|
if (s != 1) {
|
||||||
return rewriter.notifyMatchFailure(gather_op,
|
return rewriter.notifyMatchFailure(gather_op,
|
||||||
"unsupported slice sizes");
|
"unsupported slice sizes");
|
||||||
@ -1242,6 +1246,25 @@ class ConvertGatherOp : public OpConversionPattern<mhlo::GatherOp> {
|
|||||||
++offset;
|
++offset;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Transpose the operand to handle non-iota start index map.
|
||||||
|
llvm::SmallVector<int64_t, 4> transpose_dimensions;
|
||||||
|
llvm::SmallVector<int64_t, 4> transpose_shape;
|
||||||
|
for (auto s : start_index_map) {
|
||||||
|
transpose_dimensions.push_back(s.getZExtValue());
|
||||||
|
transpose_shape.push_back(operand_type.getShape()[s.getZExtValue()]);
|
||||||
|
}
|
||||||
|
for (int64_t i = 0, e = operand_type.getRank(); i < e; ++i) {
|
||||||
|
if (llvm::count(start_index_map, i) == 0) {
|
||||||
|
transpose_dimensions.push_back(i);
|
||||||
|
transpose_shape.push_back(operand_type.getShape()[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
operand_type =
|
||||||
|
RankedTensorType::get(transpose_shape, operand_type.getElementType());
|
||||||
|
operand = rewriter.create<mhlo::TransposeOp>(
|
||||||
|
gather_op.getLoc(), operand_type, operand,
|
||||||
|
rewriter.getI64TensorAttr(transpose_dimensions));
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<TF::GatherNdOp>(gather_op, result_type, operand,
|
rewriter.replaceOpWithNewOp<TF::GatherNdOp>(gather_op, result_type, operand,
|
||||||
start_indices);
|
start_indices);
|
||||||
return success();
|
return success();
|
||||||
|
Loading…
x
Reference in New Issue
Block a user