Propagate shapes to PartitionedCall op's function in shape inference pass

Similar to the shape propagation for If and While control flow ops.

PiperOrigin-RevId: 290868974
Change-Id: Id4dc95196cb97f5f76ef310925c79a399f4ad99d
This commit is contained in:
Smit Hinsu 2020-01-21 18:35:26 -08:00 committed by TensorFlower Gardener
parent 933fa7cfeb
commit 69905619c9
2 changed files with 29 additions and 17 deletions

View File

@ -149,6 +149,19 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
return %1, %arg1, %arg2 : tensor<*xf32>, tensor<*x!tf.resource>, tensor<!tf.resource<tensor<*xf32>>> return %1, %arg1, %arg2 : tensor<*xf32>, tensor<*x!tf.resource>, tensor<!tf.resource<tensor<*xf32>>>
} }
func @partitioned_call(%arg0: tensor<i32>) -> tensor<*xi32> {
%0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @partitioned_call_func} : (tensor<i32>) -> (tensor<*xi32>)
return %0 : tensor<*xi32>
}
// CHECK-LABEL: func @partitioned_call_func
// CHECK-SAME: (%arg0: tensor<i32>) -> tensor<i32>
func @partitioned_call_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
// CHECK: return
// CHECK-SAME: tensor<i32>
return %arg0 : tensor<*xi32>
}
// CHECK-LABEL: func @invalid_function_reused_by_control_flows // CHECK-LABEL: func @invalid_function_reused_by_control_flows
func @invalid_function_reused_by_control_flows(%arg0: tensor<i1>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> { func @invalid_function_reused_by_control_flows(%arg0: tensor<i1>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
// expected-warning @+1 {{unable to refine shape}} // expected-warning @+1 {{unable to refine shape}}

View File

@ -408,22 +408,15 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
return success(); return success();
} }
template <typename OpTy> LogicalResult PropagateShapeToFunctions(
LogicalResult PropagateShapeToIfWhileOpFunctions( ModuleOp module, Operation::operand_type_range input_types,
OpTy op, llvm::ArrayRef<StringRef> func_names, int64_t graph_version, llvm::ArrayRef<StringRef> func_names, int64_t graph_version,
int64_t max_iteration) { int64_t max_iteration) {
llvm::SmallVector<Type, 4> input_types;
input_types.reserve(std::distance(op.input().begin(), op.input().end()));
for (Value v : op.input()) {
input_types.push_back(v.getType());
}
ModuleOp module = op.template getParentOfType<ModuleOp>();
bool success = true; bool success = true;
auto types = llvm::to_vector<4>(input_types);
for (auto func_name : func_names) { for (auto func_name : func_names) {
FuncOp func = module.lookupSymbol<FuncOp>(func_name); FuncOp func = module.lookupSymbol<FuncOp>(func_name);
if (failed(RefineShapeForControlFlowFunc(func, input_types, graph_version, if (failed(RefineShapeForControlFlowFunc(func, types, graph_version,
max_iteration))) { max_iteration))) {
success = false; success = false;
} }
@ -434,14 +427,20 @@ LogicalResult PropagateShapeToIfWhileOpFunctions(
LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op, LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
int64_t graph_version, int64_t graph_version,
int64_t max_iteration) { int64_t max_iteration) {
ModuleOp module = op->getParentOfType<ModuleOp>();
if (auto if_op = dyn_cast<TF::IfOp>(op)) { if (auto if_op = dyn_cast<TF::IfOp>(op)) {
return PropagateShapeToIfWhileOpFunctions<TF::IfOp>( return PropagateShapeToFunctions(
if_op, {if_op.then_branch(), if_op.else_branch()}, graph_version, module, llvm::drop_begin(if_op.getOperandTypes(), 1),
{if_op.then_branch(), if_op.else_branch()}, graph_version,
max_iteration); max_iteration);
} else if (auto while_op = dyn_cast<TF::WhileOp>(op)) { } else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
return PropagateShapeToIfWhileOpFunctions<TF::WhileOp>( return PropagateShapeToFunctions(module, while_op.getOperandTypes(),
while_op, {while_op.cond(), while_op.body()}, graph_version, {while_op.cond(), while_op.body()},
max_iteration); graph_version, max_iteration);
} else if (auto call_op = dyn_cast<TF::PartitionedCallOp>(op)) {
return PropagateShapeToFunctions(module, call_op.getOperandTypes(),
{call_op.f()}, graph_version,
max_iteration);
} }
// TODO(ycao): Implement support for Call op, including function reuse. // TODO(ycao): Implement support for Call op, including function reuse.