Switch from explicit argument for inlining functions post TF -> HLO legalization to checking device type (NFC).

Currently, TPU computations can have their functions inlined, mirroring the function-inlining behavior of the Graph-based TPU bridge.

PiperOrigin-RevId: 346659159
Change-Id: I065e4defd8dc34794a739fa33f1f211abcc0301b
This commit is contained in:
Andy Ly 2020-12-09 16:08:58 -08:00 committed by TensorFlower Gardener
parent c184eb3ddc
commit b43da1eb10
6 changed files with 30 additions and 28 deletions

View File

@ -1670,15 +1670,16 @@ COMPILE_MLIR_UTIL_DEPS = [
"//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_helpers", "//tensorflow/compiler/tf2xla:xla_helpers",
"//tensorflow/compiler/tf2xla:xla_argument", "//tensorflow/compiler/tf2xla:xla_argument",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core/common_runtime:core_cpu_internal",
"//tensorflow/core/platform:logging",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime:core_cpu_internal",
"//tensorflow/core/platform:logging",
"//tensorflow/core/tpu:tpu_defs",
"//tensorflow/stream_executor/lib",
] ]
# Prefer to link 'compile_mlir_util' library that also links necessary # Prefer to link 'compile_mlir_util' library that also links necessary

View File

@ -60,6 +60,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/tpu/tpu_defs.h"
namespace tensorflow { namespace tensorflow {
namespace { namespace {
@ -266,13 +267,19 @@ static void RegisterDialects(mlir::DialectRegistry& registry) {
mlir::mhlo::registerAllMhloDialects(registry); mlir::mhlo::registerAllMhloDialects(registry);
} }
// Returns true if functions can be inlined after TF -> HLO legalization.
// Currently only the TPU device (DEVICE_TPU_XLA_JIT) is supported, to follow
// the behavior of inlining functions via the Graph-based bridge in the
// TPUCompile op kernel.
bool CanInlineFunctionsPostLegalization(llvm::StringRef device_type) {
return device_type == DEVICE_TPU_XLA_JIT;
}
} // namespace } // namespace
void CreateConvertMlirToXlaHloPipeline( void CreateConvertMlirToXlaHloPipeline(
mlir::OpPassManager& pm, llvm::StringRef device_type, mlir::OpPassManager& pm, llvm::StringRef device_type,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>> llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes, custom_legalization_passes) {
bool inline_after_legalization) {
pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions()); pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass()); pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
// Run shape inference pass before tensorlist decomposition to get buffer // Run shape inference pass before tensorlist decomposition to get buffer
@ -321,7 +328,8 @@ void CreateConvertMlirToXlaHloPipeline(
/*allow_partial_conversion=*/false, /*legalize_chlo=*/true, /*allow_partial_conversion=*/false, /*legalize_chlo=*/true,
/*tf2xla_fallback_device_type=*/device_type)); /*tf2xla_fallback_device_type=*/device_type));
if (inline_after_legalization) pm.addPass(mlir::createInlinerPass()); if (CanInlineFunctionsPostLegalization(device_type))
pm.addPass(mlir::createInlinerPass());
// In order to export to XLA, we must sink constants to control flow regions, // In order to export to XLA, we must sink constants to control flow regions,
// since XLA uses functional control flow. // since XLA uses functional control flow.
@ -335,13 +343,11 @@ Status ConvertMLIRToXlaComputation(
bool return_tuple, bool return_tuple,
const XlaHelpers::ShapeRepresentationFn shape_representation_fn, const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>> llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes, custom_legalization_passes) {
bool inline_after_legalization) {
mlir::PassManager tf2xla(module_op.getContext()); mlir::PassManager tf2xla(module_op.getContext());
applyTensorflowAndCLOptions(tf2xla); applyTensorflowAndCLOptions(tf2xla);
CreateConvertMlirToXlaHloPipeline(tf2xla, device_type, CreateConvertMlirToXlaHloPipeline(tf2xla, device_type,
custom_legalization_passes, custom_legalization_passes);
inline_after_legalization);
if (VLOG_IS_ON(1)) { if (VLOG_IS_ON(1)) {
// Print the whole module after each pass which requires disabling // Print the whole module after each pass which requires disabling
@ -379,8 +385,7 @@ Status CompileMlirToXlaHlo(
XlaHelpers::ShapeRepresentationFn shape_representation_fn, XlaHelpers::ShapeRepresentationFn shape_representation_fn,
XlaCompilationResult* compilation_result, XlaCompilationResult* compilation_result,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>> llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes, custom_legalization_passes) {
bool inline_after_legalization) {
if (VLOG_IS_ON(1)) if (VLOG_IS_ON(1))
tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op); tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op);
@ -398,7 +403,7 @@ Status CompileMlirToXlaHlo(
TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation( TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
module_op, device_type, compilation_result->computation.get(), module_op, device_type, compilation_result->computation.get(),
use_tuple_args, use_return_tuple, shape_representation_fn, use_tuple_args, use_return_tuple, shape_representation_fn,
custom_legalization_passes, inline_after_legalization)); custom_legalization_passes));
// Construct mapping from XlaComputation's arg to input edges of execute // Construct mapping from XlaComputation's arg to input edges of execute
// node. // node.
@ -441,8 +446,7 @@ Status CompileSerializedMlirToXlaHlo(
return CompileMlirToXlaHlo( return CompileMlirToXlaHlo(
mlir_module.get(), tensor_or_resource_shapes, device_type, use_tuple_args, mlir_module.get(), tensor_or_resource_shapes, device_type, use_tuple_args,
/*use_return_tuple=*/true, /*use_resource_updates_for_aliases=*/false, /*use_return_tuple=*/true, /*use_resource_updates_for_aliases=*/false,
shape_representation_fn, compilation_result, custom_legalization_passes, shape_representation_fn, compilation_result, custom_legalization_passes);
/*inline_after_legalization=*/true);
} }
// Rewrites the given module with specified args. For each of the constant args, // Rewrites the given module with specified args. For each of the constant args,
@ -543,8 +547,7 @@ Status CompileGraphToXlaHlo(
auto status = CompileMlirToXlaHlo( auto status = CompileMlirToXlaHlo(
module_op, arg_shapes, device_type, use_tuple_args, use_return_tuple, module_op, arg_shapes, device_type, use_tuple_args, use_return_tuple,
/*use_resource_updates_for_aliases=*/true, shape_representation_fn, /*use_resource_updates_for_aliases=*/true, shape_representation_fn,
compilation_result, custom_legalization_passes, compilation_result, custom_legalization_passes);
/*inline_after_legalization=*/false);
compilation_result->input_mapping = remaining_params; compilation_result->input_mapping = remaining_params;
return status; return status;
} }

