Add option to emit args and results as tuples

Use these to either group all arguments into a single input tuple or to always return a tuple (even if there is only one result). Also add testing flags in the translation registration.

PiperOrigin-RevId: 267182876
Jacques Pienaar, 2019-09-04 10:37:49 -07:00, committed by TensorFlower Gardener
parent 7b8e672d22
commit 357615e12a
5 changed files with 88 additions and 22 deletions
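The flags wired up in this change can be exercised directly through the translate tool; a representative invocation combining both options (input.mlir is a placeholder path, the flags themselves are registered in the last file below):

  tf-mlir-translate -mlir-hlo-to-hlo-text -emit-use-tuple-args -emit-always-return-tuple input.mlir

With both flags set, an entry function such as (f32[4], f32[4]) -> f32[4] is emitted as (arg_tuple.1: (f32[4], f32[4])) -> (f32[4]), as the tests below check.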

View File

@@ -137,8 +137,12 @@ class ConvertToHloModule {
   using ValueLoweringMap = llvm::DenseMap<Value*, xla::XlaOp>;
   using FunctionLoweringMap = llvm::DenseMap<mlir::FuncOp, xla::XlaComputation>;
 
-  explicit ConvertToHloModule(mlir::ModuleOp module)
-      : module_(module), module_builder_("main") {}
+  explicit ConvertToHloModule(mlir::ModuleOp module, bool use_tuple_args,
+                              bool always_return_tuple)
+      : module_(module),
+        module_builder_("main"),
+        use_tuple_args_(use_tuple_args),
+        always_return_tuple_(always_return_tuple) {}
 
   // Perform the lowering to XLA. This function returns failure if an error was
   // encountered.
@@ -160,6 +164,9 @@ class ConvertToHloModule {
   }
 
  private:
+  LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder,
+                      ConvertToHloModule::ValueLoweringMap* value_lowering);
+
   // The module being lowered.
   mlir::ModuleOp module_;
@@ -168,11 +175,17 @@ class ConvertToHloModule {
   // Map between function and lowered computation.
   FunctionLoweringMap lowered_computation_;
+
+  // Whether the entry function should take a single tuple as input.
+  bool use_tuple_args_;
+
+  // Whether to always return a tuple.
+  bool always_return_tuple_;
 };
 
-LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder,
-                    ConvertToHloModule::FunctionLoweringMap* function_lowering,
-                    ConvertToHloModule::ValueLoweringMap* value_lowering) {
+LogicalResult ConvertToHloModule::Lower(
+    mlir::Operation* inst, xla::XlaBuilder* builder,
+    ConvertToHloModule::ValueLoweringMap* value_lowering) {
   if (auto xla_op = CreateXlaOperator(inst, value_lowering)) return success();
 
   // TODO(riverriddle) We currently don't support lowering constant operations.
@@ -187,7 +200,7 @@ LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder,
     // values returned, then create a tuple, else return value directly.
     xla::XlaOp return_value;
     unsigned num_return_values = ret.getNumOperands();
-    if (num_return_values > 1) {
+    if (always_return_tuple_ || num_return_values > 1) {
       std::vector<xla::XlaOp> returns(num_return_values);
       for (unsigned i = 0, e = ret.getNumOperands(); i != e; ++i) {
         returns[i] = value_map[ret.getOperand(i)];
@@ -205,7 +218,7 @@ LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder,
       return failure();
     }
     auto f = inst->getParentOfType<mlir::FuncOp>();
-    (*function_lowering)[f] = std::move(computation_or.ValueOrDie());
+    lowered_computation_[f] = std::move(computation_or.ValueOrDie());
     return success();
   }
 
   inst->emitError("unable to lower operation of type '" +
@@ -228,28 +241,42 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
   // Mapping from the Value to lowered XlaOp. The code below lowers in
   // program order and will fail if an operand is unseen. This can be improved.
   ValueLoweringMap lowering;
-  for (auto& bb : f) {
-    int num = 0;
-    for (auto& arg : bb.getArguments()) {
+  auto& bb = f.front();
+
+  // If using tuples as input, then there is only one input
+  // parameter that is a tuple.
+  if (use_tuple_args_) {
+    std::vector<xla::Shape> arg_shapes;
+    arg_shapes.reserve(bb.getNumArguments());
+    for (auto& arg : bb.getArguments())
+      arg_shapes.push_back(xla::TypeToShape(arg->getType()));
+    xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes);
+    auto tuple = xla::Parameter(&builder, 0, input_shape, "arg_tuple");
+    for (auto& it : llvm::enumerate(bb.getArguments())) {
+      lowering[it.value()] = xla::GetTupleElement(tuple, it.index());
+    }
+  } else {
+    for (auto& it : llvm::enumerate(bb.getArguments())) {
+      auto* arg = it.value();
+      auto num = it.index();
       xla::Shape shape = xla::TypeToShape(arg->getType());
       lowering[arg] =
           xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num));
-      ++num;
     }
-
-    for (auto& inst : bb)
-      if (failed(Lower(&inst, &builder, &lowered_computation_, &lowering)))
-        return failure();
   }
+
+  for (auto& inst : bb)
+    if (failed(Lower(&inst, &builder, &lowering))) return failure();
   return success();
 }
 
 }  // namespace
 
-Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto) {
+Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
+                           bool use_tuple_args, bool always_return_tuple) {
   mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
-  ConvertToHloModule converter(module);
+  ConvertToHloModule converter(module, use_tuple_args, always_return_tuple);
 
   if (failed(converter.Run())) return diag_handler.ConsumeStatus();
   auto hlo_module = converter.ConsumeMainProto();
   hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
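For reference, a minimal caller sketch for the updated entry point; the module construction and error handling around it are assumed context, not part of this change:

  // Sketch: lower an HLO-dialect MLIR module to an HLO proto, requesting
  // tuple-packed arguments but leaving single results untupled.
  xla::HloProto hlo_proto;
  Status status = mlir::ConvertMlirHloToHlo(module, &hlo_proto,
                                            /*use_tuple_args=*/true,
                                            /*always_return_tuple=*/false);
  if (!status.ok()) return status;  // diagnostics were already emitted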

View File

@@ -23,8 +23,12 @@ limitations under the License.
 namespace mlir {
 
-// Converts a MLIR module in HLO dialect into a HloModuleProto.
-Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto);
+// Converts a MLIR module in HLO dialect into a HloModuleProto. If
+// use_tuple_args is set, then functions will take a single tuple as input.
+// If always_return_tuple is set, then functions will return a tuple even if
+// there is only one result.
+Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
+                           bool use_tuple_args, bool always_return_tuple);
 
 // Creates XlaOp equivalent of a given MLIR operation using the operand info
 // from `value_lowering` map.
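As a concrete illustration of the comment above (shapes taken from the tests in this commit), a function (f32[4], f32[4]) -> f32[4] lowers to four possible entry signatures:

  (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4]        // neither flag
  (arg_tuple.1: (f32[4], f32[4])) -> f32[4]           // use_tuple_args
  (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> (f32[4])      // always_return_tuple
  (arg_tuple.1: (f32[4], f32[4])) -> (f32[4])         // both flags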

View File

@@ -1,6 +1,12 @@
 // RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
+// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text -emit-use-tuple-args %s | FileCheck %s --check-prefix=TUPLE-ARG
+// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text -emit-always-return-tuple %s | FileCheck %s --check-prefix=TUPLE-RET
+// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text -emit-use-tuple-args -emit-always-return-tuple %s | FileCheck %s --check-prefix=TUPLES
 
-// CHECK-LABEL: ENTRY %main.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
+// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4]
+// TUPLE-ARG-LABEL: ENTRY %main.{{.*}} (arg_tuple.1: (f32[4], f32[4])) -> f32[4]
+// TUPLE-RET-LABEL: ENTRY %main.{{.*}} (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> (f32[4])
+// TUPLES-LABEL: ENTRY %main.{{.*}} (arg_tuple.1: (f32[4], f32[4])) -> (f32[4])
 func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
   // CHECK-NEXT: %Arg_0.1 = f32[4] parameter(0)
   // CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1)

View File

@@ -0,0 +1,14 @@
+// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
+// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text -emit-use-tuple-args -emit-always-return-tuple %s | FileCheck %s --check-prefix=TUPLE
+
+// Test that a function with multiple results, combined with
+// always_return_tuple, does not produce nested tuples.
+// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: s32[4]) -> (s32[4], s32[1,2,3,4])
+// TUPLE-LABEL: ENTRY %main.{{.*}} (arg_tuple.1: (s32[4])) -> (s32[4], s32[1,2,3,4])
+func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<1x2x3x4xi32>) {
+  // CHECK-NEXT: %Arg_0.1 = s32[4] parameter(0)
+  // CHECK-NEXT: %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] %Arg_0.1), dimensions={3}
+  %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
+  return %arg0, %0 : tensor<4xi32>, tensor<1x2x3x4xi32>
+}
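The property this test pins down follows from the branch added in Lower() above: the result tuple is built at most once. A toy, self-contained model of that decision (not the real XLA API, just the control flow):

  #include <cstdio>

  // Mirrors the lowering rule: wrap results in a tuple when the flag is set
  // or there are multiple results; the wrap happens once, so a multi-result
  // function never gains a second, nested tuple.
  static const char* ReturnStyle(bool always_return_tuple, unsigned num_results) {
    if (always_return_tuple || num_results > 1) return "single tuple";
    return "bare value";
  }

  int main() {
    std::printf("%s\n", ReturnStyle(/*always_return_tuple=*/true, 2));   // single tuple
    std::printf("%s\n", ReturnStyle(/*always_return_tuple=*/false, 1));  // bare value
  }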

View File

@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/xla/xla_mlir_translate.h"
 
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/ToolOutputFile.h"
 #include "mlir/IR/Module.h"  // TF:local_config_mlir
@@ -30,6 +31,18 @@ limitations under the License.
 using stream_executor::port::Status;
 using stream_executor::port::StatusOr;  // NOLINT TODO(b/130822468) fix this
 
+// NOLINTNEXTLINE
+static llvm::cl::opt<bool> emit_use_tuple_arg(
+    "emit-use-tuple-args",
+    llvm::cl::desc("Emit HLO modules using tuples as args"),
+    llvm::cl::init(false));
+
+// NOLINTNEXTLINE
+static llvm::cl::opt<bool> emit_always_return_tuple(
+    "emit-always-return-tuple",
+    llvm::cl::desc("Emit HLO modules that always return a tuple"),
+    llvm::cl::init(false));
+
 namespace xla {
 namespace {
@@ -122,7 +135,8 @@ static mlir::LogicalResult MlirHloToHloTranslateFunction(
   }
 
   HloProto hloProto;
-  Status status = mlir::ConvertMlirHloToHlo(module, &hloProto);
+  Status status = mlir::ConvertMlirHloToHlo(
+      module, &hloProto, emit_use_tuple_arg, emit_always_return_tuple);
   if (!status.ok()) {
     LOG(ERROR) << "Module conversion failed: " << status;
     return mlir::failure();
@@ -155,7 +169,8 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunction(
   }
 
   HloProto hloProto;
-  Status status = mlir::ConvertMlirHloToHlo(module, &hloProto);
+  Status status = mlir::ConvertMlirHloToHlo(
+      module, &hloProto, emit_use_tuple_arg, emit_always_return_tuple);
   if (!status.ok()) {
     LOG(ERROR) << "Module conversion failed: " << status;
     return mlir::failure();
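A side note on the plumbing above: llvm::cl::opt<bool> converts implicitly to its underlying bool, which is why emit_use_tuple_arg and emit_always_return_tuple can be passed straight to ConvertMlirHloToHlo. An equivalent, more explicit form of the same call:

  Status status = mlir::ConvertMlirHloToHlo(
      module, &hloProto, emit_use_tuple_arg.getValue(),
      emit_always_return_tuple.getValue());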