[XLA] Constrain slice start and limits to be the same type
Most importantly, ensures these are the same size. PiperOrigin-RevId: 261786941
This commit is contained in:
parent
66ba95a89c
commit
3f0d7f179c
@ -433,8 +433,10 @@ def XLA_CompareOp: XLA_Op<"compare",
|
||||
// XLA Slice definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def XLA_SliceOp: XLA_UnaryElementwiseOp<"slice",
|
||||
[NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
def XLA_SliceOp: XLA_UnaryElementwiseOp<
|
||||
"slice",
|
||||
[NoSideEffect, SameOperandsAndResultElementType,
|
||||
AllTypesMatch<["start_indices", "limit_indices"]>]> {
|
||||
let arguments = (
|
||||
ins XLA_Tensor:$operand,
|
||||
ElementsAttr:$start_indices,
|
||||
|
@ -394,8 +394,32 @@ func @select_bad_pred_shape(%arg0: tensor<3xi1>, %arg1: tensor<2x3xi32>, %arg2:
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @slice
|
||||
func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> {
|
||||
%0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32>
|
||||
return %0 : tensor<1x4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @slice_indices_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> {
|
||||
// expected-error@+1 {{failed to verify that all of {start_indices, limit_indices} have same type}}
|
||||
%0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 2, 3]> : tensor<3xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32>
|
||||
return %0 : tensor<1x4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @slice_operand_result_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xf32> {
|
||||
// expected-error@+1 {{requires the same element type for all operands and results}}
|
||||
%0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xf32>
|
||||
return %0 : tensor<1x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transpose
|
||||
func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
|
||||
func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
|
||||
%0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
|
||||
return %0: tensor<2x1x4x3xi32>
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user