View File

@ -39,8 +39,7 @@ namespace tensorflow {
void CreateConvertMlirToXlaHloPipeline( void CreateConvertMlirToXlaHloPipeline(
mlir::OpPassManager& pm, llvm::StringRef device_type, mlir::OpPassManager& pm, llvm::StringRef device_type,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>> llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes, custom_legalization_passes);
bool inline_after_legalization);
// Lowers MLIR module to XLA HLO inside an XlaComputation. The input module // Lowers MLIR module to XLA HLO inside an XlaComputation. The input module
// should only contain operations in tf dialect. If the input module contains // should only contain operations in tf dialect. If the input module contains
@ -74,8 +73,7 @@ Status ConvertMLIRToXlaComputation(
bool return_tuple, bool return_tuple,
const XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr, const XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>> llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes = {}, custom_legalization_passes = {});
bool inline_after_legalization = false);
// Helper struct representing argument tensor or resource handle shapes. // Helper struct representing argument tensor or resource handle shapes.
struct TensorOrResourceShape { struct TensorOrResourceShape {
@ -93,8 +91,7 @@ Status CompileMlirToXlaHlo(
XlaHelpers::ShapeRepresentationFn shape_representation_fn, XlaHelpers::ShapeRepresentationFn shape_representation_fn,
XlaCompilationResult* compilation_result, XlaCompilationResult* compilation_result,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>> llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes, custom_legalization_passes);
bool inline_after_legalization);
// Compiles a serialized MLIR module into XLA HLO, generates all accompanying // Compiles a serialized MLIR module into XLA HLO, generates all accompanying
// metadata and stores them in CompilationResult. // metadata and stores them in CompilationResult.

View File

@ -21,7 +21,7 @@ namespace {
void CreateConvertMlirToXlaHloPipelineWithDefaults(mlir::OpPassManager& pm) { void CreateConvertMlirToXlaHloPipelineWithDefaults(mlir::OpPassManager& pm) {
tensorflow::CreateConvertMlirToXlaHloPipeline( tensorflow::CreateConvertMlirToXlaHloPipeline(
pm, /*device_type=*/"XLA_CPU_JIT", pm, /*device_type=*/"XLA_CPU_JIT",
/*custom_legalization_passes=*/{}, /*inline_after_legalization=*/false); /*custom_legalization_passes=*/{});
} }
mlir::PassPipelineRegistration<> pipeline( mlir::PassPipelineRegistration<> pipeline(

View File

@ -250,7 +250,7 @@ static mlir::LogicalResult MlirTfToHloTextTranslateFunction(
module_op, arg_shapes, /*device_type=*/"XLA_CPU_JIT", emit_use_tuple_arg, module_op, arg_shapes, /*device_type=*/"XLA_CPU_JIT", emit_use_tuple_arg,
emit_return_tuple, /*use_resource_updates_for_aliases=*/true, emit_return_tuple, /*use_resource_updates_for_aliases=*/true,
IdentityShapeRepresentationFn(), &compilation_result, IdentityShapeRepresentationFn(), &compilation_result,
/*custom_legalization_passes=*/{}, /*inline_after_legalization=*/false); /*custom_legalization_passes=*/{});
if (!compilation_status.ok()) { if (!compilation_status.ok()) {
LOG(ERROR) << "TF/XLA compilation failed: " LOG(ERROR) << "TF/XLA compilation failed: "
<< compilation_status.ToString(); << compilation_status.ToString();

View File

@ -8,6 +8,7 @@ load(
package( package(
default_visibility = [ default_visibility = [
"//tensorflow/compiler/mlir/tensorflow:__subpackages__",
"//tensorflow/compiler/tf2xla/kernels:__subpackages__", "//tensorflow/compiler/tf2xla/kernels:__subpackages__",
"//tensorflow/core/tpu:__subpackages__", "//tensorflow/core/tpu:__subpackages__",
"//tensorflow/stream_executor/tpu:__subpackages__", "//tensorflow/stream_executor/tpu:__subpackages__",