Switch from explicit argument for inlining functions post TF -> HLO legalization to checking device type (NFC).
Currently TPU computations can have their functions inlined, to follow the behavior of inlining functions in the Graph based TPU bridge. PiperOrigin-RevId: 346659159 Change-Id: I065e4defd8dc34794a739fa33f1f211abcc0301b
This commit is contained in:
parent
c184eb3ddc
commit
b43da1eb10
tensorflow
compiler/mlir/tensorflow
core/tpu
@ -1670,15 +1670,16 @@ COMPILE_MLIR_UTIL_DEPS = [
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_helpers",
|
||||
"//tensorflow/compiler/tf2xla:xla_argument",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/core/common_runtime:core_cpu_internal",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime:core_cpu_internal",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/core/tpu:tpu_defs",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
]
|
||||
|
||||
# Prefer to link 'compile_mlir_util' library that also links necessary
|
||||
|
@ -60,6 +60,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/tpu/tpu_defs.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -266,13 +267,19 @@ static void RegisterDialects(mlir::DialectRegistry& registry) {
|
||||
mlir::mhlo::registerAllMhloDialects(registry);
|
||||
}
|
||||
|
||||
// Checks if functions can be inlined after TF -> HLO legalization. Currently
|
||||
// TPU's are supported, to follow the behavior of inlining functions via the
|
||||
// Graph based bridge in the TPUCompile op kernel.
|
||||
bool CanInlineFunctionsPostLegalization(llvm::StringRef device_type) {
|
||||
return device_type == DEVICE_TPU_XLA_JIT;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void CreateConvertMlirToXlaHloPipeline(
|
||||
mlir::OpPassManager& pm, llvm::StringRef device_type,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes,
|
||||
bool inline_after_legalization) {
|
||||
custom_legalization_passes) {
|
||||
pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
|
||||
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
// Run shape inference pass before tensorlist decomposition to get buffer
|
||||
@ -321,7 +328,8 @@ void CreateConvertMlirToXlaHloPipeline(
|
||||
/*allow_partial_conversion=*/false, /*legalize_chlo=*/true,
|
||||
/*tf2xla_fallback_device_type=*/device_type));
|
||||
|
||||
if (inline_after_legalization) pm.addPass(mlir::createInlinerPass());
|
||||
if (CanInlineFunctionsPostLegalization(device_type))
|
||||
pm.addPass(mlir::createInlinerPass());
|
||||
|
||||
// In order to export to XLA, we must sink constants to control flow regions,
|
||||
// since XLA uses functional control flow.
|
||||
@ -335,13 +343,11 @@ Status ConvertMLIRToXlaComputation(
|
||||
bool return_tuple,
|
||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes,
|
||||
bool inline_after_legalization) {
|
||||
custom_legalization_passes) {
|
||||
mlir::PassManager tf2xla(module_op.getContext());
|
||||
applyTensorflowAndCLOptions(tf2xla);
|
||||
CreateConvertMlirToXlaHloPipeline(tf2xla, device_type,
|
||||
custom_legalization_passes,
|
||||
inline_after_legalization);
|
||||
custom_legalization_passes);
|
||||
|
||||
if (VLOG_IS_ON(1)) {
|
||||
// Print the whole module after each pass which requires disabling
|
||||
@ -379,8 +385,7 @@ Status CompileMlirToXlaHlo(
|
||||
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompilationResult* compilation_result,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes,
|
||||
bool inline_after_legalization) {
|
||||
custom_legalization_passes) {
|
||||
if (VLOG_IS_ON(1))
|
||||
tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op);
|
||||
|
||||
@ -398,7 +403,7 @@ Status CompileMlirToXlaHlo(
|
||||
TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
|
||||
module_op, device_type, compilation_result->computation.get(),
|
||||
use_tuple_args, use_return_tuple, shape_representation_fn,
|
||||
custom_legalization_passes, inline_after_legalization));
|
||||
custom_legalization_passes));
|
||||
|
||||
// Construct mapping from XlaComputation's arg to input edges of execute
|
||||
// node.
|
||||
@ -441,8 +446,7 @@ Status CompileSerializedMlirToXlaHlo(
|
||||
return CompileMlirToXlaHlo(
|
||||
mlir_module.get(), tensor_or_resource_shapes, device_type, use_tuple_args,
|
||||
/*use_return_tuple=*/true, /*use_resource_updates_for_aliases=*/false,
|
||||
shape_representation_fn, compilation_result, custom_legalization_passes,
|
||||
/*inline_after_legalization=*/true);
|
||||
shape_representation_fn, compilation_result, custom_legalization_passes);
|
||||
}
|
||||
|
||||
// Rewrites the given module with specified args. For each of the constant args,
|
||||
@ -543,8 +547,7 @@ Status CompileGraphToXlaHlo(
|
||||
auto status = CompileMlirToXlaHlo(
|
||||
module_op, arg_shapes, device_type, use_tuple_args, use_return_tuple,
|
||||
/*use_resource_updates_for_aliases=*/true, shape_representation_fn,
|
||||
compilation_result, custom_legalization_passes,
|
||||
/*inline_after_legalization=*/false);
|
||||
compilation_result, custom_legalization_passes);
|
||||
compilation_result->input_mapping = remaining_params;
|
||||
return status;
|
||||
}
|
||||
|
@ -39,8 +39,7 @@ namespace tensorflow {
|
||||
void CreateConvertMlirToXlaHloPipeline(
|
||||
mlir::OpPassManager& pm, llvm::StringRef device_type,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes,
|
||||
bool inline_after_legalization);
|
||||
custom_legalization_passes);
|
||||
|
||||
// Lowers MLIR module to XLA HLO inside an XlaComputation. The input module
|
||||
// should only contain operations in tf dialect. If the input module contains
|
||||
@ -74,8 +73,7 @@ Status ConvertMLIRToXlaComputation(
|
||||
bool return_tuple,
|
||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes = {},
|
||||
bool inline_after_legalization = false);
|
||||
custom_legalization_passes = {});
|
||||
|
||||
// Helper struct representing argument tensor or resource handle shapes.
|
||||
struct TensorOrResourceShape {
|
||||
@ -93,8 +91,7 @@ Status CompileMlirToXlaHlo(
|
||||
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompilationResult* compilation_result,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes,
|
||||
bool inline_after_legalization);
|
||||
custom_legalization_passes);
|
||||
|
||||
// Compiles a serialized MLIR module into XLA HLO, generates all accompanying
|
||||
// metadata and stores them in CompilationResult.
|
||||
|
@ -21,7 +21,7 @@ namespace {
|
||||
void CreateConvertMlirToXlaHloPipelineWithDefaults(mlir::OpPassManager& pm) {
|
||||
tensorflow::CreateConvertMlirToXlaHloPipeline(
|
||||
pm, /*device_type=*/"XLA_CPU_JIT",
|
||||
/*custom_legalization_passes=*/{}, /*inline_after_legalization=*/false);
|
||||
/*custom_legalization_passes=*/{});
|
||||
}
|
||||
|
||||
mlir::PassPipelineRegistration<> pipeline(
|
||||
|
@ -250,7 +250,7 @@ static mlir::LogicalResult MlirTfToHloTextTranslateFunction(
|
||||
module_op, arg_shapes, /*device_type=*/"XLA_CPU_JIT", emit_use_tuple_arg,
|
||||
emit_return_tuple, /*use_resource_updates_for_aliases=*/true,
|
||||
IdentityShapeRepresentationFn(), &compilation_result,
|
||||
/*custom_legalization_passes=*/{}, /*inline_after_legalization=*/false);
|
||||
/*custom_legalization_passes=*/{});
|
||||
if (!compilation_status.ok()) {
|
||||
LOG(ERROR) << "TF/XLA compilation failed: "
|
||||
<< compilation_status.ToString();
|
||||
|
@ -8,6 +8,7 @@ load(
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow/compiler/mlir/tensorflow:__subpackages__",
|
||||
"//tensorflow/compiler/tf2xla/kernels:__subpackages__",
|
||||
"//tensorflow/core/tpu:__subpackages__",
|
||||
"//tensorflow/stream_executor/tpu:__subpackages__",
|
||||
|
Loading…
Reference in New Issue
Block a user