diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index a49648b0b37..60ee53e17b3 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -296,9 +296,11 @@ StatusOr HloFunctionImporter::ImportInstruction( std::vector slice_sizes( instruction->dynamic_slice_sizes().begin(), instruction->dynamic_slice_sizes().end()); - attributes.push_back( - builder_->getNamedAttr("slice_sizes", Convert(slice_sizes))); - MakeAndReturn(DynamicSliceOp); + return func_builder + ->create( + loc, result_type, operands[0], + makeArrayRef(operands).drop_front(), Convert(slice_sizes)) + .getOperation(); } case HloOpcode::kDynamicUpdateSlice: { return func_builder diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index a60ebd76d0e..37fc831dc3b 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -836,11 +836,64 @@ static LogicalResult Verify(ConcatenateOp op) { // DynamicSliceOp //===----------------------------------------------------------------------===// +namespace { +// Canonicalizes DynamicSlice ops that can be replaced instead with Slice ops. +// This canonicalization is applied the case when the `begin` input values are +// compile time constants and thus can be made into a tensor. 
+struct DynamicSliceToSlice : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DynamicSliceOp dynamic_slice, + PatternRewriter& rewriter) const override { + Value input = dynamic_slice.operand(); + auto input_tensor = input.getType().dyn_cast(); + if (!input_tensor) return failure(); + + SmallVector temp_start_indices; + for (Value start : dynamic_slice.start_indices()) { + APInt val; + if (!matchPattern(start, m_ConstantInt(&val))) { + return failure(); + } + temp_start_indices.push_back(*(val.getRawData())); + } + + // At this point we've determined that the start indices are all constants; + // pack them into a single tensor. + auto loc = dynamic_slice.getLoc(); + int64_t input_rank = input_tensor.getRank(); + auto slice_start_indices = + GetI64ElementsAttr(temp_start_indices, &rewriter); + DenseIntElementsAttr slice_limits = BuildSliceLimits( + slice_start_indices, dynamic_slice.slice_sizes(), &rewriter); + DenseIntElementsAttr slice_strides = + GetI64ElementsAttr(SmallVector(input_rank, 1), &rewriter); + auto result = rewriter.create(loc, input, slice_start_indices, + slice_limits, slice_strides); + rewriter.replaceOp(dynamic_slice, {result}); + return success(); + } +}; + +} // namespace + void DynamicSliceOp::getCanonicalizationPatterns( OwningRewritePatternList& results, MLIRContext* context) { results.insert(context); } +// Verifies that the number of slice sizes and the number of start indices match +static LogicalResult Verify(DynamicSliceOp op) { + int num_slice_sizes = op.slice_sizes().getNumElements(); + int num_start_indices = op.start_indices().size(); + if (num_start_indices != num_slice_sizes) { + return op.emitOpError() + << "has mismatched number of slice sizes (" << num_slice_sizes + << ") and number of start indices (" << num_start_indices << ")"; + } + return success(); +} + //===----------------------------------------------------------------------===// // InfeedOp 
//===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index abfc42b20d9..907183b62b7 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -661,11 +661,10 @@ def HLO_SliceOp: HLO_Op< } def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice", - [NoSideEffect, AllElementTypesMatch<["operand", "result"]>, - AllShapesMatch<["start_indices", "slice_sizes"]>]> { + [NoSideEffect, AllElementTypesMatch<["operand", "result"]>]> { let arguments = (ins HLO_Tensor:$operand, - HLO_Tensor:$start_indices, + Variadic:$start_indices, I64ElementsAttr:$slice_sizes ); diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir index 1b60745b20c..353673a8448 100644 --- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir @@ -1,9 +1,8 @@ // RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input-on-failure -func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { +func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { // CHECK: "xla_hlo.dynamic-slice" - %0 = xla_hlo.constant dense<[1, 4]> : tensor<2xi64> - %1 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32> + %1 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> return %1 : tensor<1x4xi32> } @@ -14,21 +13,22 @@ func @dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64> // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} // CHECK: return %[[RESULT]] : tensor<2xi32> 
- %0 = xla_hlo.constant dense<1> : tensor<1xi64> - %2 = "xla_hlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<2xi32> - return %2 : tensor<2xi32> + %0 = xla_hlo.constant dense<1> : tensor + %1 = "xla_hlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<2xi32> + return %1 : tensor<2xi32> } // CHECK-LABEL: dynamic_slice_constant_start_dynamic_shape -func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { +func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor { // CHECK: %[[RESULT:.*]] = "xla_hlo.slice"(%arg0) // CHECK-DAG-SAME: limit_indices = dense<[2, 4]> : tensor<2xi64> // CHECK-DAG-SAME: start_indices = dense<[1, 0]> : tensor<2xi64> // CHECK-DAG-SAME: strides = dense<1> : tensor<2xi64> - // CHECK: return %[[RESULT]] : tensor<1x4xi32> - %0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> - %1 = "xla_hlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor, tensor<2xi64>) -> tensor<1x4xi32> - return %1 : tensor<1x4xi32> + // CHECK: return %[[RESULT]] : tensor + %0 = xla_hlo.constant dense<1> : tensor + %1 = xla_hlo.constant dense<0> : tensor + %2 = "xla_hlo.dynamic-slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor, tensor, tensor) -> tensor + return %2 : tensor } // CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 2b1c9172f70..9c23d5b3332 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -2249,7 +2249,16 @@ func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { // CHECK: %[[START:.*]] = xla_hlo.constant 
dense<1> : tensor<1xi64> // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64> - // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<2xi32> + // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]]) + // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : + // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<1xi64> + // CHECK: %[[RESHAPED_START:.*]] = "xla_hlo.reshape"(%[[SLICED_START]]) : + // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor + // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) + // CHECK-DAG-SAME: {slice_sizes = dense<2> : tensor<1xi64>} : + // CHECK-DAG-SAME: (tensor<4xi32>, tensor) -> tensor<2xi32> // CHECK: return %[[RESULT]] : tensor<2xi32> %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>) @@ -2261,7 +2270,12 @@ func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi32> // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi32>) -> tensor<1xi64> - // CHECK: slice_sizes = dense<2> : tensor<1xi64> + // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]]) + // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<1xi64> + // CHECK: %[[RESHAPED_START:.*]] = "xla_hlo.reshape"(%[[SLICED_START]]) : (tensor<1xi64>) -> tensor + // CHECK: "xla_hlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) {slice_sizes = dense<2> : 
tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<2xi32> %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>) %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>) %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> @@ -2272,7 +2286,12 @@ func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> { // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64> // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64> - // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<3xi32> + // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]]) + // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<1xi64> + // CHECK: %[[RESHAPED_START:.*]] = "xla_hlo.reshape"(%[[SLICED_START]]) : (tensor<1xi64>) -> tensor + // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<3xi32> // CHECK: return %[[RESULT]] : tensor<3xi32> %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) %sizes = "tf.Const"() {value = dense<[-1]> : tensor<1xi64>} : () -> (tensor<1xi64>) @@ -2284,7 +2303,24 @@ func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi func @slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { // CHECK: %[[START:.*]] = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<2xi64>) -> tensor<2xi64> - // CHECK: %[[RESULT:.*]] = 
"xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor, tensor<2xi64>) -> tensor<1x4xi32> + // CHECK: %[[SLICED_START1:.*]] = "xla_hlo.slice"(%[[START_I64]]) + // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : + // CHECK-DAG-SAME: (tensor<2xi64>) -> tensor<1xi64> + // CHECK: %[[RESHAPED_START1:.*]] = "xla_hlo.reshape"(%[[SLICED_START1]]) : + // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor + // CHECK: %[[SLICED_START2:.*]] = "xla_hlo.slice"(%[[START_I64]]) + // CHECK-DAG-SAME: {limit_indices = dense<2> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : + // CHECK-DAG-SAME: (tensor<2xi64>) -> tensor<1xi64> + // CHECK: %[[RESHAPED_START2:.*]] = "xla_hlo.reshape"(%[[SLICED_START2]]) : + // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor + // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice" + // CHECK-DAG-SAME: (%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) + // CHECK-DAG-SAME: {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : + // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor<1x4xi32> // CHECK: return %[[RESULT]] : tensor<1x4xi32> %starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) @@ -2295,7 +2331,14 @@ func @slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2 // CHECK-LABEL: slice_variable_start func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%arg1) : (tensor<2xi64>) -> tensor<2xi64> - // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32> + // CHECK: 
%[[SLICED_START1:.*]] = "xla_hlo.slice"(%[[START_I64]]) + // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: %[[RESHAPED_START1:.*]] = "xla_hlo.reshape"(%[[SLICED_START1]]) : (tensor<1xi64>) -> tensor + // CHECK: %[[SLICED_START2:.*]] = "xla_hlo.slice"(%[[START_I64]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: %[[RESHAPED_START2:.*]] = "xla_hlo.reshape"(%[[SLICED_START2]]) : (tensor<1xi64>) -> tensor + // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> // CHECK: return %[[RESULT]] : tensor<1x4xi32> %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir index aa38ccd3c30..7735021ea90 100644 --- a/tensorflow/compiler/mlir/xla/tests/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir @@ -551,37 +551,45 @@ func @slice_operand_result_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xf32> { // ----- // CHECK-LABEL: func @dynamic_slice -func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32> +func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { + %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, 
tensor) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } // ----- -func @dynamic_slice_mismatch_indices(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - // expected-error@+1 {{failed to verify that all of {start_indices, slice_sizes} have same shape}} - %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32> +func @dynamic_slice_mismatch_indices(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { + // expected-error@+1 {{has mismatched number of slice sizes (1) and number of start indices (2)}} + %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } // ----- // CHECK-LABEL: @dynamic_slice_different_indice_element_type -func @dynamic_slice_different_indice_element_type(%arg0: tensor<3x4xi32>, %arg1: tensor<1xi32>) -> tensor<1x4xi32> { - %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor<1xi32>) -> tensor<1x4xi32> +func @dynamic_slice_different_indice_element_type(%arg0: tensor<3x4xi32>, %arg1: tensor) -> tensor<1x4xi32> { + %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } // ----- -func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xf32> { +func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xf32> { // expected-error@+1 {{failed to verify that all of {operand, result} have same element type}} - %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xf32> + %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : 
(tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } // ----- +func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { + // expected-error@+1 {{operand #1 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer values, but got 'tensor<2xi64>'}} + %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32> + return %0 : tensor<1x4xi32> +} + +// ----- + // CHECK-LABEL: func @transpose func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index 8953516c5fc..b65381b3a42 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -860,6 +860,21 @@ func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> { // ----- +// CHECK: HloModule +func @main(%arg: tensor<3x4xi32>, %start1: tensor, %start2: tensor) -> tensor<1x4xi32> { + %0 = "xla_hlo.dynamic-slice"(%arg, %start1, %start2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + return %0 : tensor<1x4xi32> +} + +// CHECK: ENTRY +// CHECK: %[[ARG:.*]] = s32[3,4] parameter(0) +// CHECK: %[[ARG1:.*]] = s64[] parameter(1) +// CHECK: %[[ARG2:.*]] = s64[] parameter(2) +// CHECK: ROOT +// CHECK-SAME: s32[1,4] dynamic-slice(s32[3,4] %[[ARG]], s64[] %[[ARG1]], s64[] %[[ARG2]]), dynamic_slice_sizes={1,4} + +// ----- + // CHECK: HloModule func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { "xla_hlo.trace"(%arg0) {tag = "This is a random test"} : (tensor<2xi32>) -> () diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt 
b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index 89a34dfa68a..27c9c843283 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -347,13 +347,15 @@ add { } // CHECK-LABEL: func @test_dynamic_slice -// CHECK-SAME: [[OPERAND:%.*]]: tensor<2x2x258xi32>, [[START_INDICES:%.*]]: tensor<3xi32> +// CHECK-SAME: [[OPERAND:%.*]]: tensor<2x2x258xi32>, [[START_IDX_1:%.*]]: tensor, [[START_IDX_2:%.*]]: tensor, [[START_IDX_3:%.*]]: tensor %test_dynamic_slice (operand: s32[2,2,258], start_indices: s32[3]) -> s32[1,1,32] { %operand = s32[2,2,258] parameter(0) - %start_indices = s32[3] parameter(1) - // CHECK: "xla_hlo.dynamic-slice"([[OPERAND]], [[START_INDICES]]) + %start_idx_1 = s32[] parameter(1) + %start_idx_2 = s32[] parameter(2) + %start_idx_3 = s32[] parameter(3) + // CHECK: "xla_hlo.dynamic-slice"([[OPERAND]], [[START_IDX_1]], [[START_IDX_2]], [[START_IDX_3]]) // CHECK-SAME: slice_sizes = dense<[1, 1, 32]> : tensor<3xi64> - ROOT %dynamic-slice = s32[1,1,32] dynamic-slice(s32[2,2,258] %operand, s32[3] %start_indices), dynamic_slice_sizes={1,1,32} + ROOT %dynamic-slice = s32[1,1,32] dynamic-slice(s32[2,2,258] %operand, s32[] %start_idx_1, s32[] %start_idx_2, s32[] %start_idx_3), dynamic_slice_sizes={1,1,32} } // CHECK-LABEL: func @test_dynamic_update_slice_1(%arg0: tensor<4x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor, %arg3: tensor) -> tensor<4x4xf32> { diff --git a/tensorflow/compiler/mlir/xla/transforms/canonicalize.td b/tensorflow/compiler/mlir/xla/transforms/canonicalize.td index 65f81aae9f2..b788cb80380 100644 --- a/tensorflow/compiler/mlir/xla/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/xla/transforms/canonicalize.td @@ -19,25 +19,6 @@ include "mlir/IR/OpBase.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" include "tensorflow/compiler/mlir/xla/ir/hlo_utils.td" -//===----------------------------------------------------------------------===// -// 
DynamicSlice op patterns. -//===----------------------------------------------------------------------===// - -def BuildSliceLimits : NativeCodeCall< - "BuildSliceLimits($0.cast()," - "$1.cast(), &$_builder)">; - -def BuildSliceStrides : NativeCodeCall< - "GetI64ElementsAttr(SmallVector(" - "$0.getType().cast().getRank(), 1), &$_builder)">; - -def DynamicSliceToSlice: Pat<(HLO_DynamicSliceOp HLO_Tensor:$input, - (HLO_ConstOp I64ElementsAttr:$starting_indices), - I64ElementsAttr:$slice_sizes), - (HLO_SliceOp $input, (CastIntElementsAttr $starting_indices), - (BuildSliceLimits $starting_indices, $slice_sizes), - (BuildSliceStrides $input))>; - def UnaryToBinaryEinsumEq : NativeCodeCall< "$_builder.getStringAttr(\",\" + $0.getValue().str())">; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 50536e6a124..c8c26f7f220 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -168,6 +168,20 @@ static ConvertOp CastValueToI64(Location loc, Value value, return rewriter->create(loc, value, rewriter->getIntegerType(64)); } +// Creates an unpack op along the 0th dimension of the tensor. The `value` input +// must be a ranked tensor. +static TF::UnpackOp UnpackTensorAlongZeroDim(Location loc, Value value, + PatternRewriter *rewriter) { + auto indices_type = value.getType().cast(); + int num_outputs = indices_type.getShape().front(); + SmallVector unpacked_indices_type( + num_outputs, RankedTensorType::get({}, indices_type.getElementType())); + auto unpacked_indices = rewriter->create( + loc, unpacked_indices_type, value, + IntegerAttr::get(rewriter->getIntegerType(64), 0)); + return unpacked_indices; +} + // Returns size of dimension at the specified index, if ranked tensor. // Otherwise, returns -1. 
// diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 2f825a882f7..bb505b6d3d6 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -479,6 +479,9 @@ def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$featu // Slice op patterns. //===----------------------------------------------------------------------===// +def CastToI64AndUnpackTensor: NativeCodeCall< + "UnpackTensorAlongZeroDim($0.getLoc(), CastValueToI64($0.getLoc(), $1, &$_builder), &$_builder).output()">; + def CanBeTranslatedToDynamicSlice : Constraint())">>; @@ -488,7 +491,8 @@ def TFSliceSizes2HLOSliceSizes : NativeCodeCall< def : Pat<(TF_SliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices, (TF_ConstOp $slice_sizes)), - (HLO_DynamicSliceOp $input, (CastValueToI64 $op, $starting_indices), + (HLO_DynamicSliceOp $input, + (CastToI64AndUnpackTensor $op, $starting_indices), (TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes)), [(CanBeTranslatedToDynamicSlice $input, $starting_indices, $slice_sizes)]>;