Eliminate tf.IfRegion non-condition inputs

The then and else regions can reference their inputs directly without having to wire them through the IfRegion op inputs. This will allow a more direct representation of how these values are used within these regions

PiperOrigin-RevId: 313406455
Change-Id: I0756f659c9dec4ef348c38f358bf294b3d004ae3
This commit is contained in:
Rahul Joshi 2020-05-27 09:44:03 -07:00 committed by TensorFlower Gardener
parent b847ff9b30
commit 5db729c9d6
2 changed files with 25 additions and 42 deletions

View File

@ -237,7 +237,6 @@ cond: A Tensor. If the tensor is a scalar of non-boolean type, the
True and zero means False; if the scalar is a string, non-empty
means True and empty means False. If the tensor is not a scalar,
being empty means False and being non-empty means True.
input: A list of input tensors.
then_branch: A region that computes the outputs of the op if cond = true.
It returns a list of tensors using tf.yield (as the terminator). The
types of these returned tensors is same as that of the else_branch
@ -248,7 +247,6 @@ else_branch: A region that computes the outputs of the op if cond = false.
let arguments = (ins
TF_Tensor:$cond,
Variadic<TF_Tensor>:$input,
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,

View File

@ -865,13 +865,14 @@ func @testInvalidYieldOp(%arg0: f32) -> () {
// Test valid tf.IfRegion operation
// CHECK-LABEL: func @testValidIfRegionOp
func @testValidIfRegionOp(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%neg = "tf.Neg"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "tf.IfRegion"(%arg0) ({
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
}, {
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
%e = "tf.Acos"(%neg) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@ -881,7 +882,7 @@ func @testValidIfRegionOp(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf
// Test valid tf.IfRegion operation with multiple results
// CHECK-LABEL: func @testValidIfRegionOpWithMultipleResults
func @testValidIfRegionOpWithMultipleResults(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0, %1, %2 = "tf.IfRegion"(%arg0, %arg1) ({
%0, %1, %2 = "tf.IfRegion"(%arg0) ({
%t0 = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
%t1 = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
%t2 = "tf.Acosh"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
@ -891,7 +892,7 @@ func @testValidIfRegionOpWithMultipleResults(%arg0: tensor<i1>, %arg1: tensor<2x
%e1 = "tf.Relu"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
%e2 = "tf.Sin"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%e0, %e1, %e2) : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>)
}) { is_stateless = false} : (tensor<i1>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>)
%3 = "tf.Add"(%0, %1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
%4 = "tf.Add"(%2, %3) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
@ -903,42 +904,26 @@ func @testValidIfRegionOpWithMultipleResults(%arg0: tensor<i1>, %arg1: tensor<2x
// Test invalid type for operand #0 for tf.IfRegion operation
func @testInvalidIfRegionOpType0(%arg0: f32, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+1 {{operand #0 must be tensor of tf.dtype values}}
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%0 = "tf.IfRegion"(%arg0) ({
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
}, {
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
}) { is_stateless = false} : (f32, tensor<2xf32>) -> tensor<2xf32>
}) { is_stateless = false} : (f32) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
// Test invalid type for operand #1 for tf.IfRegion operation
func @testInvalidIfRegionOpType1(%arg0: tensor<i1>, %arg1: f32) -> f32 {
// expected-error @+1 {{operand #1 must be tensor of tf.dtype values}}
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%t = addf %arg1, %arg1 : f32
"tf.Yield"(%t) : (f32) -> ()
}, {
%e = mulf %arg1, %arg1 : f32
"tf.Yield"(%e) : (f32) -> ()
}) { is_stateless = false} : (tensor<i1>, f32) -> f32
return %0 : f32
}
// -----
// tf.IfRegion operation should have 2 regions
func @testInvalidIfRegionOp1Region(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+1 {{op expected 2 regions}}
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%0 = "tf.IfRegion"(%arg0) ({
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@ -947,7 +932,7 @@ func @testInvalidIfRegionOp1Region(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> t
func @testInvalidIfRegionOpNoRegions(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+1 {{op expected 2 regions}}
%0 = "tf.IfRegion"(%arg0, %arg1) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
%0 = "tf.IfRegion"(%arg0) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@ -956,7 +941,7 @@ func @testInvalidIfRegionOpNoRegions(%arg0: tensor<i1>, %arg1: tensor<2xf32>) ->
func @testInvalidIfRegionOp3Regions(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+1 {{op expected 2 regions}}
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%0 = "tf.IfRegion"(%arg0) ({
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
}, {
@ -965,7 +950,7 @@ func @testInvalidIfRegionOp3Regions(%arg0: tensor<i1>, %arg1: tensor<2xf32>) ->
}, {
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@ -976,12 +961,12 @@ func @testInvalidIfRegionOp3Regions(%arg0: tensor<i1>, %arg1: tensor<2xf32>) ->
func @testIfRegionThenTerminator(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+2 {{'tf.IfRegion' op expects regions to end with 'tf.Yield'}}
// expected-note @+1 {{in custom textual format, the absence of terminator implies 'tf.Yield'}}
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%0 = "tf.IfRegion"(%arg0) ({
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
}, {
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@ -991,12 +976,12 @@ func @testIfRegionThenTerminator(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> ten
func @testIfRegionElseTerminator(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+2 {{'tf.IfRegion' op expects regions to end with 'tf.Yield'}}
// expected-note @+1 {{in custom textual format, the absence of terminator implies 'tf.Yield'}}
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%0 = "tf.IfRegion"(%arg0) ({
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
}, {
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@ -1006,13 +991,13 @@ func @testIfRegionElseTerminator(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> ten
// tf.Region yield number of results should match op number of results
func @testIfRegionThenResultCount(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+1 {{then region should have 1 result}}
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%0 = "tf.IfRegion"(%arg0) ({
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%t, %t) : (tensor<2xf32>, tensor<2xf32>) -> ()
}, {
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@ -1021,13 +1006,13 @@ func @testIfRegionThenResultCount(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> te
func @testIfRegionElseResultCount(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+1 {{else region should have 1 result}}
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%0 = "tf.IfRegion"(%arg0) ({
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
}, {
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%e, %e) : (tensor<2xf32>, tensor<2xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@ -1037,12 +1022,12 @@ func @testIfRegionElseResultCount(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> te
// tf.IfRegion yield types should match op result types
func @testIfRegionOpYieldMismatchThen(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+1 {{then result type tensor<i1> is incompatible with tf.IfRegion result type tensor<2xf32> at index 0}}
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%0 = "tf.IfRegion"(%arg0) ({
"tf.Yield"(%arg0) : (tensor<i1>) -> ()
}, {
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@ -1051,12 +1036,12 @@ func @testIfRegionOpYieldMismatchThen(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -
func @testIfRegionOpYieldMismatchElse(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+1 {{else result type tensor<i1> is incompatible with tf.IfRegion result type tensor<2xf32> at index 0}}
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%0 = "tf.IfRegion"(%arg0) ({
%t = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
}, {
"tf.Yield"(%arg0) : (tensor<i1>) -> ()
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}