Switch from explicit argument for inlining functions post TF -> HLO legalization to checking device type (NFC).
Currently TPU computations can have their functions inlined, to follow the behavior of inlining functions in the Graph based TPU bridge. PiperOrigin-RevId: 346659159 Change-Id: I065e4defd8dc34794a739fa33f1f211abcc0301b
This commit is contained in:
parent
c184eb3ddc
commit
b43da1eb10
@ -1670,15 +1670,16 @@ COMPILE_MLIR_UTIL_DEPS = [
|
|||||||
"//tensorflow/compiler/tf2xla:common",
|
"//tensorflow/compiler/tf2xla:common",
|
||||||
"//tensorflow/compiler/tf2xla:xla_helpers",
|
"//tensorflow/compiler/tf2xla:xla_helpers",
|
||||||
"//tensorflow/compiler/tf2xla:xla_argument",
|
"//tensorflow/compiler/tf2xla:xla_argument",
|
||||||
"//tensorflow/compiler/xla/client:xla_computation",
|
|
||||||
"//tensorflow/core/common_runtime:core_cpu_internal",
|
|
||||||
"//tensorflow/core/platform:logging",
|
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core:protos_all_cc",
|
|
||||||
"//tensorflow/stream_executor/lib",
|
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
|
"//tensorflow/compiler/xla/client:xla_computation",
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core/common_runtime:core_cpu_internal",
|
||||||
|
"//tensorflow/core/platform:logging",
|
||||||
|
"//tensorflow/core/tpu:tpu_defs",
|
||||||
|
"//tensorflow/stream_executor/lib",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Prefer to link 'compile_mlir_util' library that also links necessary
|
# Prefer to link 'compile_mlir_util' library that also links necessary
|
||||||
|
@ -60,6 +60,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/tpu/tpu_defs.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
@ -266,13 +267,19 @@ static void RegisterDialects(mlir::DialectRegistry& registry) {
|
|||||||
mlir::mhlo::registerAllMhloDialects(registry);
|
mlir::mhlo::registerAllMhloDialects(registry);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks if functions can be inlined after TF -> HLO legalization. Currently
|
||||||
|
// TPU's are supported, to follow the behavior of inlining functions via the
|
||||||
|
// Graph based bridge in the TPUCompile op kernel.
|
||||||
|
bool CanInlineFunctionsPostLegalization(llvm::StringRef device_type) {
|
||||||
|
return device_type == DEVICE_TPU_XLA_JIT;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void CreateConvertMlirToXlaHloPipeline(
|
void CreateConvertMlirToXlaHloPipeline(
|
||||||
mlir::OpPassManager& pm, llvm::StringRef device_type,
|
mlir::OpPassManager& pm, llvm::StringRef device_type,
|
||||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||||
custom_legalization_passes,
|
custom_legalization_passes) {
|
||||||
bool inline_after_legalization) {
|
|
||||||
pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
|
pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
|
||||||
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||||
// Run shape inference pass before tensorlist decomposition to get buffer
|
// Run shape inference pass before tensorlist decomposition to get buffer
|
||||||
@ -321,7 +328,8 @@ void CreateConvertMlirToXlaHloPipeline(
|
|||||||
/*allow_partial_conversion=*/false, /*legalize_chlo=*/true,
|
/*allow_partial_conversion=*/false, /*legalize_chlo=*/true,
|
||||||
/*tf2xla_fallback_device_type=*/device_type));
|
/*tf2xla_fallback_device_type=*/device_type));
|
||||||
|
|
||||||
if (inline_after_legalization) pm.addPass(mlir::createInlinerPass());
|
if (CanInlineFunctionsPostLegalization(device_type))
|
||||||
|
pm.addPass(mlir::createInlinerPass());
|
||||||
|
|
||||||
// In order to export to XLA, we must sink constants to control flow regions,
|
// In order to export to XLA, we must sink constants to control flow regions,
|
||||||
// since XLA uses functional control flow.
|
// since XLA uses functional control flow.
|
||||||
@ -335,13 +343,11 @@ Status ConvertMLIRToXlaComputation(
|
|||||||
bool return_tuple,
|
bool return_tuple,
|
||||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||||
custom_legalization_passes,
|
custom_legalization_passes) {
|
||||||
bool inline_after_legalization) {
|
|
||||||
mlir::PassManager tf2xla(module_op.getContext());
|
mlir::PassManager tf2xla(module_op.getContext());
|
||||||
applyTensorflowAndCLOptions(tf2xla);
|
applyTensorflowAndCLOptions(tf2xla);
|
||||||
CreateConvertMlirToXlaHloPipeline(tf2xla, device_type,
|
CreateConvertMlirToXlaHloPipeline(tf2xla, device_type,
|
||||||
custom_legalization_passes,
|
custom_legalization_passes);
|
||||||
inline_after_legalization);
|
|
||||||
|
|
||||||
if (VLOG_IS_ON(1)) {
|
if (VLOG_IS_ON(1)) {
|
||||||
// Print the whole module after each pass which requires disabling
|
// Print the whole module after each pass which requires disabling
|
||||||
@ -379,8 +385,7 @@ Status CompileMlirToXlaHlo(
|
|||||||
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||||
XlaCompilationResult* compilation_result,
|
XlaCompilationResult* compilation_result,
|
||||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||||
custom_legalization_passes,
|
custom_legalization_passes) {
|
||||||
bool inline_after_legalization) {
|
|
||||||
if (VLOG_IS_ON(1))
|
if (VLOG_IS_ON(1))
|
||||||
tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op);
|
tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op);
|
||||||
|
|
||||||
@ -398,7 +403,7 @@ Status CompileMlirToXlaHlo(
|
|||||||
TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
|
TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
|
||||||
module_op, device_type, compilation_result->computation.get(),
|
module_op, device_type, compilation_result->computation.get(),
|
||||||
use_tuple_args, use_return_tuple, shape_representation_fn,
|
use_tuple_args, use_return_tuple, shape_representation_fn,
|
||||||
custom_legalization_passes, inline_after_legalization));
|
custom_legalization_passes));
|
||||||
|
|
||||||
// Construct mapping from XlaComputation's arg to input edges of execute
|
// Construct mapping from XlaComputation's arg to input edges of execute
|
||||||
// node.
|
// node.
|
||||||
@ -441,8 +446,7 @@ Status CompileSerializedMlirToXlaHlo(
|
|||||||
return CompileMlirToXlaHlo(
|
return CompileMlirToXlaHlo(
|
||||||
mlir_module.get(), tensor_or_resource_shapes, device_type, use_tuple_args,
|
mlir_module.get(), tensor_or_resource_shapes, device_type, use_tuple_args,
|
||||||
/*use_return_tuple=*/true, /*use_resource_updates_for_aliases=*/false,
|
/*use_return_tuple=*/true, /*use_resource_updates_for_aliases=*/false,
|
||||||
shape_representation_fn, compilation_result, custom_legalization_passes,
|
shape_representation_fn, compilation_result, custom_legalization_passes);
|
||||||
/*inline_after_legalization=*/true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rewrites the given module with specified args. For each of the constant args,
|
// Rewrites the given module with specified args. For each of the constant args,
|
||||||
@ -543,8 +547,7 @@ Status CompileGraphToXlaHlo(
|
|||||||
auto status = CompileMlirToXlaHlo(
|
auto status = CompileMlirToXlaHlo(
|
||||||
module_op, arg_shapes, device_type, use_tuple_args, use_return_tuple,
|
module_op, arg_shapes, device_type, use_tuple_args, use_return_tuple,
|
||||||
/*use_resource_updates_for_aliases=*/true, shape_representation_fn,
|
/*use_resource_updates_for_aliases=*/true, shape_representation_fn,
|
||||||
compilation_result, custom_legalization_passes,
|
compilation_result, custom_legalization_passes);
|
||||||
/*inline_after_legalization=*/false);
|
|
||||||
compilation_result->input_mapping = remaining_params;
|
compilation_result->input_mapping = remaining_params;
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
@ -39,8 +39,7 @@ namespace tensorflow {
|
|||||||
void CreateConvertMlirToXlaHloPipeline(
|
void CreateConvertMlirToXlaHloPipeline(
|
||||||
mlir::OpPassManager& pm, llvm::StringRef device_type,
|
mlir::OpPassManager& pm, llvm::StringRef device_type,
|
||||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||||
custom_legalization_passes,
|
custom_legalization_passes);
|
||||||
bool inline_after_legalization);
|
|
||||||
|
|
||||||
// Lowers MLIR module to XLA HLO inside an XlaComputation. The input module
|
// Lowers MLIR module to XLA HLO inside an XlaComputation. The input module
|
||||||
// should only contain operations in tf dialect. If the input module contains
|
// should only contain operations in tf dialect. If the input module contains
|
||||||
@ -74,8 +73,7 @@ Status ConvertMLIRToXlaComputation(
|
|||||||
bool return_tuple,
|
bool return_tuple,
|
||||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr,
|
const XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr,
|
||||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||||
custom_legalization_passes = {},
|
custom_legalization_passes = {});
|
||||||
bool inline_after_legalization = false);
|
|
||||||
|
|
||||||
// Helper struct representing argument tensor or resource handle shapes.
|
// Helper struct representing argument tensor or resource handle shapes.
|
||||||
struct TensorOrResourceShape {
|
struct TensorOrResourceShape {
|
||||||
@ -93,8 +91,7 @@ Status CompileMlirToXlaHlo(
|
|||||||
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||||
XlaCompilationResult* compilation_result,
|
XlaCompilationResult* compilation_result,
|
||||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||||
custom_legalization_passes,
|
custom_legalization_passes);
|
||||||
bool inline_after_legalization);
|
|
||||||
|
|
||||||
// Compiles a serialized MLIR module into XLA HLO, generates all accompanying
|
// Compiles a serialized MLIR module into XLA HLO, generates all accompanying
|
||||||
// metadata and stores them in CompilationResult.
|
// metadata and stores them in CompilationResult.
|
||||||
|
@ -21,7 +21,7 @@ namespace {
|
|||||||
void CreateConvertMlirToXlaHloPipelineWithDefaults(mlir::OpPassManager& pm) {
|
void CreateConvertMlirToXlaHloPipelineWithDefaults(mlir::OpPassManager& pm) {
|
||||||
tensorflow::CreateConvertMlirToXlaHloPipeline(
|
tensorflow::CreateConvertMlirToXlaHloPipeline(
|
||||||
pm, /*device_type=*/"XLA_CPU_JIT",
|
pm, /*device_type=*/"XLA_CPU_JIT",
|
||||||
/*custom_legalization_passes=*/{}, /*inline_after_legalization=*/false);
|
/*custom_legalization_passes=*/{});
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::PassPipelineRegistration<> pipeline(
|
mlir::PassPipelineRegistration<> pipeline(
|
||||||
|
@ -250,7 +250,7 @@ static mlir::LogicalResult MlirTfToHloTextTranslateFunction(
|
|||||||
module_op, arg_shapes, /*device_type=*/"XLA_CPU_JIT", emit_use_tuple_arg,
|
module_op, arg_shapes, /*device_type=*/"XLA_CPU_JIT", emit_use_tuple_arg,
|
||||||
emit_return_tuple, /*use_resource_updates_for_aliases=*/true,
|
emit_return_tuple, /*use_resource_updates_for_aliases=*/true,
|
||||||
IdentityShapeRepresentationFn(), &compilation_result,
|
IdentityShapeRepresentationFn(), &compilation_result,
|
||||||
/*custom_legalization_passes=*/{}, /*inline_after_legalization=*/false);
|
/*custom_legalization_passes=*/{});
|
||||||
if (!compilation_status.ok()) {
|
if (!compilation_status.ok()) {
|
||||||
LOG(ERROR) << "TF/XLA compilation failed: "
|
LOG(ERROR) << "TF/XLA compilation failed: "
|
||||||
<< compilation_status.ToString();
|
<< compilation_status.ToString();
|
||||||
|
@ -8,6 +8,7 @@ load(
|
|||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = [
|
default_visibility = [
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow:__subpackages__",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:__subpackages__",
|
"//tensorflow/compiler/tf2xla/kernels:__subpackages__",
|
||||||
"//tensorflow/core/tpu:__subpackages__",
|
"//tensorflow/core/tpu:__subpackages__",
|
||||||
"//tensorflow/stream_executor/tpu:__subpackages__",
|
"//tensorflow/stream_executor/tpu:__subpackages__",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user