From 4a227e9fc5f3b5fc4bbcf96aa322d050375e21d0 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar
Date: Thu, 10 Dec 2020 10:05:25 -0800
Subject: [PATCH] Fix GetOutputShapeForBroadcastGradientArgs for 1 dim sizes

If the dimensions are equal and of size 1, the index gets added to both
r0 and r1. E.g.,

```python
s0 = [501, 1, 32, 1280]
s1 = [  1, 1,  1, 1280]
print(tf.raw_ops.BroadcastGradientArgs(s0=s0, s1=s1))
>> BroadcastGradientArgs(r0=[1], r1=[0, 1, 2])
```

PiperOrigin-RevId: 346808589
Change-Id: I871ecd451ed564de196a28e2edee2a5fd2e894da
---
 .../compiler/mlir/tensorflow/ir/tf_ops_a_m.cc |  3 +
 .../mlir/tensorflow/tests/constant-fold.mlir  | 71 +++++++++++++++++--
 2 files changed, 70 insertions(+), 4 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index e13f8e04756..95b3f962d6f 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -625,6 +625,9 @@ void GetOutputShapeForBroadcastGradientArgs(ArrayRef<int64_t> bcasted_shape,
         r0.push_back(idx);
       else
         r1.push_back(idx);
+    } else if (s0_shape[s0_idx] == 1) {
+      r0.push_back(idx);
+      r1.push_back(idx);
     }
   }
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
index 528d26c47e7..2859b1ddcb2 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
@@ -512,17 +512,80 @@ func @DontFoldNoConstantFold() -> tensor<8xf32> {
   return %2 : tensor<8xf32>
 }
 
-// CHECK-LABEL: func @testBroadcastGradientArgs
-func @testBroadcastGradientArgs() -> (tensor<1xi32>, tensor<0xi32>) {
+// CHECK-LABEL: func @testBroadcastGradientArgs1
+func @testBroadcastGradientArgs1() -> (tensor<1xi32>, tensor<0xi32>) {
   %s0 = "tf.Const"() {value = dense<[4]> : tensor<1xi32>} : () -> tensor<1xi32>
   %s1 = "tf.Const"() {value = dense<[2, 4]> : tensor<2xi32>} : () -> tensor<2xi32>
   %r0, %r1 = "tf.BroadcastGradientArgs"(%s0, %s1) {} : (tensor<1xi32>, tensor<2xi32>) -> (tensor<1xi32>, tensor<0xi32>)
+  // CHECK-DAG: %[[R0:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-DAG: %[[R1:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+  // CHECK-NOT: tf.BroadcastGradientArgs
+  // CHECK: return %[[R0]], %[[R1]]
+
+  return %r0, %r1 : tensor<1xi32>, tensor<0xi32>
+}
+
+// CHECK-LABEL: func @testBroadcastGradientArgs2
+func @testBroadcastGradientArgs2() -> (tensor<1xi32>, tensor<3xi32>) {
+  %s2 = "tf.Const"() {value = dense<[501, 1, 32, 1280]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %s3 = "tf.Const"() {value = dense<[ 1, 1, 1, 1280]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %r2, %r3 = "tf.BroadcastGradientArgs"(%s2, %s3) {} : (tensor<4xi32>, tensor<4xi32>) -> (tensor<1xi32>, tensor<3xi32>)
+  // CHECK-DAG: %[[R2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-DAG: %[[R3:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // CHECK-NOT: tf.BroadcastGradientArgs
+  // CHECK: return %[[R2]], %[[R3]]
+
+  return %r2, %r3 : tensor<1xi32>, tensor<3xi32>
+}
+
+// CHECK-LABEL: func @testBroadcastGradientArgs3
+func @testBroadcastGradientArgs3() -> (tensor<3xi32>, tensor<3xi32>) {
+  %s4 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
+  %s5 = "tf.Const"() {value = dense<[1, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+  %r4, %r5 = "tf.BroadcastGradientArgs"(%s4, %s5) {} : (tensor<0xi32>, tensor<3xi32>) -> (tensor<3xi32>, tensor<3xi32>)
+  // CHECK: %[[R0:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // CHECK-NOT: tf.BroadcastGradientArgs
+  // CHECK: return %[[R0]], %[[R0]]
+
+  return %r4, %r5 : tensor<3xi32>, tensor<3xi32>
+}
+
+// CHECK-LABEL: func @testBroadcastGradientArgs4
+func @testBroadcastGradientArgs4() -> (tensor<2xi32>, tensor<3xi32>) {
+  %s4 = "tf.Const"() {value = dense<[1, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+  %s5 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
+  %r4, %r5 = "tf.BroadcastGradientArgs"(%s4, %s5) {} : (tensor<3xi32>, tensor<0xi32>) -> (tensor<2xi32>, tensor<3xi32>)
+  // CHECK-DAG: %[[R0:.*]] = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK-DAG: %[[R1:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // CHECK-NOT: tf.BroadcastGradientArgs
+  // CHECK: return %[[R0]], %[[R1]]
+
+  return %r4, %r5 : tensor<2xi32>, tensor<3xi32>
+}
+
+// CHECK-LABEL: func @testBroadcastGradientArgs5
+func @testBroadcastGradientArgs5() -> (tensor<1xi32>, tensor<1xi32>) {
+  %s4 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
+  %s5 = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> tensor<1xi32>
+  %r4, %r5 = "tf.BroadcastGradientArgs"(%s4, %s5) {} : (tensor<0xi32>, tensor<1xi32>) -> (tensor<1xi32>, tensor<1xi32>)
+  // CHECK: %[[R0:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK-NOT: tf.BroadcastGradientArgs
+  // CHECK: return %[[R0]], %[[R0]]
+
+  return %r4, %r5 : tensor<1xi32>, tensor<1xi32>
+}
+
+// CHECK-LABEL: func @testBroadcastGradientArgs6
+func @testBroadcastGradientArgs6() -> (tensor<1xi32>, tensor<0xi32>) {
+  %s4 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
+  %s5 = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> tensor<1xi32>
+  %r4, %r5 = "tf.BroadcastGradientArgs"(%s4, %s5) {} : (tensor<0xi32>, tensor<1xi32>) -> (tensor<1xi32>, tensor<0xi32>)
   // CHECK: %[[R0:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
   // CHECK: %[[R1:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
   // CHECK-NOT: tf.BroadcastGradientArgs
-  // CEHCK: return [[R0]], [[R1]]
+  // CHECK: return %[[R0]], %[[R1]]
 
-  return %r0, %r1 : tensor<1xi32>, tensor<0xi32>
+  return %r4, %r5 : tensor<1xi32>, tensor<0xi32>
 }
 
 // CHECK-LABEL: func @testBroadcastGradientArgsHigherRank
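
For reference, the reduction-index logic that the patched `else if` branch completes can be mirrored in plain Python. This is a minimal sketch, not the actual TensorFlow implementation; the helper name `broadcast_gradient_args` is hypothetical, and valid broadcastable shapes are assumed.

```python
# Sketch of the fixed GetOutputShapeForBroadcastGradientArgs logic
# (hypothetical helper, not TensorFlow's implementation). Assumes the
# two shapes are broadcast-compatible.
def broadcast_gradient_args(s0, s1):
    rank = max(len(s0), len(s1))
    s0 = [1] * (rank - len(s0)) + list(s0)  # left-pad shorter shape with 1s
    s1 = [1] * (rank - len(s1)) + list(s1)
    r0, r1 = [], []
    for idx, (d0, d1) in enumerate(zip(s0, s1)):
        if d0 == d1:
            # The case this patch adds: dims that are equal *and* 1 are
            # reported as reduction indices on both sides.
            if d0 == 1:
                r0.append(idx)
                r1.append(idx)
        elif d0 == 1:
            r0.append(idx)  # s0 was broadcast along this dim
        else:
            r1.append(idx)  # s1 was broadcast along this dim
    return r0, r1

print(broadcast_gradient_args([501, 1, 32, 1280], [1, 1, 1, 1280]))
# ([1], [0, 1, 2]) -- matches testBroadcastGradientArgs2 above
```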