Switch from an explicit argument for inlining functions post TF -> HLO legalization to checking the device type (NFC).

TPU computations now have their functions inlined after legalization, following the function inlining behavior of the Graph-based TPU bridge.

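In effect (excerpted from the diff below), the pipeline now derives the inlining decision from the device type it already receives, instead of threading a boolean flag through every caller:

    // Checks if functions can be inlined after TF -> HLO legalization.
    bool CanInlineFunctionsPostLegalization(llvm::StringRef device_type) {
      return device_type == DEVICE_TPU_XLA_JIT;  // From tensorflow/core/tpu/tpu_defs.h.
    }

    // ...

    if (CanInlineFunctionsPostLegalization(device_type))
      pm.addPass(mlir::createInlinerPass());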
PiperOrigin-RevId: 346659159
Change-Id: I065e4defd8dc34794a739fa33f1f211abcc0301b
Andy Ly 2020-12-09 16:08:58 -08:00 committed by TensorFlower Gardener
parent c184eb3ddc
commit b43da1eb10
6 changed files with 30 additions and 28 deletions

@@ -1670,15 +1670,16 @@ COMPILE_MLIR_UTIL_DEPS = [
     "//tensorflow/compiler/tf2xla:common",
     "//tensorflow/compiler/tf2xla:xla_helpers",
     "//tensorflow/compiler/tf2xla:xla_argument",
-    "//tensorflow/compiler/xla/client:xla_computation",
-    "//tensorflow/core/common_runtime:core_cpu_internal",
-    "//tensorflow/core/platform:logging",
-    "//tensorflow/core:framework",
-    "//tensorflow/core:protos_all_cc",
-    "//tensorflow/stream_executor/lib",
     "//tensorflow/compiler/xla:shape_util",
     "//tensorflow/compiler/xla:xla_data_proto_cc",
+    "//tensorflow/compiler/xla/client:xla_computation",
     "//tensorflow/compiler/xla/service:hlo",
+    "//tensorflow/core:framework",
+    "//tensorflow/core:protos_all_cc",
+    "//tensorflow/core/common_runtime:core_cpu_internal",
+    "//tensorflow/core/platform:logging",
+    "//tensorflow/core/tpu:tpu_defs",
+    "//tensorflow/stream_executor/lib",
 ]
 
 # Prefer to link 'compile_mlir_util' library that also links necessary

@@ -60,6 +60,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
 
 namespace tensorflow {
 namespace {
@@ -266,13 +267,19 @@ static void RegisterDialects(mlir::DialectRegistry& registry) {
   mlir::mhlo::registerAllMhloDialects(registry);
 }
 
+// Checks if functions can be inlined after TF -> HLO legalization. Currently
+// TPUs are supported, to follow the behavior of inlining functions via the
+// Graph based bridge in the TPUCompile op kernel.
+bool CanInlineFunctionsPostLegalization(llvm::StringRef device_type) {
+  return device_type == DEVICE_TPU_XLA_JIT;
+}
+
 }  // namespace
 
 void CreateConvertMlirToXlaHloPipeline(
     mlir::OpPassManager& pm, llvm::StringRef device_type,
     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
-        custom_legalization_passes,
-    bool inline_after_legalization) {
+        custom_legalization_passes) {
   pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
   pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
   // Run shape inference pass before tensorlist decomposition to get buffer
@@ -321,7 +328,8 @@ void CreateConvertMlirToXlaHloPipeline(
       /*allow_partial_conversion=*/false, /*legalize_chlo=*/true,
       /*tf2xla_fallback_device_type=*/device_type));
 
-  if (inline_after_legalization) pm.addPass(mlir::createInlinerPass());
+  if (CanInlineFunctionsPostLegalization(device_type))
+    pm.addPass(mlir::createInlinerPass());
 
   // In order to export to XLA, we must sink constants to control flow regions,
   // since XLA uses functional control flow.
@@ -335,13 +343,11 @@ Status ConvertMLIRToXlaComputation(
     bool return_tuple,
     const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
-        custom_legalization_passes,
-    bool inline_after_legalization) {
+        custom_legalization_passes) {
   mlir::PassManager tf2xla(module_op.getContext());
   applyTensorflowAndCLOptions(tf2xla);
   CreateConvertMlirToXlaHloPipeline(tf2xla, device_type,
-                                    custom_legalization_passes,
-                                    inline_after_legalization);
+                                    custom_legalization_passes);
 
   if (VLOG_IS_ON(1)) {
     // Print the whole module after each pass which requires disabling
@@ -379,8 +385,7 @@ Status CompileMlirToXlaHlo(
     XlaHelpers::ShapeRepresentationFn shape_representation_fn,
     XlaCompilationResult* compilation_result,
     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
-        custom_legalization_passes,
-    bool inline_after_legalization) {
+        custom_legalization_passes) {
   if (VLOG_IS_ON(1))
     tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op);
 
@@ -398,7 +403,7 @@ Status CompileMlirToXlaHlo(
   TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
       module_op, device_type, compilation_result->computation.get(),
       use_tuple_args, use_return_tuple, shape_representation_fn,
-      custom_legalization_passes, inline_after_legalization));
+      custom_legalization_passes));
 
   // Construct mapping from XlaComputation's arg to input edges of execute
   // node.
@@ -441,8 +446,7 @@ Status CompileSerializedMlirToXlaHlo(
   return CompileMlirToXlaHlo(
       mlir_module.get(), tensor_or_resource_shapes, device_type, use_tuple_args,
       /*use_return_tuple=*/true, /*use_resource_updates_for_aliases=*/false,
-      shape_representation_fn, compilation_result, custom_legalization_passes,
-      /*inline_after_legalization=*/true);
+      shape_representation_fn, compilation_result, custom_legalization_passes);
 }
 
 // Rewrites the given module with specified args. For each of the constant args,
@@ -543,8 +547,7 @@ Status CompileGraphToXlaHlo(
   auto status = CompileMlirToXlaHlo(
       module_op, arg_shapes, device_type, use_tuple_args, use_return_tuple,
       /*use_resource_updates_for_aliases=*/true, shape_representation_fn,
-      compilation_result, custom_legalization_passes,
-      /*inline_after_legalization=*/false);
+      compilation_result, custom_legalization_passes);
   compilation_result->input_mapping = remaining_params;
   return status;
 }
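For callers, the update is mechanical: drop the trailing inline_after_legalization argument and let the device type decide. A minimal sketch of an updated call site (the wrapper function below is hypothetical; CreateConvertMlirToXlaHloPipeline and DEVICE_TPU_XLA_JIT come from the diff above):

    #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
    #include "tensorflow/core/tpu/tpu_defs.h"

    // Hypothetical helper: builds the TF -> HLO lowering pipeline for a TPU
    // computation. With DEVICE_TPU_XLA_JIT as the device type, the pipeline
    // itself appends the inliner after legalization; there is no flag to pass.
    void BuildTpuLoweringPipeline(mlir::OpPassManager& pm) {
      tensorflow::CreateConvertMlirToXlaHloPipeline(
          pm, /*device_type=*/tensorflow::DEVICE_TPU_XLA_JIT,
          /*custom_legalization_passes=*/{});
    }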

@@ -39,8 +39,7 @@ namespace tensorflow {
 void CreateConvertMlirToXlaHloPipeline(
     mlir::OpPassManager& pm, llvm::StringRef device_type,
     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
-        custom_legalization_passes,
-    bool inline_after_legalization);
+        custom_legalization_passes);
 
 // Lowers MLIR module to XLA HLO inside an XlaComputation. The input module
 // should only contain operations in tf dialect. If the input module contains
@@ -74,8 +73,7 @@ Status ConvertMLIRToXlaComputation(
     bool return_tuple,
     const XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr,
     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
-        custom_legalization_passes = {},
-    bool inline_after_legalization = false);
+        custom_legalization_passes = {});
 
 // Helper struct representing argument tensor or resource handle shapes.
 struct TensorOrResourceShape {
@@ -93,8 +91,7 @@ Status CompileMlirToXlaHlo(
     XlaHelpers::ShapeRepresentationFn shape_representation_fn,
     XlaCompilationResult* compilation_result,
     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
-        custom_legalization_passes,
-    bool inline_after_legalization);
+        custom_legalization_passes);
 
 // Compiles a serialized MLIR module into XLA HLO, generates all accompanying
 // metadata and stores them in CompilationResult.

@@ -21,7 +21,7 @@ namespace {
 void CreateConvertMlirToXlaHloPipelineWithDefaults(mlir::OpPassManager& pm) {
   tensorflow::CreateConvertMlirToXlaHloPipeline(
       pm, /*device_type=*/"XLA_CPU_JIT",
-      /*custom_legalization_passes=*/{}, /*inline_after_legalization=*/false);
+      /*custom_legalization_passes=*/{});
 }
 
 mlir::PassPipelineRegistration<> pipeline(

@@ -250,7 +250,7 @@ static mlir::LogicalResult MlirTfToHloTextTranslateFunction(
       module_op, arg_shapes, /*device_type=*/"XLA_CPU_JIT", emit_use_tuple_arg,
       emit_return_tuple, /*use_resource_updates_for_aliases=*/true,
       IdentityShapeRepresentationFn(), &compilation_result,
-      /*custom_legalization_passes=*/{}, /*inline_after_legalization=*/false);
+      /*custom_legalization_passes=*/{});
   if (!compilation_status.ok()) {
     LOG(ERROR) << "TF/XLA compilation failed: "
                << compilation_status.ToString();

@@ -8,6 +8,7 @@ load(
 package(
     default_visibility = [
+        "//tensorflow/compiler/mlir/tensorflow:__subpackages__",
         "//tensorflow/compiler/tf2xla/kernels:__subpackages__",
         "//tensorflow/core/tpu:__subpackages__",
         "//tensorflow/stream_executor/tpu:__subpackages__",