Check that the function caller op can be inlined.
TPUPartitionedCall inherits the CallOpInterface but should not be inlined because it dispatches a function on a single TPU core based on the input argument. Similar restriction may be true of other op that inherit CallOpInterface and they should be included here as well. PiperOrigin-RevId: 338289266 Change-Id: Ie78abc6c66973bf76fa4cdccb8c86c286c6e72bf
This commit is contained in:
parent
1057c67cbd
commit
8e84212c47
@ -90,7 +90,9 @@ bool HasSingleUse(FuncOp func) {
|
||||
// Inspect function uses in the containing module and all parent
|
||||
// modules.
|
||||
bool use_seen = false;
|
||||
for (; module; module = module.getParentOfType<ModuleOp>()) {
|
||||
for (; module; module = func.isPrivate()
|
||||
? nullptr
|
||||
: module.getParentOfType<ModuleOp>()) {
|
||||
auto func_uses_optional =
|
||||
SymbolTable::getSymbolUses(func, &module.getBodyRegion());
|
||||
// Found an unknown use.
|
||||
@ -105,15 +107,36 @@ bool HasSingleUse(FuncOp func) {
|
||||
|
||||
// This is the first use seen.
|
||||
use_seen = true;
|
||||
|
||||
// If the function is private, no need to inspect parent modules.
|
||||
if (func.isPrivate()) break;
|
||||
}
|
||||
|
||||
// No multiple uses seen.
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns true if the caller ops can be inlined.
|
||||
bool HasInlinableUsers(FuncOp func) {
|
||||
// Return false if unexpected IR structure seen.
|
||||
ModuleOp module = func.getParentOfType<ModuleOp>();
|
||||
if (!module) return false;
|
||||
|
||||
// Inspect function uses in the containing module and all parent
|
||||
// modules.
|
||||
for (; module; module = func.isPrivate()
|
||||
? nullptr
|
||||
: module.getParentOfType<ModuleOp>()) {
|
||||
auto func_uses_optional =
|
||||
SymbolTable::getSymbolUses(func, &module.getBodyRegion());
|
||||
// Found an unknown use.
|
||||
if (!func_uses_optional) return false;
|
||||
|
||||
for (auto &use : func_uses_optional.getValue())
|
||||
if (isa<TPUPartitionedCallOp>(use.getUser())) return false;
|
||||
}
|
||||
|
||||
// All caller ops that can be inlined.
|
||||
return true;
|
||||
}
|
||||
|
||||
struct TFConstantFoldInterface : public DialectFoldInterface {
|
||||
TFConstantFoldInterface(Dialect *dialect) : DialectFoldInterface(dialect) {}
|
||||
LogicalResult fold(Operation *op, ArrayRef<Attribute> operands,
|
||||
@ -153,11 +176,14 @@ struct TFInlinerInterface : public DialectInlinerInterface {
|
||||
BlockAndValueMapping &) const final {
|
||||
// An op is legal to inline if either of the following conditions is true:
|
||||
// (a) Its legal to duplicate the Op.
|
||||
// (a) The Op is inside a single use function. If that function is inlined,
|
||||
// (b) The Op is inside a single use function. If that function is inlined,
|
||||
// post inlining, the function will be dead and eliminated from the IR.
|
||||
// So there won't be any code duplication.
|
||||
// plus the function caller op can be replaced by inlined ops.
|
||||
FuncOp func = op->getParentOfType<FuncOp>();
|
||||
return !func || TensorFlowDialect::CanDuplicate(op) || HasSingleUse(func);
|
||||
if (!func) return true;
|
||||
if (!HasInlinableUsers(func)) return false;
|
||||
return TensorFlowDialect::CanDuplicate(op) || HasSingleUse(func);
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
@ -15,6 +15,23 @@ func @inline_simple() -> tensor<2xi32> {
|
||||
return %result : tensor<2xi32>
|
||||
}
|
||||
|
||||
// Test that TPUParitionedCallOp is not inlined.
|
||||
|
||||
func @simple_callee() -> tensor<2xi32> attributes {sym_visibility = "private"} {
|
||||
%cst = "tf.Const"() { value = dense<2> : tensor<2xi32> } : () -> tensor<2xi32>
|
||||
return %cst : tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @dont_inline_tpu_partitioned_call(
|
||||
func @dont_inline_tpu_partitioned_call() -> tensor<2xi32> {
|
||||
// CHECK-NEXT: %[[ORDINAL:.*]] = "tf.TPUOrdinalSelector"
|
||||
// CHECK-NEXT: %[[PARTITIONED_CALL:.*]] = "tf.TPUPartitionedCall"(%[[ORDINAL]])
|
||||
// CHECK-NEXT: return %[[PARTITIONED_CALL]]
|
||||
%0 = "tf.TPUOrdinalSelector"() {device = ""} : () -> tensor<?xi32>
|
||||
%result = "tf.TPUPartitionedCall"(%0) {config = "", config_proto = "", executor_type = "", f = @simple_callee} : (tensor<?xi32>) -> tensor<2xi32>
|
||||
return %result : tensor<2xi32>
|
||||
}
|
||||
|
||||
// Check that TF call operations can be inlined, even when the shape of the
|
||||
// argument or result is different than the called function.
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user