Add BuildHloFromMlirHlo to mlir_hlo_to_hlo.h. Refactor internal functions to support both the new BuildHloFromMlirHlo and the existing ConvertMlirHloToHlo.

BuildHloFromMlirHlo will be used for MLIR legalization in the old bridge.

PiperOrigin-RevId: 345388758
Change-Id: I3614e860583bfa83a6dc88c35b60bb801bb31462
This commit is contained in:
Michael Delorimier 2020-12-02 23:26:35 -08:00 committed by TensorFlower Gardener
parent 4679dcf1ee
commit a74f748fa4
6 changed files with 156 additions and 43 deletions

View File

@ -354,6 +354,7 @@ cc_library(
":mhlo_to_lhlo_with_xla",
":mlir_hlo_to_hlo",
":translate_cl_options",
":type_to_shape",
"//tensorflow/compiler/jit:xla_cpu_jit",
"//tensorflow/compiler/jit:xla_gpu_jit",
"//tensorflow/compiler/mlir/hlo",

View File

@ -497,11 +497,12 @@ class ConvertToHloModule {
// Multiple return values are always converted to a tuple and returned as a
// single value.
explicit ConvertToHloModule(
mlir::ModuleOp module, bool use_tuple_args, bool return_tuple,
mlir::ModuleOp module, xla::XlaBuilder& module_builder,
bool use_tuple_args, bool return_tuple,
tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
MlirToHloConversionOptions options)
: module_(module),
module_builder_("main"),
module_builder_(module_builder),
use_tuple_args_(use_tuple_args),
return_tuple_(return_tuple),
shape_representation_fn_(shape_representation_fn),
@ -547,14 +548,14 @@ class ConvertToHloModule {
mlir::CallOp call_op, xla::XlaBuilder* builder,
ConvertToHloModule::ValueLoweringMap* value_lowering);
private:
LogicalResult Lower(
mlir::Operation* inst, bool is_entry_function,
llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
xla::XlaBuilder* builder,
ConvertToHloModule::ValueLoweringMap* value_lowering,
xla::XlaComputation* result);
xla::XlaOp* return_value);
private:
LogicalResult SetEntryTupleShapesAndLeafReplication(
Block* block, const std::vector<bool>& entry_args_same_across_replicas,
llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
@ -569,7 +570,7 @@ class ConvertToHloModule {
mlir::ModuleOp module_;
// The top-level XlaBuilder.
xla::XlaBuilder module_builder_;
xla::XlaBuilder& module_builder_;
// Map between function and lowered computation.
FunctionLoweringMap lowered_computation_;
@ -1189,7 +1190,9 @@ LogicalResult ConvertToHloModule::Lower(
llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
xla::XlaBuilder* builder,
ConvertToHloModule::ValueLoweringMap* value_lowering,
xla::XlaComputation* result) {
xla::XlaOp* return_value) {
*return_value = xla::XlaOp();
// See MlirToHloConversionOptions for more about layouts.
auto propagate_layouts = [this](mlir::Operation* inst, xla::XlaOp xla_op) {
if (options_.propagate_layouts) {
@ -1255,7 +1258,6 @@ LogicalResult ConvertToHloModule::Lower(
if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
// Construct the return value for the function. If there are multiple
// values returned, then create a tuple, else return value directly.
xla::XlaOp return_value;
unsigned num_return_values = inst->getNumOperands();
if ((return_tuple_ && is_entry_function) || num_return_values > 1) {
const bool has_ret_shardings =
@ -1291,24 +1293,16 @@ LogicalResult ConvertToHloModule::Lower(
builder->SetSharding(sharding);
}
return_value = xla::Tuple(builder, returns);
*return_value = xla::Tuple(builder, returns);
builder->ClearSharding();
} else if (num_return_values == 1) {
xla::XlaOp operand;
if (failed(GetXlaOp(inst->getOperand(0), value_map, &operand, inst)))
return failure();
return_value = operand;
*return_value = operand;
}
// Build the XlaComputation and check for failures.
auto computation_or =
return_value.valid() ? builder->Build(return_value) : builder->Build();
if (!computation_or.ok()) {
inst->emitError(llvm::Twine(computation_or.status().error_message()));
return failure();
}
*result = std::move(computation_or.ValueOrDie());
return success();
}
@ -1515,11 +1509,21 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
}
}
xla::XlaOp return_value;
for (auto& inst : *block)
if (failed(Lower(&inst, is_entry_function, ret_shardings, builder,
&lowering, result)))
&lowering, &return_value)))
return failure();
// Build the XlaComputation and check for failures.
auto computation_or =
return_value.valid() ? builder->Build(return_value) : builder->Build();
if (!computation_or.ok()) {
block->back().emitError(
llvm::Twine(computation_or.status().error_message()));
return failure();
}
*result = std::move(computation_or.ValueOrDie());
return success();
}
@ -1704,7 +1708,8 @@ Status ConvertRegionToComputation(mlir::Region* region,
xla::XlaComputation* func,
MlirToHloConversionOptions options) {
mlir::ModuleOp module;
ConvertToHloModule converter(module, true, true, {}, options);
xla::XlaBuilder module_builder("main");
ConvertToHloModule converter(module, module_builder, true, true, {}, options);
if (failed(converter.LowerRegionAsComputation(region, func)))
return tensorflow::errors::Internal(
"failed to convert region to computation");
@ -1717,14 +1722,55 @@ Status ConvertMlirHloToHlo(
const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
MlirToHloConversionOptions options) {
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
ConvertToHloModule converter(module, use_tuple_args, return_tuple,
shape_representation_fn, options);
xla::XlaBuilder module_builder("main");
ConvertToHloModule converter(module, module_builder, use_tuple_args,
return_tuple, shape_representation_fn, options);
if (failed(converter.Run())) return diag_handler.ConsumeStatus();
auto hlo_module = converter.ConsumeMainProto();
hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
if (failed(AddDynamicParameterBindings(
module, hlo_proto->mutable_hlo_module(), use_tuple_args)))
return diag_handler.ConsumeStatus();
return Status::OK();
}
Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder,
llvm::ArrayRef<xla::XlaOp> xla_params,
std::vector<xla::XlaOp>& returns,
MlirToHloConversionOptions options) {
auto module = block.getParentOp()->getParentOfType<mlir::ModuleOp>();
ConvertToHloModule converter(module, builder,
/*use_tuple_args=*/false, /*return_tuple=*/false,
/*shape_representation_fn=*/nullptr, options);
ConvertToHloModule::ValueLoweringMap lowering;
if (xla_params.size() != block.getArguments().size())
return tensorflow::errors::Internal(
"xla_params size != block arguments size");
for (BlockArgument& arg : block.getArguments()) {
auto num = arg.getArgNumber();
lowering[arg] = xla_params[num];
}
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
for (auto& inst : block) {
if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
returns.resize(inst.getNumOperands());
for (OpOperand& ret : inst.getOpOperands()) {
unsigned index = ret.getOperandNumber();
xla::XlaOp operand;
if (failed(GetXlaOp(ret.get(), lowering, &operand, &inst)))
return diag_handler.ConsumeStatus();
returns[index] = operand;
}
} else {
xla::XlaOp return_value;
if (failed(converter.Lower(&inst, /*is_entry_function=*/true,
/*ret_shardings=*/{}, &builder, &lowering,
&return_value)))
return diag_handler.ConsumeStatus();
}
}
return Status::OK();
}

