Add BuildHloFromMlirHlo to mlir_hlo_to_hlo.h. Refactor internal functions to support both the new BuildHloFromMlirHlo and the existing ConvertMlirHloToHlo.
BuildHloFromMlirHlo will be used for MLIR legalization in the old bridge. PiperOrigin-RevId: 345388758 Change-Id: I3614e860583bfa83a6dc88c35b60bb801bb31462
This commit is contained in:
parent
4679dcf1ee
commit
a74f748fa4
@ -354,6 +354,7 @@ cc_library(
|
||||
":mhlo_to_lhlo_with_xla",
|
||||
":mlir_hlo_to_hlo",
|
||||
":translate_cl_options",
|
||||
":type_to_shape",
|
||||
"//tensorflow/compiler/jit:xla_cpu_jit",
|
||||
"//tensorflow/compiler/jit:xla_gpu_jit",
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
|
@ -497,11 +497,12 @@ class ConvertToHloModule {
|
||||
// Multiple return values are always converted to a tuple and returned as a
|
||||
// single value.
|
||||
explicit ConvertToHloModule(
|
||||
mlir::ModuleOp module, bool use_tuple_args, bool return_tuple,
|
||||
mlir::ModuleOp module, xla::XlaBuilder& module_builder,
|
||||
bool use_tuple_args, bool return_tuple,
|
||||
tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
MlirToHloConversionOptions options)
|
||||
: module_(module),
|
||||
module_builder_("main"),
|
||||
module_builder_(module_builder),
|
||||
use_tuple_args_(use_tuple_args),
|
||||
return_tuple_(return_tuple),
|
||||
shape_representation_fn_(shape_representation_fn),
|
||||
@ -547,14 +548,14 @@ class ConvertToHloModule {
|
||||
mlir::CallOp call_op, xla::XlaBuilder* builder,
|
||||
ConvertToHloModule::ValueLoweringMap* value_lowering);
|
||||
|
||||
private:
|
||||
LogicalResult Lower(
|
||||
mlir::Operation* inst, bool is_entry_function,
|
||||
llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
|
||||
xla::XlaBuilder* builder,
|
||||
ConvertToHloModule::ValueLoweringMap* value_lowering,
|
||||
xla::XlaComputation* result);
|
||||
xla::XlaOp* return_value);
|
||||
|
||||
private:
|
||||
LogicalResult SetEntryTupleShapesAndLeafReplication(
|
||||
Block* block, const std::vector<bool>& entry_args_same_across_replicas,
|
||||
llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
|
||||
@ -569,7 +570,7 @@ class ConvertToHloModule {
|
||||
mlir::ModuleOp module_;
|
||||
|
||||
// The top-level XlaBuilder.
|
||||
xla::XlaBuilder module_builder_;
|
||||
xla::XlaBuilder& module_builder_;
|
||||
|
||||
// Map between function and lowered computation.
|
||||
FunctionLoweringMap lowered_computation_;
|
||||
@ -1189,7 +1190,9 @@ LogicalResult ConvertToHloModule::Lower(
|
||||
llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
|
||||
xla::XlaBuilder* builder,
|
||||
ConvertToHloModule::ValueLoweringMap* value_lowering,
|
||||
xla::XlaComputation* result) {
|
||||
xla::XlaOp* return_value) {
|
||||
*return_value = xla::XlaOp();
|
||||
|
||||
// See MlirToHloConversionOptions for more about layouts.
|
||||
auto propagate_layouts = [this](mlir::Operation* inst, xla::XlaOp xla_op) {
|
||||
if (options_.propagate_layouts) {
|
||||
@ -1255,7 +1258,6 @@ LogicalResult ConvertToHloModule::Lower(
|
||||
if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
|
||||
// Construct the return value for the function. If there are multiple
|
||||
// values returned, then create a tuple, else return value directly.
|
||||
xla::XlaOp return_value;
|
||||
unsigned num_return_values = inst->getNumOperands();
|
||||
if ((return_tuple_ && is_entry_function) || num_return_values > 1) {
|
||||
const bool has_ret_shardings =
|
||||
@ -1291,24 +1293,16 @@ LogicalResult ConvertToHloModule::Lower(
|
||||
builder->SetSharding(sharding);
|
||||
}
|
||||
|
||||
return_value = xla::Tuple(builder, returns);
|
||||
*return_value = xla::Tuple(builder, returns);
|
||||
builder->ClearSharding();
|
||||
} else if (num_return_values == 1) {
|
||||
xla::XlaOp operand;
|
||||
if (failed(GetXlaOp(inst->getOperand(0), value_map, &operand, inst)))
|
||||
return failure();
|
||||
|
||||
return_value = operand;
|
||||
*return_value = operand;
|
||||
}
|
||||
|
||||
// Build the XlaComputation and check for failures.
|
||||
auto computation_or =
|
||||
return_value.valid() ? builder->Build(return_value) : builder->Build();
|
||||
if (!computation_or.ok()) {
|
||||
inst->emitError(llvm::Twine(computation_or.status().error_message()));
|
||||
return failure();
|
||||
}
|
||||
*result = std::move(computation_or.ValueOrDie());
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -1515,11 +1509,21 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
|
||||
}
|
||||
}
|
||||
|
||||
xla::XlaOp return_value;
|
||||
for (auto& inst : *block)
|
||||
if (failed(Lower(&inst, is_entry_function, ret_shardings, builder,
|
||||
&lowering, result)))
|
||||
&lowering, &return_value)))
|
||||
return failure();
|
||||
|
||||
// Build the XlaComputation and check for failures.
|
||||
auto computation_or =
|
||||
return_value.valid() ? builder->Build(return_value) : builder->Build();
|
||||
if (!computation_or.ok()) {
|
||||
block->back().emitError(
|
||||
llvm::Twine(computation_or.status().error_message()));
|
||||
return failure();
|
||||
}
|
||||
*result = std::move(computation_or.ValueOrDie());
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -1704,7 +1708,8 @@ Status ConvertRegionToComputation(mlir::Region* region,
|
||||
xla::XlaComputation* func,
|
||||
MlirToHloConversionOptions options) {
|
||||
mlir::ModuleOp module;
|
||||
ConvertToHloModule converter(module, true, true, {}, options);
|
||||
xla::XlaBuilder module_builder("main");
|
||||
ConvertToHloModule converter(module, module_builder, true, true, {}, options);
|
||||
if (failed(converter.LowerRegionAsComputation(region, func)))
|
||||
return tensorflow::errors::Internal(
|
||||
"failed to convert region to computation");
|
||||
@ -1717,14 +1722,55 @@ Status ConvertMlirHloToHlo(
|
||||
const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
MlirToHloConversionOptions options) {
|
||||
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
|
||||
ConvertToHloModule converter(module, use_tuple_args, return_tuple,
|
||||
shape_representation_fn, options);
|
||||
xla::XlaBuilder module_builder("main");
|
||||
ConvertToHloModule converter(module, module_builder, use_tuple_args,
|
||||
return_tuple, shape_representation_fn, options);
|
||||
if (failed(converter.Run())) return diag_handler.ConsumeStatus();
|
||||
auto hlo_module = converter.ConsumeMainProto();
|
||||
hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
|
||||
if (failed(AddDynamicParameterBindings(
|
||||
module, hlo_proto->mutable_hlo_module(), use_tuple_args)))
|
||||
return diag_handler.ConsumeStatus();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder,
|
||||
llvm::ArrayRef<xla::XlaOp> xla_params,
|
||||
std::vector<xla::XlaOp>& returns,
|
||||
MlirToHloConversionOptions options) {
|
||||
auto module = block.getParentOp()->getParentOfType<mlir::ModuleOp>();
|
||||
ConvertToHloModule converter(module, builder,
|
||||
/*use_tuple_args=*/false, /*return_tuple=*/false,
|
||||
/*shape_representation_fn=*/nullptr, options);
|
||||
|
||||
ConvertToHloModule::ValueLoweringMap lowering;
|
||||
if (xla_params.size() != block.getArguments().size())
|
||||
return tensorflow::errors::Internal(
|
||||
"xla_params size != block arguments size");
|
||||
for (BlockArgument& arg : block.getArguments()) {
|
||||
auto num = arg.getArgNumber();
|
||||
lowering[arg] = xla_params[num];
|
||||
}
|
||||
|
||||
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
|
||||
for (auto& inst : block) {
|
||||
if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
|
||||
returns.resize(inst.getNumOperands());
|
||||
for (OpOperand& ret : inst.getOpOperands()) {
|
||||
unsigned index = ret.getOperandNumber();
|
||||
xla::XlaOp operand;
|
||||
if (failed(GetXlaOp(ret.get(), lowering, &operand, &inst)))
|
||||
return diag_handler.ConsumeStatus();
|
||||
returns[index] = operand;
|
||||
}
|
||||
} else {
|
||||
xla::XlaOp return_value;
|
||||
if (failed(converter.Lower(&inst, /*is_entry_function=*/true,
|
||||
/*ret_shardings=*/{}, &builder, &lowering,
|
||||
&return_value)))
|
||||
return diag_handler.ConsumeStatus();
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -52,6 +52,14 @@ Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto,
|
||||
shape_representation_fn = nullptr,
|
||||
MlirToHloConversionOptions options = {});
|
||||
|
||||
// Transforms a Block into HLO, where the HLO is represented as calls into an
|
||||
// XlaBuilder. Callee functions are allowed in the Block's ancestor ModuleOp.
|
||||
// xla_params are inputs to block. returns are the returned XlaOps.
|
||||
Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder,
|
||||
llvm::ArrayRef<xla::XlaOp> xla_params,
|
||||
std::vector<xla::XlaOp>& returns,
|
||||
MlirToHloConversionOptions options = {});
|
||||
|
||||
// Converts a region to a computation. It returns a standalone module that
|
||||
// contains the converted region as the entry computation.
|
||||
Status ConvertRegionToComputation(mlir::Region* region,
|
||||
|
@ -1,4 +1,5 @@
|
||||
// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text-via-builder %s | FileCheck %s
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main(%arg0: !mhlo.token, %arg1: !mhlo.token) -> !mhlo.token {
|
||||
@ -1004,22 +1005,6 @@ func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the exported HLO module keeps parameter replication annotation.
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {mhlo.is_same_data_across_replicas}) -> tensor<16x16xf32> {
|
||||
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32>
|
||||
return %0 : tensor<16x16xf32>
|
||||
}
|
||||
|
||||
// CHECK: ENTRY
|
||||
// CHECK: %[[ARG0:.*]] = f32[16,16] parameter(0)
|
||||
// CHECK-NOT: parameter_replication={true}
|
||||
// CHECK: %[[ARG1:.*]] = f32[16,16] parameter(1), parameter_replication={true}
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[16,16] add(f32[16,16] %[[ARG0]], f32[16,16] %[[ARG1]])
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f64>>) -> (tensor<2xf32>, tensor<2xf64>) {
|
||||
%0 = "mhlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
|
@ -0,0 +1,15 @@
|
||||
// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
// Tests that the exported HLO module keeps parameter replication annotation.
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {mhlo.is_same_data_across_replicas}) -> tensor<16x16xf32> {
|
||||
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32>
|
||||
return %0 : tensor<16x16xf32>
|
||||
}
|
||||
|
||||
// CHECK: ENTRY
|
||||
// CHECK: %[[ARG0:.*]] = f32[16,16] parameter(0)
|
||||
// CHECK-NOT: parameter_replication={true}
|
||||
// CHECK: %[[ARG1:.*]] = f32[16,16] parameter(1), parameter_replication={true}
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[16,16] add(f32[16,16] %[[ARG0]], f32[16,16] %[[ARG1]])
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
|
||||
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
|
||||
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
|
||||
#include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h"
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
@ -124,16 +125,58 @@ static StatusOr<std::unique_ptr<HloModule>> HloModuleFromProto(
|
||||
return HloModule::CreateFromProto(module_proto, module_config);
|
||||
}
|
||||
|
||||
// Wraps BuildHloFromMlirHlo to output an HloProto that's the same as
|
||||
// ConvertMlirHloToHlo.
|
||||
Status ConvertMlirHloToHloViaBuilder(mlir::ModuleOp module,
|
||||
::xla::HloProto* hlo_proto,
|
||||
mlir::MlirToHloConversionOptions options) {
|
||||
mlir::FuncOp main = module.lookupSymbol<mlir::FuncOp>("main");
|
||||
mlir::Block& block = main.getRegion().front();
|
||||
xla::XlaBuilder builder("main");
|
||||
|
||||
// Create xla_params.
|
||||
std::vector<xla::XlaOp> xla_params;
|
||||
for (mlir::BlockArgument& arg : block.getArguments()) {
|
||||
auto num = arg.getArgNumber();
|
||||
xla::Shape shape = xla::TypeToShape(arg.getType());
|
||||
XlaOp argop =
|
||||
xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num));
|
||||
xla_params.push_back(argop);
|
||||
}
|
||||
|
||||
std::vector<xla::XlaOp> returns(1);
|
||||
TF_RETURN_IF_ERROR(
|
||||
mlir::BuildHloFromMlirHlo(block, builder, xla_params, returns, options));
|
||||
|
||||
xla::XlaOp return_value;
|
||||
if (returns.size() == 1)
|
||||
return_value = returns[0];
|
||||
else if (returns.size() > 1)
|
||||
return_value = xla::Tuple(&builder, returns);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
xla::XlaComputation computation,
|
||||
return_value.valid() ? builder.Build(return_value) : builder.Build());
|
||||
auto hlo_module = computation.proto();
|
||||
hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static mlir::LogicalResult MlirHloToHloTextTranslateFunctionImpl(
|
||||
mlir::ModuleOp module, llvm::raw_ostream& output, bool with_layouts) {
|
||||
mlir::ModuleOp module, llvm::raw_ostream& output, bool with_layouts,
|
||||
bool via_builder) {
|
||||
if (!module) return mlir::failure();
|
||||
|
||||
HloProto hloProto;
|
||||
mlir::MlirToHloConversionOptions options;
|
||||
options.propagate_layouts = with_layouts;
|
||||
Status status = mlir::ConvertMlirHloToHlo(
|
||||
module, &hloProto, emit_use_tuple_arg, emit_return_tuple,
|
||||
/*shape_representation_fn=*/nullptr, options);
|
||||
Status status =
|
||||
via_builder
|
||||
? ConvertMlirHloToHloViaBuilder(module, &hloProto, options)
|
||||
: mlir::ConvertMlirHloToHlo(
|
||||
module, &hloProto, emit_use_tuple_arg, emit_return_tuple,
|
||||
/*shape_representation_fn=*/nullptr, options);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Module conversion failed: " << status;
|
||||
return mlir::failure();
|
||||
@ -167,13 +210,24 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunctionImpl(
|
||||
static mlir::LogicalResult MlirHloToHloTextTranslateFunction(
|
||||
mlir::ModuleOp module, llvm::raw_ostream& output) {
|
||||
return MlirHloToHloTextTranslateFunctionImpl(module, output,
|
||||
/*with_layouts=*/false);
|
||||
/*with_layouts=*/false,
|
||||
/*via_builder=*/false);
|
||||
}
|
||||
|
||||
static mlir::LogicalResult MlirHloToHloTextWithLayoutsTranslateFunction(
|
||||
mlir::ModuleOp module, llvm::raw_ostream& output) {
|
||||
return MlirHloToHloTextTranslateFunctionImpl(module, output,
|
||||
/*with_layouts=*/true);
|
||||
/*with_layouts=*/true,
|
||||
/*via_builder=*/false);
|
||||
}
|
||||
|
||||
// This converts MlirHlo to Hlo by first converting to XlaBuilder.
|
||||
// This is useful for testing conversion to XlaBuilder.
|
||||
static mlir::LogicalResult MlirHloToHloTextViaBuilderTranslateFunction(
|
||||
mlir::ModuleOp module, llvm::raw_ostream& output) {
|
||||
return MlirHloToHloTextTranslateFunctionImpl(module, output,
|
||||
/*with_layouts=*/false,
|
||||
/*via_builder=*/true);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
@ -194,6 +248,10 @@ static mlir::TranslateFromMLIRRegistration MlirHloToHloTextWithLayoutsTranslate(
|
||||
"mlir-hlo-to-hlo-text-with-layouts",
|
||||
xla::MlirHloToHloTextWithLayoutsTranslateFunction, RegisterInputDialects);
|
||||
|
||||
static mlir::TranslateFromMLIRRegistration MlirHloToHloTextViaBuilderTranslate(
|
||||
"mlir-hlo-to-hlo-text-via-builder",
|
||||
xla::MlirHloToHloTextViaBuilderTranslateFunction, RegisterInputDialects);
|
||||
|
||||
static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate(
|
||||
"hlo-to-mlir-hlo", xla::HloToMlirHloTranslateFunction);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user