Fix GetOutputShapeForBroadcastGradientArgs for 1 dim sizes

If the dimension is equal and 1 then index gets added to both r0 and r1. E.g.,

```python
s0 = [501, 1, 32, 1280]
s1 = [  1, 1,  1, 1280]
print(tf.raw_ops.BroadcastGradientArgs(s0=s0, s1=s1))
>> BroadcastGradientArgs(
    r0=<tf.Tensor: shape=(1,), dtype=int32, numpy=array([1], dtype=int32)>,
    r1=<tf.Tensor: shape=(3,), dtype=int32, numpy=array([0, 1, 2], dtype=int32)>)
```
PiperOrigin-RevId: 346808589
Change-Id: I871ecd451ed564de196a28e2edee2a5fd2e894da
This commit is contained in:
Jacques Pienaar 2020-12-10 10:05:25 -08:00 committed by TensorFlower Gardener
parent ec86d80f19
commit 4a227e9fc5
2 changed files with 70 additions and 4 deletions

View File

@ -625,6 +625,9 @@ void GetOutputShapeForBroadcastGradientArgs(ArrayRef<int64_t> bcasted_shape,
r0.push_back(idx);
else
r1.push_back(idx);
} else if (s0_shape[s0_idx] == 1) {
r0.push_back(idx);
r1.push_back(idx);
}
}
}

View File

@ -512,17 +512,80 @@ func @DontFoldNoConstantFold() -> tensor<8xf32> {
return %2 : tensor<8xf32>
}
// CHECK-LABEL: func @testBroadcastGradientArgs
func @testBroadcastGradientArgs() -> (tensor<1xi32>, tensor<0xi32>) {
// CHECK-LABEL: func @testBroadcastGradientArgs1
func @testBroadcastGradientArgs1() -> (tensor<1xi32>, tensor<0xi32>) {
%s0 = "tf.Const"() {value = dense<[4]> : tensor<1xi32>} : () -> tensor<1xi32>
%s1 = "tf.Const"() {value = dense<[2, 4]> : tensor<2xi32>} : () -> tensor<2xi32>
%r0, %r1 = "tf.BroadcastGradientArgs"(%s0, %s1) {} : (tensor<1xi32>, tensor<2xi32>) -> (tensor<1xi32>, tensor<0xi32>)
// CHECK-DAG: %[[R0:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK-DAG: %[[R1:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
// CHECK-NOT: tf.BroadcastGradientArgs
// CHECK: return %[[R0]], %[[R1]]
return %r0, %r1 : tensor<1xi32>, tensor<0xi32>
}
// CHECK-LABEL: func @testBroadcastGradientArgs2
func @testBroadcastGradientArgs2() -> (tensor<1xi32>, tensor<3xi32>) {
%s2 = "tf.Const"() {value = dense<[501, 1, 32, 1280]> : tensor<4xi32>} : () -> tensor<4xi32>
%s3 = "tf.Const"() {value = dense<[ 1, 1, 1, 1280]> : tensor<4xi32>} : () -> tensor<4xi32>
%r2, %r3 = "tf.BroadcastGradientArgs"(%s2, %s3) {} : (tensor<4xi32>, tensor<4xi32>) -> (tensor<1xi32>, tensor<3xi32>)
// CHECK-DAG: %[[R2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK-DAG: %[[R3:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK-NOT: tf.BroadcastGradientArgs
// CHECK: return %[[R2]], %[[R3]]
return %r2, %r3 : tensor<1xi32>, tensor<3xi32>
}
// CHECK-LABEL: func @testBroadcastGradientArgs3
func @testBroadcastGradientArgs3() -> (tensor<3xi32>, tensor<3xi32>) {
%s4 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
%s5 = "tf.Const"() {value = dense<[1, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
%r4, %r5 = "tf.BroadcastGradientArgs"(%s4, %s5) {} : (tensor<0xi32>, tensor<3xi32>) -> (tensor<3xi32>, tensor<3xi32>)
// CHECK: %[[R0:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK-NOT: tf.BroadcastGradientArgs
// CHECK: return %[[R0]], %[[R0]]
return %r4, %r5 : tensor<3xi32>, tensor<3xi32>
}
// CHECK-LABEL: func @testBroadcastGradientArgs4
func @testBroadcastGradientArgs4() -> (tensor<2xi32>, tensor<3xi32>) {
%s4 = "tf.Const"() {value = dense<[1, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
%s5 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
%r4, %r5 = "tf.BroadcastGradientArgs"(%s4, %s5) {} : (tensor<3xi32>, tensor<0xi32>) -> (tensor<2xi32>, tensor<3xi32>)
// CHECK-DAG: %[[R0:.*]] = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK-DAG: %[[R1:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK-NOT: tf.BroadcastGradientArgs
// CHECK: return %[[R0]], %[[R1]]
return %r4, %r5 : tensor<2xi32>, tensor<3xi32>
}
// CHECK-LABEL: func @testBroadcastGradientArgs5
func @testBroadcastGradientArgs5() -> (tensor<1xi32>, tensor<1xi32>) {
%s4 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
%s5 = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> tensor<1xi32>
%r4, %r5 = "tf.BroadcastGradientArgs"(%s4, %s5) {} : (tensor<0xi32>, tensor<1xi32>) -> (tensor<1xi32>, tensor<1xi32>)
// CHECK: %[[R0:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK-NOT: tf.BroadcastGradientArgs
// CHECK: return %[[R0]], %[[R0]]
return %r4, %r5 : tensor<1xi32>, tensor<1xi32>
}
// CHECK-LABEL: func @testBroadcastGradientArgs6
func @testBroadcastGradientArgs6() -> (tensor<1xi32>, tensor<0xi32>) {
%s4 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
%s5 = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> tensor<1xi32>
%r4, %r5 = "tf.BroadcastGradientArgs"(%s4, %s5) {} : (tensor<0xi32>, tensor<1xi32>) -> (tensor<1xi32>, tensor<0xi32>)
// CHECK: %[[R0:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: %[[R1:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
// CHECK-NOT: tf.BroadcastGradientArgs
// CEHCK: return [[R0]], [[R1]]
// CHECK: return %[[R0]], %[[R1]]
return %r0, %r1 : tensor<1xi32>, tensor<0xi32>
return %r4, %r5 : tensor<1xi32>, tensor<0xi32>
}
// CHECK-LABEL: func @testBroadcastGradientArgsHigherRank