Check that the function caller op can be inlined.

TPUPartitionedCall inherits the CallOpInterface but should not be inlined because it dispatches a function on a single TPU core based on the input argument.  Similar restriction may be true of other op that inherit CallOpInterface and they should be included here as well.

PiperOrigin-RevId: 338289266
Change-Id: Ie78abc6c66973bf76fa4cdccb8c86c286c6e72bf
This commit is contained in:
Ken Franko 2020-10-21 10:22:50 -07:00 committed by TensorFlower Gardener
parent 1057c67cbd
commit 8e84212c47
2 changed files with 49 additions and 6 deletions

View File

@ -90,7 +90,9 @@ bool HasSingleUse(FuncOp func) {
// Inspect function uses in the containing module and all parent
// modules.
bool use_seen = false;
for (; module; module = module.getParentOfType<ModuleOp>()) {
for (; module; module = func.isPrivate()
? nullptr
: module.getParentOfType<ModuleOp>()) {
auto func_uses_optional =
SymbolTable::getSymbolUses(func, &module.getBodyRegion());
// Found an unknown use.
@ -105,15 +107,36 @@ bool HasSingleUse(FuncOp func) {
// This is the first use seen.
use_seen = true;
// If the function is private, no need to inspect parent modules.
if (func.isPrivate()) break;
}
// No multiple uses seen.
return true;
}
// Returns true if the caller ops can be inlined.
bool HasInlinableUsers(FuncOp func) {
// Return false if unexpected IR structure seen.
ModuleOp module = func.getParentOfType<ModuleOp>();
if (!module) return false;
// Inspect function uses in the containing module and all parent
// modules.
for (; module; module = func.isPrivate()
? nullptr
: module.getParentOfType<ModuleOp>()) {
auto func_uses_optional =
SymbolTable::getSymbolUses(func, &module.getBodyRegion());
// Found an unknown use.
if (!func_uses_optional) return false;
for (auto &use : func_uses_optional.getValue())
if (isa<TPUPartitionedCallOp>(use.getUser())) return false;
}
// All caller ops that can be inlined.
return true;
}
struct TFConstantFoldInterface : public DialectFoldInterface {
TFConstantFoldInterface(Dialect *dialect) : DialectFoldInterface(dialect) {}
LogicalResult fold(Operation *op, ArrayRef<Attribute> operands,
@ -153,11 +176,14 @@ struct TFInlinerInterface : public DialectInlinerInterface {
BlockAndValueMapping &) const final {
// An op is legal to inline if either of the following conditions is true:
// (a) Its legal to duplicate the Op.
// (a) The Op is inside a single use function. If that function is inlined,
// (b) The Op is inside a single use function. If that function is inlined,
// post inlining, the function will be dead and eliminated from the IR.
// So there won't be any code duplication.
// plus the function caller op can be replaced by inlined ops.
FuncOp func = op->getParentOfType<FuncOp>();
return !func || TensorFlowDialect::CanDuplicate(op) || HasSingleUse(func);
if (!func) return true;
if (!HasInlinableUsers(func)) return false;
return TensorFlowDialect::CanDuplicate(op) || HasSingleUse(func);
}
//===--------------------------------------------------------------------===//

View File

@ -15,6 +15,23 @@ func @inline_simple() -> tensor<2xi32> {
return %result : tensor<2xi32>
}
// Test that TPUParitionedCallOp is not inlined.
func @simple_callee() -> tensor<2xi32> attributes {sym_visibility = "private"} {
%cst = "tf.Const"() { value = dense<2> : tensor<2xi32> } : () -> tensor<2xi32>
return %cst : tensor<2xi32>
}
// CHECK-LABEL: func @dont_inline_tpu_partitioned_call(
func @dont_inline_tpu_partitioned_call() -> tensor<2xi32> {
// CHECK-NEXT: %[[ORDINAL:.*]] = "tf.TPUOrdinalSelector"
// CHECK-NEXT: %[[PARTITIONED_CALL:.*]] = "tf.TPUPartitionedCall"(%[[ORDINAL]])
// CHECK-NEXT: return %[[PARTITIONED_CALL]]
%0 = "tf.TPUOrdinalSelector"() {device = ""} : () -> tensor<?xi32>
%result = "tf.TPUPartitionedCall"(%0) {config = "", config_proto = "", executor_type = "", f = @simple_callee} : (tensor<?xi32>) -> tensor<2xi32>
return %result : tensor<2xi32>
}
// Check that TF call operations can be inlined, even when the shape of the
// argument or result is different than the called function.