[XLA] Constrain slice start and limits to be the same type
Most importantly, ensures these are the same size. PiperOrigin-RevId: 261786941
This commit is contained in:
parent
66ba95a89c
commit
3f0d7f179c
@ -433,8 +433,10 @@ def XLA_CompareOp: XLA_Op<"compare",
|
|||||||
// XLA Slice definitions.
|
// XLA Slice definitions.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def XLA_SliceOp: XLA_UnaryElementwiseOp<"slice",
|
def XLA_SliceOp: XLA_UnaryElementwiseOp<
|
||||||
[NoSideEffect, SameOperandsAndResultElementType]> {
|
"slice",
|
||||||
|
[NoSideEffect, SameOperandsAndResultElementType,
|
||||||
|
AllTypesMatch<["start_indices", "limit_indices"]>]> {
|
||||||
let arguments = (
|
let arguments = (
|
||||||
ins XLA_Tensor:$operand,
|
ins XLA_Tensor:$operand,
|
||||||
ElementsAttr:$start_indices,
|
ElementsAttr:$start_indices,
|
||||||
|
@ -394,8 +394,32 @@ func @select_bad_pred_shape(%arg0: tensor<3xi1>, %arg1: tensor<2x3xi32>, %arg2:
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @slice
|
||||||
|
func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> {
|
||||||
|
%0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32>
|
||||||
|
return %0 : tensor<1x4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @slice_indices_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> {
|
||||||
|
// expected-error@+1 {{failed to verify that all of {start_indices, limit_indices} have same type}}
|
||||||
|
%0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 2, 3]> : tensor<3xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32>
|
||||||
|
return %0 : tensor<1x4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @slice_operand_result_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xf32> {
|
||||||
|
// expected-error@+1 {{requires the same element type for all operands and results}}
|
||||||
|
%0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xf32>
|
||||||
|
return %0 : tensor<1x4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @transpose
|
// CHECK-LABEL: func @transpose
|
||||||
func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
|
func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
|
||||||
%0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
|
%0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
|
||||||
return %0: tensor<2x1x4x3xi32>
|
return %0: tensor<2x1x4x3xi32>
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user