Propagate shapes to PartitionedCall op's function in shape inference pass

Similar to the shape propagation for If and While control flow ops.
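
Concretely, once a call site's operand types are known, the pass rewrites the callee's signature in place. A minimal before/after sketch of the intended behavior, mirroring the test added below:

// Before shape inference, the callee only knows an unranked tensor:
func @partitioned_call_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
  return %arg0 : tensor<*xi32>
}
// A call site supplies a ranked operand type, tensor<i32>:
%0 = "tf.PartitionedCall"(%arg0) {f = @partitioned_call_func, config = "", config_proto = "", executor_type = ""} : (tensor<i32>) -> (tensor<*xi32>)
// After the pass, the callee's signature (and, here, its result type)
// is refined to (tensor<i32>) -> tensor<i32>.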

PiperOrigin-RevId: 290868974
Change-Id: Id4dc95196cb97f5f76ef310925c79a399f4ad99d
Smit Hinsu, 2020-01-21 18:35:26 -08:00, committed by TensorFlower Gardener
parent 933fa7cfeb
commit 69905619c9
2 changed files with 29 additions and 17 deletions

tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir

@@ -149,6 +149,19 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
   return %1, %arg1, %arg2 : tensor<*xf32>, tensor<*x!tf.resource>, tensor<!tf.resource<tensor<*xf32>>>
 }
 
+func @partitioned_call(%arg0: tensor<i32>) -> tensor<*xi32> {
+  %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @partitioned_call_func} : (tensor<i32>) -> (tensor<*xi32>)
+  return %0 : tensor<*xi32>
+}
+
+// CHECK-LABEL: func @partitioned_call_func
+// CHECK-SAME: (%arg0: tensor<i32>) -> tensor<i32>
+func @partitioned_call_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
+  // CHECK: return
+  // CHECK-SAME: tensor<i32>
+  return %arg0 : tensor<*xi32>
+}
+
 // CHECK-LABEL: func @invalid_function_reused_by_control_flows
 func @invalid_function_reused_by_control_flows(%arg0: tensor<i1>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
   // expected-warning @+1 {{unable to refine shape}}
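
For reference, shape inference tests like this one are driven by a FileCheck RUN line at the top of the file; the exact flags in the actual file may differ, but an invocation of roughly this shape is typical:

// RUN: tf-opt %s -tf-shape-inference -verify-diagnostics | FileCheck %s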

tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc

@@ -408,22 +408,15 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
   return success();
 }
 
-template <typename OpTy>
-LogicalResult PropagateShapeToIfWhileOpFunctions(
-    OpTy op, llvm::ArrayRef<StringRef> func_names, int64_t graph_version,
+LogicalResult PropagateShapeToFunctions(
+    ModuleOp module, Operation::operand_type_range input_types,
+    llvm::ArrayRef<StringRef> func_names, int64_t graph_version,
     int64_t max_iteration) {
-  llvm::SmallVector<Type, 4> input_types;
-  input_types.reserve(std::distance(op.input().begin(), op.input().end()));
-  for (Value v : op.input()) {
-    input_types.push_back(v.getType());
-  }
-
-  ModuleOp module = op.template getParentOfType<ModuleOp>();
-
   bool success = true;
+  auto types = llvm::to_vector<4>(input_types);
   for (auto func_name : func_names) {
     FuncOp func = module.lookupSymbol<FuncOp>(func_name);
-    if (failed(RefineShapeForControlFlowFunc(func, input_types, graph_version,
+    if (failed(RefineShapeForControlFlowFunc(func, types, graph_version,
                                              max_iteration))) {
       success = false;
     }
@@ -434,14 +427,20 @@ LogicalResult PropagateShapeToIfWhileOpFunctions(
 LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
                                                   int64_t graph_version,
                                                   int64_t max_iteration) {
+  ModuleOp module = op->getParentOfType<ModuleOp>();
   if (auto if_op = dyn_cast<TF::IfOp>(op)) {
-    return PropagateShapeToIfWhileOpFunctions<TF::IfOp>(
-        if_op, {if_op.then_branch(), if_op.else_branch()}, graph_version,
+    return PropagateShapeToFunctions(
+        module, llvm::drop_begin(if_op.getOperandTypes(), 1),
+        {if_op.then_branch(), if_op.else_branch()}, graph_version,
         max_iteration);
   } else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
-    return PropagateShapeToIfWhileOpFunctions<TF::WhileOp>(
-        while_op, {while_op.cond(), while_op.body()}, graph_version,
-        max_iteration);
+    return PropagateShapeToFunctions(module, while_op.getOperandTypes(),
+                                     {while_op.cond(), while_op.body()},
+                                     graph_version, max_iteration);
+  } else if (auto call_op = dyn_cast<TF::PartitionedCallOp>(op)) {
+    return PropagateShapeToFunctions(module, call_op.getOperandTypes(),
+                                     {call_op.f()}, graph_version,
+                                     max_iteration);
   }
 
   // TODO(ycao): Implement support for Call op, including function reuse.
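
A note on the If case above: the first operand of tf.If is the branch predicate and is not forwarded to the branch functions, hence llvm::drop_begin(if_op.getOperandTypes(), 1), whereas tf.While and tf.PartitionedCall forward all of their operands. An illustrative snippet (operand and function names are made up):

// Operand 0 (%pred : tensor<i1>) only selects the branch; %x is the sole
// value passed to @then/@else, so only tensor<2xf32> is propagated into them.
%r = "tf.If"(%pred, %x) {then_branch = @then, else_branch = @else, is_stateless = true} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>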