Constrain HLO DynamicUpdateSlice op start_indices input.
The variadic start_indices input must consist of zero-dimension integer tensors. This constrains the op definition to enforce this. PiperOrigin-RevId: 306919159 Change-Id: I218f2632b5091c3699e7c7de4f70b5887e682f8e
This commit is contained in:
parent
d13a09ad56
commit
cfe262c8cb
@ -678,7 +678,7 @@ def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic-update-slice",
|
|||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
HLO_Tensor:$operand,
|
HLO_Tensor:$operand,
|
||||||
HLO_Tensor:$update,
|
HLO_Tensor:$update,
|
||||||
Variadic<HLO_ScalarIntTensor>:$start_indices
|
Variadic<HLO_Tensor>:$start_indices
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs HLO_Tensor:$result);
|
let results = (outs HLO_Tensor:$result);
|
||||||
|
@ -590,22 +590,6 @@ func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>)
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @dynamic_update_slice
|
|
||||||
func @dynamic_update_slice(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start1: tensor<i64>, %start2: tensor<i64>) -> tensor<3x4xi64> {
|
|
||||||
%0 = "xla_hlo.dynamic-update-slice"(%input, %update, %start1, %start2) : (tensor<3x4xi64>, tensor<2xi64>, tensor<i64>, tensor<i64>) -> tensor<3x4xi64>
|
|
||||||
return %0 : tensor<3x4xi64>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> {
|
|
||||||
// expected-error@+1 {{operand #2 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer values, but got 'tensor<2xi64>'}}
|
|
||||||
%0 = "xla_hlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64>
|
|
||||||
return %0 : tensor<3x4xi64>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @transpose
|
// CHECK-LABEL: func @transpose
|
||||||
func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
|
func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
|
||||||
%0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
|
%0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
|
||||||
|
@ -358,24 +358,24 @@ add {
|
|||||||
ROOT %dynamic-slice = s32[1,1,32] dynamic-slice(s32[2,2,258] %operand, s32[] %start_idx_1, s32[] %start_idx_2, s32[] %start_idx_3), dynamic_slice_sizes={1,1,32}
|
ROOT %dynamic-slice = s32[1,1,32] dynamic-slice(s32[2,2,258] %operand, s32[] %start_idx_1, s32[] %start_idx_2, s32[] %start_idx_3), dynamic_slice_sizes={1,1,32}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @test_dynamic_update_slice_1(%arg0: tensor<4x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<4x4xf32> {
|
// CHECK-LABEL: func @test_dynamic_update_slice_1(%arg0: tensor<4x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<4x4xf32> {
|
||||||
%test_dynamic_update_slice_1 (Arg_0.1: f32[4, 4], Arg_1.2: f32[1, 4], Arg_2.3: f32[], Arg_3.4: f32[]) -> f32[4, 4] {
|
%test_dynamic_update_slice_1 (Arg_0.1: f32[4, 4], Arg_1.2: f32[1, 4], Arg_2.3: f32[], Arg_3.4: f32[]) -> f32[4, 4] {
|
||||||
%Arg_0.1 = f32[4, 4] parameter(0)
|
%Arg_0.1 = f32[4, 4] parameter(0)
|
||||||
%Arg_1.2 = f32[1, 4] parameter(1)
|
%Arg_1.2 = f32[1, 4] parameter(1)
|
||||||
%Arg_2.3 = s32[] parameter(2)
|
%Arg_2.3 = f32[] parameter(2)
|
||||||
%Arg_3.4 = s32[] parameter(3)
|
%Arg_3.4 = f32[] parameter(3)
|
||||||
|
|
||||||
// CHECK-NEXT: "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor<i32>, tensor<i32>) -> tensor<4x4xf32>
|
// CHECK-NEXT: "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor<f32>, tensor<f32>) -> tensor<4x4xf32>
|
||||||
ROOT %dynamic-update-slice.5 = f32[4, 4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_3.4)
|
ROOT %dynamic-update-slice.5 = f32[4, 4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_3.4)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @test_dynamic_update_slice_2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<i32>) -> tensor<4xf32>
|
// CHECK-LABEL: func @test_dynamic_update_slice_2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<f32>) -> tensor<4xf32>
|
||||||
%test_dynamic_update_slice_2 (Arg_0.1: f32[4], Arg_1.2: f32[2], Arg_2.3: f32[]) -> f32[4] {
|
%test_dynamic_update_slice_2 (Arg_0.1: f32[4], Arg_1.2: f32[2], Arg_2.3: f32[]) -> f32[4] {
|
||||||
%Arg_0.1 = f32[4] parameter(0)
|
%Arg_0.1 = f32[4] parameter(0)
|
||||||
%Arg_1.2 = f32[2] parameter(1)
|
%Arg_1.2 = f32[2] parameter(1)
|
||||||
%Arg_2.3 = s32[] parameter(2)
|
%Arg_2.3 = f32[] parameter(2)
|
||||||
|
|
||||||
// CHECK-NEXT: "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<i32>) -> tensor<4xf32>
|
// CHECK-NEXT: "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<f32>) -> tensor<4xf32>
|
||||||
ROOT %dynamic-update-slice.5 = f32[4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3)
|
ROOT %dynamic-update-slice.5 = f32[4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user