Eliminate tf.IfRegion non-condition inputs
The then and else regions can reference their inputs directly without having to wire them through the IfRegion op inputs. This will allow a more direct representation of how these values are used within these regions PiperOrigin-RevId: 313406455 Change-Id: I0756f659c9dec4ef348c38f358bf294b3d004ae3
This commit is contained in:
parent
b847ff9b30
commit
5db729c9d6
@ -237,7 +237,6 @@ cond: A Tensor. If the tensor is a scalar of non-boolean type, the
|
||||
True and zero means False; if the scalar is a string, non-empty
|
||||
means True and empty means False. If the tensor is not a scalar,
|
||||
being empty means False and being non-empty means True.
|
||||
input: A list of input tensors.
|
||||
then_branch: A region that computes the outputs of the op if cond = true.
|
||||
It returns a list of tensors using tf.yield (as the terminator). The
|
||||
types of these returned tensors is same as that of the else_branch
|
||||
@ -248,7 +247,6 @@ else_branch: A region that computes the outputs of the op if cond = false.
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$cond,
|
||||
Variadic<TF_Tensor>:$input,
|
||||
|
||||
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
|
||||
|
||||
|
||||
@ -865,13 +865,14 @@ func @testInvalidYieldOp(%arg0: f32) -> () {
|
||||
// Test valid tf.IfRegion operation
|
||||
// CHECK-LABEL: func @testValidIfRegionOp
|
||||
func @testValidIfRegionOp(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%neg = "tf.Neg"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
%0 = "tf.IfRegion"(%arg0) ({
|
||||
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
|
||||
}, {
|
||||
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
%e = "tf.Acos"(%neg) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
@ -881,7 +882,7 @@ func @testValidIfRegionOp(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf
|
||||
// Test valid tf.IfRegion operation with multiple results
|
||||
// CHECK-LABEL: func @testValidIfRegionOpWithMultipleResults
|
||||
func @testValidIfRegionOpWithMultipleResults(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
%0, %1, %2 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%0, %1, %2 = "tf.IfRegion"(%arg0) ({
|
||||
%t0 = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
%t1 = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
%t2 = "tf.Acosh"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
@ -891,7 +892,7 @@ func @testValidIfRegionOpWithMultipleResults(%arg0: tensor<i1>, %arg1: tensor<2x
|
||||
%e1 = "tf.Relu"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
%e2 = "tf.Sin"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%e0, %e1, %e2) : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>)
|
||||
}) { is_stateless = false} : (tensor<i1>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>)
|
||||
|
||||
%3 = "tf.Add"(%0, %1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||
%4 = "tf.Add"(%2, %3) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||
@ -903,42 +904,26 @@ func @testValidIfRegionOpWithMultipleResults(%arg0: tensor<i1>, %arg1: tensor<2x
|
||||
// Test invalid type for operand #0 for tf.IfRegion operation
|
||||
func @testInvalidIfRegionOpType0(%arg0: f32, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+1 {{operand #0 must be tensor of tf.dtype values}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%0 = "tf.IfRegion"(%arg0) ({
|
||||
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
|
||||
}, {
|
||||
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
|
||||
}) { is_stateless = false} : (f32, tensor<2xf32>) -> tensor<2xf32>
|
||||
}) { is_stateless = false} : (f32) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test invalid type for operand #1 for tf.IfRegion operation
|
||||
func @testInvalidIfRegionOpType1(%arg0: tensor<i1>, %arg1: f32) -> f32 {
|
||||
// expected-error @+1 {{operand #1 must be tensor of tf.dtype values}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%t = addf %arg1, %arg1 : f32
|
||||
"tf.Yield"(%t) : (f32) -> ()
|
||||
}, {
|
||||
%e = mulf %arg1, %arg1 : f32
|
||||
"tf.Yield"(%e) : (f32) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>, f32) -> f32
|
||||
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// tf.IfRegion operation should have 2 regions
|
||||
func @testInvalidIfRegionOp1Region(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+1 {{op expected 2 regions}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%0 = "tf.IfRegion"(%arg0) ({
|
||||
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
@ -947,7 +932,7 @@ func @testInvalidIfRegionOp1Region(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> t
|
||||
|
||||
func @testInvalidIfRegionOpNoRegions(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+1 {{op expected 2 regions}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
%0 = "tf.IfRegion"(%arg0) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
@ -956,7 +941,7 @@ func @testInvalidIfRegionOpNoRegions(%arg0: tensor<i1>, %arg1: tensor<2xf32>) ->
|
||||
|
||||
func @testInvalidIfRegionOp3Regions(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+1 {{op expected 2 regions}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%0 = "tf.IfRegion"(%arg0) ({
|
||||
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
|
||||
}, {
|
||||
@ -965,7 +950,7 @@ func @testInvalidIfRegionOp3Regions(%arg0: tensor<i1>, %arg1: tensor<2xf32>) ->
|
||||
}, {
|
||||
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
@ -976,12 +961,12 @@ func @testInvalidIfRegionOp3Regions(%arg0: tensor<i1>, %arg1: tensor<2xf32>) ->
|
||||
func @testIfRegionThenTerminator(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+2 {{'tf.IfRegion' op expects regions to end with 'tf.Yield'}}
|
||||
// expected-note @+1 {{in custom textual format, the absence of terminator implies 'tf.Yield'}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%0 = "tf.IfRegion"(%arg0) ({
|
||||
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
}, {
|
||||
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
@ -991,12 +976,12 @@ func @testIfRegionThenTerminator(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> ten
|
||||
func @testIfRegionElseTerminator(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+2 {{'tf.IfRegion' op expects regions to end with 'tf.Yield'}}
|
||||
// expected-note @+1 {{in custom textual format, the absence of terminator implies 'tf.Yield'}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%0 = "tf.IfRegion"(%arg0) ({
|
||||
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
|
||||
}, {
|
||||
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
@ -1006,13 +991,13 @@ func @testIfRegionElseTerminator(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> ten
|
||||
// tf.Region yield number of results should match op number of results
|
||||
func @testIfRegionThenResultCount(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+1 {{then region should have 1 result}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%0 = "tf.IfRegion"(%arg0) ({
|
||||
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%t, %t) : (tensor<2xf32>, tensor<2xf32>) -> ()
|
||||
}, {
|
||||
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
@ -1021,13 +1006,13 @@ func @testIfRegionThenResultCount(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> te
|
||||
|
||||
func @testIfRegionElseResultCount(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+1 {{else region should have 1 result}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%0 = "tf.IfRegion"(%arg0) ({
|
||||
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
|
||||
}, {
|
||||
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%e, %e) : (tensor<2xf32>, tensor<2xf32>) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
@ -1037,12 +1022,12 @@ func @testIfRegionElseResultCount(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> te
|
||||
// tf.IfRegion yield types should match op result types
|
||||
func @testIfRegionOpYieldMismatchThen(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+1 {{then result type tensor<i1> is incompatible with tf.IfRegion result type tensor<2xf32> at index 0}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%0 = "tf.IfRegion"(%arg0) ({
|
||||
"tf.Yield"(%arg0) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
@ -1051,12 +1036,12 @@ func @testIfRegionOpYieldMismatchThen(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -
|
||||
|
||||
func @testIfRegionOpYieldMismatchElse(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+1 {{else result type tensor<i1> is incompatible with tf.IfRegion result type tensor<2xf32> at index 0}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%0 = "tf.IfRegion"(%arg0) ({
|
||||
%t = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
|
||||
}, {
|
||||
"tf.Yield"(%arg0) : (tensor<i1>) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user