From 5db729c9d63a224625fd1f396a7c45500145d73a Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Wed, 27 May 2020 09:44:03 -0700 Subject: [PATCH] Eliminate tf.IfRegion non-condition inputs The then and else regions can reference their inputs directly without having to wire them through the IfRegion op inputs. This will allow a more direct representation of how these values are used within these regions PiperOrigin-RevId: 313406455 Change-Id: I0756f659c9dec4ef348c38f358bf294b3d004ae3 --- .../compiler/mlir/tensorflow/ir/tf_ops.td | 2 - .../mlir/tensorflow/tests/tf-ops.mlir | 65 +++++++------------ 2 files changed, 25 insertions(+), 42 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 1b8f9eb4bb6..7f31c274a09 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -237,7 +237,6 @@ cond: A Tensor. If the tensor is a scalar of non-boolean type, the True and zero means False; if the scalar is a string, non-empty means True and empty means False. If the tensor is not a scalar, being empty means False and being non-empty means True. -input: A list of input tensors. then_branch: A region that computes the outputs of the op if cond = true. It returns a list of tensors using tf.yield (as the terminator). The types of these returned tensors is same as that of the else_branch @@ -248,7 +247,6 @@ else_branch: A region that computes the outputs of the op if cond = false. let arguments = (ins TF_Tensor:$cond, - Variadic:$input, DefaultValuedAttr:$output_shapes, diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index c0d1a914788..2e00dd6a517 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -865,13 +865,14 @@ func @testInvalidYieldOp(%arg0: f32) -> () { // Test valid tf.IfRegion operation // CHECK-LABEL: func @testValidIfRegionOp func @testValidIfRegionOp(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - %0 = "tf.IfRegion"(%arg0, %arg1) ({ + %neg = "tf.Neg"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t) : (tensor<2xf32>) -> () }, { - %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + %e = "tf.Acos"(%neg) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%e) : (tensor<2xf32>) -> () - }) { is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> + }) { is_stateless = false} : (tensor) -> tensor<2xf32> return %0 : tensor<2xf32> } @@ -881,7 +882,7 @@ func @testValidIfRegionOp(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf // Test valid tf.IfRegion operation with multiple results // CHECK-LABEL: func @testValidIfRegionOpWithMultipleResults func @testValidIfRegionOpWithMultipleResults(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - %0, %1, %2 = "tf.IfRegion"(%arg0, %arg1) ({ + %0, %1, %2 = "tf.IfRegion"(%arg0) ({ %t0 = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> %t1 = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> %t2 = "tf.Acosh"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> @@ -891,7 +892,7 @@ func @testValidIfRegionOpWithMultipleResults(%arg0: tensor, %arg1: tensor<2x %e1 = "tf.Relu"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> %e2 = "tf.Sin"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%e0, %e1, %e2) : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> () - }) { is_stateless = false} : (tensor, tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) + }) { is_stateless = false} : (tensor) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) %3 = "tf.Add"(%0, %1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> %4 = "tf.Add"(%2, %3) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> @@ -903,42 +904,26 @@ func @testValidIfRegionOpWithMultipleResults(%arg0: tensor, %arg1: tensor<2x // Test invalid type for operand #0 for tf.IfRegion operation func @testInvalidIfRegionOpType0(%arg0: f32, %arg1: tensor<2xf32>) -> tensor<2xf32> { // expected-error @+1 {{operand #0 must be tensor of tf.dtype values}} - %0 = "tf.IfRegion"(%arg0, %arg1) ({ + %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t) : (tensor<2xf32>) -> () }, { %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%e) : (tensor<2xf32>) -> () - }) { is_stateless = false} : (f32, tensor<2xf32>) -> tensor<2xf32> + }) { is_stateless = false} : (f32) -> tensor<2xf32> return %0 : tensor<2xf32> } // ----- -// Test invalid type for operand #1 for tf.IfRegion operation -func @testInvalidIfRegionOpType1(%arg0: tensor, %arg1: f32) -> f32 { - // expected-error @+1 {{operand #1 must be tensor of tf.dtype values}} - %0 = "tf.IfRegion"(%arg0, %arg1) ({ - %t = addf %arg1, %arg1 : f32 - "tf.Yield"(%t) : (f32) -> () - }, { - %e = mulf %arg1, %arg1 : f32 - "tf.Yield"(%e) : (f32) -> () - }) { is_stateless = false} : (tensor, f32) -> f32 - - return %0 : f32 -} - -// ----- - // tf.IfRegion operation should have 2 regions func @testInvalidIfRegionOp1Region(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { // expected-error @+1 {{op expected 2 regions}} - %0 = "tf.IfRegion"(%arg0, %arg1) ({ + %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t) : (tensor<2xf32>) -> () - }) { is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> + }) { is_stateless = false} : (tensor) -> tensor<2xf32> return %0 : tensor<2xf32> } @@ -947,7 +932,7 @@ func @testInvalidIfRegionOp1Region(%arg0: tensor, %arg1: tensor<2xf32>) -> t func @testInvalidIfRegionOpNoRegions(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { // expected-error @+1 {{op expected 2 regions}} - %0 = "tf.IfRegion"(%arg0, %arg1) { is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> + %0 = "tf.IfRegion"(%arg0) { is_stateless = false} : (tensor) -> tensor<2xf32> return %0 : tensor<2xf32> } @@ -956,7 +941,7 @@ func @testInvalidIfRegionOpNoRegions(%arg0: tensor, %arg1: tensor<2xf32>) -> func @testInvalidIfRegionOp3Regions(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { // expected-error @+1 {{op expected 2 regions}} - %0 = "tf.IfRegion"(%arg0, %arg1) ({ + %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t) : (tensor<2xf32>) -> () }, { @@ -965,7 +950,7 @@ func @testInvalidIfRegionOp3Regions(%arg0: tensor, %arg1: tensor<2xf32>) -> }, { %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%e) : (tensor<2xf32>) -> () - }) { is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> + }) { is_stateless = false} : (tensor) -> tensor<2xf32> return %0 : tensor<2xf32> } @@ -976,12 +961,12 @@ func @testInvalidIfRegionOp3Regions(%arg0: tensor, %arg1: tensor<2xf32>) -> func @testIfRegionThenTerminator(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { // expected-error @+2 {{'tf.IfRegion' op expects regions to end with 'tf.Yield'}} // expected-note @+1 {{in custom textual format, the absence of terminator implies 'tf.Yield'}} - %0 = "tf.IfRegion"(%arg0, %arg1) ({ + %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> }, { %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%e) : (tensor<2xf32>) -> () - }) { is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> + }) { is_stateless = false} : (tensor) -> tensor<2xf32> return %0 : tensor<2xf32> } @@ -991,12 +976,12 @@ func @testIfRegionThenTerminator(%arg0: tensor, %arg1: tensor<2xf32>) -> ten func @testIfRegionElseTerminator(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { // expected-error @+2 {{'tf.IfRegion' op expects regions to end with 'tf.Yield'}} // expected-note @+1 {{in custom textual format, the absence of terminator implies 'tf.Yield'}} - %0 = "tf.IfRegion"(%arg0, %arg1) ({ + %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t) : (tensor<2xf32>) -> () }, { %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> - }) { is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> + }) { is_stateless = false} : (tensor) -> tensor<2xf32> return %0 : tensor<2xf32> } @@ -1006,13 +991,13 @@ func @testIfRegionElseTerminator(%arg0: tensor, %arg1: tensor<2xf32>) -> ten // tf.Region yield number of results should match op number of results func @testIfRegionThenResultCount(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { // expected-error @+1 {{then region should have 1 result}} - %0 = "tf.IfRegion"(%arg0, %arg1) ({ + %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t, %t) : (tensor<2xf32>, tensor<2xf32>) -> () }, { %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%e) : (tensor<2xf32>) -> () - }) { is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> + }) { is_stateless = false} : (tensor) -> tensor<2xf32> return %0 : tensor<2xf32> } @@ -1021,13 +1006,13 @@ func @testIfRegionThenResultCount(%arg0: tensor, %arg1: tensor<2xf32>) -> te func @testIfRegionElseResultCount(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { // expected-error @+1 {{else region should have 1 result}} - %0 = "tf.IfRegion"(%arg0, %arg1) ({ + %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t) : (tensor<2xf32>) -> () }, { %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%e, %e) : (tensor<2xf32>, tensor<2xf32>) -> () - }) { is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> + }) { is_stateless = false} : (tensor) -> tensor<2xf32> return %0 : tensor<2xf32> } @@ -1037,12 +1022,12 @@ func @testIfRegionElseResultCount(%arg0: tensor, %arg1: tensor<2xf32>) -> te // tf.IfRegion yield types should match op result types func @testIfRegionOpYieldMismatchThen(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { // expected-error @+1 {{then result type tensor is incompatible with tf.IfRegion result type tensor<2xf32> at index 0}} - %0 = "tf.IfRegion"(%arg0, %arg1) ({ + %0 = "tf.IfRegion"(%arg0) ({ "tf.Yield"(%arg0) : (tensor) -> () }, { %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%e) : (tensor<2xf32>) -> () - }) { is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> + }) { is_stateless = false} : (tensor) -> tensor<2xf32> return %0 : tensor<2xf32> } @@ -1051,12 +1036,12 @@ func @testIfRegionOpYieldMismatchThen(%arg0: tensor, %arg1: tensor<2xf32>) - func @testIfRegionOpYieldMismatchElse(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { // expected-error @+1 {{else result type tensor is incompatible with tf.IfRegion result type tensor<2xf32> at index 0}} - %0 = "tf.IfRegion"(%arg0, %arg1) ({ + %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t) : (tensor<2xf32>) -> () }, { "tf.Yield"(%arg0) : (tensor) -> () - }) { is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> + }) { is_stateless = false} : (tensor) -> tensor<2xf32> return %0 : tensor<2xf32> }