Eliminate output_types from If/IfRegion ODS specs

- Also eliminate output_types attribute from several test cases
- This attribute may still be present on these ops, since the importer seems to
  still generate it.
- Added a test to verify that values generated on one branch of the if cannot be
  consumed on the other branch

PiperOrigin-RevId: 313668390
Change-Id: I97bed79f52f6694ead1931a64c411686067d2800
This commit is contained in:
Rahul Joshi 2020-05-28 15:12:28 -07:00 committed by TensorFlower Gardener
parent bfd14e0aa3
commit 60c828a70e
11 changed files with 47 additions and 40 deletions

View File

@ -292,7 +292,7 @@ func @tensorlistResize(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: ten
// CHECK: [[SIZE_DIFF:%.*]] = "tf.Sub"([[SIZE]], [[INPUT_SIZE]]) : (tensor<i32>, tensor<i32>) -> tensor<i32> // CHECK: [[SIZE_DIFF:%.*]] = "tf.Sub"([[SIZE]], [[INPUT_SIZE]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK: [[DIFF_RES:%.*]] = "tf.Greater"([[SIZE_DIFF]], [[ZERO]]) : (tensor<i32>, tensor<i32>) -> tensor<i1> // CHECK: [[DIFF_RES:%.*]] = "tf.Greater"([[SIZE_DIFF]], [[ZERO]]) : (tensor<i32>, tensor<i32>) -> tensor<i1>
// CHECK: [[SHAPE_1:%.*]] = "tf.Shape"([[INPUT]]) : (tensor<3x10xf32>) -> tensor<?xi32> // CHECK: [[SHAPE_1:%.*]] = "tf.Shape"([[INPUT]]) : (tensor<3x10xf32>) -> tensor<?xi32>
// CHECK: [[RESULT:%.*]] = "tf.If"([[DIFF_RES]], [[INPUT]], [[SHAPE_1]], [[SIZE_DIFF]], [[SIZE]]) {else_branch = @cond_false, is_stateless = true, output_shapes = [], then_branch = @cond_true} : (tensor<i1>, tensor<3x10xf32>, tensor<?xi32>, tensor<i32>, tensor<i32>) -> tensor<?x10xf32> // CHECK: [[RESULT:%.*]] = "tf.If"([[DIFF_RES]], [[INPUT]], [[SHAPE_1]], [[SIZE_DIFF]], [[SIZE]]) {else_branch = @cond_false, is_stateless = true, then_branch = @cond_true} : (tensor<i1>, tensor<3x10xf32>, tensor<?xi32>, tensor<i32>, tensor<i32>) -> tensor<?x10xf32>
// CHECK: return [[RESULT]] : tensor<?x10xf32> // CHECK: return [[RESULT]] : tensor<?x10xf32>
} }

View File

@ -577,7 +577,6 @@ struct ConvertTensorListResize
ArrayRef<Value>({input_handle, input_shape, size_diff, size}), ArrayRef<Value>({input_handle, input_shape, size_diff, size}),
/*then_branch=*/rewriter.getSymbolRefAttr(then_branch_op), /*then_branch=*/rewriter.getSymbolRefAttr(then_branch_op),
/*else_branch=*/rewriter.getSymbolRefAttr(else_branch_op), /*else_branch=*/rewriter.getSymbolRefAttr(else_branch_op),
/*output_shapes=*/rewriter.getArrayAttr({}),
/*is_stateless=*/rewriter.getBoolAttr(true)); /*is_stateless=*/rewriter.getBoolAttr(true));
return success(); return success();
} }

View File

