Add CompileGraphToXlaBuilder to compile_mlir_util.h. Refactor internal functions to support both the new CompileGraphToXlaBuilder and the existing CompileGraphToHlo.
CompileGraphToXlaBuilder will be used for MLIR legalization in the old bridge. PiperOrigin-RevId: 346679170 Change-Id: I34b4e7b8b41ef3aa6a6a65426ecee4d113a97979
This commit is contained in:
		
							parent
							
								
									2090fe76f2
								
							
						
					
					
						commit
						c6293e9bfc
					
				@ -1742,6 +1742,7 @@ cc_library(
 | 
			
		||||
        ":translate_cl_options",
 | 
			
		||||
        "//tensorflow/compiler/mlir:string_container_utils",
 | 
			
		||||
        "//tensorflow/compiler/mlir/xla:translate_cl_options",
 | 
			
		||||
        "//tensorflow/compiler/mlir/xla:type_to_shape",
 | 
			
		||||
        "//tensorflow/compiler/tf2xla:xla_argument",
 | 
			
		||||
        "//tensorflow/compiler/tf2xla:xla_helpers",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:hlo",
 | 
			
		||||
@ -1755,7 +1756,6 @@ cc_library(
 | 
			
		||||
        "@llvm-project//llvm:Support",
 | 
			
		||||
        "@llvm-project//mlir:IR",
 | 
			
		||||
        "@llvm-project//mlir:Parser",
 | 
			
		||||
        "@llvm-project//mlir:Pass",
 | 
			
		||||
        "@llvm-project//mlir:StandardOps",
 | 
			
		||||
        "@llvm-project//mlir:Support",
 | 
			
		||||
        "@llvm-project//mlir:Translation",
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,7 @@
 | 
			
		||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: -emit-return-tuple | FileCheck %s
 | 
			
		||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: -emit-use-tuple-args -emit-return-tuple | FileCheck -check-prefix=TUPLE-ARGS %s
 | 
			
		||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: | FileCheck -check-prefix=NO_RET_TUPLE %s
 | 
			
		||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text-via-builder %s -tf-input-shapes=: | FileCheck -check-prefix=NO_RET_TUPLE %s
 | 
			
		||||
 | 
			
		||||
module attributes {tf.versions = {producer = 179 : i32}} {
 | 
			
		||||
  func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
 | 
			
		||||
@ -36,3 +38,16 @@ module attributes {tf.versions = {producer = 179 : i32}} {
 | 
			
		||||
// TUPLE-ARGS-NEXT:  // XlaInputShape (f32[], f32[])
 | 
			
		||||
// TUPLE-ARGS-NEXT:  // XlaOutputShape (f32[])
 | 
			
		||||
// TUPLE-ARGS-NEXT:  // XlaOutputDescription type=float shape=()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
// NO_RET_TUPLE-LABEL: HloModule main{{[.0-9]*}}
 | 
			
		||||
// NO_RET_TUPLE:       ENTRY %main.{{[0-9]+}} ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[] {
 | 
			
		||||
// NO_RET_TUPLE-NEXT:    %[[ARG0]] = f32[] parameter(0)
 | 
			
		||||
// NO_RET_TUPLE-NEXT:    %[[ARG1]] = f32[] parameter(1)
 | 
			
		||||
// NO_RET_TUPLE-NEXT:    ROOT [[ADD:%.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]])
 | 
			
		||||
 | 
			
		||||
// NO_RET_TUPLE:       // InputMapping {0, 1}
 | 
			
		||||
// NO_RET_TUPLE-NEXT:  // XlaInputShape f32[]
 | 
			
		||||
// NO_RET_TUPLE-NEXT:  // XlaInputShape f32[]
 | 
			
		||||
// NO_RET_TUPLE-NEXT:  // XlaOutputShape (f32[])
 | 
			
		||||
// NO_RET_TUPLE-NEXT:  // XlaOutputDescription type=float shape=()
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,6 @@
 | 
			
		||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: -emit-use-tuple-args -emit-return-tuple | FileCheck %s
 | 
			
		||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: | FileCheck -check-prefix=NO_TUPLES %s
 | 
			
		||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text-via-builder %s -tf-input-shapes=: | FileCheck -check-prefix=NO_TUPLES %s
 | 
			
		||||
 | 
			
		||||
module attributes {tf.versions = {producer = 179 : i32}} {
 | 
			
		||||
  func @main() -> (tensor<0xi32>, tensor<0xi32>) {
 | 
			
		||||
@ -14,3 +16,9 @@ module attributes {tf.versions = {producer = 179 : i32}} {
 | 
			
		||||
// CHECK:         [[CONSTANT:%.*]] = s32[0]{0} constant({})
 | 
			
		||||
// CHECK:         ROOT %tuple.{{[0-9]+}} = (s32[0]{0}, s32[0]{0}) tuple(s32[0]{0} [[CONSTANT]], s32[0]{0} [[CONSTANT]])
 | 
			
		||||
// CHECK:       }
 | 
			
		||||
 | 
			
		||||
// NO_TUPLES-LABEL: HloModule main{{.[0-9+]}}
 | 
			
		||||
// NO_TUPLES:       ENTRY %main.{{[0-9+]}} () -> (s32[0], s32[0]) {
 | 
			
		||||
// NO_TUPLES:         [[CONSTANT:%.*]] = s32[0]{0} constant({})
 | 
			
		||||
// NO_TUPLES:         ROOT %tuple.{{[0-9]+}} = (s32[0]{0}, s32[0]{0}) tuple(s32[0]{0} [[CONSTANT]], s32[0]{0} [[CONSTANT]])
 | 
			
		||||
// NO_TUPLES:       }
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,5 @@
 | 
			
		||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=8,16,16,64:64 -emit-use-tuple-args -emit-return-tuple | FileCheck %s
 | 
			
		||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text-via-builder %s -tf-input-shapes=8,16,16,64:64 | FileCheck %s
 | 
			
		||||
 | 
			
		||||
module attributes {tf.versions = {producer = 179 : i32}} {
 | 
			
		||||
  func @main(%arg0: tensor<8x16x16x64xbf16>, %arg1: tensor<64xf32>) -> (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>) {
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,6 @@
 | 
			
		||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=10,17:17,19 -emit-use-tuple-args -emit-return-tuple | FileCheck %s
 | 
			
		||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=10,17:17,19 | FileCheck -check-prefix=NO_TUPLES %s
 | 
			
		||||
// RUN: tf-mlir-translate -mlir-tf-to-hlo-text-via-builder %s -tf-input-shapes=10,17:17,19 | FileCheck -check-prefix=NO_TUPLES %s
 | 
			
		||||
 | 
			
		||||
module attributes {tf.versions = {producer = 179 : i32}} {
 | 
			
		||||
  func @main(%arg0: tensor<*xf32>, %arg1: tensor<?x19xf32>) -> tensor<?x19xf32> {
 | 
			
		||||
@ -9,3 +11,6 @@ module attributes {tf.versions = {producer = 179 : i32}} {
 | 
			
		||||
 | 
			
		||||
// CHECK-LABEL: HloModule main
 | 
			
		||||
// CHECK:       (arg_tuple.{{[0-9]+}}: (f32[10,17], f32[17,19])) -> (f32[10,19])
 | 
			
		||||
 | 
			
		||||
// NO_TUPLES-LABEL: HloModule main{{.[0-9]*}}
 | 
			
		||||
// NO_TUPLES:       ({{.+}}: f32[10,17], {{.+}}: f32[17,19]) -> f32[10,19]
 | 
			
		||||
 | 
			
		||||
@ -211,7 +211,20 @@ void GetInputMappingForMlir(int num_inputs, std::vector<int>* input_mapping) {
 | 
			
		||||
  std::iota(input_mapping->begin(), input_mapping->end(), 0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Refine MLIR types based on new shape information.
 | 
			
		||||
static void RegisterDialects(mlir::DialectRegistry& registry) {
 | 
			
		||||
  mlir::RegisterAllTensorFlowDialects(registry);
 | 
			
		||||
  mlir::mhlo::registerAllMhloDialects(registry);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Checks if functions can be inlined after TF -> HLO legalization. Currently
 | 
			
		||||
// TPU's are supported, to follow the behavior of inlining functions via the
 | 
			
		||||
// Graph based bridge in the TPUCompile op kernel.
 | 
			
		||||
bool CanInlineFunctionsPostLegalization(llvm::StringRef device_type) {
 | 
			
		||||
  return device_type == DEVICE_TPU_XLA_JIT;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  //  namespace
 | 
			
		||||
 | 
			
		||||
Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
 | 
			
		||||
                    mlir::ModuleOp module) {
 | 
			
		||||
  auto producer_or = GetTfGraphProducerVersion(module);
 | 
			
		||||
@ -262,20 +275,6 @@ Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void RegisterDialects(mlir::DialectRegistry& registry) {
 | 
			
		||||
  mlir::RegisterAllTensorFlowDialects(registry);
 | 
			
		||||
  mlir::mhlo::registerAllMhloDialects(registry);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Checks if functions can be inlined after TF -> HLO legalization. Currently
 | 
			
		||||
// TPU's are supported, to follow the behavior of inlining functions via the
 | 
			
		||||
// Graph based bridge in the TPUCompile op kernel.
 | 
			
		||||
bool CanInlineFunctionsPostLegalization(llvm::StringRef device_type) {
 | 
			
		||||
  return device_type == DEVICE_TPU_XLA_JIT;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  //  namespace
 | 
			
		||||
 | 
			
		||||
void CreateConvertMlirToXlaHloPipeline(
 | 
			
		||||
    mlir::OpPassManager& pm, llvm::StringRef device_type,
 | 
			
		||||
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
 | 
			
		||||
@ -337,13 +336,9 @@ void CreateConvertMlirToXlaHloPipeline(
 | 
			
		||||
      mlir::mhlo::createSinkConstantsToControlFlowPass());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status ConvertMLIRToXlaComputation(
 | 
			
		||||
    mlir::ModuleOp module_op, llvm::StringRef device_type,
 | 
			
		||||
    xla::XlaComputation* xla_computation, bool use_tuple_args,
 | 
			
		||||
    bool return_tuple,
 | 
			
		||||
    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
 | 
			
		||||
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
 | 
			
		||||
        custom_legalization_passes) {
 | 
			
		||||
Status LegalizeToHlo(mlir::ModuleOp module_op, llvm::StringRef device_type,
 | 
			
		||||
                     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
 | 
			
		||||
                         custom_legalization_passes) {
 | 
			
		||||
  mlir::PassManager tf2xla(module_op.getContext());
 | 
			
		||||
  applyTensorflowAndCLOptions(tf2xla);
 | 
			
		||||
  CreateConvertMlirToXlaHloPipeline(tf2xla, device_type,
 | 
			
		||||
@ -370,6 +365,32 @@ Status ConvertMLIRToXlaComputation(
 | 
			
		||||
  if (VLOG_IS_ON(1))
 | 
			
		||||
    tensorflow::DumpMlirOpToFile("mlir_compile_legalize_hlo", module_op);
 | 
			
		||||
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status BuildHloFromTfInner(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
 | 
			
		||||
                           llvm::ArrayRef<xla::XlaOp> xla_params,
 | 
			
		||||
                           std::vector<xla::XlaOp>& returns,
 | 
			
		||||
                           llvm::StringRef device_type,
 | 
			
		||||
                           llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
 | 
			
		||||
                               custom_legalization_passes) {
 | 
			
		||||
  TF_RETURN_IF_ERROR(
 | 
			
		||||
      LegalizeToHlo(module_op, device_type, custom_legalization_passes));
 | 
			
		||||
 | 
			
		||||
  mlir::Block& block = module_op.lookupSymbol<mlir::FuncOp>("main").front();
 | 
			
		||||
  return mlir::BuildHloFromMlirHlo(block, builder, xla_params, returns);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status ConvertMLIRToXlaComputation(
 | 
			
		||||
    mlir::ModuleOp module_op, llvm::StringRef device_type,
 | 
			
		||||
    xla::XlaComputation* xla_computation, bool use_tuple_args,
 | 
			
		||||
    bool return_tuple,
 | 
			
		||||
    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
 | 
			
		||||
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
 | 
			
		||||
        custom_legalization_passes) {
 | 
			
		||||
  TF_RETURN_IF_ERROR(
 | 
			
		||||
      LegalizeToHlo(module_op, device_type, custom_legalization_passes));
 | 
			
		||||
 | 
			
		||||
  xla::HloProto hlo_proto;
 | 
			
		||||
  TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module_op, &hlo_proto,
 | 
			
		||||
                                               use_tuple_args, return_tuple,
 | 
			
		||||
@ -378,14 +399,9 @@ Status ConvertMLIRToXlaComputation(
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status CompileMlirToXlaHlo(
 | 
			
		||||
Status CompileMlirSetup(
 | 
			
		||||
    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
 | 
			
		||||
    llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
 | 
			
		||||
    bool use_resource_updates_for_aliases,
 | 
			
		||||
    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
 | 
			
		||||
    XlaCompilationResult* compilation_result,
 | 
			
		||||
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
 | 
			
		||||
        custom_legalization_passes) {
 | 
			
		||||
    XlaHelpers::ShapeRepresentationFn* shape_representation_fn) {
 | 
			
		||||
  if (VLOG_IS_ON(1))
 | 
			
		||||
    tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op);
 | 
			
		||||
 | 
			
		||||
@ -395,16 +411,39 @@ Status CompileMlirToXlaHlo(
 | 
			
		||||
  if (VLOG_IS_ON(1))
 | 
			
		||||
    tensorflow::DumpMlirOpToFile("mlir_compile_shape_refiner", module_op);
 | 
			
		||||
 | 
			
		||||
  if (!shape_representation_fn)
 | 
			
		||||
    shape_representation_fn = IdentityShapeRepresentationFn();
 | 
			
		||||
  if (!*shape_representation_fn)
 | 
			
		||||
    *shape_representation_fn = IdentityShapeRepresentationFn();
 | 
			
		||||
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
 | 
			
		||||
                      llvm::ArrayRef<xla::XlaOp> xla_params,
 | 
			
		||||
                      std::vector<xla::XlaOp>& returns,
 | 
			
		||||
                      llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
 | 
			
		||||
                      llvm::StringRef device_type,
 | 
			
		||||
                      llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
 | 
			
		||||
                          custom_legalization_passes) {
 | 
			
		||||
  XlaHelpers::ShapeRepresentationFn shape_representation_fn;
 | 
			
		||||
  TF_RETURN_IF_ERROR(
 | 
			
		||||
      CompileMlirSetup(module_op, arg_shapes, &shape_representation_fn));
 | 
			
		||||
 | 
			
		||||
  // Convert MLIR module to XLA HLO proto contained in XlaComputation.
 | 
			
		||||
  compilation_result->computation = std::make_shared<xla::XlaComputation>();
 | 
			
		||||
  TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
 | 
			
		||||
      module_op, device_type, compilation_result->computation.get(),
 | 
			
		||||
      use_tuple_args, use_return_tuple, shape_representation_fn,
 | 
			
		||||
      custom_legalization_passes));
 | 
			
		||||
  TF_RETURN_IF_ERROR(BuildHloFromTfInner(module_op, builder, xla_params,
 | 
			
		||||
                                         returns, device_type,
 | 
			
		||||
                                         custom_legalization_passes));
 | 
			
		||||
 | 
			
		||||
  if (VLOG_IS_ON(1))
 | 
			
		||||
    tensorflow::DumpMlirOpToFile("mlir_compile_after", module_op);
 | 
			
		||||
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status PopulateResultIOInfo(
 | 
			
		||||
    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
 | 
			
		||||
    bool use_tuple_args, bool use_resource_updates_for_aliases,
 | 
			
		||||
    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
 | 
			
		||||
    XlaCompilationResult* compilation_result) {
 | 
			
		||||
  // Construct mapping from XlaComputation's arg to input edges of execute
 | 
			
		||||
  // node.
 | 
			
		||||
  GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping);
 | 
			
		||||
@ -426,6 +465,29 @@ Status CompileMlirToXlaHlo(
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status CompileMlirToXlaHlo(
 | 
			
		||||
    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
 | 
			
		||||
    llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
 | 
			
		||||
    bool use_resource_updates_for_aliases,
 | 
			
		||||
    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
 | 
			
		||||
    XlaCompilationResult* compilation_result,
 | 
			
		||||
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
 | 
			
		||||
        custom_legalization_passes) {
 | 
			
		||||
  TF_RETURN_IF_ERROR(
 | 
			
		||||
      CompileMlirSetup(module_op, arg_shapes, &shape_representation_fn));
 | 
			
		||||
 | 
			
		||||
  // Convert MLIR module to XLA HLO proto contained in XlaComputation.
 | 
			
		||||
  compilation_result->computation = std::make_shared<xla::XlaComputation>();
 | 
			
		||||
  TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
 | 
			
		||||
      module_op, device_type, compilation_result->computation.get(),
 | 
			
		||||
      use_tuple_args, use_return_tuple, shape_representation_fn,
 | 
			
		||||
      custom_legalization_passes));
 | 
			
		||||
 | 
			
		||||
  return PopulateResultIOInfo(module_op, arg_shapes, use_tuple_args,
 | 
			
		||||
                              use_resource_updates_for_aliases,
 | 
			
		||||
                              shape_representation_fn, compilation_result);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status CompileSerializedMlirToXlaHlo(
 | 
			
		||||
    llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
 | 
			
		||||
    llvm::StringRef device_type, bool use_tuple_args,
 | 
			
		||||
@ -516,18 +578,13 @@ static StatusOr<std::vector<int>> RewriteWithArgs(
 | 
			
		||||
  return params;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status CompileGraphToXlaHlo(
 | 
			
		||||
Status CompileGraphSetup(
 | 
			
		||||
    mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args,
 | 
			
		||||
    llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
 | 
			
		||||
    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
 | 
			
		||||
    XlaCompilationResult* compilation_result,
 | 
			
		||||
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
 | 
			
		||||
        custom_legalization_passes) {
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(std::vector<int> remaining_params,
 | 
			
		||||
                      RewriteWithArgs(module_op, args));
 | 
			
		||||
  llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
 | 
			
		||||
  arg_shapes.reserve(remaining_params.size());
 | 
			
		||||
  for (unsigned idx : remaining_params) {
 | 
			
		||||
    std::vector<int>* remaining_params,
 | 
			
		||||
    llvm::SmallVector<TensorOrResourceShape, 4>& arg_shapes) {
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(*remaining_params, RewriteWithArgs(module_op, args));
 | 
			
		||||
  arg_shapes.reserve(remaining_params->size());
 | 
			
		||||
  for (unsigned idx : *remaining_params) {
 | 
			
		||||
    const auto& arg = args[idx];
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(TensorShape arg_shape,
 | 
			
		||||
                        GetTensorShapeFromXlaArgument(arg));
 | 
			
		||||
@ -539,10 +596,39 @@ Status CompileGraphToXlaHlo(
 | 
			
		||||
  applyTensorflowAndCLOptions(pm);
 | 
			
		||||
  mlir::TF::StandardPipelineOptions tf_options;
 | 
			
		||||
  mlir::TF::CreateTFStandardPipeline(pm, tf_options);
 | 
			
		||||
  {
 | 
			
		||||
    mlir::StatusScopedDiagnosticHandler diag_handler(module_op.getContext());
 | 
			
		||||
    if (failed(pm.run(module_op))) return diag_handler.ConsumeStatus();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  mlir::StatusScopedDiagnosticHandler diag_handler(module_op.getContext());
 | 
			
		||||
  if (failed(pm.run(module_op))) return diag_handler.ConsumeStatus();
 | 
			
		||||
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status BuildHloFromModule(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
 | 
			
		||||
                          llvm::ArrayRef<xla::XlaOp> xla_params,
 | 
			
		||||
                          std::vector<xla::XlaOp>& returns,
 | 
			
		||||
                          llvm::ArrayRef<XlaArgument> args,
 | 
			
		||||
                          llvm::StringRef device_type,
 | 
			
		||||
                          llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
 | 
			
		||||
                              custom_legalization_passes) {
 | 
			
		||||
  std::vector<int> remaining_params;
 | 
			
		||||
  llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
 | 
			
		||||
  TF_RETURN_IF_ERROR(
 | 
			
		||||
      CompileGraphSetup(module_op, args, &remaining_params, arg_shapes));
 | 
			
		||||
  return BuildHloFromTf(module_op, builder, xla_params, returns, arg_shapes,
 | 
			
		||||
                        device_type, custom_legalization_passes);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status CompileGraphToXlaHlo(
 | 
			
		||||
    mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args,
 | 
			
		||||
    llvm::StringRef device_type, bool use_tuple_args, bool use_return_tuple,
 | 
			
		||||
    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
 | 
			
		||||
    XlaCompilationResult* compilation_result,
 | 
			
		||||
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
 | 
			
		||||
        custom_legalization_passes) {
 | 
			
		||||
  std::vector<int> remaining_params;
 | 
			
		||||
  llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
 | 
			
		||||
  TF_RETURN_IF_ERROR(
 | 
			
		||||
      CompileGraphSetup(module_op, args, &remaining_params, arg_shapes));
 | 
			
		||||
 | 
			
		||||
  auto status = CompileMlirToXlaHlo(
 | 
			
		||||
      module_op, arg_shapes, device_type, use_tuple_args, use_return_tuple,
 | 
			
		||||
@ -552,6 +638,49 @@ Status CompileGraphToXlaHlo(
 | 
			
		||||
  return status;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status GraphToModule(const Graph& graph,
 | 
			
		||||
                     llvm::ArrayRef<std::string> control_rets,
 | 
			
		||||
                     const FunctionLibraryDefinition& flib_def,
 | 
			
		||||
                     const GraphDebugInfo& debug_info,
 | 
			
		||||
                     mlir::MLIRContext* context,
 | 
			
		||||
                     mlir::OwningModuleRef* module) {
 | 
			
		||||
  RegisterDialects(context->getDialectRegistry());
 | 
			
		||||
  GraphImportConfig config;
 | 
			
		||||
  config.graph_as_function = true;
 | 
			
		||||
  config.control_outputs = control_rets;
 | 
			
		||||
  // Disable shape inference during import as some TensorFlow op fails during
 | 
			
		||||
  // shape inference with dynamic shaped operands. This in turn causes the
 | 
			
		||||
  // import to fail. Shape inference during import is going to be removed and
 | 
			
		||||
  // the shape inference pass is run early in the pass pipeline, shape inference
 | 
			
		||||
  // during import is not necessary.
 | 
			
		||||
  config.enable_shape_inference = false;
 | 
			
		||||
  auto module_or =
 | 
			
		||||
      ConvertGraphToMlir(graph, debug_info, flib_def, config, context);
 | 
			
		||||
  if (!module_or.ok()) return module_or.status();
 | 
			
		||||
 | 
			
		||||
  *module = std::move(module_or.ValueOrDie());
 | 
			
		||||
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status BuildHloFromGraph(const Graph& graph, xla::XlaBuilder& builder,
 | 
			
		||||
                         llvm::ArrayRef<xla::XlaOp> xla_params,
 | 
			
		||||
                         std::vector<xla::XlaOp>& returns,
 | 
			
		||||
                         llvm::ArrayRef<XlaArgument> args,
 | 
			
		||||
                         llvm::ArrayRef<std::string> control_rets,
 | 
			
		||||
                         llvm::StringRef device_type,
 | 
			
		||||
                         const FunctionLibraryDefinition& flib_def,
 | 
			
		||||
                         const GraphDebugInfo& debug_info,
 | 
			
		||||
                         llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
 | 
			
		||||
                             custom_legalization_passes) {
 | 
			
		||||
  mlir::MLIRContext context;
 | 
			
		||||
  mlir::OwningModuleRef module;
 | 
			
		||||
  TF_RETURN_IF_ERROR(GraphToModule(graph, control_rets, flib_def, debug_info,
 | 
			
		||||
                                   &context, &module));
 | 
			
		||||
  return BuildHloFromModule(module.get(), builder, xla_params, returns, args,
 | 
			
		||||
                            device_type, custom_legalization_passes);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status CompileGraphToXlaHlo(
 | 
			
		||||
    const Graph& graph, llvm::ArrayRef<XlaArgument> args,
 | 
			
		||||
    llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
 | 
			
		||||
@ -562,22 +691,10 @@ Status CompileGraphToXlaHlo(
 | 
			
		||||
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
 | 
			
		||||
        custom_legalization_passes) {
 | 
			
		||||
  mlir::MLIRContext context;
 | 
			
		||||
  RegisterDialects(context.getDialectRegistry());
 | 
			
		||||
  GraphImportConfig config;
 | 
			
		||||
  config.graph_as_function = true;
 | 
			
		||||
  config.control_outputs = control_rets;
 | 
			
		||||
  // Disable shape inference during import as some TensorFlow op fails during
 | 
			
		||||
  // shape inference with dynamic shaped operands. This in turn causes the
 | 
			
		||||
  // import to fail. Shape inference during import is going to be removed and
 | 
			
		||||
  // the shape inference pass is run early in the pass pipeline, shape inference
 | 
			
		||||
  // during import is not necessary.
 | 
			
		||||
  config.enable_shape_inference = false;
 | 
			
		||||
  auto module_or =
 | 
			
		||||
      ConvertGraphToMlir(graph, debug_info, flib_def, config, &context);
 | 
			
		||||
  if (!module_or.ok()) return module_or.status();
 | 
			
		||||
 | 
			
		||||
  mlir::ModuleOp module_op = module_or.ValueOrDie().get();
 | 
			
		||||
  return CompileGraphToXlaHlo(module_op, args, device_type, use_tuple_args,
 | 
			
		||||
  mlir::OwningModuleRef module;
 | 
			
		||||
  TF_RETURN_IF_ERROR(GraphToModule(graph, control_rets, flib_def, debug_info,
 | 
			
		||||
                                   &context, &module));
 | 
			
		||||
  return CompileGraphToXlaHlo(module.get(), args, device_type, use_tuple_args,
 | 
			
		||||
                              /*use_return_tuple=*/true,
 | 
			
		||||
                              shape_representation_fn, compilation_result,
 | 
			
		||||
                              custom_legalization_passes);
 | 
			
		||||
 | 
			
		||||
@ -81,6 +81,30 @@ struct TensorOrResourceShape {
 | 
			
		||||
  bool is_resource = false;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Refine MLIR types based on new shape information.
 | 
			
		||||
Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
 | 
			
		||||
                    mlir::ModuleOp module);
 | 
			
		||||
 | 
			
		||||
// Lower TF to MHLO and insert HLO into the XlaBuilder. xla_params are HLO-level
 | 
			
		||||
// inputs to module_op that have already been added to the XlaBuilder. returns
 | 
			
		||||
// are the returned XlaOps.
 | 
			
		||||
Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
 | 
			
		||||
                      llvm::ArrayRef<xla::XlaOp> xla_params,
 | 
			
		||||
                      std::vector<xla::XlaOp>& returns,
 | 
			
		||||
                      llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
 | 
			
		||||
                      llvm::StringRef device_type,
 | 
			
		||||
                      llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
 | 
			
		||||
                          custom_legalization_passes);
 | 
			
		||||
 | 
			
		||||
// Apply shape, description, and resource information to inputs and outputs
 | 
			
		||||
// in the XlaCompilationResult. This should be called after
 | 
			
		||||
// compilation_result->computation was set.
 | 
			
		||||
Status PopulateResultIOInfo(
 | 
			
		||||
    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
 | 
			
		||||
    bool use_tuple_args, bool use_resource_updates_for_aliases,
 | 
			
		||||
    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
 | 
			
		||||
    XlaCompilationResult* compilation_result);
 | 
			
		||||
 | 
			
		||||
// Compiles a MLIR module into XLA HLO, generates all accompanying metadata and
 | 
			
		||||
// stores them in CompilationResult.
 | 
			
		||||
// TODO(hinsu): Migrate options to separate struct.
 | 
			
		||||
@ -128,6 +152,21 @@ Status CompileGraphToXlaHlo(
 | 
			
		||||
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
 | 
			
		||||
        custom_legalization_passes = {});
 | 
			
		||||
 | 
			
		||||
// Compiles a Graph from TF to HLO and adds the resulting HLO to the
 | 
			
		||||
// XlaBuilder. This function adds HLO to a larger HLO computation, so
 | 
			
		||||
// HLO-level inputs are supplied, and HLO-level outputs are produced.
 | 
			
		||||
// xla_params is the HLO-level inputs and returns is the HLO-level outputs.
 | 
			
		||||
Status BuildHloFromGraph(const Graph& graph, xla::XlaBuilder& builder,
 | 
			
		||||
                         llvm::ArrayRef<xla::XlaOp> xla_params,
 | 
			
		||||
                         std::vector<xla::XlaOp>& returns,
 | 
			
		||||
                         llvm::ArrayRef<XlaArgument> args,
 | 
			
		||||
                         llvm::ArrayRef<std::string> control_rets,
 | 
			
		||||
                         llvm::StringRef device_type,
 | 
			
		||||
                         const FunctionLibraryDefinition& flib_def,
 | 
			
		||||
                         const GraphDebugInfo& debug_info,
 | 
			
		||||
                         llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
 | 
			
		||||
                             custom_legalization_passes = {});
 | 
			
		||||
 | 
			
		||||
}  // namespace tensorflow
 | 
			
		||||
 | 
			
		||||
#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
 | 
			
		||||
 | 
			
		||||
@ -40,6 +40,7 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
 | 
			
		||||
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
 | 
			
		||||
#include "tensorflow/compiler/mlir/utils/string_container_utils.h"
 | 
			
		||||
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
 | 
			
		||||
#include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h"
 | 
			
		||||
#include "tensorflow/compiler/tf2xla/xla_argument.h"
 | 
			
		||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
 | 
			
		||||
@ -233,8 +234,64 @@ Status ParseXlaArguments(absl::string_view input_shapes_str,
 | 
			
		||||
 | 
			
		||||
}  // anonymous namespace
 | 
			
		||||
 | 
			
		||||
static mlir::LogicalResult MlirTfToHloTextTranslateFunction(
 | 
			
		||||
    mlir::ModuleOp module_op, llvm::raw_ostream& output) {
 | 
			
		||||
// Test BuildHloFromTf. BuildHloFromTf only performs part of the conversion, so
 | 
			
		||||
// to make this test comparable to other compile tests, the test implements
 | 
			
		||||
// the remaining parts of the conversion.
 | 
			
		||||
Status CompileMlirToXlaHloViaBuilder(
 | 
			
		||||
    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
 | 
			
		||||
    llvm::StringRef device_type, XlaCompilationResult* compilation_result,
 | 
			
		||||
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
 | 
			
		||||
        custom_legalization_passes) {
 | 
			
		||||
  // This call to RefineShapes is redundant with the call in BuildHloFromTf.
 | 
			
		||||
  // It's here so xla::Parameters that are created form block.getArguments will
 | 
			
		||||
  // have the proper shapes.
 | 
			
		||||
  TF_RETURN_IF_ERROR(RefineShapes(arg_shapes, module_op));
 | 
			
		||||
 | 
			
		||||
  mlir::FuncOp main = module_op.lookupSymbol<mlir::FuncOp>("main");
 | 
			
		||||
  mlir::Block& block = main.getRegion().front();
 | 
			
		||||
  xla::XlaBuilder builder("main");
 | 
			
		||||
 | 
			
		||||
  // Create xla_params.
 | 
			
		||||
  std::vector<xla::XlaOp> xla_params;
 | 
			
		||||
  for (mlir::BlockArgument& arg : block.getArguments()) {
 | 
			
		||||
    auto num = arg.getArgNumber();
 | 
			
		||||
    xla::Shape shape = xla::TypeToShape(arg.getType());
 | 
			
		||||
    xla::XlaOp argop =
 | 
			
		||||
        xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num));
 | 
			
		||||
    xla_params.push_back(argop);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  std::vector<xla::XlaOp> returns(1);
 | 
			
		||||
  TF_RETURN_IF_ERROR(BuildHloFromTf(module_op, builder, xla_params, returns,
 | 
			
		||||
                                    arg_shapes, device_type,
 | 
			
		||||
                                    custom_legalization_passes));
 | 
			
		||||
 | 
			
		||||
  xla::XlaOp return_value;
 | 
			
		||||
  if (returns.size() == 1)
 | 
			
		||||
    return_value = returns[0];
 | 
			
		||||
  else
 | 
			
		||||
    return_value = xla::Tuple(&builder, returns);
 | 
			
		||||
 | 
			
		||||
  TF_ASSIGN_OR_RETURN(
 | 
			
		||||
      xla::XlaComputation computation,
 | 
			
		||||
      return_value.valid() ? builder.Build(return_value) : builder.Build());
 | 
			
		||||
  auto hlo_module = computation.proto();
 | 
			
		||||
  xla::HloProto hlo_proto;
 | 
			
		||||
  hlo_proto.mutable_hlo_module()->Swap(&hlo_module);
 | 
			
		||||
 | 
			
		||||
  compilation_result->computation = std::make_shared<xla::XlaComputation>();
 | 
			
		||||
  xla::XlaComputation* xla_computation = compilation_result->computation.get();
 | 
			
		||||
  *xla_computation = xla::XlaComputation(hlo_proto.hlo_module());
 | 
			
		||||
 | 
			
		||||
  XlaHelpers::ShapeRepresentationFn shape_representation_fn =
 | 
			
		||||
      IdentityShapeRepresentationFn();
 | 
			
		||||
  return PopulateResultIOInfo(module_op, arg_shapes, /*use_tuple_args=*/false,
 | 
			
		||||
                              /*use_resource_updates_for_aliases=*/false,
 | 
			
		||||
                              shape_representation_fn, compilation_result);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static mlir::LogicalResult MlirTfToHloTextTranslateFunctionImpl(
 | 
			
		||||
    mlir::ModuleOp module_op, llvm::raw_ostream& output, bool via_builder) {
 | 
			
		||||
  if (!module_op) return mlir::failure();
 | 
			
		||||
 | 
			
		||||
  llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
 | 
			
		||||
@ -245,12 +302,21 @@ static mlir::LogicalResult MlirTfToHloTextTranslateFunction(
 | 
			
		||||
    return mlir::failure();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto device_type = "XLA_CPU_JIT";
 | 
			
		||||
  llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
 | 
			
		||||
      custom_legalization_passes{};
 | 
			
		||||
  XlaCompilationResult compilation_result;
 | 
			
		||||
  auto compilation_status = CompileMlirToXlaHlo(
 | 
			
		||||
      module_op, arg_shapes, /*device_type=*/"XLA_CPU_JIT", emit_use_tuple_arg,
 | 
			
		||||
      emit_return_tuple, /*use_resource_updates_for_aliases=*/true,
 | 
			
		||||
      IdentityShapeRepresentationFn(), &compilation_result,
 | 
			
		||||
      /*custom_legalization_passes=*/{});
 | 
			
		||||
  auto compilation_status =
 | 
			
		||||
      via_builder
 | 
			
		||||
          ? CompileMlirToXlaHloViaBuilder(module_op, arg_shapes, device_type,
 | 
			
		||||
                                          &compilation_result,
 | 
			
		||||
                                          custom_legalization_passes)
 | 
			
		||||
          : CompileMlirToXlaHlo(module_op, arg_shapes, device_type,
 | 
			
		||||
                                emit_use_tuple_arg, emit_return_tuple,
 | 
			
		||||
                                /*use_resource_updates_for_aliases=*/true,
 | 
			
		||||
                                IdentityShapeRepresentationFn(),
 | 
			
		||||
                                &compilation_result,
 | 
			
		||||
                                custom_legalization_passes);
 | 
			
		||||
  if (!compilation_status.ok()) {
 | 
			
		||||
    LOG(ERROR) << "TF/XLA compilation failed: "
 | 
			
		||||
               << compilation_status.ToString();
 | 
			
		||||
@ -326,12 +392,27 @@ static mlir::LogicalResult MlirModuleToSerializedMlirStringAttrTranslate(
 | 
			
		||||
  return mlir::success();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static mlir::LogicalResult MlirTfToHloTextTranslateFunction(
 | 
			
		||||
    mlir::ModuleOp module_op, llvm::raw_ostream& output) {
 | 
			
		||||
  return MlirTfToHloTextTranslateFunctionImpl(module_op, output, false);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static mlir::LogicalResult MlirTfToHloTextViaBuilderTranslateFunction(
 | 
			
		||||
    mlir::ModuleOp module_op, llvm::raw_ostream& output) {
 | 
			
		||||
  return MlirTfToHloTextTranslateFunctionImpl(module_op, output, true);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace tensorflow
 | 
			
		||||
 | 
			
		||||
static mlir::TranslateFromMLIRRegistration MlirTfToHloTextTranslate(
 | 
			
		||||
    "mlir-tf-to-hlo-text", tensorflow::MlirTfToHloTextTranslateFunction,
 | 
			
		||||
    tensorflow::RegisterMlirInputDialects);
 | 
			
		||||
 | 
			
		||||
static mlir::TranslateFromMLIRRegistration MlirTfToHloTextViaBuilderTranslate(
 | 
			
		||||
    "mlir-tf-to-hlo-text-via-builder",
 | 
			
		||||
    tensorflow::MlirTfToHloTextViaBuilderTranslateFunction,
 | 
			
		||||
    tensorflow::RegisterMlirInputDialects);
 | 
			
		||||
 | 
			
		||||
static mlir::TranslateFromMLIRRegistration MlirTfGraphToHloTextTranslate(
 | 
			
		||||
    "mlir-tf-graph-to-hlo-text",
 | 
			
		||||
    tensorflow::MlirTfGraphToHloTextTranslateFunction,
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user