Eliminate output_types from If/IfRegion ODS specs
- Also eliminate output_types attribute from several test cases - This attribute may still be present on these ops since the importer seems to generate them. - Added a test to verify that values generated on one branch of the if cannot be consumed on the other branch PiperOrigin-RevId: 313668390 Change-Id: I97bed79f52f6694ead1931a64c411686067d2800
This commit is contained in:
parent
bfd14e0aa3
commit
60c828a70e
@ -292,7 +292,7 @@ func @tensorlistResize(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: ten
|
||||
// CHECK: [[SIZE_DIFF:%.*]] = "tf.Sub"([[SIZE]], [[INPUT_SIZE]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
// CHECK: [[DIFF_RES:%.*]] = "tf.Greater"([[SIZE_DIFF]], [[ZERO]]) : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
// CHECK: [[SHAPE_1:%.*]] = "tf.Shape"([[INPUT]]) : (tensor<3x10xf32>) -> tensor<?xi32>
|
||||
// CHECK: [[RESULT:%.*]] = "tf.If"([[DIFF_RES]], [[INPUT]], [[SHAPE_1]], [[SIZE_DIFF]], [[SIZE]]) {else_branch = @cond_false, is_stateless = true, output_shapes = [], then_branch = @cond_true} : (tensor<i1>, tensor<3x10xf32>, tensor<?xi32>, tensor<i32>, tensor<i32>) -> tensor<?x10xf32>
|
||||
// CHECK: [[RESULT:%.*]] = "tf.If"([[DIFF_RES]], [[INPUT]], [[SHAPE_1]], [[SIZE_DIFF]], [[SIZE]]) {else_branch = @cond_false, is_stateless = true, then_branch = @cond_true} : (tensor<i1>, tensor<3x10xf32>, tensor<?xi32>, tensor<i32>, tensor<i32>) -> tensor<?x10xf32>
|
||||
// CHECK: return [[RESULT]] : tensor<?x10xf32>
|
||||
}
|
||||
|
||||
|
@ -577,7 +577,6 @@ struct ConvertTensorListResize
|
||||
ArrayRef<Value>({input_handle, input_shape, size_diff, size}),
|
||||
/*then_branch=*/rewriter.getSymbolRefAttr(then_branch_op),
|
||||
/*else_branch=*/rewriter.getSymbolRefAttr(else_branch_op),
|
||||
/*output_shapes=*/rewriter.getArrayAttr({}),
|
||||
/*is_stateless=*/rewriter.getBoolAttr(true));
|
||||
return success();
|
||||
}
|
||||
|
@ -26,7 +26,7 @@ namespace TF {
|
||||
|
||||
namespace AttrKind {
|
||||
|
||||
// List of supported custom TensorFlow Attributes kinds, necessary for
|
||||
// List of supported custom TensorFlow Attribute kinds, necessary for
|
||||
// isa/dyn_cast.
|
||||
enum Kind {
|
||||
FIRST_USED_TENSORFLOW_ATTR = Attribute::FIRST_TENSORFLOW_ATTR,
|
||||
|
@ -188,7 +188,6 @@ else_branch: A function that takes 'inputs' and returns a list of
|
||||
|
||||
FlatSymbolRefAttr:$then_branch,
|
||||
FlatSymbolRefAttr:$else_branch,
|
||||
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
|
||||
|
||||
// Used to map StatelessIf and If op defined in TensorFlow to a common op.
|
||||
BoolAttr:$is_stateless
|
||||
@ -248,8 +247,6 @@ else_branch: A region that computes the outputs of the op if cond = false.
|
||||
let arguments = (ins
|
||||
TF_Tensor:$cond,
|
||||
|
||||
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
|
||||
|
||||
// Used to map StatelessIf and If op defined in TensorFlow to a common op.
|
||||
BoolAttr:$is_stateless
|
||||
);
|
||||
|
@ -145,8 +145,8 @@ func @main(%arg0: tensor<i1>) -> tensor<2xf32> attributes {tf.entry_function = {
|
||||
%2 = "tf.ReadVariableOp"(%1) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
%3 = "tf.Less"(%2, %0) : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
%4 = "tf.If"(%3, %1, %2) {Tcond = i1, Tin = ["tfdtype$DT_RESOURCE", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"],
|
||||
else_branch = @cond_false, is_stateless = false, output_shapes = [#tf.shape<>],
|
||||
then_branch = @cond_true} : (tensor<i1>, tensor<!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
|
||||
else_branch = @cond_false, is_stateless = false,then_branch = @cond_true} :
|
||||
(tensor<i1>, tensor<!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
|
||||
%5 = "tf.Identity"(%4) : (tensor<f32>) -> tensor<f32>
|
||||
%6 = "tf.Pack"(%2, %5) {N = 2 : i64, T = f32, axis = 0 : i64, device = ""} : (tensor<f32>, tensor<f32>) -> tensor<2xf32>
|
||||
return %6 : tensor<2xf32>
|
||||
|
@ -217,7 +217,7 @@ func @error_on_conflict_multiple_callers(
|
||||
// expected-error@above {{Conflicting device assignment for resource}}
|
||||
then_branch = @if_then_and_else,
|
||||
else_branch = @if_then_and_else,
|
||||
output_shapes = [], is_stateless = false}
|
||||
is_stateless = false}
|
||||
: (tensor<i1>, tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>) -> ()
|
||||
tf_executor.yield
|
||||
|
@ -420,7 +420,7 @@ func @cluster_with_if(%arg0: tensor<i1>) -> tensor<4xf32> {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
// CHECK: %[[IF:.*]]:2 = "tf.If"(%[[ARG0]], %[[READ0]], %[[READ1]])
|
||||
%3:2 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else,
|
||||
output_shapes = [#tf.shape<>, #tf.shape<4>], is_stateless = false}
|
||||
is_stateless = false}
|
||||
: (tensor<i1>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<*x!tf.resource<tensor<4xf32>>>)
|
||||
-> (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>)
|
||||
// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[IF]]#1, %[[IF]]#0)
|
||||
@ -468,7 +468,7 @@ func @cluster_with_nested_if(%arg0: tensor<i1>) -> tensor<f32> {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
// CHECK: %[[IF:.*]] = "tf.If"(%[[ARG0]], %[[READ0]])
|
||||
%3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else,
|
||||
output_shapes = [], is_stateless = false}
|
||||
is_stateless = false}
|
||||
: (tensor<i1>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
|
||||
-> (tensor<*x!tf.resource<tensor<f32>>>)
|
||||
// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[IF]], %[[IF]])
|
||||
@ -488,7 +488,7 @@ func @if_then(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.re
|
||||
// CHECK-NEXT: %[[IIF:.*]] = "tf.If"(%[[TARG0]], %[[TARG0]])
|
||||
%read = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
%3 = "tf.If"(%read, %arg0) {then_branch = @inner_if_then, else_branch = @inner_if_else,
|
||||
output_shapes = [], is_stateless = false}
|
||||
is_stateless = false}
|
||||
: (tensor<f32>, tensor<*x!tf.resource<tensor<f32>>>)
|
||||
-> (tensor<*x!tf.resource<tensor<f32>>>)
|
||||
// CHECK-NEXT: return %[[IIF]]
|
||||
@ -526,7 +526,7 @@ func @cluster_with_if(%arg0: tensor<i1>) -> tensor<4xf32> {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
// expected-error @+1 {{unsupported tf.IfOp output: resource does not alias a single input.}}
|
||||
%3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else,
|
||||
output_shapes = [#tf.shape<>], is_stateless = false}
|
||||
is_stateless = false}
|
||||
: (tensor<i1>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<*x!tf.resource<tensor<4xf32>>>)
|
||||
-> (tensor<*x!tf.resource<tensor<4xf32>>>)
|
||||
%4 = "tf.ReadVariableOp"(%3) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
|
||||
|
@ -102,7 +102,7 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
|
||||
|
||||
// CHECK-LABEL: func @shape_from_if_to_branch_functions
|
||||
func @shape_from_if_to_branch_functions(%arg0: tensor<i1>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
|
||||
%0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @if_else_branch, is_stateless = true, name = "if", output_shapes = [#tf.shape<>], then_branch = @if_then_branch} : (tensor<i1>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
|
||||
%0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @if_else_branch, is_stateless = true, name = "if", then_branch = @if_then_branch} : (tensor<i1>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
|
||||
return %0 : tensor<1x2x3xf32>
|
||||
}
|
||||
|
||||
@ -185,9 +185,9 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
|
||||
// CHECK-LABEL: func @invalid_function_reused_by_control_flows
|
||||
func @invalid_function_reused_by_control_flows(%arg0: tensor<i1>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
|
||||
// expected-warning @+1 {{unable to refine shape}}
|
||||
%0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", output_shapes = [#tf.shape<>], then_branch = @reused_if_then_branch} : (tensor<i1>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
|
||||
%0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", then_branch = @reused_if_then_branch} : (tensor<i1>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
|
||||
// expected-warning @+1 {{unable to refine shape}}
|
||||
%1 = "tf.If"(%arg0, %0) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", output_shapes = [#tf.shape<>], then_branch = @reused_if_then_branch} : (tensor<i1>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
|
||||
%1 = "tf.If"(%arg0, %0) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", then_branch = @reused_if_then_branch} : (tensor<i1>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
|
||||
return %0 : tensor<1x2x3xf32>
|
||||
}
|
||||
|
||||
|
@ -1048,6 +1048,36 @@ func @testIfRegionOpYieldMismatchElse(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -
|
||||
|
||||
// -----
|
||||
|
||||
// value generated in one branch cannot be consumed in the other branch
|
||||
func @testIfRegionElseConsumingThen(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
%0 = "tf.IfRegion"(%arg0) ({
|
||||
%t = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
|
||||
}, {
|
||||
// expected-error @+1 {{use of undeclared SSA value name}}
|
||||
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testIfRegionThenConsumingElse(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
%0 = "tf.IfRegion"(%arg0) ({
|
||||
// expected-error @+1 {{does not dominate this use}}
|
||||
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
|
||||
}, {
|
||||
// expected-note @+1 {{operand defined here}}
|
||||
%t = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test valid tf.MatrixBandPart
|
||||
// CHECK-LABEL: func @testValidMatrixBandPartOp
|
||||
func @testValidMatrixBandPartOp(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
|
||||
|
@ -700,15 +700,10 @@ LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch,
|
||||
// Erase the resource outputs from the branches.
|
||||
int64_t non_resource_results = 0;
|
||||
llvm::SmallVector<int64_t, 4> old_to_new_output_indices;
|
||||
llvm::SmallVector<Attribute, 4> new_output_shapes;
|
||||
bool output_removed = false;
|
||||
for (auto result : if_op.getResults()) {
|
||||
if (!getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
|
||||
old_to_new_output_indices.push_back(non_resource_results++);
|
||||
if (!if_op.output_shapes().getValue().empty()) {
|
||||
new_output_shapes.push_back(
|
||||
if_op.output_shapes().getValue()[result.getResultNumber()]);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
old_to_new_output_indices.push_back(-1);
|
||||
@ -781,8 +776,7 @@ LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch,
|
||||
auto new_if = builder.create<TF::IfOp>(if_op.getLoc(),
|
||||
then_branch.getType().getResults(),
|
||||
new_operands, if_op.getAttrs());
|
||||
// Prepare for AddLoadsStoresOutsideControlFlowOp() and update
|
||||
// new_output_shapes.
|
||||
// Prepare for AddLoadsStoresOutsideControlFlowOp()
|
||||
llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
|
||||
arg_data_type_and_updated_output_index;
|
||||
for (const auto& entry : remaining_resource_data_types) {
|
||||
@ -792,14 +786,9 @@ LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch,
|
||||
: new_output_it->getSecond();
|
||||
arg_data_type_and_updated_output_index[entry.getFirst() + 1] = {
|
||||
entry.getSecond(), update_index};
|
||||
if (!if_op.output_shapes().getValue().empty() && update_index >= 0) {
|
||||
new_output_shapes.push_back(
|
||||
tensorflow::ConvertTypeToTensorShapeAttr(entry.getSecond()));
|
||||
}
|
||||
}
|
||||
AddLoadsStoresOutsideControlFlowOp(new_if,
|
||||
arg_data_type_and_updated_output_index);
|
||||
new_if.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes));
|
||||
// Replace uses.
|
||||
for (int64_t i = 0; i < old_to_new_output_indices.size(); ++i) {
|
||||
if (old_to_new_output_indices[i] >= 0) {
|
||||
|
@ -254,22 +254,14 @@ LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module,
|
||||
if (output_buffer_to_size.empty() && arg_no_changed) return success();
|
||||
// Recreate the If op.
|
||||
auto new_if_operands = llvm::to_vector<8>(if_op.getOperands());
|
||||
auto new_output_shapes = llvm::to_vector<8>(if_op.output_shapes().getValue());
|
||||
for (int64_t i = 1; i < if_op.getNumOperands(); ++i) {
|
||||
auto it = buffer_to_size->find(if_op.getOperand(i));
|
||||
if (it == buffer_to_size->end()) continue;
|
||||
new_if_operands.push_back(it->getSecond().size);
|
||||
if (!new_output_shapes.empty()) {
|
||||
// Size is a scalar shape.
|
||||
tensorflow::TensorShapeProto shape_proto;
|
||||
new_output_shapes.push_back(builder.getStringAttr(
|
||||
tensorflow::mangling_util::MangleShape(shape_proto)));
|
||||
}
|
||||
}
|
||||
auto new_if = OpBuilder(if_op).create<TF::IfOp>(
|
||||
if_op.getLoc(), then_branch.getType().getResults(), new_if_operands,
|
||||
if_op.getAttrs());
|
||||
new_if.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes));
|
||||
for (const auto& entry : output_buffer_to_size) {
|
||||
(*buffer_to_size)[new_if.getResult(std::get<0>(entry))] = {
|
||||
new_if.getResult(std::get<1>(entry)), std::get<2>(entry)};
|
||||
|
Loading…
Reference in New Issue
Block a user