Add support for called function in tpuv1 inlining pass

When a call graph is involved, we need to inline the called functions back as well before deleting the nested module.

PiperOrigin-RevId: 295885585
Change-Id: I61a4274e06a3009e97ca800cc2ed60591e522149

This commit is contained in:
parent f9e9fb9de2
commit 9771b11027
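The pass hunk further down only shows the tail of the inlining walk, so here is a sketch of the whole loop for context. It is illustrative rather than verbatim: kNestedModule (the symbol name of the outlined module), the f() accessor, and the surrounding pass state are assumptions, while InlinerInterface and inlineCall are MLIR's generic helpers from mlir/Transforms/InliningUtils.h.

// Sketch only (not the verbatim pass body): inline every tf.PartitionedCall
// that targets the outlined nested module, erasing both the call and the
// now-dead callee afterwards.
InlinerInterface inliner(&getContext());
auto walk_result = getModule().walk([&](TF::PartitionedCallOp call_op) {
  // Ignore calls that do not target the nested outlined module
  // (kNestedModule is assumed to hold its symbol name).
  if (!call_op.f().getRootReference().startswith(kNestedModule))
    return WalkResult::advance();
  // Resolve the callee through the parent module's symbol table.
  FuncOp called_func = dyn_cast_or_null<FuncOp>(
      symbol_table.lookupSymbolIn(getModule(), call_op.f()));
  // Splice the callee body in place of the call site.
  if (failed(inlineCall(inliner,
                        cast<CallOpInterface>(call_op.getOperation()),
                        cast<CallableOpInterface>(called_func.getOperation()),
                        called_func.getCallableRegion(),
                        /*shouldCloneInlinedRegion=*/false))) {
    call_op.emitOpError() << "Failed to inline\n";
    return WalkResult::interrupt();
  }
  called_func.erase();
  call_op.erase();
  return WalkResult::advance();
});

The TFInlinerInterface registered in the first file below is what makes inlineCall willing to move TF/tf_device ops across the call boundary and to materialize tf.Cast conversions on type mismatches.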
@@ -41,11 +41,52 @@ limitations under the License.
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
#include "mlir/Support/STLExtras.h"  // TF:llvm-project
#include "mlir/Transforms/InliningUtils.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/logging.h"

namespace mlir {
namespace tf_device {

//===----------------------------------------------------------------------===//
// TF Device Dialect Interfaces
//===----------------------------------------------------------------------===//

namespace {
struct TFInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  // Defines the legality of inlining TF Device operations.
  bool isLegalToInline(Operation*, Region*, BlockAndValueMapping&) const final {
    // For now, enable inlining all operations.
    return true;
  }

  //===--------------------------------------------------------------------===//
  // Transformation Hooks
  //===--------------------------------------------------------------------===//

  // Attempts to materialize a conversion for a type mismatch between a call
  // from this dialect, and a callable region. This method should generate an
  // operation that takes 'input' as the only operand, and produces a single
  // result of 'resultType'. If a conversion can not be generated, nullptr
  // should be returned.
  // This is just re-using the same logic as the TensorFlow dialect right now.
  Operation* materializeCallConversion(OpBuilder& builder, Value input,
                                       Type result_type,
                                       Location conversion_loc) const final {
    if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
      return nullptr;
    return builder.create<TF::CastOp>(conversion_loc, result_type, input,
                                      /*truncate=*/builder.getBoolAttr(false));
  }
};
}  // end anonymous namespace

TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
    : Dialect(/*name=*/"tf_device", context) {
  addOperations<
@@ -54,6 +95,8 @@ TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
      >();

  addOperations<ParallelExecuteOp>();

  addInterfaces<TFInlinerInterface>();
}

//===----------------------------------------------------------------------===//
@@ -0,0 +1,44 @@
// RUN: tf-opt %s -tf-executor-tpu-v1-island-inlining | FileCheck %s --dump-input=fail

// CHECK-NOT: tf.PartitionedCall
// CHECK-NOT: module @_tpu_v1_compat_outlined

module {
  func @control_input(%arg0: tensor<i1>) -> tensor<i32> {
    %0:4 = tf_executor.graph {
      %outputs:4, %control = tf_executor.island wraps "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @_tpu_v1_compat_outlined::@_tpu_v1_compat_outlined_func0} : (tensor<i1>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
      tf_executor.fetch %outputs#0, %outputs#1, %outputs#2, %outputs#3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
    }
    return %0#0 : tensor<i32>
  }
  module @_tpu_v1_compat_outlined {
    func @_tpu_v1_compat_outlined_func0(%arg0: tensor<i1>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
      "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1 : i64, topology = "topology"} : () -> ()
      %0 = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
      %1 = "tf.While"(%0) {body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "A", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
      %2 = "tf.While"(%0) {body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "C", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
      %3 = "tf.While"(%0) {body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, name = "E", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
      return %0, %1, %2, %3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
    }
    func @while_body_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
      %0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
      return %0 : tensor<i32>
    }
    func @while_cond_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
      %0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i1>
      return %0 : tensor<i1>
    }
    func @while_body_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
      %0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i32>
      return %0 : tensor<i32>
    }
    func @while_cond_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
      %0 = "tf.PartionedCalledOp"(%arg0) {f = @callee_func} : (tensor<i32>) -> tensor<i1>
      return %0 : tensor<i1>
    }
    func @callee_func(%arg0: tensor<i32>) -> tensor<i1> {
      %0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i1>
      return %0 : tensor<i1>
    }
  }
}
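As the RUN line above shows, the test drives the pass standalone through tf-opt, e.g. (file name hypothetical):

tf-opt while_test.mlir -tf-executor-tpu-v1-island-inlining

The CHECK-NOT lines then assert that no tf.PartitionedCall and no @_tpu_v1_compat_outlined module survive: every callee, including the transitively referenced @callee_func, must be inlined or hoisted before the nested module is deleted.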
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
@@ -70,10 +71,20 @@ void TPUBridgeExecutorIslandInlining::runOnModule() {
      call_op.emitOpError() << "Failed to inline\n";
      return WalkResult::interrupt();
    }
    called_func.erase();
    call_op.erase();
    return WalkResult::advance();
  });
  if (walk_result.wasInterrupted()) return signalPassFailure();
  // Move all remaining nested functions back into the parent module.
  Block &nested_block = nested_module->getRegion(0).front();
  for (FuncOp func_op :
       llvm::make_early_inc_range(nested_block.getOps<FuncOp>())) {
    if (!symbol_table.lookupSymbolIn(getModule(), func_op.getName())) {
      nested_block.getOperations().remove(func_op.getOperation());
      symbol_table.insert(func_op.getOperation());
    }
  }
  nested_module->erase();
}
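Note the second half of this hunk: functions that are still referenced by symbol after inlining (the while bodies and conditions, and @callee_func in the test above) are not erased. They are detached from the nested block and re-inserted into the parent module's symbol table, so ops like tf.While keep valid symbol references once the nested module is deleted.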