diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
index 582f2237d01..ab9d2a44f63 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
@@ -149,6 +149,19 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
     return %1, %arg1, %arg2 : tensor<*xf32>, tensor<*x!tf.resource>, tensor<!tf.resource<tensor<*xf32>>>
   }
 
+  func @partitioned_call(%arg0: tensor<i32>) -> tensor<*xi32> {
+    %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @partitioned_call_func} : (tensor<i32>) -> (tensor<*xi32>)
+    return %0 : tensor<*xi32>
+  }
+
+  // CHECK-LABEL: func @partitioned_call_func
+  // CHECK-SAME: (%arg0: tensor<i32>) -> tensor<i32>
+  func @partitioned_call_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
+    // CHECK: return
+    // CHECK-SAME: tensor<i32>
+    return %arg0 : tensor<*xi32>
+  }
+
   // CHECK-LABEL: func @invalid_function_reused_by_control_flows
   func @invalid_function_reused_by_control_flows(%arg0: tensor<i1>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
     // expected-warning @+1 {{unable to refine shape}}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 6a2d89c9ee3..631c15f5bdf 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -408,22 +408,15 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
   return success();
 }
 
-template <typename OpTy>
-LogicalResult PropagateShapeToIfWhileOpFunctions(
-    OpTy op, llvm::ArrayRef<StringRef> func_names, int64_t graph_version,
+LogicalResult PropagateShapeToFunctions(
+    ModuleOp module, Operation::operand_type_range input_types,
+    llvm::ArrayRef<StringRef> func_names, int64_t graph_version,
     int64_t max_iteration) {
-  llvm::SmallVector<Type, 4> input_types;
-  input_types.reserve(std::distance(op.input().begin(), op.input().end()));
-  for (Value v : op.input()) {
-    input_types.push_back(v.getType());
-  }
-
-  ModuleOp module = op.template getParentOfType<ModuleOp>();
-
   bool success = true;
+  auto types = llvm::to_vector<4>(input_types);
   for (auto func_name : func_names) {
     FuncOp func = module.lookupSymbol<FuncOp>(func_name);
-    if (failed(RefineShapeForControlFlowFunc(func, input_types, graph_version,
+    if (failed(RefineShapeForControlFlowFunc(func, types, graph_version,
                                              max_iteration))) {
       success = false;
     }
@@ -434,14 +427,20 @@ LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
                                                   int64_t graph_version,
                                                   int64_t max_iteration) {
+  ModuleOp module = op->getParentOfType<ModuleOp>();
   if (auto if_op = dyn_cast<TF::IfOp>(op)) {
-    return PropagateShapeToIfWhileOpFunctions<TF::IfOp>(
-        if_op, {if_op.then_branch(), if_op.else_branch()}, graph_version,
+    return PropagateShapeToFunctions(
+        module, llvm::drop_begin(if_op.getOperandTypes(), 1),
+        {if_op.then_branch(), if_op.else_branch()}, graph_version,
         max_iteration);
   } else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
-    return PropagateShapeToIfWhileOpFunctions<TF::WhileOp>(
-        while_op, {while_op.cond(), while_op.body()}, graph_version,
-        max_iteration);
+    return PropagateShapeToFunctions(module, while_op.getOperandTypes(),
+                                     {while_op.cond(), while_op.body()},
+                                     graph_version, max_iteration);
+  } else if (auto call_op = dyn_cast<TF::PartitionedCallOp>(op)) {
+    return PropagateShapeToFunctions(module, call_op.getOperandTypes(),
+                                     {call_op.f()}, graph_version,
+                                     max_iteration);
   }
 
   // TODO(ycao): Implement support for Call op, including function reuse.