Fix GetOutputShapeForBroadcastGradientArgs for 1 dim sizes
If the dimension is equal and 1 then index gets added to both r0 and r1. E.g., ```python s0 = [501, 1, 32, 1280] s1 = [ 1, 1, 1, 1280] print(tf.raw_ops.BroadcastGradientArgs(s0=s0, s1=s1)) >> BroadcastGradientArgs( r0=<tf.Tensor: shape=(1,), dtype=int32, numpy=array([1], dtype=int32)>, r1=<tf.Tensor: shape=(3,), dtype=int32, numpy=array([0, 1, 2], dtype=int32)>) ``` PiperOrigin-RevId: 346808589 Change-Id: I871ecd451ed564de196a28e2edee2a5fd2e894da
This commit is contained in:
parent
ec86d80f19
commit
4a227e9fc5
@ -625,6 +625,9 @@ void GetOutputShapeForBroadcastGradientArgs(ArrayRef<int64_t> bcasted_shape,
|
||||
r0.push_back(idx);
|
||||
else
|
||||
r1.push_back(idx);
|
||||
} else if (s0_shape[s0_idx] == 1) {
|
||||
r0.push_back(idx);
|
||||
r1.push_back(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -512,17 +512,80 @@ func @DontFoldNoConstantFold() -> tensor<8xf32> {
|
||||
return %2 : tensor<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @testBroadcastGradientArgs
|
||||
func @testBroadcastGradientArgs() -> (tensor<1xi32>, tensor<0xi32>) {
|
||||
// CHECK-LABEL: func @testBroadcastGradientArgs1
|
||||
func @testBroadcastGradientArgs1() -> (tensor<1xi32>, tensor<0xi32>) {
|
||||
%s0 = "tf.Const"() {value = dense<[4]> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%s1 = "tf.Const"() {value = dense<[2, 4]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
%r0, %r1 = "tf.BroadcastGradientArgs"(%s0, %s1) {} : (tensor<1xi32>, tensor<2xi32>) -> (tensor<1xi32>, tensor<0xi32>)
|
||||
// CHECK-DAG: %[[R0:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
// CHECK-DAG: %[[R1:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
|
||||
// CHECK-NOT: tf.BroadcastGradientArgs
|
||||
// CHECK: return %[[R0]], %[[R1]]
|
||||
|
||||
return %r0, %r1 : tensor<1xi32>, tensor<0xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @testBroadcastGradientArgs2
|
||||
func @testBroadcastGradientArgs2() -> (tensor<1xi32>, tensor<3xi32>) {
|
||||
%s2 = "tf.Const"() {value = dense<[501, 1, 32, 1280]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%s3 = "tf.Const"() {value = dense<[ 1, 1, 1, 1280]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%r2, %r3 = "tf.BroadcastGradientArgs"(%s2, %s3) {} : (tensor<4xi32>, tensor<4xi32>) -> (tensor<1xi32>, tensor<3xi32>)
|
||||
// CHECK-DAG: %[[R2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
// CHECK-DAG: %[[R3:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
// CHECK-NOT: tf.BroadcastGradientArgs
|
||||
// CHECK: return %[[R2]], %[[R3]]
|
||||
|
||||
return %r2, %r3 : tensor<1xi32>, tensor<3xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @testBroadcastGradientArgs3
|
||||
func @testBroadcastGradientArgs3() -> (tensor<3xi32>, tensor<3xi32>) {
|
||||
%s4 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
|
||||
%s5 = "tf.Const"() {value = dense<[1, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
%r4, %r5 = "tf.BroadcastGradientArgs"(%s4, %s5) {} : (tensor<0xi32>, tensor<3xi32>) -> (tensor<3xi32>, tensor<3xi32>)
|
||||
// CHECK: %[[R0:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
// CHECK-NOT: tf.BroadcastGradientArgs
|
||||
// CHECK: return %[[R0]], %[[R0]]
|
||||
|
||||
return %r4, %r5 : tensor<3xi32>, tensor<3xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @testBroadcastGradientArgs4
|
||||
func @testBroadcastGradientArgs4() -> (tensor<2xi32>, tensor<3xi32>) {
|
||||
%s4 = "tf.Const"() {value = dense<[1, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
%s5 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
|
||||
%r4, %r5 = "tf.BroadcastGradientArgs"(%s4, %s5) {} : (tensor<3xi32>, tensor<0xi32>) -> (tensor<2xi32>, tensor<3xi32>)
|
||||
// CHECK-DAG: %[[R0:.*]] = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK-DAG: %[[R1:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
// CHECK-NOT: tf.BroadcastGradientArgs
|
||||
// CHECK: return %[[R0]], %[[R1]]
|
||||
|
||||
return %r4, %r5 : tensor<2xi32>, tensor<3xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @testBroadcastGradientArgs5
|
||||
func @testBroadcastGradientArgs5() -> (tensor<1xi32>, tensor<1xi32>) {
|
||||
%s4 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
|
||||
%s5 = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%r4, %r5 = "tf.BroadcastGradientArgs"(%s4, %s5) {} : (tensor<0xi32>, tensor<1xi32>) -> (tensor<1xi32>, tensor<1xi32>)
|
||||
// CHECK: %[[R0:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
// CHECK-NOT: tf.BroadcastGradientArgs
|
||||
// CHECK: return %[[R0]], %[[R0]]
|
||||
|
||||
return %r4, %r5 : tensor<1xi32>, tensor<1xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @testBroadcastGradientArgs6
|
||||
func @testBroadcastGradientArgs6() -> (tensor<1xi32>, tensor<0xi32>) {
|
||||
%s4 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
|
||||
%s5 = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%r4, %r5 = "tf.BroadcastGradientArgs"(%s4, %s5) {} : (tensor<0xi32>, tensor<1xi32>) -> (tensor<1xi32>, tensor<0xi32>)
|
||||
// CHECK: %[[R0:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
// CHECK: %[[R1:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
|
||||
// CHECK-NOT: tf.BroadcastGradientArgs
|
||||
// CEHCK: return [[R0]], [[R1]]
|
||||
// CHECK: return %[[R0]], %[[R1]]
|
||||
|
||||
return %r0, %r1 : tensor<1xi32>, tensor<0xi32>
|
||||
return %r4, %r5 : tensor<1xi32>, tensor<0xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @testBroadcastGradientArgsHigherRank
|
||||
|
Loading…
x
Reference in New Issue
Block a user