Add support for called function in tpuv1 inlining pass

When a callgraph is involved, we need to inline back the called
functions as well before deleting the nested module.

PiperOrigin-RevId: 295885585
Change-Id: I61a4274e06a3009e97ca800cc2ed60591e522149
This commit is contained in:
Mehdi Amini 2020-02-18 20:43:05 -08:00 committed by TensorFlower Gardener
parent f9e9fb9de2
commit 9771b11027
4 changed files with 98 additions and 0 deletions

View File

@ -41,11 +41,52 @@ limitations under the License.
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Support/STLExtras.h" // TF:llvm-project
#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/logging.h"
namespace mlir {
namespace tf_device {
//===----------------------------------------------------------------------===//
// TF Device Dialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
struct TFInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
//===--------------------------------------------------------------------===//
// Analysis Hooks
//===--------------------------------------------------------------------===//
// Defines the legality of inlining TF Device operations.
bool isLegalToInline(Operation*, Region*, BlockAndValueMapping&) const final {
// For now, enable inlining all operations.
return true;
}
//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
// Attempts to materialize a conversion for a type mismatch between a call
// from this dialect, and a callable region. This method should generate an
// operation that takes 'input' as the only operand, and produces a single
// result of 'resultType'. If a conversion can not be generated, nullptr
// should be returned.
// This is just re-using the same logic as the TensorFlow dialect right now.
Operation* materializeCallConversion(OpBuilder& builder, Value input,
Type result_type,
Location conversion_loc) const final {
if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
return nullptr;
return builder.create<TF::CastOp>(conversion_loc, result_type, input,
/*truncate=*/builder.getBoolAttr(false));
}
};
} // end anonymous namespace
TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
: Dialect(/*name=*/"tf_device", context) {
addOperations<
@ -54,6 +95,8 @@ TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
>();
addOperations<ParallelExecuteOp>();
addInterfaces<TFInlinerInterface>();
}
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,44 @@
// RUN: tf-opt %s -tf-executor-tpu-v1-island-inlining | FileCheck %s --dump-input=fail
// CHECK-NOT: tf.PartitionedCall
// CHECK-NOT: module @_tpu_v1_compat_outlined
module {
func @control_input(%arg0: tensor<i1>) -> tensor<i32> {
%0:4 = tf_executor.graph {
%outputs:4, %control = tf_executor.island wraps "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @_tpu_v1_compat_outlined::@_tpu_v1_compat_outlined_func0} : (tensor<i1>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
tf_executor.fetch %outputs#0, %outputs#1, %outputs#2, %outputs#3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
return %0#0 : tensor<i32>
}
module @_tpu_v1_compat_outlined {
func @_tpu_v1_compat_outlined_func0(%arg0: tensor<i1>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
"tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1 : i64, topology = "topology"} : () -> ()
%0 = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
%1 = "tf.While"(%0) {body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "A", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
%2 = "tf.While"(%0) {body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "C", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
%3 = "tf.While"(%0) {body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, name = "E", parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
return %0, %1, %2, %3 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
func @while_body_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @while_cond_with_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @while_body_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i32> {
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
func @while_cond_without_cluster_attr(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "tf.PartionedCalledOp"(%arg0) {f = @callee_func} : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @callee_func(%arg0: tensor<i32>) -> tensor<i1> {
%0 = "some.op"(%arg0) : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
}
}

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
@ -70,10 +71,20 @@ void TPUBridgeExecutorIslandInlining::runOnModule() {
call_op.emitOpError() << "Failed to inline\n";
return WalkResult::interrupt();
}
called_func.erase();
call_op.erase();
return WalkResult::advance();
});
if (walk_result.wasInterrupted()) return signalPassFailure();
// Move all remaining nested functions back into the parent module.
Block &nested_block = nested_module->getRegion(0).front();
for (FuncOp func_op :
llvm::make_early_inc_range(nested_block.getOps<FuncOp>())) {
if (!symbol_table.lookupSymbolIn(getModule(), func_op.getName())) {
nested_block.getOperations().remove(func_op.getOperation());
symbol_table.insert(func_op.getOperation());
}
}
nested_module->erase();
}