diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc index 5c277eeb9db..c88ddaf7806 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -41,11 +41,52 @@ limitations under the License. #include "mlir/Support/LLVM.h" // TF:llvm-project #include "mlir/Support/LogicalResult.h" // TF:llvm-project #include "mlir/Support/STLExtras.h" // TF:llvm-project +#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/logging.h" namespace mlir { namespace tf_device { +//===----------------------------------------------------------------------===// +// TF Device Dialect Interfaces +//===----------------------------------------------------------------------===// + +namespace { +struct TFInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks + //===--------------------------------------------------------------------===// + + // Defines the legality of inlining TF Device operations. + bool isLegalToInline(Operation*, Region*, BlockAndValueMapping&) const final { + // For now, enable inlining all operations. + return true; + } + + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + // Attempts to materialize a conversion for a type mismatch between a call + // from this dialect, and a callable region. This method should generate an + // operation that takes 'input' as the only operand, and produces a single + // result of 'resultType'. If a conversion can not be generated, nullptr + // should be returned. 
+  // This is just re-using the same logic as the TensorFlow dialect right now.
+  Operation* materializeCallConversion(OpBuilder& builder, Value input,
+                                       Type result_type,
+                                       Location conversion_loc) const final {
+    if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
+      return nullptr;
+    return builder.create<TF::CastOp>(conversion_loc, result_type, input,
+                                      /*truncate=*/builder.getBoolAttr(false));
+  }
+};
+}  // end anonymous namespace
+
 TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
     : Dialect(/*name=*/"tf_device", context) {
   addOperations<
@@ -54,6 +95,8 @@ TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
       >();
   addOperations<ParallelExecuteOp>();
+
+  addInterfaces<TFInlinerInterface>();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_inline_tpu_island.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/executor_tpuv1_inline_tpu_island.mlir
similarity index 100%
rename from tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_inline_tpu_island.mlir
rename to tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/executor_tpuv1_inline_tpu_island.mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/while_op.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/while_op.mlir
new file mode 100644
index 00000000000..010b5346e1e
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/while_op.mlir
@@ -0,0 +1,44 @@
+// RUN: tf-opt %s -tf-executor-tpu-v1-island-inlining | FileCheck %s --dump-input=fail
+
+// CHECK-NOT: tf.PartitionedCall
+// CHECK-NOT: module @_tpu_v1_compat_outlined
+
+module {
+  func @control_input(%arg0: tensor<i1>) -> tensor<i1> {
+    %0:4 = tf_executor.graph {
+      %outputs:4, %control = tf_executor.island wraps "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f =
@_tpu_v1_compat_outlined::@_tpu_v1_compat_outlined_func0} : (tensor<i1>) -> (tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>)
+      tf_executor.fetch %outputs#0, %outputs#1, %outputs#2, %outputs#3 : tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>
+    }
+    return %0#0 : tensor<i1>
+  }
+  module @_tpu_v1_compat_outlined {
+    func @_tpu_v1_compat_outlined_func0(%arg0: tensor<i1>) -> (tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>) {
+      "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1 : i64, topology = "topology"} : () -> ()
+      %0 = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i1>
+      %1 = "tf.While"(%0) {body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "A", parallel_iterations = 10 : i64} : (tensor<i1>) -> tensor<i1>
+      %2 = "tf.While"(%0) {body = @while_body_without_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "C", parallel_iterations = 10 : i64} : (tensor<i1>) -> tensor<i1>
+      %3 = "tf.While"(%0) {body = @while_body_with_cluster_attr, cond = @while_cond_without_cluster_attr, is_stateless = false, name = "E", parallel_iterations = 10 : i64} : (tensor<i1>) -> tensor<i1>
+      return %0, %1, %2, %3 : tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>
+    }
+    func @while_body_with_cluster_attr(%arg0: tensor<i1>) -> tensor<i1> {
+      %0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i1>
+      return %0 : tensor<i1>
+    }
+    func @while_cond_with_cluster_attr(%arg0: tensor<i1>) -> tensor<i1> {
+      %0 = "some.op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i1>
+      return %0 : tensor<i1>
+    }
+    func @while_body_without_cluster_attr(%arg0: tensor<i1>) -> tensor<i1> {
+      %0 = "some.op"(%arg0) : (tensor<i1>) -> tensor<i1>
+      return %0 : tensor<i1>
+    }
+    func @while_cond_without_cluster_attr(%arg0: tensor<i1>) -> tensor<i1> {
+      %0 = "tf.PartionedCalledOp"(%arg0) {f = @callee_func} : (tensor<i1>) -> tensor<i1>
+      return %0 : tensor<i1>
+    }
+    func @callee_func(%arg0: tensor<i1>) -> tensor<i1> {
+      %0 = "some.op"(%arg0) : (tensor<i1>) -> tensor<i1>
+      return %0 : tensor<i1>
+    }
+  }
+}
diff --git
a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc
index 80fcd52056d..9660367cb68 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
@@ -70,10 +71,20 @@ void TPUBridgeExecutorIslandInlining::runOnModule() {
       call_op.emitOpError() << "Failed to inline\n";
       return WalkResult::interrupt();
     }
+    called_func.erase();
     call_op.erase();
     return WalkResult::advance();
   });
   if (walk_result.wasInterrupted()) return signalPassFailure();
+  // Move all remaining nested functions back into the parent module.
+  Block &nested_block = nested_module->getRegion(0).front();
+  for (FuncOp func_op :
+       llvm::make_early_inc_range(nested_block.getOps<FuncOp>())) {
+    if (!symbol_table.lookupSymbolIn(getModule(), func_op.getName())) {
+      nested_block.getOperations().remove(func_op.getOperation());
+      symbol_table.insert(func_op.getOperation());
+    }
+  }
   nested_module->erase();
 }