Add CompileGraphToXlaBuilder to compile_mlir_util.h. Refactor internal functions to support both the new CompileGraphToXlaBuilder and the existing CompileGraphToHlo.
CompileGraphToXlaBuilder will be used for MLIR legalization in the old bridge. PiperOrigin-RevId: 346679170 Change-Id: I34b4e7b8b41ef3aa6a6a65426ecee4d113a97979
This commit is contained in:
parent
2090fe76f2
commit
c6293e9bfc
@ -1742,6 +1742,7 @@ cc_library(
|
||||
":translate_cl_options",
|
||||
"//tensorflow/compiler/mlir:string_container_utils",
|
||||
"//tensorflow/compiler/mlir/xla:translate_cl_options",
|
||||
"//tensorflow/compiler/mlir/xla:type_to_shape",
|
||||
"//tensorflow/compiler/tf2xla:xla_argument",
|
||||
"//tensorflow/compiler/tf2xla:xla_helpers",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
@ -1755,7 +1756,6 @@ cc_library(
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Parser",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Translation",
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: -emit-return-tuple | FileCheck %s
|
||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: -emit-use-tuple-args -emit-return-tuple | FileCheck -check-prefix=TUPLE-ARGS %s
|
||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: | FileCheck -check-prefix=NO_RET_TUPLE %s
|
||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text-via-builder %s -tf-input-shapes=: | FileCheck -check-prefix=NO_RET_TUPLE %s
|
||||
|
||||
module attributes {tf.versions = {producer = 179 : i32}} {
|
||||
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
|
||||
@ -36,3 +38,16 @@ module attributes {tf.versions = {producer = 179 : i32}} {
|
||||
// TUPLE-ARGS-NEXT: // XlaInputShape (f32[], f32[])
|
||||
// TUPLE-ARGS-NEXT: // XlaOutputShape (f32[])
|
||||
// TUPLE-ARGS-NEXT: // XlaOutputDescription type=float shape=()
|
||||
|
||||
|
||||
// NO_RET_TUPLE-LABEL: HloModule main{{[.0-9]*}}
|
||||
// NO_RET_TUPLE: ENTRY %main.{{[0-9]+}} ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[] {
|
||||
// NO_RET_TUPLE-NEXT: %[[ARG0]] = f32[] parameter(0)
|
||||
// NO_RET_TUPLE-NEXT: %[[ARG1]] = f32[] parameter(1)
|
||||
// NO_RET_TUPLE-NEXT: ROOT [[ADD:%.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]])
|
||||
|
||||
// NO_RET_TUPLE: // InputMapping {0, 1}
|
||||
// NO_RET_TUPLE-NEXT: // XlaInputShape f32[]
|
||||
// NO_RET_TUPLE-NEXT: // XlaInputShape f32[]
|
||||
// NO_RET_TUPLE-NEXT: // XlaOutputShape (f32[])
|
||||
// NO_RET_TUPLE-NEXT: // XlaOutputDescription type=float shape=()
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: -emit-use-tuple-args -emit-return-tuple | FileCheck %s
|
||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: | FileCheck -check-prefix=NO_TUPLES %s
|
||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text-via-builder %s -tf-input-shapes=: | FileCheck -check-prefix=NO_TUPLES %s
|
||||
|
||||
module attributes {tf.versions = {producer = 179 : i32}} {
|
||||
func @main() -> (tensor<0xi32>, tensor<0xi32>) {
|
||||
@ -14,3 +16,9 @@ module attributes {tf.versions = {producer = 179 : i32}} {
|
||||
// CHECK: [[CONSTANT:%.*]] = s32[0]{0} constant({})
|
||||
// CHECK: ROOT %tuple.{{[0-9]+}} = (s32[0]{0}, s32[0]{0}) tuple(s32[0]{0} [[CONSTANT]], s32[0]{0} [[CONSTANT]])
|
||||
// CHECK: }
|
||||
|
||||
// NO_TUPLES-LABEL: HloModule main{{.[0-9+]}}
|
||||
// NO_TUPLES: ENTRY %main.{{[0-9+]}} () -> (s32[0], s32[0]) {
|
||||
// NO_TUPLES: [[CONSTANT:%.*]] = s32[0]{0} constant({})
|
||||
// NO_TUPLES: ROOT %tuple.{{[0-9]+}} = (s32[0]{0}, s32[0]{0}) tuple(s32[0]{0} [[CONSTANT]], s32[0]{0} [[CONSTANT]])
|
||||
// NO_TUPLES: }
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=8,16,16,64:64 -emit-use-tuple-args -emit-return-tuple | FileCheck %s
|
||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text-via-builder %s -tf-input-shapes=8,16,16,64:64 | FileCheck %s
|
||||
|
||||
module attributes {tf.versions = {producer = 179 : i32}} {
|
||||
func @main(%arg0: tensor<8x16x16x64xbf16>, %arg1: tensor<64xf32>) -> (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>) {
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=10,17:17,19 -emit-use-tuple-args -emit-return-tuple | FileCheck %s
|
||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=10,17:17,19 | FileCheck -check-prefix=NO_TUPLES %s
|
||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text-via-builder %s -tf-input-shapes=10,17:17,19 | FileCheck -check-prefix=NO_TUPLES %s
|
||||
|
||||
module attributes {tf.versions = {producer = 179 : i32}} {
|
||||
func @main(%arg0: tensor<*xf32>, %arg1: tensor<?x19xf32>) -> tensor<?x19xf32> {
|
||||
@ -9,3 +11,6 @@ module attributes {tf.versions = {producer = 179 : i32}} {
|
||||
|
||||
// CHECK-LABEL: HloModule main
|
||||
// CHECK: (arg_tuple.{{[0-9]+}}: (f32[10,17], f32[17,19])) -> (f32[10,19])
|
||||
|
||||
// NO_TUPLES-LABEL: HloModule main{{.[0-9]*}}
|
||||
// NO_TUPLES: ({{.+}}: f32[10,17], {{.+}}: f32[17,19]) -> f32[10,19]
|
||||
|
||||
@ -211,7 +211,20 @@ void GetInputMappingForMlir(int num_inputs, std::vector<int>* input_mapping) {
|
||||
std::iota(input_mapping->begin(), input_mapping->end(), 0);
|
||||
}
|
||||
|
||||
// Refine MLIR types based on new shape information.
|
||||
static void RegisterDialects(mlir::DialectRegistry& registry) {
|
||||
mlir::RegisterAllTensorFlowDialects(registry);
|
||||
mlir::mhlo::registerAllMhloDialects(registry);
|
||||
}
|
||||
|
||||
// Checks if functions can be inlined after TF -> HLO legalization. Currently
|
||||
// TPU's are supported, to follow the behavior of inlining functions via the
|
||||
// Graph based bridge in the TPUCompile op kernel.
|
||||
bool CanInlineFunctionsPostLegalization(llvm::StringRef device_type) {
|
||||
return device_type == DEVICE_TPU_XLA_JIT;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
|
||||
mlir::ModuleOp module) {
|
||||
auto producer_or = GetTfGraphProducerVersion(module);
|
||||
@ -262,20 +275,6 @@ Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static void RegisterDialects(mlir::DialectRegistry& registry) {
|
||||
mlir::RegisterAllTensorFlowDialects(registry);
|
||||
mlir::mhlo::registerAllMhloDialects(registry);
|
||||
}
|
||||
|
||||
// Checks if functions can be inlined after TF -> HLO legalization. Currently
|
||||
// TPU's are supported, to follow the behavior of inlining functions via the
|
||||
// Graph based bridge in the TPUCompile op kernel.
|
||||
bool CanInlineFunctionsPostLegalization(llvm::StringRef device_type) {
|
||||
return device_type == DEVICE_TPU_XLA_JIT;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void CreateConvertMlirToXlaHloPipeline(
|
||||
mlir::OpPassManager& pm, llvm::StringRef device_type,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
@ -337,11 +336,7 @@ void CreateConvertMlirToXlaHloPipeline(
|
||||
mlir::mhlo::createSinkConstantsToControlFlowPass());
|
||||
}
|
||||
|
||||
Status ConvertMLIRToXlaComputation(
|
||||
mlir::ModuleOp module_op, llvm::StringRef device_type,
|
||||
xla::XlaComputation* xla_computation, bool use_tuple_args,
|
||||
bool return_tuple,
|
||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
Status LegalizeToHlo(mlir::ModuleOp module_op, llvm::StringRef device_type,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes) {
|
||||
mlir::PassManager tf2xla(module_op.getContext());
|
||||
@ -370,6 +365,32 @@ Status ConvertMLIRToXlaComputation(
|
||||
if (VLOG_IS_ON(1))
|
||||
tensorflow::DumpMlirOpToFile("mlir_compile_legalize_hlo", module_op);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BuildHloFromTfInner(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
|
||||
llvm::ArrayRef<xla::XlaOp> xla_params,
|
||||
std::vector<xla::XlaOp>& returns,
|
||||
llvm::StringRef device_type,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
LegalizeToHlo(module_op, device_type, custom_legalization_passes));
|
||||
|
||||
mlir::Block& block = module_op.lookupSymbol<mlir::FuncOp>("main").front();
|
||||
return mlir::BuildHloFromMlirHlo(block, builder, xla_params, returns);
|
||||
}
|
||||
|
||||
Status ConvertMLIRToXlaComputation(
|
||||
mlir::ModuleOp module_op, llvm::StringRef device_type,
|
||||
xla::XlaComputation* xla_computation, bool use_tuple_args,
|
||||
bool return_tuple,
|
||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
LegalizeToHlo(module_op, device_type, custom_legalization_passes));
|
||||
|
||||
xla::HloProto hlo_proto;
|
||||
TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module_op, &hlo_proto,
|
||||
use_tuple_args, return_tuple,
|
||||
@ -378,14 +399,9 @@ Status ConvertMLIRToXlaComputation(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CompileMlirToXlaHlo(
|
||||
Status CompileMlirSetup(
|
||||
mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
|
||||
llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
|
||||
bool use_resource_updates_for_aliases,
|
||||
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompilationResult* compilation_result,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes) {
|
||||
XlaHelpers::ShapeRepresentationFn* shape_representation_fn) {
|
||||
if (VLOG_IS_ON(1))
|
||||
tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op);
|
||||
|
||||
@ -395,16 +411,39 @@ Status CompileMlirToXlaHlo(
|
||||
if (VLOG_IS_ON(1))
|
||||
tensorflow::DumpMlirOpToFile("mlir_compile_shape_refiner", module_op);
|
||||
|
||||
if (!shape_representation_fn)
|
||||
shape_representation_fn = IdentityShapeRepresentationFn();
|
||||
if (!*shape_representation_fn)
|
||||
*shape_representation_fn = IdentityShapeRepresentationFn();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
|
||||
llvm::ArrayRef<xla::XlaOp> xla_params,
|
||||
std::vector<xla::XlaOp>& returns,
|
||||
llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
|
||||
llvm::StringRef device_type,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes) {
|
||||
XlaHelpers::ShapeRepresentationFn shape_representation_fn;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CompileMlirSetup(module_op, arg_shapes, &shape_representation_fn));
|
||||
|
||||
// Convert MLIR module to XLA HLO proto contained in XlaComputation.
|
||||
compilation_result->computation = std::make_shared<xla::XlaComputation>();
|
||||
TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
|
||||
module_op, device_type, compilation_result->computation.get(),
|
||||
use_tuple_args, use_return_tuple, shape_representation_fn,
|
||||
TF_RETURN_IF_ERROR(BuildHloFromTfInner(module_op, builder, xla_params,
|
||||
returns, device_type,
|
||||
custom_legalization_passes));
|
||||
|
||||
if (VLOG_IS_ON(1))
|
||||
tensorflow::DumpMlirOpToFile("mlir_compile_after", module_op);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PopulateResultIOInfo(
|
||||
mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
|
||||
bool use_tuple_args, bool use_resource_updates_for_aliases,
|
||||
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompilationResult* compilation_result) {
|
||||
// Construct mapping from XlaComputation's arg to input edges of execute
|
||||
// node.
|
||||
GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping);
|
||||
@ -426,6 +465,29 @@ Status CompileMlirToXlaHlo(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CompileMlirToXlaHlo(
|
||||
mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
|
||||
llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
|
||||
bool use_resource_updates_for_aliases,
|
||||
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompilationResult* compilation_result,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
CompileMlirSetup(module_op, arg_shapes, &shape_representation_fn));
|
||||
|
||||
// Convert MLIR module to XLA HLO proto contained in XlaComputation.
|
||||
compilation_result->computation = std::make_shared<xla::XlaComputation>();
|
||||
TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
|
||||
module_op, device_type, compilation_result->computation.get(),
|
||||
use_tuple_args, use_return_tuple, shape_representation_fn,
|
||||
custom_legalization_passes));
|
||||
|
||||
return PopulateResultIOInfo(module_op, arg_shapes, use_tuple_args,
|
||||
use_resource_updates_for_aliases,
|
||||
shape_representation_fn, compilation_result);
|
||||
}
|
||||
|
||||
Status CompileSerializedMlirToXlaHlo(
|
||||
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
@ -516,18 +578,13 @@ static StatusOr<std::vector<int>> RewriteWithArgs(
|
||||
return params;
|
||||
}
|
||||
|
||||
Status CompileGraphToXlaHlo(
|
||||
Status CompileGraphSetup(
|
||||
mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args,
|
||||
llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
|
||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompilationResult* compilation_result,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes) {
|
||||
TF_ASSIGN_OR_RETURN(std::vector<int> remaining_params,
|
||||
RewriteWithArgs(module_op, args));
|
||||
llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
|
||||
arg_shapes.reserve(remaining_params.size());
|
||||
for (unsigned idx : remaining_params) {
|
||||
std::vector<int>* remaining_params,
|
||||
llvm::SmallVector<TensorOrResourceShape, 4>& arg_shapes) {
|
||||
TF_ASSIGN_OR_RETURN(*remaining_params, RewriteWithArgs(module_op, args));
|
||||
arg_shapes.reserve(remaining_params->size());
|
||||
for (unsigned idx : *remaining_params) {
|
||||
const auto& arg = args[idx];
|
||||
TF_ASSIGN_OR_RETURN(TensorShape arg_shape,
|
||||
GetTensorShapeFromXlaArgument(arg));
|
||||
@ -539,10 +596,39 @@ Status CompileGraphToXlaHlo(
|
||||
applyTensorflowAndCLOptions(pm);
|
||||
mlir::TF::StandardPipelineOptions tf_options;
|
||||
mlir::TF::CreateTFStandardPipeline(pm, tf_options);
|
||||
{
|
||||
|
||||
mlir::StatusScopedDiagnosticHandler diag_handler(module_op.getContext());
|
||||
if (failed(pm.run(module_op))) return diag_handler.ConsumeStatus();
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BuildHloFromModule(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
|
||||
llvm::ArrayRef<xla::XlaOp> xla_params,
|
||||
std::vector<xla::XlaOp>& returns,
|
||||
llvm::ArrayRef<XlaArgument> args,
|
||||
llvm::StringRef device_type,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes) {
|
||||
std::vector<int> remaining_params;
|
||||
llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CompileGraphSetup(module_op, args, &remaining_params, arg_shapes));
|
||||
return BuildHloFromTf(module_op, builder, xla_params, returns, arg_shapes,
|
||||
device_type, custom_legalization_passes);
|
||||
}
|
||||
|
||||
Status CompileGraphToXlaHlo(
|
||||
mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args,
|
||||
llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
|
||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompilationResult* compilation_result,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes) {
|
||||
std::vector<int> remaining_params;
|
||||
llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CompileGraphSetup(module_op, args, &remaining_params, arg_shapes));
|
||||
|
||||
auto status = CompileMlirToXlaHlo(
|
||||
module_op, arg_shapes, device_type, use_tuple_args, use_return_tuple,
|
||||
@ -552,6 +638,49 @@ Status CompileGraphToXlaHlo(
|
||||
return status;
|
||||
}
|
||||
|
||||
Status GraphToModule(const Graph& graph,
|
||||
llvm::ArrayRef<std::string> control_rets,
|
||||
const FunctionLibraryDefinition& flib_def,
|
||||
const GraphDebugInfo& debug_info,
|
||||
mlir::MLIRContext* context,
|
||||
mlir::OwningModuleRef* module) {
|
||||
RegisterDialects(context->getDialectRegistry());
|
||||
GraphImportConfig config;
|
||||
config.graph_as_function = true;
|
||||
config.control_outputs = control_rets;
|
||||
// Disable shape inference during import as some TensorFlow op fails during
|
||||
// shape inference with dynamic shaped operands. This in turn causes the
|
||||
// import to fail. Shape inference during import is going to be removed and
|
||||
// the shape inference pass is run early in the pass pipeline, shape inference
|
||||
// during import is not necessary.
|
||||
config.enable_shape_inference = false;
|
||||
auto module_or =
|
||||
ConvertGraphToMlir(graph, debug_info, flib_def, config, context);
|
||||
if (!module_or.ok()) return module_or.status();
|
||||
|
||||
*module = std::move(module_or.ValueOrDie());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BuildHloFromGraph(const Graph& graph, xla::XlaBuilder& builder,
|
||||
llvm::ArrayRef<xla::XlaOp> xla_params,
|
||||
std::vector<xla::XlaOp>& returns,
|
||||
llvm::ArrayRef<XlaArgument> args,
|
||||
llvm::ArrayRef<std::string> control_rets,
|
||||
llvm::StringRef device_type,
|
||||
const FunctionLibraryDefinition& flib_def,
|
||||
const GraphDebugInfo& debug_info,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module;
|
||||
TF_RETURN_IF_ERROR(GraphToModule(graph, control_rets, flib_def, debug_info,
|
||||
&context, &module));
|
||||
return BuildHloFromModule(module.get(), builder, xla_params, returns, args,
|
||||
device_type, custom_legalization_passes);
|
||||
}
|
||||
|
||||
Status CompileGraphToXlaHlo(
|
||||
const Graph& graph, llvm::ArrayRef<XlaArgument> args,
|
||||
llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
|
||||
@ -562,22 +691,10 @@ Status CompileGraphToXlaHlo(
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes) {
|
||||
mlir::MLIRContext context;
|
||||
RegisterDialects(context.getDialectRegistry());
|
||||
GraphImportConfig config;
|
||||
config.graph_as_function = true;
|
||||
config.control_outputs = control_rets;
|
||||
// Disable shape inference during import as some TensorFlow op fails during
|
||||
// shape inference with dynamic shaped operands. This in turn causes the
|
||||
// import to fail. Shape inference during import is going to be removed and
|
||||
// the shape inference pass is run early in the pass pipeline, shape inference
|
||||
// during import is not necessary.
|
||||
config.enable_shape_inference = false;
|
||||
auto module_or =
|
||||
ConvertGraphToMlir(graph, debug_info, flib_def, config, &context);
|
||||
if (!module_or.ok()) return module_or.status();
|
||||
|
||||
mlir::ModuleOp module_op = module_or.ValueOrDie().get();
|
||||
return CompileGraphToXlaHlo(module_op, args, device_type, use_tuple_args,
|
||||
mlir::OwningModuleRef module;
|
||||
TF_RETURN_IF_ERROR(GraphToModule(graph, control_rets, flib_def, debug_info,
|
||||
&context, &module));
|
||||
return CompileGraphToXlaHlo(module.get(), args, device_type, use_tuple_args,
|
||||
/*use_return_tuple=*/true,
|
||||
shape_representation_fn, compilation_result,
|
||||
custom_legalization_passes);
|
||||
|
||||
@ -81,6 +81,30 @@ struct TensorOrResourceShape {
|
||||
bool is_resource = false;
|
||||
};
|
||||
|
||||
// Refine MLIR types based on new shape information.
|
||||
Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
|
||||
mlir::ModuleOp module);
|
||||
|
||||
// Lower TF to MHLO and insert HLO into the XlaBuilder. xla_params are HLO-level
|
||||
// inputs to module_op that have already been added to the XlaBuilder. returns
|
||||
// are the returned XlaOps.
|
||||
Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
|
||||
llvm::ArrayRef<xla::XlaOp> xla_params,
|
||||
std::vector<xla::XlaOp>& returns,
|
||||
llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
|
||||
llvm::StringRef device_type,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes);
|
||||
|
||||
// Apply shape, description, and resource information to inputs and outputs
|
||||
// in the XlaCompilationResult. This should be called after
|
||||
// compilation_result->computation was set.
|
||||
Status PopulateResultIOInfo(
|
||||
mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
|
||||
bool use_tuple_args, bool use_resource_updates_for_aliases,
|
||||
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompilationResult* compilation_result);
|
||||
|
||||
// Compiles a MLIR module into XLA HLO, generates all accompanying metadata and
|
||||
// stores them in CompilationResult.
|
||||
// TODO(hinsu): Migrate options to separate struct.
|
||||
@ -128,6 +152,21 @@ Status CompileGraphToXlaHlo(
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes = {});
|
||||
|
||||
// Compiles a Graph from TF to HLO and adds the resulting HLO to the
|
||||
// XlaBuilder. This function adds HLO to a larger HLO computation, so
|
||||
// HLO-level inputs are supplied, and HLO-level outputs are produced.
|
||||
// xla_params is the HLO-level inputs and returns is the HLO-level outputs.
|
||||
Status BuildHloFromGraph(const Graph& graph, xla::XlaBuilder& builder,
|
||||
llvm::ArrayRef<xla::XlaOp> xla_params,
|
||||
std::vector<xla::XlaOp>& returns,
|
||||
llvm::ArrayRef<XlaArgument> args,
|
||||
llvm::ArrayRef<std::string> control_rets,
|
||||
llvm::StringRef device_type,
|
||||
const FunctionLibraryDefinition& flib_def,
|
||||
const GraphDebugInfo& debug_info,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes = {});
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
|
||||
|
||||
@ -40,6 +40,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
|
||||
#include "tensorflow/compiler/mlir/utils/string_container_utils.h"
|
||||
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
|
||||
#include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_argument.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
@ -233,8 +234,64 @@ Status ParseXlaArguments(absl::string_view input_shapes_str,
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
static mlir::LogicalResult MlirTfToHloTextTranslateFunction(
|
||||
mlir::ModuleOp module_op, llvm::raw_ostream& output) {
|
||||
// Test BuildHloFromTf. BuildHloFromTf only performs part of the conversion, so
|
||||
// to make this test comparable to other compile tests, the test implements
|
||||
// the remaining parts of the conversion.
|
||||
Status CompileMlirToXlaHloViaBuilder(
|
||||
mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
|
||||
llvm::StringRef device_type, XlaCompilationResult* compilation_result,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes) {
|
||||
// This call to RefineShapes is redundant with the call in BuildHloFromTf.
|
||||
// It's here so xla::Parameters that are created form block.getArguments will
|
||||
// have the proper shapes.
|
||||
TF_RETURN_IF_ERROR(RefineShapes(arg_shapes, module_op));
|
||||
|
||||
mlir::FuncOp main = module_op.lookupSymbol<mlir::FuncOp>("main");
|
||||
mlir::Block& block = main.getRegion().front();
|
||||
xla::XlaBuilder builder("main");
|
||||
|
||||
// Create xla_params.
|
||||
std::vector<xla::XlaOp> xla_params;
|
||||
for (mlir::BlockArgument& arg : block.getArguments()) {
|
||||
auto num = arg.getArgNumber();
|
||||
xla::Shape shape = xla::TypeToShape(arg.getType());
|
||||
xla::XlaOp argop =
|
||||
xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num));
|
||||
xla_params.push_back(argop);
|
||||
}
|
||||
|
||||
std::vector<xla::XlaOp> returns(1);
|
||||
TF_RETURN_IF_ERROR(BuildHloFromTf(module_op, builder, xla_params, returns,
|
||||
arg_shapes, device_type,
|
||||
custom_legalization_passes));
|
||||
|
||||
xla::XlaOp return_value;
|
||||
if (returns.size() == 1)
|
||||
return_value = returns[0];
|
||||
else
|
||||
return_value = xla::Tuple(&builder, returns);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
xla::XlaComputation computation,
|
||||
return_value.valid() ? builder.Build(return_value) : builder.Build());
|
||||
auto hlo_module = computation.proto();
|
||||
xla::HloProto hlo_proto;
|
||||
hlo_proto.mutable_hlo_module()->Swap(&hlo_module);
|
||||
|
||||
compilation_result->computation = std::make_shared<xla::XlaComputation>();
|
||||
xla::XlaComputation* xla_computation = compilation_result->computation.get();
|
||||
*xla_computation = xla::XlaComputation(hlo_proto.hlo_module());
|
||||
|
||||
XlaHelpers::ShapeRepresentationFn shape_representation_fn =
|
||||
IdentityShapeRepresentationFn();
|
||||
return PopulateResultIOInfo(module_op, arg_shapes, /*use_tuple_args=*/false,
|
||||
/*use_resource_updates_for_aliases=*/false,
|
||||
shape_representation_fn, compilation_result);
|
||||
}
|
||||
|
||||
static mlir::LogicalResult MlirTfToHloTextTranslateFunctionImpl(
|
||||
mlir::ModuleOp module_op, llvm::raw_ostream& output, bool via_builder) {
|
||||
if (!module_op) return mlir::failure();
|
||||
|
||||
llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
|
||||
@ -245,12 +302,21 @@ static mlir::LogicalResult MlirTfToHloTextTranslateFunction(
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
auto device_type = "XLA_CPU_JIT";
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
custom_legalization_passes{};
|
||||
XlaCompilationResult compilation_result;
|
||||
auto compilation_status = CompileMlirToXlaHlo(
|
||||
module_op, arg_shapes, /*device_type=*/"XLA_CPU_JIT", emit_use_tuple_arg,
|
||||
emit_return_tuple, /*use_resource_updates_for_aliases=*/true,
|
||||
IdentityShapeRepresentationFn(), &compilation_result,
|
||||
/*custom_legalization_passes=*/{});
|
||||
auto compilation_status =
|
||||
via_builder
|
||||
? CompileMlirToXlaHloViaBuilder(module_op, arg_shapes, device_type,
|
||||
&compilation_result,
|
||||
custom_legalization_passes)
|
||||
: CompileMlirToXlaHlo(module_op, arg_shapes, device_type,
|
||||
emit_use_tuple_arg, emit_return_tuple,
|
||||
/*use_resource_updates_for_aliases=*/true,
|
||||
IdentityShapeRepresentationFn(),
|
||||
&compilation_result,
|
||||
custom_legalization_passes);
|
||||
if (!compilation_status.ok()) {
|
||||
LOG(ERROR) << "TF/XLA compilation failed: "
|
||||
<< compilation_status.ToString();
|
||||
@ -326,12 +392,27 @@ static mlir::LogicalResult MlirModuleToSerializedMlirStringAttrTranslate(
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
static mlir::LogicalResult MlirTfToHloTextTranslateFunction(
|
||||
mlir::ModuleOp module_op, llvm::raw_ostream& output) {
|
||||
return MlirTfToHloTextTranslateFunctionImpl(module_op, output, false);
|
||||
}
|
||||
|
||||
static mlir::LogicalResult MlirTfToHloTextViaBuilderTranslateFunction(
|
||||
mlir::ModuleOp module_op, llvm::raw_ostream& output) {
|
||||
return MlirTfToHloTextTranslateFunctionImpl(module_op, output, true);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
static mlir::TranslateFromMLIRRegistration MlirTfToHloTextTranslate(
|
||||
"mlir-tf-to-hlo-text", tensorflow::MlirTfToHloTextTranslateFunction,
|
||||
tensorflow::RegisterMlirInputDialects);
|
||||
|
||||
static mlir::TranslateFromMLIRRegistration MlirTfToHloTextViaBuilderTranslate(
|
||||
"mlir-tf-to-hlo-text-via-builder",
|
||||
tensorflow::MlirTfToHloTextViaBuilderTranslateFunction,
|
||||
tensorflow::RegisterMlirInputDialects);
|
||||
|
||||
static mlir::TranslateFromMLIRRegistration MlirTfGraphToHloTextTranslate(
|
||||
"mlir-tf-graph-to-hlo-text",
|
||||
tensorflow::MlirTfGraphToHloTextTranslateFunction,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user