Add option to emit args and results as tuples
Use this to either group all arguments as single input tuple or always return a tuple (even if only result). Add testing flags in translation registration. PiperOrigin-RevId: 267182876
This commit is contained in:
parent
7b8e672d22
commit
357615e12a
@ -137,8 +137,12 @@ class ConvertToHloModule {
|
||||
using ValueLoweringMap = llvm::DenseMap<Value*, xla::XlaOp>;
|
||||
using FunctionLoweringMap = llvm::DenseMap<mlir::FuncOp, xla::XlaComputation>;
|
||||
|
||||
explicit ConvertToHloModule(mlir::ModuleOp module)
|
||||
: module_(module), module_builder_("main") {}
|
||||
explicit ConvertToHloModule(mlir::ModuleOp module, bool use_tuple_args,
|
||||
bool always_return_tuple)
|
||||
: module_(module),
|
||||
module_builder_("main"),
|
||||
use_tuple_args_(use_tuple_args),
|
||||
always_return_tuple_(always_return_tuple) {}
|
||||
|
||||
// Perform the lowering to XLA. This function returns failure if an error was
|
||||
// encountered.
|
||||
@ -160,6 +164,9 @@ class ConvertToHloModule {
|
||||
}
|
||||
|
||||
private:
|
||||
LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder,
|
||||
ConvertToHloModule::ValueLoweringMap* value_lowering);
|
||||
|
||||
// The module being lowered.
|
||||
mlir::ModuleOp module_;
|
||||
|
||||
@ -168,11 +175,17 @@ class ConvertToHloModule {
|
||||
|
||||
// Map between function and lowered computation.
|
||||
FunctionLoweringMap lowered_computation_;
|
||||
|
||||
// Whether the entry function should take a single tuple as input.
|
||||
bool use_tuple_args_;
|
||||
|
||||
// Whether to always return a tuple.
|
||||
bool always_return_tuple_;
|
||||
};
|
||||
|
||||
LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder,
|
||||
ConvertToHloModule::FunctionLoweringMap* function_lowering,
|
||||
ConvertToHloModule::ValueLoweringMap* value_lowering) {
|
||||
LogicalResult ConvertToHloModule::Lower(
|
||||
mlir::Operation* inst, xla::XlaBuilder* builder,
|
||||
ConvertToHloModule::ValueLoweringMap* value_lowering) {
|
||||
if (auto xla_op = CreateXlaOperator(inst, value_lowering)) return success();
|
||||
|
||||
// TODO(riverriddle) We currently don't support lowering constant operations.
|
||||
@ -187,7 +200,7 @@ LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder,
|
||||
// values returned, then create a tuple, else return value directly.
|
||||
xla::XlaOp return_value;
|
||||
unsigned num_return_values = ret.getNumOperands();
|
||||
if (num_return_values > 1) {
|
||||
if (always_return_tuple_ || num_return_values > 1) {
|
||||
std::vector<xla::XlaOp> returns(num_return_values);
|
||||
for (unsigned i = 0, e = ret.getNumOperands(); i != e; ++i) {
|
||||
returns[i] = value_map[ret.getOperand(i)];
|
||||
@ -205,7 +218,7 @@ LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder,
|
||||
return failure();
|
||||
}
|
||||
auto f = inst->getParentOfType<mlir::FuncOp>();
|
||||
(*function_lowering)[f] = std::move(computation_or.ValueOrDie());
|
||||
lowered_computation_[f] = std::move(computation_or.ValueOrDie());
|
||||
return success();
|
||||
}
|
||||
inst->emitError("unable to lower operation of type '" +
|
||||
@ -228,28 +241,42 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
|
||||
// Mapping from the Value to lowered XlaOp. The code below lowers in
|
||||
// program order and will fail if an operand is unseen. This can be improved.
|
||||
ValueLoweringMap lowering;
|
||||
for (auto& bb : f) {
|
||||
int num = 0;
|
||||
for (auto& arg : bb.getArguments()) {
|
||||
auto& bb = f.front();
|
||||
|
||||
// If using tuples as input, then there is only one input
|
||||
// parameter that is a tuple.
|
||||
if (use_tuple_args_) {
|
||||
std::vector<xla::Shape> arg_shapes;
|
||||
arg_shapes.reserve(bb.getNumArguments());
|
||||
for (auto& arg : bb.getArguments())
|
||||
arg_shapes.push_back(xla::TypeToShape(arg->getType()));
|
||||
xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes);
|
||||
auto tuple = xla::Parameter(&builder, 0, input_shape, "arg_tuple");
|
||||
for (auto& it : llvm::enumerate(bb.getArguments())) {
|
||||
lowering[it.value()] = xla::GetTupleElement(tuple, it.index());
|
||||
}
|
||||
} else {
|
||||
for (auto& it : llvm::enumerate(bb.getArguments())) {
|
||||
auto* arg = it.value();
|
||||
auto num = it.index();
|
||||
xla::Shape shape = xla::TypeToShape(arg->getType());
|
||||
lowering[arg] =
|
||||
xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num));
|
||||
++num;
|
||||
}
|
||||
|
||||
for (auto& inst : bb)
|
||||
if (failed(Lower(&inst, &builder, &lowered_computation_, &lowering)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
for (auto& inst : bb)
|
||||
if (failed(Lower(&inst, &builder, &lowering))) return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto) {
|
||||
Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
|
||||
bool use_tuple_args, bool always_return_tuple) {
|
||||
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
|
||||
ConvertToHloModule converter(module);
|
||||
ConvertToHloModule converter(module, use_tuple_args, always_return_tuple);
|
||||
if (failed(converter.Run())) return diag_handler.ConsumeStatus();
|
||||
auto hlo_module = converter.ConsumeMainProto();
|
||||
hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
|
||||
|
@ -23,8 +23,12 @@ limitations under the License.
|
||||
|
||||
namespace mlir {
|
||||
|
||||
// Converts a MLIR module in HLO dialect into a HloModuleProto.
|
||||
Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto);
|
||||
// Converts a MLIR module in HLO dialect into a HloModuleProto. If
|
||||
// use_tuple_args is set, then functions will have a single tuple as input. If
|
||||
// always_return_tuple is set, then functions will return tuple whether or not
|
||||
// there is only one result.
|
||||
Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
|
||||
bool use_tuple_args, bool always_return_tuple);
|
||||
|
||||
// Creates XlaOp equivalent of a given MLIR operation using the operand info
|
||||
// from `value_lowering` map.
|
||||
|
@ -1,6 +1,12 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text -emit-use-tuple-args %s | FileCheck %s --check-prefix=TUPLE-ARG
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text -emit-always-return-tuple %s | FileCheck %s --check-prefix=TUPLE-RET
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text -emit-use-tuple-args -emit-always-return-tuple %s | FileCheck %s --check-prefix=TUPLES
|
||||
|
||||
// CHECK-LABEL: ENTRY %main.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
|
||||
// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4]
|
||||
// TUPLE-ARG-LABEL: ENTRY %main.{{.*}} (arg_tuple.1: (f32[4], f32[4])) -> f32[4]
|
||||
// TUPLE-RET-LABEL: ENTRY %main.{{.*}} (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> (f32[4])
|
||||
// TUPLES-LABEL: ENTRY %main.{{.*}} (arg_tuple.1: (f32[4], f32[4])) -> (f32[4])
|
||||
func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-NEXT: %Arg_0.1 = f32[4] parameter(0)
|
||||
// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1)
|
||||
|
@ -0,0 +1,14 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text -emit-use-tuple-args -emit-always-return-tuple %s | FileCheck %s --check-prefix=TUPLE
|
||||
|
||||
// Test to verify that multiple result function with always emit return tuple
|
||||
// does not result in nested tuples.
|
||||
|
||||
// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: s32[4]) -> (s32[4], s32[1,2,3,4])
|
||||
// TUPLE-LABEL: ENTRY %main.{{.*}} (arg_tuple.1: (s32[4])) -> (s32[4], s32[1,2,3,4])
|
||||
func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<1x2x3x4xi32>) {
|
||||
// CHECK-NEXT: %Arg_0.1 = s32[4] parameter(0)
|
||||
// CHECK-NEXT: %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] %Arg_0.1), dimensions={3}
|
||||
%0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
|
||||
return %arg0, %0 : tensor<4xi32>, tensor<1x2x3x4xi32>
|
||||
}
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/xla_mlir_translate.h"
|
||||
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/MemoryBuffer.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
@ -30,6 +31,18 @@ limitations under the License.
|
||||
using stream_executor::port::Status;
|
||||
using stream_executor::port::StatusOr; // NOLINT TODO(b/130822468) fix this
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<bool> emit_use_tuple_arg(
|
||||
"emit-use-tuple-args",
|
||||
llvm::cl::desc("Emit HLO modules using tuples as args"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<bool> emit_always_return_tuple(
|
||||
"emit-always-return-tuple",
|
||||
llvm::cl::desc("Emit HLO modules always return tuple"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
@ -122,7 +135,8 @@ static mlir::LogicalResult MlirHloToHloTranslateFunction(
|
||||
}
|
||||
|
||||
HloProto hloProto;
|
||||
Status status = mlir::ConvertMlirHloToHlo(module, &hloProto);
|
||||
Status status = mlir::ConvertMlirHloToHlo(
|
||||
module, &hloProto, emit_use_tuple_arg, emit_always_return_tuple);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Module conversion failed: " << status;
|
||||
return mlir::failure();
|
||||
@ -155,7 +169,8 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunction(
|
||||
}
|
||||
|
||||
HloProto hloProto;
|
||||
Status status = mlir::ConvertMlirHloToHlo(module, &hloProto);
|
||||
Status status = mlir::ConvertMlirHloToHlo(
|
||||
module, &hloProto, emit_use_tuple_arg, emit_always_return_tuple);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Module conversion failed: " << status;
|
||||
return mlir::failure();
|
||||
|
Loading…
Reference in New Issue
Block a user