Constrain HLO DynamicUpdateSlice op start_indices input.

The variadic start_indices input must consist of zero-dimension integer tensors. This constrains the op definition to enforce this.

PiperOrigin-RevId: 306919159
Change-Id: I218f2632b5091c3699e7c7de4f70b5887e682f8e
This commit is contained in:
A. Unique TensorFlower 2020-04-16 14:14:59 -07:00 committed by TensorFlower Gardener
parent d13a09ad56
commit cfe262c8cb
3 changed files with 8 additions and 24 deletions

View File

@ -678,7 +678,7 @@ def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic-update-slice",
let arguments = (ins let arguments = (ins
HLO_Tensor:$operand, HLO_Tensor:$operand,
HLO_Tensor:$update, HLO_Tensor:$update,
Variadic<HLO_ScalarIntTensor>:$start_indices Variadic<HLO_Tensor>:$start_indices
); );
let results = (outs HLO_Tensor:$result); let results = (outs HLO_Tensor:$result);

View File

@ -590,22 +590,6 @@ func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>)
// ----- // -----
// CHECK-LABEL: @dynamic_update_slice
func @dynamic_update_slice(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start1: tensor<i64>, %start2: tensor<i64>) -> tensor<3x4xi64> {
%0 = "xla_hlo.dynamic-update-slice"(%input, %update, %start1, %start2) : (tensor<3x4xi64>, tensor<2xi64>, tensor<i64>, tensor<i64>) -> tensor<3x4xi64>
return %0 : tensor<3x4xi64>
}
// -----
func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> {
// expected-error@+1 {{operand #2 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer values, but got 'tensor<2xi64>'}}
%0 = "xla_hlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64>
return %0 : tensor<3x4xi64>
}
// -----
// CHECK-LABEL: func @transpose // CHECK-LABEL: func @transpose
func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
%0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>

View File

@ -358,24 +358,24 @@ add {
ROOT %dynamic-slice = s32[1,1,32] dynamic-slice(s32[2,2,258] %operand, s32[] %start_idx_1, s32[] %start_idx_2, s32[] %start_idx_3), dynamic_slice_sizes={1,1,32} ROOT %dynamic-slice = s32[1,1,32] dynamic-slice(s32[2,2,258] %operand, s32[] %start_idx_1, s32[] %start_idx_2, s32[] %start_idx_3), dynamic_slice_sizes={1,1,32}
} }
// CHECK-LABEL: func @test_dynamic_update_slice_1(%arg0: tensor<4x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<4x4xf32> { // CHECK-LABEL: func @test_dynamic_update_slice_1(%arg0: tensor<4x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<4x4xf32> {
%test_dynamic_update_slice_1 (Arg_0.1: f32[4, 4], Arg_1.2: f32[1, 4], Arg_2.3: f32[], Arg_3.4: f32[]) -> f32[4, 4] { %test_dynamic_update_slice_1 (Arg_0.1: f32[4, 4], Arg_1.2: f32[1, 4], Arg_2.3: f32[], Arg_3.4: f32[]) -> f32[4, 4] {
%Arg_0.1 = f32[4, 4] parameter(0) %Arg_0.1 = f32[4, 4] parameter(0)
%Arg_1.2 = f32[1, 4] parameter(1) %Arg_1.2 = f32[1, 4] parameter(1)
%Arg_2.3 = s32[] parameter(2) %Arg_2.3 = f32[] parameter(2)
%Arg_3.4 = s32[] parameter(3) %Arg_3.4 = f32[] parameter(3)
// CHECK-NEXT: "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor<i32>, tensor<i32>) -> tensor<4x4xf32> // CHECK-NEXT: "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor<f32>, tensor<f32>) -> tensor<4x4xf32>
ROOT %dynamic-update-slice.5 = f32[4, 4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_3.4) ROOT %dynamic-update-slice.5 = f32[4, 4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_3.4)
} }
// CHECK-LABEL: func @test_dynamic_update_slice_2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<i32>) -> tensor<4xf32> // CHECK-LABEL: func @test_dynamic_update_slice_2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<f32>) -> tensor<4xf32>
%test_dynamic_update_slice_2 (Arg_0.1: f32[4], Arg_1.2: f32[2], Arg_2.3: f32[]) -> f32[4] { %test_dynamic_update_slice_2 (Arg_0.1: f32[4], Arg_1.2: f32[2], Arg_2.3: f32[]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0) %Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[2] parameter(1) %Arg_1.2 = f32[2] parameter(1)
%Arg_2.3 = s32[] parameter(2) %Arg_2.3 = f32[] parameter(2)
// CHECK-NEXT: "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<i32>) -> tensor<4xf32> // CHECK-NEXT: "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<f32>) -> tensor<4xf32>
ROOT %dynamic-update-slice.5 = f32[4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3) ROOT %dynamic-update-slice.5 = f32[4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3)
} }