@ -26,7 +26,7 @@ namespace TF {
namespace AttrKind { namespace AttrKind {
// List of supported custom TensorFlow Attributes kinds, necessary for // List of supported custom TensorFlow Attribute kinds, necessary for
// isa/dyn_cast. // isa/dyn_cast.
enum Kind { enum Kind {
FIRST_USED_TENSORFLOW_ATTR = Attribute::FIRST_TENSORFLOW_ATTR, FIRST_USED_TENSORFLOW_ATTR = Attribute::FIRST_TENSORFLOW_ATTR,

View File

@ -188,7 +188,6 @@ else_branch: A function that takes 'inputs' and returns a list of
FlatSymbolRefAttr:$then_branch, FlatSymbolRefAttr:$then_branch,
FlatSymbolRefAttr:$else_branch, FlatSymbolRefAttr:$else_branch,
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
// Used to map StatelessIf and If op defined in TensorFlow to a common op. // Used to map StatelessIf and If op defined in TensorFlow to a common op.
BoolAttr:$is_stateless BoolAttr:$is_stateless
@ -248,8 +247,6 @@ else_branch: A region that computes the outputs of the op if cond = false.
let arguments = (ins let arguments = (ins
TF_Tensor:$cond, TF_Tensor:$cond,
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
// Used to map StatelessIf and If op defined in TensorFlow to a common op. // Used to map StatelessIf and If op defined in TensorFlow to a common op.
BoolAttr:$is_stateless BoolAttr:$is_stateless
); );

View File

@ -145,8 +145,8 @@ func @main(%arg0: tensor<i1>) -> tensor<2xf32> attributes {tf.entry_function = {
%2 = "tf.ReadVariableOp"(%1) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32> %2 = "tf.ReadVariableOp"(%1) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
%3 = "tf.Less"(%2, %0) : (tensor<f32>, tensor<f32>) -> tensor<i1> %3 = "tf.Less"(%2, %0) : (tensor<f32>, tensor<f32>) -> tensor<i1>
%4 = "tf.If"(%3, %1, %2) {Tcond = i1, Tin = ["tfdtype$DT_RESOURCE", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], %4 = "tf.If"(%3, %1, %2) {Tcond = i1, Tin = ["tfdtype$DT_RESOURCE", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"],
else_branch = @cond_false, is_stateless = false, output_shapes = [#tf.shape<>], else_branch = @cond_false, is_stateless = false,then_branch = @cond_true} :
then_branch = @cond_true} : (tensor<i1>, tensor<!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32> (tensor<i1>, tensor<!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
%5 = "tf.Identity"(%4) : (tensor<f32>) -> tensor<f32> %5 = "tf.Identity"(%4) : (tensor<f32>) -> tensor<f32>
%6 = "tf.Pack"(%2, %5) {N = 2 : i64, T = f32, axis = 0 : i64, device = ""} : (tensor<f32>, tensor<f32>) -> tensor<2xf32> %6 = "tf.Pack"(%2, %5) {N = 2 : i64, T = f32, axis = 0 : i64, device = ""} : (tensor<f32>, tensor<f32>) -> tensor<2xf32>
return %6 : tensor<2xf32> return %6 : tensor<2xf32>

View File

@ -217,7 +217,7 @@ func @error_on_conflict_multiple_callers(
// expected-error@above {{Conflicting device assignment for resource}} // expected-error@above {{Conflicting device assignment for resource}}
then_branch = @if_then_and_else, then_branch = @if_then_and_else,
else_branch = @if_then_and_else, else_branch = @if_then_and_else,
output_shapes = [], is_stateless = false} is_stateless = false}
: (tensor<i1>, tensor<*x!tf.resource<tensor<32xf32>>>, : (tensor<i1>, tensor<*x!tf.resource<tensor<32xf32>>>,
tensor<*x!tf.resource<tensor<32xf32>>>) -> () tensor<*x!tf.resource<tensor<32xf32>>>) -> ()
tf_executor.yield tf_executor.yield

View File

@ -420,7 +420,7 @@ func @cluster_with_if(%arg0: tensor<i1>) -> tensor<4xf32> {
%2 = "tf_device.cluster"() ( { %2 = "tf_device.cluster"() ( {
// CHECK: %[[IF:.*]]:2 = "tf.If"(%[[ARG0]], %[[READ0]], %[[READ1]]) // CHECK: %[[IF:.*]]:2 = "tf.If"(%[[ARG0]], %[[READ0]], %[[READ1]])
%3:2 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else, %3:2 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else,
output_shapes = [#tf.shape<>, #tf.shape<4>], is_stateless = false} is_stateless = false}
: (tensor<i1>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<*x!tf.resource<tensor<4xf32>>>) : (tensor<i1>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<*x!tf.resource<tensor<4xf32>>>)
-> (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) -> (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>)
// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[IF]]#1, %[[IF]]#0) // CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[IF]]#1, %[[IF]]#0)
@ -468,7 +468,7 @@ func @cluster_with_nested_if(%arg0: tensor<i1>) -> tensor<f32> {
%2 = "tf_device.cluster"() ( { %2 = "tf_device.cluster"() ( {
// CHECK: %[[IF:.*]] = "tf.If"(%[[ARG0]], %[[READ0]]) // CHECK: %[[IF:.*]] = "tf.If"(%[[ARG0]], %[[READ0]])
%3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else, %3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else,
output_shapes = [], is_stateless = false} is_stateless = false}
: (tensor<i1>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>) : (tensor<i1>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
-> (tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>)
// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[IF]], %[[IF]]) // CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[IF]], %[[IF]])
@ -488,7 +488,7 @@ func @if_then(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.re
// CHECK-NEXT: %[[IIF:.*]] = "tf.If"(%[[TARG0]], %[[TARG0]]) // CHECK-NEXT: %[[IIF:.*]] = "tf.If"(%[[TARG0]], %[[TARG0]])
%read = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> %read = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%3 = "tf.If"(%read, %arg0) {then_branch = @inner_if_then, else_branch = @inner_if_else, %3 = "tf.If"(%read, %arg0) {then_branch = @inner_if_then, else_branch = @inner_if_else,
output_shapes = [], is_stateless = false} is_stateless = false}
: (tensor<f32>, tensor<*x!tf.resource<tensor<f32>>>) : (tensor<f32>, tensor<*x!tf.resource<tensor<f32>>>)
-> (tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>)
// CHECK-NEXT: return %[[IIF]] // CHECK-NEXT: return %[[IIF]]
@ -526,7 +526,7 @@ func @cluster_with_if(%arg0: tensor<i1>) -> tensor<4xf32> {
%2 = "tf_device.cluster"() ( { %2 = "tf_device.cluster"() ( {
// expected-error @+1 {{unsupported tf.IfOp output: resource does not alias a single input.}} // expected-error @+1 {{unsupported tf.IfOp output: resource does not alias a single input.}}
%3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else, %3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else,
output_shapes = [#tf.shape<>], is_stateless = false} is_stateless = false}
: (tensor<i1>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<*x!tf.resource<tensor<4xf32>>>) : (tensor<i1>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<*x!tf.resource<tensor<4xf32>>>)
-> (tensor<*x!tf.resource<tensor<4xf32>>>) -> (tensor<*x!tf.resource<tensor<4xf32>>>)
%4 = "tf.ReadVariableOp"(%3) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32> %4 = "tf.ReadVariableOp"(%3) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>

View File

@ -102,7 +102,7 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
// CHECK-LABEL: func @shape_from_if_to_branch_functions // CHECK-LABEL: func @shape_from_if_to_branch_functions
func @shape_from_if_to_branch_functions(%arg0: tensor<i1>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> { func @shape_from_if_to_branch_functions(%arg0: tensor<i1>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
%0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @if_else_branch, is_stateless = true, name = "if", output_shapes = [#tf.shape<>], then_branch = @if_then_branch} : (tensor<i1>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> %0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @if_else_branch, is_stateless = true, name = "if", then_branch = @if_then_branch} : (tensor<i1>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
return %0 : tensor<1x2x3xf32> return %0 : tensor<1x2x3xf32>
} }
@ -185,9 +185,9 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
// CHECK-LABEL: func @invalid_function_reused_by_control_flows // CHECK-LABEL: func @invalid_function_reused_by_control_flows
func @invalid_function_reused_by_control_flows(%arg0: tensor<i1>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> { func @invalid_function_reused_by_control_flows(%arg0: tensor<i1>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
// expected-warning @+1 {{unable to refine shape}} // expected-warning @+1 {{unable to refine shape}}
%0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", output_shapes = [#tf.shape<>], then_branch = @reused_if_then_branch} : (tensor<i1>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> %0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", then_branch = @reused_if_then_branch} : (tensor<i1>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
// expected-warning @+1 {{unable to refine shape}} // expected-warning @+1 {{unable to refine shape}}
%1 = "tf.If"(%arg0, %0) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", output_shapes = [#tf.shape<>], then_branch = @reused_if_then_branch} : (tensor<i1>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> %1 = "tf.If"(%arg0, %0) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", then_branch = @reused_if_then_branch} : (tensor<i1>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
return %0 : tensor<1x2x3xf32> return %0 : tensor<1x2x3xf32>
} }

View File

@ -1048,6 +1048,36 @@ func @testIfRegionOpYieldMismatchElse(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -
// ----- // -----
// value generated in one branch cannot be consumed in the other branch
func @testIfRegionElseConsumingThen(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = "tf.IfRegion"(%arg0) ({
%t = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
}, {
// expected-error @+1 {{use of undeclared SSA value name}}
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
func @testIfRegionThenConsumingElse(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = "tf.IfRegion"(%arg0) ({
// expected-error @+1 {{does not dominate this use}}
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
}, {
// expected-note @+1 {{operand defined here}}
%t = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
// Test valid tf.MatrixBandPart // Test valid tf.MatrixBandPart
// CHECK-LABEL: func @testValidMatrixBandPartOp // CHECK-LABEL: func @testValidMatrixBandPartOp
func @testValidMatrixBandPartOp(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> { func @testValidMatrixBandPartOp(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {

View File

@ -700,15 +700,10 @@ LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch,
// Erase the resource outputs from the branches. // Erase the resource outputs from the branches.
int64_t non_resource_results = 0; int64_t non_resource_results = 0;
llvm::SmallVector<int64_t, 4> old_to_new_output_indices; llvm::SmallVector<int64_t, 4> old_to_new_output_indices;
llvm::SmallVector<Attribute, 4> new_output_shapes;
bool output_removed = false; bool output_removed = false;
for (auto result : if_op.getResults()) { for (auto result : if_op.getResults()) {
if (!getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) { if (!getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
old_to_new_output_indices.push_back(non_resource_results++); old_to_new_output_indices.push_back(non_resource_results++);
if (!if_op.output_shapes().getValue().empty()) {
new_output_shapes.push_back(
if_op.output_shapes().getValue()[result.getResultNumber()]);
}
continue; continue;
} }
old_to_new_output_indices.push_back(-1); old_to_new_output_indices.push_back(-1);
@ -781,8 +776,7 @@ LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch,
auto new_if = builder.create<TF::IfOp>(if_op.getLoc(), auto new_if = builder.create<TF::IfOp>(if_op.getLoc(),
then_branch.getType().getResults(), then_branch.getType().getResults(),
new_operands, if_op.getAttrs()); new_operands, if_op.getAttrs());
// Prepare for AddLoadsStoresOutsideControlFlowOp() and update // Prepare for AddLoadsStoresOutsideControlFlowOp()
// new_output_shapes.
llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>> llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
arg_data_type_and_updated_output_index; arg_data_type_and_updated_output_index;
for (const auto& entry : remaining_resource_data_types) { for (const auto& entry : remaining_resource_data_types) {
@ -792,14 +786,9 @@ LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch,
: new_output_it->getSecond(); : new_output_it->getSecond();
arg_data_type_and_updated_output_index[entry.getFirst() + 1] = { arg_data_type_and_updated_output_index[entry.getFirst() + 1] = {
entry.getSecond(), update_index}; entry.getSecond(), update_index};
if (!if_op.output_shapes().getValue().empty() && update_index >= 0) {
new_output_shapes.push_back(
tensorflow::ConvertTypeToTensorShapeAttr(entry.getSecond()));
}
} }
AddLoadsStoresOutsideControlFlowOp(new_if, AddLoadsStoresOutsideControlFlowOp(new_if,
arg_data_type_and_updated_output_index); arg_data_type_and_updated_output_index);
new_if.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes));
// Replace uses. // Replace uses.
for (int64_t i = 0; i < old_to_new_output_indices.size(); ++i) { for (int64_t i = 0; i < old_to_new_output_indices.size(); ++i) {
if (old_to_new_output_indices[i] >= 0) { if (old_to_new_output_indices[i] >= 0) {

View File

@ -254,22 +254,14 @@ LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module,
if (output_buffer_to_size.empty() && arg_no_changed) return success(); if (output_buffer_to_size.empty() && arg_no_changed) return success();
// Recreate the If op. // Recreate the If op.
auto new_if_operands = llvm::to_vector<8>(if_op.getOperands()); auto new_if_operands = llvm::to_vector<8>(if_op.getOperands());
auto new_output_shapes = llvm::to_vector<8>(if_op.output_shapes().getValue());
for (int64_t i = 1; i < if_op.getNumOperands(); ++i) { for (int64_t i = 1; i < if_op.getNumOperands(); ++i) {
auto it = buffer_to_size->find(if_op.getOperand(i)); auto it = buffer_to_size->find(if_op.getOperand(i));
if (it == buffer_to_size->end()) continue; if (it == buffer_to_size->end()) continue;
new_if_operands.push_back(it->getSecond().size); new_if_operands.push_back(it->getSecond().size);
if (!new_output_shapes.empty()) {
// Size is a scalar shape.
tensorflow::TensorShapeProto shape_proto;
new_output_shapes.push_back(builder.getStringAttr(
tensorflow::mangling_util::MangleShape(shape_proto)));
}
} }
auto new_if = OpBuilder(if_op).create<TF::IfOp>( auto new_if = OpBuilder(if_op).create<TF::IfOp>(
if_op.getLoc(), then_branch.getType().getResults(), new_if_operands, if_op.getLoc(), then_branch.getType().getResults(), new_if_operands,
if_op.getAttrs()); if_op.getAttrs());
new_if.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes));
for (const auto& entry : output_buffer_to_size) { for (const auto& entry : output_buffer_to_size) {
(*buffer_to_size)[new_if.getResult(std::get<0>(entry))] = { (*buffer_to_size)[new_if.getResult(std::get<0>(entry))] = {
new_if.getResult(std::get<1>(entry)), std::get<2>(entry)}; new_if.getResult(std::get<1>(entry)), std::get<2>(entry)};