diff --git a/tensorflow/compiler/mlir/xla/ir/xla_ops.td b/tensorflow/compiler/mlir/xla/ir/xla_ops.td index 08e1a3d8ff6..96444b4d7f5 100644 --- a/tensorflow/compiler/mlir/xla/ir/xla_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/xla_ops.td @@ -433,8 +433,10 @@ def XLA_CompareOp: XLA_Op<"compare", // XLA Slice definitions. //===----------------------------------------------------------------------===// -def XLA_SliceOp: XLA_UnaryElementwiseOp<"slice", - [NoSideEffect, SameOperandsAndResultElementType]> { +def XLA_SliceOp: XLA_UnaryElementwiseOp< + "slice", + [NoSideEffect, SameOperandsAndResultElementType, + AllTypesMatch<["start_indices", "limit_indices"]>]> { let arguments = ( ins XLA_Tensor:$operand, ElementsAttr:$start_indices, diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir index 854b0e7456a..11dd3db607b 100644 --- a/tensorflow/compiler/mlir/xla/tests/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir @@ -394,8 +394,32 @@ func @select_bad_pred_shape(%arg0: tensor<3xi1>, %arg1: tensor<2x3xi32>, %arg2: // ----- +// CHECK-LABEL: func @slice +func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { + %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> + return %0 : tensor<1x4xi32> +} + +// ----- + +func @slice_indices_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { + // expected-error@+1 {{failed to verify that all of {start_indices, limit_indices} have same type}} + %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 2, 3]> : tensor<3xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> + return %0 : tensor<1x4xi32> +} + +// ----- + +func @slice_operand_result_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xf32> { + // expected-error@+1 {{requires the same element type for all operands and results}} + %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xf32> + return %0 : tensor<1x4xf32> +} + +// ----- + // CHECK-LABEL: func @transpose -func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { +func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> return %0: tensor<2x1x4x3xi32> }