Add CompileGraphToXlaBuilder to compile_mlir_util.h. Refactor internal functions to support both the new CompileGraphToXlaBuilder and the existing CompileGraphToHlo.

CompileGraphToXlaBuilder will be used for MLIR legalization in the old bridge.

PiperOrigin-RevId: 346679170
Change-Id: I34b4e7b8b41ef3aa6a6a65426ecee4d113a97979
This commit is contained in:
Michael Delorimier 2020-12-09 18:01:35 -08:00 committed by TensorFlower Gardener
parent 2090fe76f2
commit c6293e9bfc
8 changed files with 341 additions and 75 deletions

View File

@ -1742,6 +1742,7 @@ cc_library(
":translate_cl_options",
"//tensorflow/compiler/mlir:string_container_utils",
"//tensorflow/compiler/mlir/xla:translate_cl_options",
"//tensorflow/compiler/mlir/xla:type_to_shape",
"//tensorflow/compiler/tf2xla:xla_argument",
"//tensorflow/compiler/tf2xla:xla_helpers",
"//tensorflow/compiler/xla/service:hlo",
@ -1755,7 +1756,6 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Translation",

View File

@ -1,5 +1,7 @@
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: -emit-return-tuple | FileCheck %s
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: -emit-use-tuple-args -emit-return-tuple | FileCheck -check-prefix=TUPLE-ARGS %s
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: | FileCheck -check-prefix=NO_RET_TUPLE %s
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text-via-builder %s -tf-input-shapes=: | FileCheck -check-prefix=NO_RET_TUPLE %s
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
@ -36,3 +38,16 @@ module attributes {tf.versions = {producer = 179 : i32}} {
// TUPLE-ARGS-NEXT: // XlaInputShape (f32[], f32[])
// TUPLE-ARGS-NEXT: // XlaOutputShape (f32[])
// TUPLE-ARGS-NEXT: // XlaOutputDescription type=float shape=()
// NO_RET_TUPLE-LABEL: HloModule main{{[.0-9]*}}
// NO_RET_TUPLE: ENTRY %main.{{[0-9]+}} ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[] {
// NO_RET_TUPLE-NEXT: %[[ARG0]] = f32[] parameter(0)
// NO_RET_TUPLE-NEXT: %[[ARG1]] = f32[] parameter(1)
// NO_RET_TUPLE-NEXT: ROOT [[ADD:%.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]])
// NO_RET_TUPLE: // InputMapping {0, 1}
// NO_RET_TUPLE-NEXT: // XlaInputShape f32[]
// NO_RET_TUPLE-NEXT: // XlaInputShape f32[]
// NO_RET_TUPLE-NEXT: // XlaOutputShape (f32[])
// NO_RET_TUPLE-NEXT: // XlaOutputDescription type=float shape=()

View File

@ -1,4 +1,6 @@
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: -emit-use-tuple-args -emit-return-tuple | FileCheck %s
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: | FileCheck -check-prefix=NO_TUPLES %s
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text-via-builder %s -tf-input-shapes=: | FileCheck -check-prefix=NO_TUPLES %s
module attributes {tf.versions = {producer = 179 : i32}} {
func @main() -> (tensor<0xi32>, tensor<0xi32>) {
@ -14,3 +16,9 @@ module attributes {tf.versions = {producer = 179 : i32}} {
// CHECK: [[CONSTANT:%.*]] = s32[0]{0} constant({})
// CHECK: ROOT %tuple.{{[0-9]+}} = (s32[0]{0}, s32[0]{0}) tuple(s32[0]{0} [[CONSTANT]], s32[0]{0} [[CONSTANT]])
// CHECK: }
// NO_TUPLES-LABEL: HloModule main{{.[0-9+]}}
// NO_TUPLES: ENTRY %main.{{[0-9+]}} () -> (s32[0], s32[0]) {
// NO_TUPLES: [[CONSTANT:%.*]] = s32[0]{0} constant({})
// NO_TUPLES: ROOT %tuple.{{[0-9]+}} = (s32[0]{0}, s32[0]{0}) tuple(s32[0]{0} [[CONSTANT]], s32[0]{0} [[CONSTANT]])
// NO_TUPLES: }

View File

@ -1,4 +1,5 @@
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=8,16,16,64:64 -emit-use-tuple-args -emit-return-tuple | FileCheck %s
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text-via-builder %s -tf-input-shapes=8,16,16,64:64 | FileCheck %s
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<8x16x16x64xbf16>, %arg1: tensor<64xf32>) -> (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>) {

View File

@ -1,4 +1,6 @@
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=10,17:17,19 -emit-use-tuple-args -emit-return-tuple | FileCheck %s
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=10,17:17,19 | FileCheck -check-prefix=NO_TUPLES %s
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text-via-builder %s -tf-input-shapes=10,17:17,19 | FileCheck -check-prefix=NO_TUPLES %s
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<*xf32>, %arg1: tensor<?x19xf32>) -> tensor<?x19xf32> {
@ -9,3 +11,6 @@ module attributes {tf.versions = {producer = 179 : i32}} {
// CHECK-LABEL: HloModule main
// CHECK: (arg_tuple.{{[0-9]+}}: (f32[10,17], f32[17,19])) -> (f32[10,19])
// NO_TUPLES-LABEL: HloModule main{{.[0-9]*}}
// NO_TUPLES: ({{.+}}: f32[10,17], {{.+}}: f32[17,19]) -> f32[10,19]

View File

@ -211,7 +211,20 @@ void GetInputMappingForMlir(int num_inputs, std::vector<int>* input_mapping) {
std::iota(input_mapping->begin(), input_mapping->end(), 0);
}
// Refine MLIR types based on new shape information.
static void RegisterDialects(mlir::DialectRegistry& registry) {
mlir::RegisterAllTensorFlowDialects(registry);
mlir::mhlo::registerAllMhloDialects(registry);
}
// Checks if functions can be inlined after TF -> HLO legalization. Currently
// TPU's are supported, to follow the behavior of inlining functions via the
// Graph based bridge in the TPUCompile op kernel.
bool CanInlineFunctionsPostLegalization(llvm::StringRef device_type) {
return device_type == DEVICE_TPU_XLA_JIT;
}
} // namespace
Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
mlir::ModuleOp module) {
auto producer_or = GetTfGraphProducerVersion(module);
@ -262,20 +275,6 @@ Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
return Status::OK();
}
static void RegisterDialects(mlir::DialectRegistry& registry) {
mlir::RegisterAllTensorFlowDialects(registry);
mlir::mhlo::registerAllMhloDialects(registry);
}
// Checks if functions can be inlined after TF -> HLO legalization. Currently
// TPU's are supported, to follow the behavior of inlining functions via the
// Graph based bridge in the TPUCompile op kernel.
bool CanInlineFunctionsPostLegalization(llvm::StringRef device_type) {
return device_type == DEVICE_TPU_XLA_JIT;
}
} // namespace
void CreateConvertMlirToXlaHloPipeline(
mlir::OpPassManager& pm, llvm::StringRef device_type,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
@ -337,13 +336,9 @@ void CreateConvertMlirToXlaHloPipeline(
mlir::mhlo::createSinkConstantsToControlFlowPass());
}
Status ConvertMLIRToXlaComputation(
mlir::ModuleOp module_op, llvm::StringRef device_type,
xla::XlaComputation* xla_computation, bool use_tuple_args,
bool return_tuple,
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes) {
Status LegalizeToHlo(mlir::ModuleOp module_op, llvm::StringRef device_type,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes) {
mlir::PassManager tf2xla(module_op.getContext());
applyTensorflowAndCLOptions(tf2xla);
CreateConvertMlirToXlaHloPipeline(tf2xla, device_type,
@ -370,6 +365,32 @@ Status ConvertMLIRToXlaComputation(
if (VLOG_IS_ON(1))
tensorflow::DumpMlirOpToFile("mlir_compile_legalize_hlo", module_op);
return Status::OK();
}
Status BuildHloFromTfInner(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
llvm::ArrayRef<xla::XlaOp> xla_params,
std::vector<xla::XlaOp>& returns,
llvm::StringRef device_type,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes) {
TF_RETURN_IF_ERROR(
LegalizeToHlo(module_op, device_type, custom_legalization_passes));
mlir::Block& block = module_op.lookupSymbol<mlir::FuncOp>("main").front();
return mlir::BuildHloFromMlirHlo(block, builder, xla_params, returns);
}
Status ConvertMLIRToXlaComputation(
mlir::ModuleOp module_op, llvm::StringRef device_type,
xla::XlaComputation* xla_computation, bool use_tuple_args,
bool return_tuple,
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes) {
TF_RETURN_IF_ERROR(
LegalizeToHlo(module_op, device_type, custom_legalization_passes));
xla::HloProto hlo_proto;
TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module_op, &hlo_proto,
use_tuple_args, return_tuple,
@ -378,14 +399,9 @@ Status ConvertMLIRToXlaComputation(
return Status::OK();
}
Status CompileMlirToXlaHlo(
Status CompileMlirSetup(
mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
bool use_resource_updates_for_aliases,
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
XlaCompilationResult* compilation_result,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes) {
XlaHelpers::ShapeRepresentationFn* shape_representation_fn) {
if (VLOG_IS_ON(1))
tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op);
@ -395,16 +411,39 @@ Status CompileMlirToXlaHlo(
if (VLOG_IS_ON(1))
tensorflow::DumpMlirOpToFile("mlir_compile_shape_refiner", module_op);
if (!shape_representation_fn)
shape_representation_fn = IdentityShapeRepresentationFn();
if (!*shape_representation_fn)
*shape_representation_fn = IdentityShapeRepresentationFn();
return Status::OK();
}
Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
llvm::ArrayRef<xla::XlaOp> xla_params,
std::vector<xla::XlaOp>& returns,
llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
llvm::StringRef device_type,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes) {
XlaHelpers::ShapeRepresentationFn shape_representation_fn;
TF_RETURN_IF_ERROR(
CompileMlirSetup(module_op, arg_shapes, &shape_representation_fn));
// Convert MLIR module to XLA HLO proto contained in XlaComputation.
compilation_result->computation = std::make_shared<xla::XlaComputation>();
TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
module_op, device_type, compilation_result->computation.get(),
use_tuple_args, use_return_tuple, shape_representation_fn,
custom_legalization_passes));
TF_RETURN_IF_ERROR(BuildHloFromTfInner(module_op, builder, xla_params,
returns, device_type,
custom_legalization_passes));
if (VLOG_IS_ON(1))
tensorflow::DumpMlirOpToFile("mlir_compile_after", module_op);
return Status::OK();
}
Status PopulateResultIOInfo(
mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
bool use_tuple_args, bool use_resource_updates_for_aliases,
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
XlaCompilationResult* compilation_result) {
// Construct mapping from XlaComputation's arg to input edges of execute
// node.
GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping);
@ -426,6 +465,29 @@ Status CompileMlirToXlaHlo(
return Status::OK();
}
Status CompileMlirToXlaHlo(
mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
bool use_resource_updates_for_aliases,
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
XlaCompilationResult* compilation_result,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes) {
TF_RETURN_IF_ERROR(
CompileMlirSetup(module_op, arg_shapes, &shape_representation_fn));
// Convert MLIR module to XLA HLO proto contained in XlaComputation.
compilation_result->computation = std::make_shared<xla::XlaComputation>();
TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
module_op, device_type, compilation_result->computation.get(),
use_tuple_args, use_return_tuple, shape_representation_fn,
custom_legalization_passes));
return PopulateResultIOInfo(module_op, arg_shapes, use_tuple_args,
use_resource_updates_for_aliases,
shape_representation_fn, compilation_result);
}
Status CompileSerializedMlirToXlaHlo(
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
llvm::StringRef device_type, bool use_tuple_args,
@ -516,18 +578,13 @@ static StatusOr<std::vector<int>> RewriteWithArgs(
return params;
}
Status CompileGraphToXlaHlo(
Status CompileGraphSetup(
mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args,
llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
XlaCompilationResult* compilation_result,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes) {
TF_ASSIGN_OR_RETURN(std::vector<int> remaining_params,
RewriteWithArgs(module_op, args));
llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
arg_shapes.reserve(remaining_params.size());
for (unsigned idx : remaining_params) {
std::vector<int>* remaining_params,
llvm::SmallVector<TensorOrResourceShape, 4>& arg_shapes) {
TF_ASSIGN_OR_RETURN(*remaining_params, RewriteWithArgs(module_op, args));
arg_shapes.reserve(remaining_params->size());
for (unsigned idx : *remaining_params) {
const auto& arg = args[idx];
TF_ASSIGN_OR_RETURN(TensorShape arg_shape,
GetTensorShapeFromXlaArgument(arg));
@ -539,10 +596,39 @@ Status CompileGraphToXlaHlo(
applyTensorflowAndCLOptions(pm);
mlir::TF::StandardPipelineOptions tf_options;
mlir::TF::CreateTFStandardPipeline(pm, tf_options);
{
mlir::StatusScopedDiagnosticHandler diag_handler(module_op.getContext());
if (failed(pm.run(module_op))) return diag_handler.ConsumeStatus();
}
mlir::StatusScopedDiagnosticHandler diag_handler(module_op.getContext());
if (failed(pm.run(module_op))) return diag_handler.ConsumeStatus();
return Status::OK();
}
Status BuildHloFromModule(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
llvm::ArrayRef<xla::XlaOp> xla_params,
std::vector<xla::XlaOp>& returns,
llvm::ArrayRef<XlaArgument> args,
llvm::StringRef device_type,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes) {
std::vector<int> remaining_params;
llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
TF_RETURN_IF_ERROR(
CompileGraphSetup(module_op, args, &remaining_params, arg_shapes));
return BuildHloFromTf(module_op, builder, xla_params, returns, arg_shapes,
device_type, custom_legalization_passes);
}
Status CompileGraphToXlaHlo(
mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args,
llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
XlaCompilationResult* compilation_result,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes) {
std::vector<int> remaining_params;
llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
TF_RETURN_IF_ERROR(
CompileGraphSetup(module_op, args, &remaining_params, arg_shapes));
auto status = CompileMlirToXlaHlo(
module_op, arg_shapes, device_type, use_tuple_args, use_return_tuple,
@ -552,6 +638,49 @@ Status CompileGraphToXlaHlo(
return status;
}
Status GraphToModule(const Graph& graph,
llvm::ArrayRef<std::string> control_rets,
const FunctionLibraryDefinition& flib_def,
const GraphDebugInfo& debug_info,
mlir::MLIRContext* context,
mlir::OwningModuleRef* module) {
RegisterDialects(context->getDialectRegistry());
GraphImportConfig config;
config.graph_as_function = true;
config.control_outputs = control_rets;
// Disable shape inference during import as some TensorFlow op fails during
// shape inference with dynamic shaped operands. This in turn causes the
// import to fail. Shape inference during import is going to be removed and
// the shape inference pass is run early in the pass pipeline, shape inference
// during import is not necessary.
config.enable_shape_inference = false;
auto module_or =
ConvertGraphToMlir(graph, debug_info, flib_def, config, context);
if (!module_or.ok()) return module_or.status();
*module = std::move(module_or.ValueOrDie());
return Status::OK();
}
Status BuildHloFromGraph(const Graph& graph, xla::XlaBuilder& builder,
llvm::ArrayRef<xla::XlaOp> xla_params,
std::vector<xla::XlaOp>& returns,
llvm::ArrayRef<XlaArgument> args,
llvm::ArrayRef<std::string> control_rets,
llvm::StringRef device_type,
const FunctionLibraryDefinition& flib_def,
const GraphDebugInfo& debug_info,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes) {
mlir::MLIRContext context;
mlir::OwningModuleRef module;
TF_RETURN_IF_ERROR(GraphToModule(graph, control_rets, flib_def, debug_info,
&context, &module));
return BuildHloFromModule(module.get(), builder, xla_params, returns, args,
device_type, custom_legalization_passes);
}
Status CompileGraphToXlaHlo(
const Graph& graph, llvm::ArrayRef<XlaArgument> args,
llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
@ -562,22 +691,10 @@ Status CompileGraphToXlaHlo(
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes) {
mlir::MLIRContext context;
RegisterDialects(context.getDialectRegistry());
GraphImportConfig config;
config.graph_as_function = true;
config.control_outputs = control_rets;
// Disable shape inference during import as some TensorFlow op fails during
// shape inference with dynamic shaped operands. This in turn causes the
// import to fail. Shape inference during import is going to be removed and
// the shape inference pass is run early in the pass pipeline, shape inference
// during import is not necessary.
config.enable_shape_inference = false;
auto module_or =
ConvertGraphToMlir(graph, debug_info, flib_def, config, &context);
if (!module_or.ok()) return module_or.status();
mlir::ModuleOp module_op = module_or.ValueOrDie().get();
return CompileGraphToXlaHlo(module_op, args, device_type, use_tuple_args,
mlir::OwningModuleRef module;
TF_RETURN_IF_ERROR(GraphToModule(graph, control_rets, flib_def, debug_info,
&context, &module));
return CompileGraphToXlaHlo(module.get(), args, device_type, use_tuple_args,
/*use_return_tuple=*/true,
shape_representation_fn, compilation_result,
custom_legalization_passes);

View File

@ -81,6 +81,30 @@ struct TensorOrResourceShape {
bool is_resource = false;
};
// Refine MLIR types based on new shape information.
Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
mlir::ModuleOp module);
// Lower TF to MHLO and insert HLO into the XlaBuilder. xla_params are HLO-level
// inputs to module_op that have already been added to the XlaBuilder. returns
// are the returned XlaOps.
Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
llvm::ArrayRef<xla::XlaOp> xla_params,
std::vector<xla::XlaOp>& returns,
llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
llvm::StringRef device_type,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes);
// Apply shape, description, and resource information to inputs and outputs
// in the XlaCompilationResult. This should be called after
// compilation_result->computation was set.
Status PopulateResultIOInfo(
mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
bool use_tuple_args, bool use_resource_updates_for_aliases,
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
XlaCompilationResult* compilation_result);
// Compiles a MLIR module into XLA HLO, generates all accompanying metadata and
// stores them in CompilationResult.
// TODO(hinsu): Migrate options to separate struct.
@ -128,6 +152,21 @@ Status CompileGraphToXlaHlo(
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes = {});
// Compiles a Graph from TF to HLO and adds the resulting HLO to the
// XlaBuilder. This function adds HLO to a larger HLO computation, so
// HLO-level inputs are supplied, and HLO-level outputs are produced.
// xla_params is the HLO-level inputs and returns is the HLO-level outputs.
Status BuildHloFromGraph(const Graph& graph, xla::XlaBuilder& builder,
llvm::ArrayRef<xla::XlaOp> xla_params,
std::vector<xla::XlaOp>& returns,
llvm::ArrayRef<XlaArgument> args,
llvm::ArrayRef<std::string> control_rets,
llvm::StringRef device_type,
const FunctionLibraryDefinition& flib_def,
const GraphDebugInfo& debug_info,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes = {});
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_

View File

@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/mlir/utils/string_container_utils.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h"
#include "tensorflow/compiler/tf2xla/xla_argument.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
@ -233,8 +234,64 @@ Status ParseXlaArguments(absl::string_view input_shapes_str,
} // anonymous namespace
static mlir::LogicalResult MlirTfToHloTextTranslateFunction(
mlir::ModuleOp module_op, llvm::raw_ostream& output) {
// Test BuildHloFromTf. BuildHloFromTf only performs part of the conversion, so
// to make this test comparable to other compile tests, the test implements
// the remaining parts of the conversion.
Status CompileMlirToXlaHloViaBuilder(
mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
llvm::StringRef device_type, XlaCompilationResult* compilation_result,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes) {
// This call to RefineShapes is redundant with the call in BuildHloFromTf.
// It's here so xla::Parameters that are created form block.getArguments will
// have the proper shapes.
TF_RETURN_IF_ERROR(RefineShapes(arg_shapes, module_op));
mlir::FuncOp main = module_op.lookupSymbol<mlir::FuncOp>("main");
mlir::Block& block = main.getRegion().front();
xla::XlaBuilder builder("main");
// Create xla_params.
std::vector<xla::XlaOp> xla_params;
for (mlir::BlockArgument& arg : block.getArguments()) {
auto num = arg.getArgNumber();
xla::Shape shape = xla::TypeToShape(arg.getType());
xla::XlaOp argop =
xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num));
xla_params.push_back(argop);
}
std::vector<xla::XlaOp> returns(1);
TF_RETURN_IF_ERROR(BuildHloFromTf(module_op, builder, xla_params, returns,
arg_shapes, device_type,
custom_legalization_passes));
xla::XlaOp return_value;
if (returns.size() == 1)
return_value = returns[0];
else
return_value = xla::Tuple(&builder, returns);
TF_ASSIGN_OR_RETURN(
xla::XlaComputation computation,
return_value.valid() ? builder.Build(return_value) : builder.Build());
auto hlo_module = computation.proto();
xla::HloProto hlo_proto;
hlo_proto.mutable_hlo_module()->Swap(&hlo_module);
compilation_result->computation = std::make_shared<xla::XlaComputation>();
xla::XlaComputation* xla_computation = compilation_result->computation.get();
*xla_computation = xla::XlaComputation(hlo_proto.hlo_module());
XlaHelpers::ShapeRepresentationFn shape_representation_fn =
IdentityShapeRepresentationFn();
return PopulateResultIOInfo(module_op, arg_shapes, /*use_tuple_args=*/false,
/*use_resource_updates_for_aliases=*/false,
shape_representation_fn, compilation_result);
}
static mlir::LogicalResult MlirTfToHloTextTranslateFunctionImpl(
mlir::ModuleOp module_op, llvm::raw_ostream& output, bool via_builder) {
if (!module_op) return mlir::failure();
llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
@ -245,12 +302,21 @@ static mlir::LogicalResult MlirTfToHloTextTranslateFunction(
return mlir::failure();
}
auto device_type = "XLA_CPU_JIT";
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes{};
XlaCompilationResult compilation_result;
auto compilation_status = CompileMlirToXlaHlo(
module_op, arg_shapes, /*device_type=*/"XLA_CPU_JIT", emit_use_tuple_arg,
emit_return_tuple, /*use_resource_updates_for_aliases=*/true,
IdentityShapeRepresentationFn(), &compilation_result,
/*custom_legalization_passes=*/{});
auto compilation_status =
via_builder
? CompileMlirToXlaHloViaBuilder(module_op, arg_shapes, device_type,
&compilation_result,
custom_legalization_passes)
: CompileMlirToXlaHlo(module_op, arg_shapes, device_type,
emit_use_tuple_arg, emit_return_tuple,
/*use_resource_updates_for_aliases=*/true,
IdentityShapeRepresentationFn(),
&compilation_result,
custom_legalization_passes);
if (!compilation_status.ok()) {
LOG(ERROR) << "TF/XLA compilation failed: "
<< compilation_status.ToString();
@ -326,12 +392,27 @@ static mlir::LogicalResult MlirModuleToSerializedMlirStringAttrTranslate(
return mlir::success();
}
static mlir::LogicalResult MlirTfToHloTextTranslateFunction(
mlir::ModuleOp module_op, llvm::raw_ostream& output) {
return MlirTfToHloTextTranslateFunctionImpl(module_op, output, false);
}
static mlir::LogicalResult MlirTfToHloTextViaBuilderTranslateFunction(
mlir::ModuleOp module_op, llvm::raw_ostream& output) {
return MlirTfToHloTextTranslateFunctionImpl(module_op, output, true);
}
} // namespace tensorflow
static mlir::TranslateFromMLIRRegistration MlirTfToHloTextTranslate(
"mlir-tf-to-hlo-text", tensorflow::MlirTfToHloTextTranslateFunction,
tensorflow::RegisterMlirInputDialects);
static mlir::TranslateFromMLIRRegistration MlirTfToHloTextViaBuilderTranslate(
"mlir-tf-to-hlo-text-via-builder",
tensorflow::MlirTfToHloTextViaBuilderTranslateFunction,
tensorflow::RegisterMlirInputDialects);
static mlir::TranslateFromMLIRRegistration MlirTfGraphToHloTextTranslate(
"mlir-tf-graph-to-hlo-text",
tensorflow::MlirTfGraphToHloTextTranslateFunction,