Add BuildHloFromMlirHlo to mlir_hlo_to_hlo.h. Refactor internal functions to
support both the new BuildHloFromMlirHlo and the existing ConvertMlirHloToHlo.

BuildHloFromMlirHlo will be used for MLIR legalization in the old bridge.

PiperOrigin-RevId: 345388758
Change-Id: I3614e860583bfa83a6dc88c35b60bb801bb31462
parent 4679dcf1ee
commit a74f748fa4
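In outline, the refactor threads a caller-owned xla::XlaBuilder through
ConvertToHloModule instead of letting the converter construct its own, so the
same per-op lowering can serve both entry points. A minimal sketch of the two
resulting call patterns, lifted from the call sites updated below (error
handling elided; `converter2` is just a second illustrative local):

    // ConvertMlirHloToHlo: the builder is a local that backs the whole module.
    xla::XlaBuilder module_builder("main");
    ConvertToHloModule converter(module, module_builder, use_tuple_args,
                                 return_tuple, shape_representation_fn, options);

    // BuildHloFromMlirHlo: the caller supplies a builder it keeps appending to;
    // no XlaComputation is finalized on the caller's behalf.
    ConvertToHloModule converter2(module, builder,
                                  /*use_tuple_args=*/false, /*return_tuple=*/false,
                                  /*shape_representation_fn=*/nullptr, options);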
@@ -354,6 +354,7 @@ cc_library(
         ":mhlo_to_lhlo_with_xla",
         ":mlir_hlo_to_hlo",
         ":translate_cl_options",
+        ":type_to_shape",
         "//tensorflow/compiler/jit:xla_cpu_jit",
         "//tensorflow/compiler/jit:xla_gpu_jit",
         "//tensorflow/compiler/mlir/hlo",

@@ -497,11 +497,12 @@ class ConvertToHloModule {
   // Multiple return values are always converted to a tuple and returned as a
   // single value.
   explicit ConvertToHloModule(
-      mlir::ModuleOp module, bool use_tuple_args, bool return_tuple,
+      mlir::ModuleOp module, xla::XlaBuilder& module_builder,
+      bool use_tuple_args, bool return_tuple,
       tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
       MlirToHloConversionOptions options)
       : module_(module),
-        module_builder_("main"),
+        module_builder_(module_builder),
         use_tuple_args_(use_tuple_args),
         return_tuple_(return_tuple),
         shape_representation_fn_(shape_representation_fn),

@@ -547,14 +548,14 @@ class ConvertToHloModule {
       mlir::CallOp call_op, xla::XlaBuilder* builder,
       ConvertToHloModule::ValueLoweringMap* value_lowering);

- private:
   LogicalResult Lower(
       mlir::Operation* inst, bool is_entry_function,
       llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
       xla::XlaBuilder* builder,
       ConvertToHloModule::ValueLoweringMap* value_lowering,
-      xla::XlaComputation* result);
+      xla::XlaOp* return_value);

+ private:
   LogicalResult SetEntryTupleShapesAndLeafReplication(
       Block* block, const std::vector<bool>& entry_args_same_across_replicas,
       llvm::SmallVectorImpl<xla::Shape>* arg_shapes,

@@ -569,7 +570,7 @@ class ConvertToHloModule {
   mlir::ModuleOp module_;

   // The top-level XlaBuilder.
-  xla::XlaBuilder module_builder_;
+  xla::XlaBuilder& module_builder_;

   // Map between function and lowered computation.
   FunctionLoweringMap lowered_computation_;

@@ -1189,7 +1190,9 @@ LogicalResult ConvertToHloModule::Lower(
     llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
     xla::XlaBuilder* builder,
     ConvertToHloModule::ValueLoweringMap* value_lowering,
-    xla::XlaComputation* result) {
+    xla::XlaOp* return_value) {
+  *return_value = xla::XlaOp();
+
   // See MlirToHloConversionOptions for more about layouts.
   auto propagate_layouts = [this](mlir::Operation* inst, xla::XlaOp xla_op) {
     if (options_.propagate_layouts) {

@@ -1255,7 +1258,6 @@ LogicalResult ConvertToHloModule::Lower(
   if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
     // Construct the return value for the function. If there are multiple
     // values returned, then create a tuple, else return value directly.
-    xla::XlaOp return_value;
     unsigned num_return_values = inst->getNumOperands();
     if ((return_tuple_ && is_entry_function) || num_return_values > 1) {
       const bool has_ret_shardings =

@@ -1291,24 +1293,16 @@ LogicalResult ConvertToHloModule::Lower(
         builder->SetSharding(sharding);
       }

-      return_value = xla::Tuple(builder, returns);
+      *return_value = xla::Tuple(builder, returns);
       builder->ClearSharding();
     } else if (num_return_values == 1) {
       xla::XlaOp operand;
       if (failed(GetXlaOp(inst->getOperand(0), value_map, &operand, inst)))
         return failure();

-      return_value = operand;
+      *return_value = operand;
     }

-    // Build the XlaComputation and check for failures.
-    auto computation_or =
-        return_value.valid() ? builder->Build(return_value) : builder->Build();
-    if (!computation_or.ok()) {
-      inst->emitError(llvm::Twine(computation_or.status().error_message()));
-      return failure();
-    }
-    *result = std::move(computation_or.ValueOrDie());
     return success();
   }

@@ -1515,11 +1509,21 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
     }
   }

+  xla::XlaOp return_value;
   for (auto& inst : *block)
     if (failed(Lower(&inst, is_entry_function, ret_shardings, builder,
-                     &lowering, result)))
+                     &lowering, &return_value)))
       return failure();

+  // Build the XlaComputation and check for failures.
+  auto computation_or =
+      return_value.valid() ? builder->Build(return_value) : builder->Build();
+  if (!computation_or.ok()) {
+    block->back().emitError(
+        llvm::Twine(computation_or.status().error_message()));
+    return failure();
+  }
+  *result = std::move(computation_or.ValueOrDie());
   return success();
 }

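The hunks above capture the core of the refactor: Lower() no longer builds an
XlaComputation itself; it only writes the root value through its return_value
out-parameter when it lowers a return op, and LowerBasicBlockAsFunction() now
decides whether and when to finalize. Condensed, the new contract is (a
sketch, assuming the surrounding declarations):

    xla::XlaOp return_value;  // stays invalid until a return op is lowered
    for (auto& inst : *block)
      if (failed(Lower(&inst, is_entry_function, ret_shardings, builder,
                       &lowering, &return_value)))
        return failure();
    // Only this block-level driver calls Build(); BuildHloFromMlirHlo (added
    // below) reuses Lower() without ever finalizing the caller's builder.
    auto computation_or =
        return_value.valid() ? builder->Build(return_value) : builder->Build();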
@@ -1704,7 +1708,8 @@ Status ConvertRegionToComputation(mlir::Region* region,
                                   xla::XlaComputation* func,
                                   MlirToHloConversionOptions options) {
   mlir::ModuleOp module;
-  ConvertToHloModule converter(module, true, true, {}, options);
+  xla::XlaBuilder module_builder("main");
+  ConvertToHloModule converter(module, module_builder, true, true, {}, options);
   if (failed(converter.LowerRegionAsComputation(region, func)))
     return tensorflow::errors::Internal(
         "failed to convert region to computation");

@@ -1717,14 +1722,55 @@ Status ConvertMlirHloToHlo(
     const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
     MlirToHloConversionOptions options) {
   mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
-  ConvertToHloModule converter(module, use_tuple_args, return_tuple,
-                               shape_representation_fn, options);
+  xla::XlaBuilder module_builder("main");
+  ConvertToHloModule converter(module, module_builder, use_tuple_args,
+                               return_tuple, shape_representation_fn, options);
   if (failed(converter.Run())) return diag_handler.ConsumeStatus();
   auto hlo_module = converter.ConsumeMainProto();
   hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
   if (failed(AddDynamicParameterBindings(
           module, hlo_proto->mutable_hlo_module(), use_tuple_args)))
     return diag_handler.ConsumeStatus();
+  return Status::OK();
+}
+
+Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder,
+                           llvm::ArrayRef<xla::XlaOp> xla_params,
+                           std::vector<xla::XlaOp>& returns,
+                           MlirToHloConversionOptions options) {
+  auto module = block.getParentOp()->getParentOfType<mlir::ModuleOp>();
+  ConvertToHloModule converter(module, builder,
+                               /*use_tuple_args=*/false, /*return_tuple=*/false,
+                               /*shape_representation_fn=*/nullptr, options);
+
+  ConvertToHloModule::ValueLoweringMap lowering;
+  if (xla_params.size() != block.getArguments().size())
+    return tensorflow::errors::Internal(
+        "xla_params size != block arguments size");
+  for (BlockArgument& arg : block.getArguments()) {
+    auto num = arg.getArgNumber();
+    lowering[arg] = xla_params[num];
+  }
+
+  mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
+  for (auto& inst : block) {
+    if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
+      returns.resize(inst.getNumOperands());
+      for (OpOperand& ret : inst.getOpOperands()) {
+        unsigned index = ret.getOperandNumber();
+        xla::XlaOp operand;
+        if (failed(GetXlaOp(ret.get(), lowering, &operand, &inst)))
+          return diag_handler.ConsumeStatus();
+        returns[index] = operand;
+      }
+    } else {
+      xla::XlaOp return_value;
+      if (failed(converter.Lower(&inst, /*is_entry_function=*/true,
+                                 /*ret_shardings=*/{}, &builder, &lowering,
+                                 &return_value)))
+        return diag_handler.ConsumeStatus();
+    }
+  }
+
   return Status::OK();
 }

@@ -52,6 +52,14 @@ Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto,
                                shape_representation_fn = nullptr,
                            MlirToHloConversionOptions options = {});

+// Transforms a Block into HLO, where the HLO is represented as calls into an
+// XlaBuilder. Callee functions are allowed in the Block's ancestor ModuleOp.
+// xla_params are inputs to block. returns are the returned XlaOps.
+Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder,
+                           llvm::ArrayRef<xla::XlaOp> xla_params,
+                           std::vector<xla::XlaOp>& returns,
+                           MlirToHloConversionOptions options = {});
+
 // Converts a region to a computation. It returns a standalone module that
 // contains the converted region as the entry computation.
 Status ConvertRegionToComputation(mlir::Region* region,

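A minimal sketch of how a caller might drive the new entry point, mirroring
the ConvertMlirHloToHloViaBuilder helper added further down (block and options
are assumed to come from the caller; error handling elided):

    xla::XlaBuilder builder("main");

    // One xla::Parameter per block argument, in argument order.
    std::vector<xla::XlaOp> xla_params;
    for (mlir::BlockArgument& arg : block.getArguments()) {
      xla::Shape shape = xla::TypeToShape(arg.getType());
      xla_params.push_back(xla::Parameter(&builder, arg.getArgNumber(), shape,
                                          absl::StrCat("Arg_", arg.getArgNumber())));
    }

    std::vector<xla::XlaOp> returns;
    TF_RETURN_IF_ERROR(
        mlir::BuildHloFromMlirHlo(block, builder, xla_params, returns, options));

    // The caller decides how to finish, e.g. tuple multiple results and Build().
    xla::XlaOp root =
        returns.size() == 1 ? returns[0] : xla::Tuple(&builder, returns);
    TF_ASSIGN_OR_RETURN(xla::XlaComputation computation, builder.Build(root));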
@@ -1,4 +1,5 @@
 // RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
+// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text-via-builder %s | FileCheck %s

 // CHECK: HloModule
 func @main(%arg0: !mhlo.token, %arg1: !mhlo.token) -> !mhlo.token {

@@ -1004,22 +1005,6 @@ func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> {

 // -----

-// Tests that the exported HLO module keeps parameter replication annotation.
-
-// CHECK: HloModule
-func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {mhlo.is_same_data_across_replicas}) -> tensor<16x16xf32> {
-  %0 = "mhlo.add"(%arg0, %arg1) : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32>
-  return %0 : tensor<16x16xf32>
-}
-
-// CHECK: ENTRY
-// CHECK: %[[ARG0:.*]] = f32[16,16] parameter(0)
-// CHECK-NOT: parameter_replication={true}
-// CHECK: %[[ARG1:.*]] = f32[16,16] parameter(1), parameter_replication={true}
-// CHECK: ROOT %[[RESULT:.*]] = f32[16,16] add(f32[16,16] %[[ARG0]], f32[16,16] %[[ARG1]])
-
-// -----
-
 // CHECK: HloModule
 func @main(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f64>>) -> (tensor<2xf32>, tensor<2xf64>) {
   %0 = "mhlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)

@@ -0,0 +1,15 @@
+// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
+
+// Tests that the exported HLO module keeps parameter replication annotation.
+
+// CHECK: HloModule
+func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {mhlo.is_same_data_across_replicas}) -> tensor<16x16xf32> {
+  %0 = "mhlo.add"(%arg0, %arg1) : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32>
+  return %0 : tensor<16x16xf32>
+}
+
+// CHECK: ENTRY
+// CHECK: %[[ARG0:.*]] = f32[16,16] parameter(0)
+// CHECK-NOT: parameter_replication={true}
+// CHECK: %[[ARG1:.*]] = f32[16,16] parameter(1), parameter_replication={true}
+// CHECK: ROOT %[[RESULT:.*]] = f32[16,16] add(f32[16,16] %[[ARG0]], f32[16,16] %[[ARG1]])

@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
 #include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
+#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
 #include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h"
 #include "tensorflow/compiler/xla/debug_options_flags.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"

@@ -124,16 +125,58 @@ static StatusOr<std::unique_ptr<HloModule>> HloModuleFromProto(
   return HloModule::CreateFromProto(module_proto, module_config);
 }

+// Wraps BuildHloFromMlirHlo to output an HloProto that's the same as
+// ConvertMlirHloToHlo.
+Status ConvertMlirHloToHloViaBuilder(mlir::ModuleOp module,
+                                     ::xla::HloProto* hlo_proto,
+                                     mlir::MlirToHloConversionOptions options) {
+  mlir::FuncOp main = module.lookupSymbol<mlir::FuncOp>("main");
+  mlir::Block& block = main.getRegion().front();
+  xla::XlaBuilder builder("main");
+
+  // Create xla_params.
+  std::vector<xla::XlaOp> xla_params;
+  for (mlir::BlockArgument& arg : block.getArguments()) {
+    auto num = arg.getArgNumber();
+    xla::Shape shape = xla::TypeToShape(arg.getType());
+    XlaOp argop =
+        xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num));
+    xla_params.push_back(argop);
+  }
+
+  std::vector<xla::XlaOp> returns(1);
+  TF_RETURN_IF_ERROR(
+      mlir::BuildHloFromMlirHlo(block, builder, xla_params, returns, options));
+
+  xla::XlaOp return_value;
+  if (returns.size() == 1)
+    return_value = returns[0];
+  else if (returns.size() > 1)
+    return_value = xla::Tuple(&builder, returns);
+
+  TF_ASSIGN_OR_RETURN(
+      xla::XlaComputation computation,
+      return_value.valid() ? builder.Build(return_value) : builder.Build());
+  auto hlo_module = computation.proto();
+  hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
+
+  return Status::OK();
+}
+
 static mlir::LogicalResult MlirHloToHloTextTranslateFunctionImpl(
-    mlir::ModuleOp module, llvm::raw_ostream& output, bool with_layouts) {
+    mlir::ModuleOp module, llvm::raw_ostream& output, bool with_layouts,
+    bool via_builder) {
   if (!module) return mlir::failure();

   HloProto hloProto;
   mlir::MlirToHloConversionOptions options;
   options.propagate_layouts = with_layouts;
-  Status status = mlir::ConvertMlirHloToHlo(
-      module, &hloProto, emit_use_tuple_arg, emit_return_tuple,
-      /*shape_representation_fn=*/nullptr, options);
+  Status status =
+      via_builder
+          ? ConvertMlirHloToHloViaBuilder(module, &hloProto, options)
+          : mlir::ConvertMlirHloToHlo(
+                module, &hloProto, emit_use_tuple_arg, emit_return_tuple,
+                /*shape_representation_fn=*/nullptr, options);
   if (!status.ok()) {
     LOG(ERROR) << "Module conversion failed: " << status;
     return mlir::failure();

@@ -167,13 +210,24 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunctionImpl(
 static mlir::LogicalResult MlirHloToHloTextTranslateFunction(
     mlir::ModuleOp module, llvm::raw_ostream& output) {
   return MlirHloToHloTextTranslateFunctionImpl(module, output,
-                                               /*with_layouts=*/false);
+                                               /*with_layouts=*/false,
+                                               /*via_builder=*/false);
 }

 static mlir::LogicalResult MlirHloToHloTextWithLayoutsTranslateFunction(
     mlir::ModuleOp module, llvm::raw_ostream& output) {
   return MlirHloToHloTextTranslateFunctionImpl(module, output,
-                                               /*with_layouts=*/true);
+                                               /*with_layouts=*/true,
+                                               /*via_builder=*/false);
+}
+
+// This converts MlirHlo to Hlo by first converting to XlaBuilder.
+// This is useful for testing conversion to XlaBuilder.
+static mlir::LogicalResult MlirHloToHloTextViaBuilderTranslateFunction(
+    mlir::ModuleOp module, llvm::raw_ostream& output) {
+  return MlirHloToHloTextTranslateFunctionImpl(module, output,
+                                               /*with_layouts=*/false,
+                                               /*via_builder=*/true);
 }

 } // namespace xla

@@ -194,6 +248,10 @@ static mlir::TranslateFromMLIRRegistration MlirHloToHloTextWithLayoutsTranslate(
     "mlir-hlo-to-hlo-text-with-layouts",
     xla::MlirHloToHloTextWithLayoutsTranslateFunction, RegisterInputDialects);

+static mlir::TranslateFromMLIRRegistration MlirHloToHloTextViaBuilderTranslate(
+    "mlir-hlo-to-hlo-text-via-builder",
+    xla::MlirHloToHloTextViaBuilderTranslateFunction, RegisterInputDialects);
+
 static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate(
     "hlo-to-mlir-hlo", xla::HloToMlirHloTranslateFunction);