Add support for called functions in the tpuv1 inlining pass

When a call graph is involved, we also need to inline the called functions back before deleting the nested module.

PiperOrigin-RevId: 295885585
Change-Id: I61a4274e06a3009e97ca800cc2ed60591e522149
commit 9771b11027
parent f9e9fb9de2
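For context, here is a minimal sketch of the situation this change handles (illustrative only; the names @main, @outlined0, @loop_body, and @loop_cond are made up and do not appear in this change). The outlined function called from the island itself refers to other functions in the nested module through symbol references, so erasing the whole nested module right after inlining the direct callee would leave those references dangling:

module {
  func @main(%arg0: tensor<i32>) -> tensor<i32> {
    %0 = tf_executor.graph {
      %outputs, %control = tf_executor.island wraps "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @_tpu_v1_compat_outlined::@outlined0} : (tensor<i32>) -> tensor<i32>
      tf_executor.fetch %outputs : tensor<i32>
    }
    return %0 : tensor<i32>
  }
  module @_tpu_v1_compat_outlined {
    func @outlined0(%arg0: tensor<i32>) -> tensor<i32> {
      %0 = "tf.While"(%arg0) {body = @loop_body, cond = @loop_cond, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
      return %0 : tensor<i32>
    }
    func @loop_body(%arg0: tensor<i32>) -> tensor<i32> {
      %0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i32>
      return %0 : tensor<i32>
    }
    func @loop_cond(%arg0: tensor<i32>) -> tensor<i1> {
      %0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i1>
      return %0 : tensor<i1>
    }
  }
}

// After running -tf-executor-tpu-v1-island-inlining on the sketch above:
// @outlined0 is inlined into the island and erased, @loop_body and @loop_cond
// are moved back into the parent module (the tf.While symbol references still
// need them), and the now-empty @_tpu_v1_compat_outlined module is erased.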
@@ -41,11 +41,52 @@ limitations under the License.
 #include "mlir/Support/LLVM.h"  // TF:llvm-project
 #include "mlir/Support/LogicalResult.h"  // TF:llvm-project
 #include "mlir/Support/STLExtras.h"  // TF:llvm-project
+#include "mlir/Transforms/InliningUtils.h"  // TF:llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/core/platform/logging.h"

 namespace mlir {
 namespace tf_device {

+//===----------------------------------------------------------------------===//
+// TF Device Dialect Interfaces
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct TFInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+
+  //===--------------------------------------------------------------------===//
+  // Analysis Hooks
+  //===--------------------------------------------------------------------===//
+
+  // Defines the legality of inlining TF Device operations.
+  bool isLegalToInline(Operation*, Region*, BlockAndValueMapping&) const final {
+    // For now, enable inlining all operations.
+    return true;
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Transformation Hooks
+  //===--------------------------------------------------------------------===//
+
+  // Attempts to materialize a conversion for a type mismatch between a call
+  // from this dialect, and a callable region. This method should generate an
+  // operation that takes 'input' as the only operand, and produces a single
+  // result of 'resultType'. If a conversion can not be generated, nullptr
+  // should be returned.
+  // This is just re-using the same logic as the TensorFlow dialect right now.
+  Operation* materializeCallConversion(OpBuilder& builder, Value input,
+                                       Type result_type,
+                                       Location conversion_loc) const final {
+    if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
+      return nullptr;
+    return builder.create<TF::CastOp>(conversion_loc, result_type, input,
+                                      /*truncate=*/builder.getBoolAttr(false));
+  }
+};
+}  // end anonymous namespace
+
 TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
     : Dialect(/*name=*/"tf_device", context) {
   addOperations<
@@ -54,6 +95,8 @@ TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
       >();

   addOperations<ParallelExecuteOp>();
+
+  addInterfaces<TFInlinerInterface>();
 }

 //===----------------------------------------------------------------------===//
@@ -0,0 +1,44 @@
+// RUN: tf-opt %s -tf-executor-tpu-v1-island-inlining | FileCheck %s --dump-input=fail
+
+// CHECK-NOT: tf.PartitionedCall
+// CHECK-NOT: module @_tpu_v1_compat_outlined
+
+module {
+  func @control_input(%arg0: tensor<i1>) -> tensor<i32> {
+    %0:4 = tf_executor.graph {
+      %outputs:4, %control = tf_executor.island wraps "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @_tpu_v1_compat_outlined::@_tpu_v1_compat_outlined_func0} : (tensor<i1>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
+      tf_executor.fetch %outputs#0, %outputs#1, %outputs#2, %outputs#3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
+    }
+    return %0#0 : tensor<i32>
+  }
+  module @_tpu_v1_compat_outlined {
+    func @_tpu_v1_compat_outlined_func0(%arg0: tensor<i1>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
+      "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1 : i64, topology = "topology"} : () -> ()
+      %0 = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
+      %1 = "tf.While"(%0) {body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "A", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
+      %2 = "tf.While"(%0) {body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "C", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
+      %3 = "tf.While"(%0) {body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, name = "E", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
+      return %0, %1, %2, %3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
+    }
+    func @while_body_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
+      %0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
+      return %0 : tensor<i32>
+    }
+    func @while_cond_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
+      %0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i1>
+      return %0 : tensor<i1>
+    }
+    func @while_body_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
+      %0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i32>
+      return %0 : tensor<i32>
+    }
+    func @while_cond_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
+      %0 = "tf.PartionedCalledOp"(%arg0) {f = @callee_func} : (tensor<i32>) -> tensor<i1>
+      return %0 : tensor<i1>
+    }
+    func @callee_func(%arg0: tensor<i32>) -> tensor<i1> {
+      %0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i1>
+      return %0 : tensor<i1>
+    }
+  }
+}
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
@@ -70,10 +71,20 @@ void TPUBridgeExecutorIslandInlining::runOnModule() {
       call_op.emitOpError() << "Failed to inline\n";
       return WalkResult::interrupt();
     }
+    called_func.erase();
     call_op.erase();
     return WalkResult::advance();
   });
   if (walk_result.wasInterrupted()) return signalPassFailure();
+  // Move all remaining nested functions back into the parent module.
+  Block &nested_block = nested_module->getRegion(0).front();
+  for (FuncOp func_op :
+       llvm::make_early_inc_range(nested_block.getOps<FuncOp>())) {
+    if (!symbol_table.lookupSymbolIn(getModule(), func_op.getName())) {
+      nested_block.getOperations().remove(func_op.getOperation());
+      symbol_table.insert(func_op.getOperation());
+    }
+  }
   nested_module->erase();
 }