View File

@ -52,6 +52,14 @@ Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto,
shape_representation_fn = nullptr,
MlirToHloConversionOptions options = {});
// Transforms a Block into HLO, where the HLO is represented as calls into an
// XlaBuilder. Callee functions are allowed in the Block's ancestor ModuleOp.
// xla_params are inputs to block. returns are the returned XlaOps.
Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder,
llvm::ArrayRef<xla::XlaOp> xla_params,
std::vector<xla::XlaOp>& returns,
MlirToHloConversionOptions options = {});
// Converts a region to a computation. It returns a standalone module that
// contains the converted region as the entry computation.
Status ConvertRegionToComputation(mlir::Region* region,

View File

@ -1,4 +1,5 @@
// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text-via-builder %s | FileCheck %s
// CHECK: HloModule
func @main(%arg0: !mhlo.token, %arg1: !mhlo.token) -> !mhlo.token {
@ -1004,22 +1005,6 @@ func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> {
// -----
// Tests that the exported HLO module keeps parameter replication annotation.
// CHECK: HloModule
func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {mhlo.is_same_data_across_replicas}) -> tensor<16x16xf32> {
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32>
return %0 : tensor<16x16xf32>
}
// CHECK: ENTRY
// CHECK: %[[ARG0:.*]] = f32[16,16] parameter(0)
// CHECK-NOT: parameter_replication={true}
// CHECK: %[[ARG1:.*]] = f32[16,16] parameter(1), parameter_replication={true}
// CHECK: ROOT %[[RESULT:.*]] = f32[16,16] add(f32[16,16] %[[ARG0]], f32[16,16] %[[ARG1]])
// -----
// CHECK: HloModule
func @main(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f64>>) -> (tensor<2xf32>, tensor<2xf64>) {
%0 = "mhlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)

View File

@ -0,0 +1,15 @@
// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
// Tests that the exported HLO module keeps parameter replication annotation.
// CHECK: HloModule
func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {mhlo.is_same_data_across_replicas}) -> tensor<16x16xf32> {
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32>
return %0 : tensor<16x16xf32>
}
// CHECK: ENTRY
// CHECK: %[[ARG0:.*]] = f32[16,16] parameter(0)
// CHECK-NOT: parameter_replication={true}
// CHECK: %[[ARG1:.*]] = f32[16,16] parameter(1), parameter_replication={true}
// CHECK: ROOT %[[RESULT:.*]] = f32[16,16] add(f32[16,16] %[[ARG0]], f32[16,16] %[[ARG1]])

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
@ -124,16 +125,58 @@ static StatusOr<std::unique_ptr<HloModule>> HloModuleFromProto(
return HloModule::CreateFromProto(module_proto, module_config);
}
// Wraps BuildHloFromMlirHlo to output an HloProto that's the same as
// ConvertMlirHloToHlo.
Status ConvertMlirHloToHloViaBuilder(mlir::ModuleOp module,
::xla::HloProto* hlo_proto,
mlir::MlirToHloConversionOptions options) {
mlir::FuncOp main = module.lookupSymbol<mlir::FuncOp>("main");
mlir::Block& block = main.getRegion().front();
xla::XlaBuilder builder("main");
// Create xla_params.
std::vector<xla::XlaOp> xla_params;
for (mlir::BlockArgument& arg : block.getArguments()) {
auto num = arg.getArgNumber();
xla::Shape shape = xla::TypeToShape(arg.getType());
XlaOp argop =
xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num));
xla_params.push_back(argop);
}
std::vector<xla::XlaOp> returns(1);
TF_RETURN_IF_ERROR(
mlir::BuildHloFromMlirHlo(block, builder, xla_params, returns, options));
xla::XlaOp return_value;
if (returns.size() == 1)
return_value = returns[0];
else if (returns.size() > 1)
return_value = xla::Tuple(&builder, returns);
TF_ASSIGN_OR_RETURN(
xla::XlaComputation computation,
return_value.valid() ? builder.Build(return_value) : builder.Build());
auto hlo_module = computation.proto();
hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
return Status::OK();
}
static mlir::LogicalResult MlirHloToHloTextTranslateFunctionImpl(
mlir::ModuleOp module, llvm::raw_ostream& output, bool with_layouts) {
mlir::ModuleOp module, llvm::raw_ostream& output, bool with_layouts,
bool via_builder) {
if (!module) return mlir::failure();
HloProto hloProto;
mlir::MlirToHloConversionOptions options;
options.propagate_layouts = with_layouts;
Status status = mlir::ConvertMlirHloToHlo(
module, &hloProto, emit_use_tuple_arg, emit_return_tuple,
/*shape_representation_fn=*/nullptr, options);
Status status =
via_builder
? ConvertMlirHloToHloViaBuilder(module, &hloProto, options)
: mlir::ConvertMlirHloToHlo(
module, &hloProto, emit_use_tuple_arg, emit_return_tuple,
/*shape_representation_fn=*/nullptr, options);
if (!status.ok()) {
LOG(ERROR) << "Module conversion failed: " << status;
return mlir::failure();
@ -167,13 +210,24 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunctionImpl(
static mlir::LogicalResult MlirHloToHloTextTranslateFunction(
mlir::ModuleOp module, llvm::raw_ostream& output) {
return MlirHloToHloTextTranslateFunctionImpl(module, output,
/*with_layouts=*/false);
/*with_layouts=*/false,
/*via_builder=*/false);
}
static mlir::LogicalResult MlirHloToHloTextWithLayoutsTranslateFunction(
mlir::ModuleOp module, llvm::raw_ostream& output) {
return MlirHloToHloTextTranslateFunctionImpl(module, output,
/*with_layouts=*/true);
/*with_layouts=*/true,
/*via_builder=*/false);
}
// This converts MlirHlo to Hlo by first converting to XlaBuilder.
// This is useful for testing conversion to XlaBuilder.
static mlir::LogicalResult MlirHloToHloTextViaBuilderTranslateFunction(
mlir::ModuleOp module, llvm::raw_ostream& output) {
return MlirHloToHloTextTranslateFunctionImpl(module, output,
/*with_layouts=*/false,
/*via_builder=*/true);
}
} // namespace xla
@ -194,6 +248,10 @@ static mlir::TranslateFromMLIRRegistration MlirHloToHloTextWithLayoutsTranslate(
"mlir-hlo-to-hlo-text-with-layouts",
xla::MlirHloToHloTextWithLayoutsTranslateFunction, RegisterInputDialects);
static mlir::TranslateFromMLIRRegistration MlirHloToHloTextViaBuilderTranslate(
"mlir-hlo-to-hlo-text-via-builder",
xla::MlirHloToHloTextViaBuilderTranslateFunction, RegisterInputDialects);
static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate(
"hlo-to-mlir-hlo", xla::HloToMlirHloTranslateFunction);