diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index a2c40453c91..7f98e3fda94 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1670,15 +1670,16 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_helpers", "//tensorflow/compiler/tf2xla:xla_argument", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/core/common_runtime:core_cpu_internal", - "//tensorflow/core/platform:logging", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - "//tensorflow/stream_executor/lib", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime:core_cpu_internal", + "//tensorflow/core/platform:logging", + "//tensorflow/core/tpu:tpu_defs", + "//tensorflow/stream_executor/lib", ] # Prefer to link 'compile_mlir_util' library that also links necessary diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 2317eaf427e..91d22bd0789 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -60,6 +60,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/tpu/tpu_defs.h" namespace tensorflow { namespace { @@ -266,13 +267,19 @@ static void RegisterDialects(mlir::DialectRegistry& registry) { mlir::mhlo::registerAllMhloDialects(registry); } +// Checks if functions can be inlined after TF -> HLO legalization. 
Currently +// TPUs are supported, to follow the behavior of inlining functions via the +// Graph based bridge in the TPUCompile op kernel. +bool CanInlineFunctionsPostLegalization(llvm::StringRef device_type) { + return device_type == DEVICE_TPU_XLA_JIT; +} + } // namespace void CreateConvertMlirToXlaHloPipeline( mlir::OpPassManager& pm, llvm::StringRef device_type, llvm::MutableArrayRef> - custom_legalization_passes, - bool inline_after_legalization) { + custom_legalization_passes) { pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions()); pm.addNestedPass(mlir::createCanonicalizerPass()); // Run shape inference pass before tensorlist decomposition to get buffer @@ -321,7 +328,8 @@ void CreateConvertMlirToXlaHloPipeline( /*allow_partial_conversion=*/false, /*legalize_chlo=*/true, /*tf2xla_fallback_device_type=*/device_type)); - if (inline_after_legalization) pm.addPass(mlir::createInlinerPass()); + if (CanInlineFunctionsPostLegalization(device_type)) + pm.addPass(mlir::createInlinerPass()); // In order to export to XLA, we must sink constants to control flow regions, // since XLA uses functional control flow. 
@@ -335,13 +343,11 @@ Status ConvertMLIRToXlaComputation( bool return_tuple, const XlaHelpers::ShapeRepresentationFn shape_representation_fn, llvm::MutableArrayRef> - custom_legalization_passes, - bool inline_after_legalization) { + custom_legalization_passes) { mlir::PassManager tf2xla(module_op.getContext()); applyTensorflowAndCLOptions(tf2xla); CreateConvertMlirToXlaHloPipeline(tf2xla, device_type, - custom_legalization_passes, - inline_after_legalization); + custom_legalization_passes); if (VLOG_IS_ON(1)) { // Print the whole module after each pass which requires disabling @@ -379,8 +385,7 @@ Status CompileMlirToXlaHlo( XlaHelpers::ShapeRepresentationFn shape_representation_fn, XlaCompilationResult* compilation_result, llvm::MutableArrayRef> - custom_legalization_passes, - bool inline_after_legalization) { + custom_legalization_passes) { if (VLOG_IS_ON(1)) tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op); @@ -398,7 +403,7 @@ Status CompileMlirToXlaHlo( TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation( module_op, device_type, compilation_result->computation.get(), use_tuple_args, use_return_tuple, shape_representation_fn, - custom_legalization_passes, inline_after_legalization)); + custom_legalization_passes)); // Construct mapping from XlaComputation's arg to input edges of execute // node. @@ -441,8 +446,7 @@ Status CompileSerializedMlirToXlaHlo( return CompileMlirToXlaHlo( mlir_module.get(), tensor_or_resource_shapes, device_type, use_tuple_args, /*use_return_tuple=*/true, /*use_resource_updates_for_aliases=*/false, - shape_representation_fn, compilation_result, custom_legalization_passes, - /*inline_after_legalization=*/true); + shape_representation_fn, compilation_result, custom_legalization_passes); } // Rewrites the given module with specified args. 
For each of the constant args, @@ -543,8 +547,7 @@ Status CompileGraphToXlaHlo( auto status = CompileMlirToXlaHlo( module_op, arg_shapes, device_type, use_tuple_args, use_return_tuple, /*use_resource_updates_for_aliases=*/true, shape_representation_fn, - compilation_result, custom_legalization_passes, - /*inline_after_legalization=*/false); + compilation_result, custom_legalization_passes); compilation_result->input_mapping = remaining_params; return status; } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index 102f4e5aa0d..71c672829fb 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -39,8 +39,7 @@ namespace tensorflow { void CreateConvertMlirToXlaHloPipeline( mlir::OpPassManager& pm, llvm::StringRef device_type, llvm::MutableArrayRef> - custom_legalization_passes, - bool inline_after_legalization); + custom_legalization_passes); // Lowers MLIR module to XLA HLO inside an XlaComputation. The input module // should only contain operations in tf dialect. If the input module contains @@ -74,8 +73,7 @@ Status ConvertMLIRToXlaComputation( bool return_tuple, const XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr, llvm::MutableArrayRef> - custom_legalization_passes = {}, - bool inline_after_legalization = false); + custom_legalization_passes = {}); // Helper struct representing argument tensor or resource handle shapes. struct TensorOrResourceShape { @@ -93,8 +91,7 @@ Status CompileMlirToXlaHlo( XlaHelpers::ShapeRepresentationFn shape_representation_fn, XlaCompilationResult* compilation_result, llvm::MutableArrayRef> - custom_legalization_passes, - bool inline_after_legalization); + custom_legalization_passes); // Compiles a serialized MLIR module into XLA HLO, generates all accompanying // metadata and stores them in CompilationResult. 
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_pass.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_pass.cc index 1ebed7cb811..57267ff027f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_pass.cc @@ -21,7 +21,7 @@ namespace { void CreateConvertMlirToXlaHloPipelineWithDefaults(mlir::OpPassManager& pm) { tensorflow::CreateConvertMlirToXlaHloPipeline( pm, /*device_type=*/"XLA_CPU_JIT", - /*custom_legalization_passes=*/{}, /*inline_after_legalization=*/false); + /*custom_legalization_passes=*/{}); } mlir::PassPipelineRegistration<> pipeline( diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc index dbc8ae06390..f6cf5a4120f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc @@ -250,7 +250,7 @@ static mlir::LogicalResult MlirTfToHloTextTranslateFunction( module_op, arg_shapes, /*device_type=*/"XLA_CPU_JIT", emit_use_tuple_arg, emit_return_tuple, /*use_resource_updates_for_aliases=*/true, IdentityShapeRepresentationFn(), &compilation_result, - /*custom_legalization_passes=*/{}, /*inline_after_legalization=*/false); + /*custom_legalization_passes=*/{}); if (!compilation_status.ok()) { LOG(ERROR) << "TF/XLA compilation failed: " << compilation_status.ToString(); diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index 8264ec393cb..1235821e80c 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -8,6 +8,7 @@ load( package( default_visibility = [ + "//tensorflow/compiler/mlir/tensorflow:__subpackages__", "//tensorflow/compiler/tf2xla/kernels:__subpackages__", "//tensorflow/core/tpu:__subpackages__", "//tensorflow/stream_executor/tpu:__subpackages__",