Propagate shapes to PartitionedCall op's function in shape inference pass
Similar to the shape propagation for If and While control flow ops. PiperOrigin-RevId: 290868974 Change-Id: Id4dc95196cb97f5f76ef310925c79a399f4ad99d
This commit is contained in:
parent
933fa7cfeb
commit
69905619c9
@ -149,6 +149,19 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
return %1, %arg1, %arg2 : tensor<*xf32>, tensor<*x!tf.resource>, tensor<!tf.resource<tensor<*xf32>>>
|
||||
}
|
||||
|
||||
func @partitioned_call(%arg0: tensor<i32>) -> tensor<*xi32> {
|
||||
%0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @partitioned_call_func} : (tensor<i32>) -> (tensor<*xi32>)
|
||||
return %0 : tensor<*xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @partitioned_call_func
|
||||
// CHECK-SAME: (%arg0: tensor<i32>) -> tensor<i32>
|
||||
func @partitioned_call_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
|
||||
// CHECK: return
|
||||
// CHECK-SAME: tensor<i32>
|
||||
return %arg0 : tensor<*xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @invalid_function_reused_by_control_flows
|
||||
func @invalid_function_reused_by_control_flows(%arg0: tensor<i1>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
|
||||
// expected-warning @+1 {{unable to refine shape}}
|
||||
|
@ -408,22 +408,15 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename OpTy>
|
||||
LogicalResult PropagateShapeToIfWhileOpFunctions(
|
||||
OpTy op, llvm::ArrayRef<StringRef> func_names, int64_t graph_version,
|
||||
LogicalResult PropagateShapeToFunctions(
|
||||
ModuleOp module, Operation::operand_type_range input_types,
|
||||
llvm::ArrayRef<StringRef> func_names, int64_t graph_version,
|
||||
int64_t max_iteration) {
|
||||
llvm::SmallVector<Type, 4> input_types;
|
||||
input_types.reserve(std::distance(op.input().begin(), op.input().end()));
|
||||
for (Value v : op.input()) {
|
||||
input_types.push_back(v.getType());
|
||||
}
|
||||
|
||||
ModuleOp module = op.template getParentOfType<ModuleOp>();
|
||||
|
||||
bool success = true;
|
||||
auto types = llvm::to_vector<4>(input_types);
|
||||
for (auto func_name : func_names) {
|
||||
FuncOp func = module.lookupSymbol<FuncOp>(func_name);
|
||||
if (failed(RefineShapeForControlFlowFunc(func, input_types, graph_version,
|
||||
if (failed(RefineShapeForControlFlowFunc(func, types, graph_version,
|
||||
max_iteration))) {
|
||||
success = false;
|
||||
}
|
||||
@ -434,14 +427,20 @@ LogicalResult PropagateShapeToIfWhileOpFunctions(
|
||||
LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
|
||||
int64_t graph_version,
|
||||
int64_t max_iteration) {
|
||||
ModuleOp module = op->getParentOfType<ModuleOp>();
|
||||
if (auto if_op = dyn_cast<TF::IfOp>(op)) {
|
||||
return PropagateShapeToIfWhileOpFunctions<TF::IfOp>(
|
||||
if_op, {if_op.then_branch(), if_op.else_branch()}, graph_version,
|
||||
return PropagateShapeToFunctions(
|
||||
module, llvm::drop_begin(if_op.getOperandTypes(), 1),
|
||||
{if_op.then_branch(), if_op.else_branch()}, graph_version,
|
||||
max_iteration);
|
||||
} else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
|
||||
return PropagateShapeToIfWhileOpFunctions<TF::WhileOp>(
|
||||
while_op, {while_op.cond(), while_op.body()}, graph_version,
|
||||
max_iteration);
|
||||
return PropagateShapeToFunctions(module, while_op.getOperandTypes(),
|
||||
{while_op.cond(), while_op.body()},
|
||||
graph_version, max_iteration);
|
||||
} else if (auto call_op = dyn_cast<TF::PartitionedCallOp>(op)) {
|
||||
return PropagateShapeToFunctions(module, call_op.getOperandTypes(),
|
||||
{call_op.f()}, graph_version,
|
||||
max_iteration);
|
||||
}
|
||||
|
||||
// TODO(ycao): Implement support for Call op, including function reuse.
|
||||
|
Loading…
Reference in New Issue
Block a user