From 60c828a70ec0aad85dbd4077a1c84f8f9da88615 Mon Sep 17 00:00:00 2001
From: Rahul Joshi
Date: Thu, 28 May 2020 15:12:28 -0700
Subject: [PATCH] Eliminate output_shapes from If/IfRegion ODS specs

- Also eliminate the output_shapes attribute from several test cases.
- This attribute may still be present on these ops since the importer seems
  to generate it.
- Added a test to verify that values generated in one branch of the If cannot
  be consumed in the other branch.

PiperOrigin-RevId: 313668390
Change-Id: I97bed79f52f6694ead1931a64c411686067d2800
---
 .../lite/tests/lower-static-tensor-list.mlir  |  2 +-
 .../transforms/lower_static_tensor_list.cc    |  1 -
 .../mlir/tensorflow/ir/tf_attributes.h        |  2 +-
 .../compiler/mlir/tensorflow/ir/tf_ops.td     |  3 --
 .../tests/promote_resources_to_args.mlir      |  4 +--
 .../tests/resource-device-inference.mlir      |  2 +-
 .../tensorflow/tests/resource_op_lifting.mlir |  8 ++---
 .../tensorflow/tests/shape_inference.mlir     | 14 ++++-----
 .../mlir/tensorflow/tests/tf-ops.mlir         | 30 +++++++++++++++++++
 .../transforms/resource_op_lifting.cc         | 13 +-------
 .../tensor_list_ops_decomposition.cc          |  8 -----
 11 files changed, 47 insertions(+), 40 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
index 9b1eeab3d7c..a7fb5b1666e 100644
--- a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
@@ -292,7 +292,7 @@ func @tensorlistResize(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: ten
 // CHECK: [[SIZE_DIFF:%.*]] = "tf.Sub"([[SIZE]], [[INPUT_SIZE]]) : (tensor, tensor) -> tensor
 // CHECK: [[DIFF_RES:%.*]] = "tf.Greater"([[SIZE_DIFF]], [[ZERO]]) : (tensor, tensor) -> tensor
 // CHECK: [[SHAPE_1:%.*]] = "tf.Shape"([[INPUT]]) : (tensor<3x10xf32>) -> tensor
-// CHECK: [[RESULT:%.*]] = "tf.If"([[DIFF_RES]], [[INPUT]], [[SHAPE_1]], [[SIZE_DIFF]], [[SIZE]]) {else_branch = @cond_false, is_stateless = true, output_shapes = [], then_branch = @cond_true} : (tensor, tensor<3x10xf32>, tensor, tensor, tensor) -> tensor
+// CHECK: [[RESULT:%.*]] = "tf.If"([[DIFF_RES]], [[INPUT]], [[SHAPE_1]], [[SIZE_DIFF]], [[SIZE]]) {else_branch = @cond_false, is_stateless = true, then_branch = @cond_true} : (tensor, tensor<3x10xf32>, tensor, tensor, tensor) -> tensor
 // CHECK: return [[RESULT]] : tensor
 }

diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
index 45b8c9e5fb2..2498a732a86 100644
--- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
@@ -577,7 +577,6 @@ struct ConvertTensorListResize
         ArrayRef({input_handle, input_shape, size_diff, size}),
         /*then_branch=*/rewriter.getSymbolRefAttr(then_branch_op),
         /*else_branch=*/rewriter.getSymbolRefAttr(else_branch_op),
-        /*output_shapes=*/rewriter.getArrayAttr({}),
         /*is_stateless=*/rewriter.getBoolAttr(true));
     return success();
   }

diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h
index ba67d6cb671..1edc7356ab4 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h
@@ -26,7 +26,7 @@ namespace TF {
 namespace AttrKind {

-// List of supported custom TensorFlow Attributes kinds, necessary for
+// List of supported custom TensorFlow Attribute kinds, necessary for
 // isa/dyn_cast.
 enum Kind {
   FIRST_USED_TENSORFLOW_ATTR = Attribute::FIRST_TENSORFLOW_ATTR,

diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index 7f31c274a09..51b9dd862ac 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -188,7 +188,6 @@ else_branch: A function that takes 'inputs' and returns a list of
     FlatSymbolRefAttr:$then_branch,
     FlatSymbolRefAttr:$else_branch,
-    DefaultValuedAttr:$output_shapes,

     // Used to map StatelessIf and If op defined in TensorFlow to a common op.
     BoolAttr:$is_stateless
@@ -248,8 +247,6 @@ else_branch: A region that computes the outputs of the op if cond = false.
   let arguments = (ins
     TF_Tensor:$cond,

-    DefaultValuedAttr:$output_shapes,
-
     // Used to map StatelessIf and If op defined in TensorFlow to a common op.
     BoolAttr:$is_stateless
   );

diff --git a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir
index 60663f4bd4a..59c93a66d12 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir
@@ -145,8 +145,8 @@ func @main(%arg0: tensor) -> tensor<2xf32> attributes {tf.entry_function = {
   %2 = "tf.ReadVariableOp"(%1) : (tensor>>) -> tensor
   %3 = "tf.Less"(%2, %0) : (tensor, tensor) -> tensor
   %4 = "tf.If"(%3, %1, %2) {Tcond = i1, Tin = ["tfdtype$DT_RESOURCE", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"],
-       else_branch = @cond_false, is_stateless = false, output_shapes = [#tf.shape<>],
-       then_branch = @cond_true} : (tensor, tensor>>, tensor) -> tensor
+       else_branch = @cond_false, is_stateless = false, then_branch = @cond_true} :
+       (tensor, tensor>>, tensor) -> tensor
   %5 = "tf.Identity"(%4) : (tensor) -> tensor
   %6 = "tf.Pack"(%2, %5) {N = 2 : i64, T = f32, axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xf32>
   return %6 : tensor<2xf32>

diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir
index c98e40fed05..60eded3de7e 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir
@@ -217,7 +217,7 @@ func @error_on_conflict_multiple_callers(
           // expected-error@above {{Conflicting device assignment for resource}}
           then_branch = @if_then_and_else,
           else_branch = @if_then_and_else,
-          output_shapes = [], is_stateless = false}
+          is_stateless = false}
         : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> ()
       tf_executor.yield

diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
index 9e7358ab2f5..b19033ce5b5 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
@@ -420,7 +420,7 @@ func @cluster_with_if(%arg0: tensor) -> tensor<4xf32> {
   %2 = "tf_device.cluster"() ( {
     // CHECK: %[[IF:.*]]:2 = "tf.If"(%[[ARG0]], %[[READ0]], %[[READ1]])
     %3:2 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else,
-        output_shapes = [#tf.shape<>, #tf.shape<4>], is_stateless = false}
+        is_stateless = false}
      : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>)
      -> (tensor<*x!tf.resource>>, tensor<4xf32>)
    // CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[IF]]#1, %[[IF]]#0)
@@ -468,7 +468,7 @@ func @cluster_with_nested_if(%arg0: tensor) -> tensor {
   %2 = "tf_device.cluster"() ( {
     // CHECK: %[[IF:.*]] = "tf.If"(%[[ARG0]], %[[READ0]])
     %3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else,
-        output_shapes = [], is_stateless = false}
+        is_stateless = false}
      : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>)
      -> (tensor<*x!tf.resource>>)
    // CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[IF]], %[[IF]])
@@ -488,7 +488,7 @@ func @if_then(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.re
   // CHECK-NEXT: %[[IIF:.*]] = "tf.If"(%[[TARG0]], %[[TARG0]])
   %read = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor
   %3 = "tf.If"(%read, %arg0) {then_branch = @inner_if_then, else_branch = @inner_if_else,
-      output_shapes = [], is_stateless = false}
+      is_stateless = false}
    : (tensor, tensor<*x!tf.resource>>)
    -> (tensor<*x!tf.resource>>)
  // CHECK-NEXT: return %[[IIF]]
@@ -526,7 +526,7 @@ func @cluster_with_if(%arg0: tensor) -> tensor<4xf32> {
   %2 = "tf_device.cluster"() ( {
     // expected-error @+1 {{unsupported tf.IfOp output: resource does not alias a single input.}}
     %3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else,
-        output_shapes = [#tf.shape<>], is_stateless = false}
+        is_stateless = false}
      : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>)
      -> (tensor<*x!tf.resource>>)
     %4 = "tf.ReadVariableOp"(%3) : (tensor<*x!tf.resource>>) -> tensor<4xf32>

diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
index 3cdade8da59..e3766a7d9d6 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
@@ -102,7 +102,7 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> {
   // CHECK-LABEL: func @shape_from_if_to_branch_functions
   func @shape_from_if_to_branch_functions(%arg0: tensor, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
-    %0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @if_else_branch, is_stateless = true, name = "if", output_shapes = [#tf.shape<>], then_branch = @if_then_branch} : (tensor, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
+    %0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @if_else_branch, is_stateless = true, name = "if", then_branch = @if_then_branch} : (tensor, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
     return %0 : tensor<1x2x3xf32>
   }

@@ -184,16 +184,16 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> {
   // CHECK-LABEL: func @invalid_function_reused_by_control_flows
   func @invalid_function_reused_by_control_flows(%arg0: tensor, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
-    // expected-warning @+1 {{unable to refine shape}}
-    %0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", output_shapes = [#tf.shape<>], then_branch = @reused_if_then_branch} : (tensor, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
-    // expected-warning @+1 {{unable to refine shape}}
-    %1 = "tf.If"(%arg0, %0) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", output_shapes = [#tf.shape<>], then_branch = @reused_if_then_branch} : (tensor, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
+    // expected-warning @+1 {{unable to refine shape}}
+    %0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", then_branch = @reused_if_then_branch} : (tensor, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
+    // expected-warning @+1 {{unable to refine shape}}
+    %1 = "tf.If"(%arg0, %0) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", then_branch = @reused_if_then_branch} : (tensor, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
     return %0 : tensor<1x2x3xf32>
   }

   // CHECK-LABEL: func @reused_if_then_branch
   // CHECK-SAME: (%arg0: tensor<*xf32>) -> tensor<*xf32>
-  // expected-warning @+1 {{expected control flow function reused_if_then_branch to have exactly 1 use}}
+  // expected-warning @+1 {{expected control flow function reused_if_then_branch to have exactly 1 use}}
   func @reused_if_then_branch(%arg0: tensor<*xf32>) -> tensor<*xf32> {
     // CHECK: return
     // CHECK-SAME: tensor<*xf32>

@@ -202,7 +202,7 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> {
   // CHECK-LABEL: func @reused_if_else_branch
   // CHECK-SAME: (%arg0: tensor<*xf32>) -> tensor<*xf32>
-  // expected-warning @+1 {{expected control flow function reused_if_else_branch to have exactly 1 use}}
+  // expected-warning @+1 {{expected control flow function reused_if_else_branch to have exactly 1 use}}
   func @reused_if_else_branch(%arg0: tensor<*xf32>) -> tensor<*xf32> {
     // CHECK: "tf.Identity"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
     %0 = "tf.Identity"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>)

diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index 2e00dd6a517..20f7c5b9ba1 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -1048,6 +1048,36 @@ func @testIfRegionOpYieldMismatchElse(%arg0: tensor, %arg1: tensor<2xf32>) -

 // -----

+// value generated in one branch cannot be consumed in the other branch
+func @testIfRegionElseConsumingThen(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "tf.IfRegion"(%arg0) ({
+    %t = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
+    "tf.Yield"(%t) : (tensor<2xf32>) -> ()
+  }, {
+    // expected-error @+1 {{use of undeclared SSA value name}}
+    "tf.Yield"(%t) : (tensor<2xf32>) -> ()
+  }) { is_stateless = false} : (tensor) -> tensor<2xf32>
+
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+func @testIfRegionThenConsumingElse(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "tf.IfRegion"(%arg0) ({
+    // expected-error @+1 {{does not dominate this use}}
+    "tf.Yield"(%t) : (tensor<2xf32>) -> ()
+  }, {
+    // expected-note @+1 {{operand defined here}}
+    %t = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
+    "tf.Yield"(%t) : (tensor<2xf32>) -> ()
+  }) { is_stateless = false} : (tensor) -> tensor<2xf32>
+
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
 // Test valid tf.MatrixBandPart
 // CHECK-LABEL: func @testValidMatrixBandPartOp
 func @testValidMatrixBandPartOp(%arg0: tensor<64x64xbf16>, %arg1: tensor, %arg2: tensor) -> tensor<64x64xbf16> {

diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
index 82bc612b1f8..c1e5241a1f0 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
@@ -700,15 +700,10 @@ LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch,
   // Erase the resource outputs from the branches.
   int64_t non_resource_results = 0;
   llvm::SmallVector old_to_new_output_indices;
-  llvm::SmallVector new_output_shapes;
   bool output_removed = false;
   for (auto result : if_op.getResults()) {
     if (!getElementTypeOrSelf(result.getType()).isa()) {
       old_to_new_output_indices.push_back(non_resource_results++);
-      if (!if_op.output_shapes().getValue().empty()) {
-        new_output_shapes.push_back(
-            if_op.output_shapes().getValue()[result.getResultNumber()]);
-      }
       continue;
     }
     old_to_new_output_indices.push_back(-1);
@@ -781,8 +776,7 @@ LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch,
   auto new_if = builder.create(if_op.getLoc(),
                                then_branch.getType().getResults(),
                                new_operands, if_op.getAttrs());
-  // Prepare for AddLoadsStoresOutsideControlFlowOp() and update
-  // new_output_shapes.
+  // Prepare for AddLoadsStoresOutsideControlFlowOp()
   llvm::SmallDenseMap>
       arg_data_type_and_updated_output_index;
   for (const auto& entry : remaining_resource_data_types) {
@@ -792,14 +786,9 @@ LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch,
                             : new_output_it->getSecond();
     arg_data_type_and_updated_output_index[entry.getFirst() + 1] = {
         entry.getSecond(), update_index};
-    if (!if_op.output_shapes().getValue().empty() && update_index >= 0) {
-      new_output_shapes.push_back(
-          tensorflow::ConvertTypeToTensorShapeAttr(entry.getSecond()));
-    }
   }
   AddLoadsStoresOutsideControlFlowOp(new_if,
                                      arg_data_type_and_updated_output_index);
-  new_if.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes));
   // Replace uses.
   for (int64_t i = 0; i < old_to_new_output_indices.size(); ++i) {
     if (old_to_new_output_indices[i] >= 0) {

diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc
index 6e27823191b..b2203c890e3 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc
@@ -254,22 +254,14 @@ LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module,
   if (output_buffer_to_size.empty() && arg_no_changed) return success();
   // Recreate the If op.
   auto new_if_operands = llvm::to_vector<8>(if_op.getOperands());
-  auto new_output_shapes = llvm::to_vector<8>(if_op.output_shapes().getValue());
   for (int64_t i = 1; i < if_op.getNumOperands(); ++i) {
     auto it = buffer_to_size->find(if_op.getOperand(i));
     if (it == buffer_to_size->end()) continue;
     new_if_operands.push_back(it->getSecond().size);
-    if (!new_output_shapes.empty()) {
-      // Size is a scalar shape.
-      tensorflow::TensorShapeProto shape_proto;
-      new_output_shapes.push_back(builder.getStringAttr(
-          tensorflow::mangling_util::MangleShape(shape_proto)));
-    }
   }
   auto new_if = OpBuilder(if_op).create(
       if_op.getLoc(), then_branch.getType().getResults(), new_if_operands,
      if_op.getAttrs());
-  new_if.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes));
   for (const auto& entry : output_buffer_to_size) {
     (*buffer_to_size)[new_if.getResult(std::get<0>(entry))] = {
         new_if.getResult(std::get<1>(entry)), std::get<2>(entry)};