From a74f748fa4593f6e526f100b321be1d86adee12b Mon Sep 17 00:00:00 2001 From: Michael Delorimier Date: Wed, 2 Dec 2020 23:26:35 -0800 Subject: [PATCH] Add BuildHloFromMlirHlo to mlir_hlo_to_hlo.h. Refactor internal functions to support both the new BuildHloFromMlirHlo and the existing ConvertMlirHloToHlo. BuildHloFromMlirHlo will be used for MLIR legalization in the old bridge. PiperOrigin-RevId: 345388758 Change-Id: I3614e860583bfa83a6dc88c35b60bb801bb31462 --- tensorflow/compiler/mlir/xla/BUILD | 1 + .../compiler/mlir/xla/mlir_hlo_to_hlo.cc | 88 ++++++++++++++----- .../compiler/mlir/xla/mlir_hlo_to_hlo.h | 8 ++ .../mlir/xla/tests/translate/export.mlir | 17 +--- .../xla/tests/translate/export_replicas.mlir | 15 ++++ .../compiler/mlir/xla/xla_mlir_translate.cc | 70 +++++++++++++-- 6 files changed, 156 insertions(+), 43 deletions(-) create mode 100644 tensorflow/compiler/mlir/xla/tests/translate/export_replicas.mlir diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 2daa8a86d37..d122790af07 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -354,6 +354,7 @@ cc_library( ":mhlo_to_lhlo_with_xla", ":mlir_hlo_to_hlo", ":translate_cl_options", + ":type_to_shape", "//tensorflow/compiler/jit:xla_cpu_jit", "//tensorflow/compiler/jit:xla_gpu_jit", "//tensorflow/compiler/mlir/hlo", diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 5c7a592df27..36aa31b0d34 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -497,11 +497,12 @@ class ConvertToHloModule { // Multiple return values are always converted to a tuple and returned as a // single value. explicit ConvertToHloModule( - mlir::ModuleOp module, bool use_tuple_args, bool return_tuple, + mlir::ModuleOp module, xla::XlaBuilder& module_builder, + bool use_tuple_args, bool return_tuple, tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn, MlirToHloConversionOptions options) : module_(module), - module_builder_("main"), + module_builder_(module_builder), use_tuple_args_(use_tuple_args), return_tuple_(return_tuple), shape_representation_fn_(shape_representation_fn), @@ -547,14 +548,14 @@ class ConvertToHloModule { mlir::CallOp call_op, xla::XlaBuilder* builder, ConvertToHloModule::ValueLoweringMap* value_lowering); - private: LogicalResult Lower( mlir::Operation* inst, bool is_entry_function, llvm::ArrayRef> ret_shardings, xla::XlaBuilder* builder, ConvertToHloModule::ValueLoweringMap* value_lowering, - xla::XlaComputation* result); + xla::XlaOp* return_value); + private: LogicalResult SetEntryTupleShapesAndLeafReplication( Block* block, const std::vector& entry_args_same_across_replicas, llvm::SmallVectorImpl* arg_shapes, @@ -569,7 +570,7 @@ class ConvertToHloModule { mlir::ModuleOp module_; // The top-level XlaBuilder. - xla::XlaBuilder module_builder_; + xla::XlaBuilder& module_builder_; // Map between function and lowered computation. FunctionLoweringMap lowered_computation_; @@ -1189,7 +1190,9 @@ LogicalResult ConvertToHloModule::Lower( llvm::ArrayRef> ret_shardings, xla::XlaBuilder* builder, ConvertToHloModule::ValueLoweringMap* value_lowering, - xla::XlaComputation* result) { + xla::XlaOp* return_value) { + *return_value = xla::XlaOp(); + // See MlirToHloConversionOptions for more about layouts. auto propagate_layouts = [this](mlir::Operation* inst, xla::XlaOp xla_op) { if (options_.propagate_layouts) { @@ -1255,7 +1258,6 @@ LogicalResult ConvertToHloModule::Lower( if (isa(inst)) { // Construct the return value for the function. If there are multiple // values returned, then create a tuple, else return value directly. - xla::XlaOp return_value; unsigned num_return_values = inst->getNumOperands(); if ((return_tuple_ && is_entry_function) || num_return_values > 1) { const bool has_ret_shardings = @@ -1291,24 +1293,16 @@ LogicalResult ConvertToHloModule::Lower( builder->SetSharding(sharding); } - return_value = xla::Tuple(builder, returns); + *return_value = xla::Tuple(builder, returns); builder->ClearSharding(); } else if (num_return_values == 1) { xla::XlaOp operand; if (failed(GetXlaOp(inst->getOperand(0), value_map, &operand, inst))) return failure(); - return_value = operand; + *return_value = operand; } - // Build the XlaComputation and check for failures. - auto computation_or = - return_value.valid() ? builder->Build(return_value) : builder->Build(); - if (!computation_or.ok()) { - inst->emitError(llvm::Twine(computation_or.status().error_message())); - return failure(); - } - *result = std::move(computation_or.ValueOrDie()); return success(); } @@ -1515,11 +1509,21 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( } } + xla::XlaOp return_value; for (auto& inst : *block) if (failed(Lower(&inst, is_entry_function, ret_shardings, builder, - &lowering, result))) + &lowering, &return_value))) return failure(); + // Build the XlaComputation and check for failures. + auto computation_or = + return_value.valid() ? builder->Build(return_value) : builder->Build(); + if (!computation_or.ok()) { + block->back().emitError( + llvm::Twine(computation_or.status().error_message())); + return failure(); + } + *result = std::move(computation_or.ValueOrDie()); return success(); } @@ -1704,7 +1708,8 @@ Status ConvertRegionToComputation(mlir::Region* region, xla::XlaComputation* func, MlirToHloConversionOptions options) { mlir::ModuleOp module; - ConvertToHloModule converter(module, true, true, {}, options); + xla::XlaBuilder module_builder("main"); + ConvertToHloModule converter(module, module_builder, true, true, {}, options); if (failed(converter.LowerRegionAsComputation(region, func))) return tensorflow::errors::Internal( "failed to convert region to computation"); @@ -1717,14 +1722,55 @@ Status ConvertMlirHloToHlo( const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn, MlirToHloConversionOptions options) { mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); - ConvertToHloModule converter(module, use_tuple_args, return_tuple, - shape_representation_fn, options); + xla::XlaBuilder module_builder("main"); + ConvertToHloModule converter(module, module_builder, use_tuple_args, + return_tuple, shape_representation_fn, options); if (failed(converter.Run())) return diag_handler.ConsumeStatus(); auto hlo_module = converter.ConsumeMainProto(); hlo_proto->mutable_hlo_module()->Swap(&hlo_module); if (failed(AddDynamicParameterBindings( module, hlo_proto->mutable_hlo_module(), use_tuple_args))) return diag_handler.ConsumeStatus(); + return Status::OK(); +} + +Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder, + llvm::ArrayRef xla_params, + std::vector& returns, + MlirToHloConversionOptions options) { + auto module = block.getParentOp()->getParentOfType(); + ConvertToHloModule converter(module, builder, + /*use_tuple_args=*/false, /*return_tuple=*/false, + /*shape_representation_fn=*/nullptr, options); + + ConvertToHloModule::ValueLoweringMap lowering; + if (xla_params.size() != block.getArguments().size()) + return tensorflow::errors::Internal( + "xla_params size != block arguments size"); + for (BlockArgument& arg : block.getArguments()) { + auto num = arg.getArgNumber(); + lowering[arg] = xla_params[num]; + } + + mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); + for (auto& inst : block) { + if (isa(inst)) { + returns.resize(inst.getNumOperands()); + for (OpOperand& ret : inst.getOpOperands()) { + unsigned index = ret.getOperandNumber(); + xla::XlaOp operand; + if (failed(GetXlaOp(ret.get(), lowering, &operand, &inst))) + return diag_handler.ConsumeStatus(); + returns[index] = operand; + } + } else { + xla::XlaOp return_value; + if (failed(converter.Lower(&inst, /*is_entry_function=*/true, + /*ret_shardings=*/{}, &builder, &lowering, + &return_value))) + return diag_handler.ConsumeStatus(); + } + } return Status::OK(); } diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h index a260a797354..a1c1cb5c7da 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h @@ -52,6 +52,14 @@ Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto, shape_representation_fn = nullptr, MlirToHloConversionOptions options = {}); +// Transforms a Block into HLO, where the HLO is represented as calls into an +// XlaBuilder. Callee functions are allowed in the Block's ancestor ModuleOp. +// xla_params are inputs to block. returns are the returned XlaOps. +Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder, + llvm::ArrayRef xla_params, + std::vector& returns, + MlirToHloConversionOptions options = {}); + // Converts a region to a computation. It returns a standalone module that // contains the converted region as the entry computation. Status ConvertRegionToComputation(mlir::Region* region, diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index 61686e13b26..b3d3603ae41 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -1,4 +1,5 @@ // RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s +// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text-via-builder %s | FileCheck %s // CHECK: HloModule func @main(%arg0: !mhlo.token, %arg1: !mhlo.token) -> !mhlo.token { @@ -1004,22 +1005,6 @@ func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> { // ----- -// Tests that the exported HLO module keeps parameter replication annotation. - -// CHECK: HloModule -func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {mhlo.is_same_data_across_replicas}) -> tensor<16x16xf32> { - %0 = "mhlo.add"(%arg0, %arg1) : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> - return %0 : tensor<16x16xf32> -} - -// CHECK: ENTRY -// CHECK: %[[ARG0:.*]] = f32[16,16] parameter(0) -// CHECK-NOT: parameter_replication={true} -// CHECK: %[[ARG1:.*]] = f32[16,16] parameter(1), parameter_replication={true} -// CHECK: ROOT %[[RESULT:.*]] = f32[16,16] add(f32[16,16] %[[ARG0]], f32[16,16] %[[ARG1]]) - -// ----- - // CHECK: HloModule func @main(%arg0: tensor<2xcomplex>, %arg1: tensor<2xcomplex>) -> (tensor<2xf32>, tensor<2xf64>) { %0 = "mhlo.abs"(%arg0) : (tensor<2xcomplex>) -> (tensor<2xf32>) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export_replicas.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export_replicas.mlir new file mode 100644 index 00000000000..40012f18c71 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/export_replicas.mlir @@ -0,0 +1,15 @@ +// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s + +// Tests that the exported HLO module keeps parameter replication annotation. + +// CHECK: HloModule +func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {mhlo.is_same_data_across_replicas}) -> tensor<16x16xf32> { + %0 = "mhlo.add"(%arg0, %arg1) : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + return %0 : tensor<16x16xf32> +} + +// CHECK: ENTRY +// CHECK: %[[ARG0:.*]] = f32[16,16] parameter(0) +// CHECK-NOT: parameter_replication={true} +// CHECK: %[[ARG1:.*]] = f32[16,16] parameter(1), parameter_replication={true} +// CHECK: ROOT %[[RESULT:.*]] = f32[16,16] add(f32[16,16] %[[ARG0]], f32[16,16] %[[ARG1]]) diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc index 1be19de10c0..cc8c23ca124 100644 --- a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" #include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h" +#include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h" #include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -124,16 +125,58 @@ static StatusOr> HloModuleFromProto( return HloModule::CreateFromProto(module_proto, module_config); } +// Wraps BuildHloFromMlirHlo to output an HloProto that's the same as +// ConvertMlirHloToHlo. +Status ConvertMlirHloToHloViaBuilder(mlir::ModuleOp module, + ::xla::HloProto* hlo_proto, + mlir::MlirToHloConversionOptions options) { + mlir::FuncOp main = module.lookupSymbol("main"); + mlir::Block& block = main.getRegion().front(); + xla::XlaBuilder builder("main"); + + // Create xla_params. + std::vector xla_params; + for (mlir::BlockArgument& arg : block.getArguments()) { + auto num = arg.getArgNumber(); + xla::Shape shape = xla::TypeToShape(arg.getType()); + XlaOp argop = + xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num)); + xla_params.push_back(argop); + } + + std::vector returns(1); + TF_RETURN_IF_ERROR( + mlir::BuildHloFromMlirHlo(block, builder, xla_params, returns, options)); + + xla::XlaOp return_value; + if (returns.size() == 1) + return_value = returns[0]; + else if (returns.size() > 1) + return_value = xla::Tuple(&builder, returns); + + TF_ASSIGN_OR_RETURN( + xla::XlaComputation computation, + return_value.valid() ? builder.Build(return_value) : builder.Build()); + auto hlo_module = computation.proto(); + hlo_proto->mutable_hlo_module()->Swap(&hlo_module); + + return Status::OK(); +} + static mlir::LogicalResult MlirHloToHloTextTranslateFunctionImpl( - mlir::ModuleOp module, llvm::raw_ostream& output, bool with_layouts) { + mlir::ModuleOp module, llvm::raw_ostream& output, bool with_layouts, + bool via_builder) { if (!module) return mlir::failure(); HloProto hloProto; mlir::MlirToHloConversionOptions options; options.propagate_layouts = with_layouts; - Status status = mlir::ConvertMlirHloToHlo( - module, &hloProto, emit_use_tuple_arg, emit_return_tuple, - /*shape_representation_fn=*/nullptr, options); + Status status = + via_builder + ? ConvertMlirHloToHloViaBuilder(module, &hloProto, options) + : mlir::ConvertMlirHloToHlo( + module, &hloProto, emit_use_tuple_arg, emit_return_tuple, + /*shape_representation_fn=*/nullptr, options); if (!status.ok()) { LOG(ERROR) << "Module conversion failed: " << status; return mlir::failure(); @@ -167,13 +210,24 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunctionImpl( static mlir::LogicalResult MlirHloToHloTextTranslateFunction( mlir::ModuleOp module, llvm::raw_ostream& output) { return MlirHloToHloTextTranslateFunctionImpl(module, output, - /*with_layouts=*/false); + /*with_layouts=*/false, + /*via_builder=*/false); } static mlir::LogicalResult MlirHloToHloTextWithLayoutsTranslateFunction( mlir::ModuleOp module, llvm::raw_ostream& output) { return MlirHloToHloTextTranslateFunctionImpl(module, output, - /*with_layouts=*/true); + /*with_layouts=*/true, + /*via_builder=*/false); +} + +// This converts MlirHlo to Hlo by first converting to XlaBuilder. +// This is useful for testing conversion to XlaBuilder. +static mlir::LogicalResult MlirHloToHloTextViaBuilderTranslateFunction( + mlir::ModuleOp module, llvm::raw_ostream& output) { + return MlirHloToHloTextTranslateFunctionImpl(module, output, + /*with_layouts=*/false, + /*via_builder=*/true); } } // namespace xla @@ -194,6 +248,10 @@ static mlir::TranslateFromMLIRRegistration MlirHloToHloTextWithLayoutsTranslate( "mlir-hlo-to-hlo-text-with-layouts", xla::MlirHloToHloTextWithLayoutsTranslateFunction, RegisterInputDialects); +static mlir::TranslateFromMLIRRegistration MlirHloToHloTextViaBuilderTranslate( + "mlir-hlo-to-hlo-text-via-builder", + xla::MlirHloToHloTextViaBuilderTranslateFunction, RegisterInputDialects); + static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate( "hlo-to-mlir-hlo", xla::HloToMlirHloTranslateFunction);