From 357615e12a7a76981c1d259eafb04588625513a0 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 4 Sep 2019 10:37:49 -0700 Subject: [PATCH] Add option to emit args and results as tuples Use this to either group all arguments as single input tuple or always return a tuple (even if only result). Add testing flags in translation registration. PiperOrigin-RevId: 267182876 --- .../compiler/mlir/xla/mlir_hlo_to_hlo.cc | 61 +++++++++++++------ .../compiler/mlir/xla/mlir_hlo_to_hlo.h | 8 ++- .../mlir/xla/tests/translate/add.mlir | 8 ++- .../translate/multiple_return_tuple.mlir | 14 +++++ .../compiler/mlir/xla/xla_mlir_translate.cc | 19 +++++- 5 files changed, 88 insertions(+), 22 deletions(-) create mode 100644 tensorflow/compiler/mlir/xla/tests/translate/multiple_return_tuple.mlir diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 230044d538b..5b4da82cd3a 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -137,8 +137,12 @@ class ConvertToHloModule { using ValueLoweringMap = llvm::DenseMap; using FunctionLoweringMap = llvm::DenseMap; - explicit ConvertToHloModule(mlir::ModuleOp module) - : module_(module), module_builder_("main") {} + explicit ConvertToHloModule(mlir::ModuleOp module, bool use_tuple_args, + bool always_return_tuple) + : module_(module), + module_builder_("main"), + use_tuple_args_(use_tuple_args), + always_return_tuple_(always_return_tuple) {} // Perform the lowering to XLA. This function returns failure if an error was // encountered. @@ -160,6 +164,9 @@ class ConvertToHloModule { } private: + LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder, + ConvertToHloModule::ValueLoweringMap* value_lowering); + // The module being lowered. mlir::ModuleOp module_; @@ -168,11 +175,17 @@ class ConvertToHloModule { // Map between function and lowered computation. FunctionLoweringMap lowered_computation_; + + // Whether the entry function should take a single tuple as input. + bool use_tuple_args_; + + // Whether to always return a tuple. + bool always_return_tuple_; }; -LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder, - ConvertToHloModule::FunctionLoweringMap* function_lowering, - ConvertToHloModule::ValueLoweringMap* value_lowering) { +LogicalResult ConvertToHloModule::Lower( + mlir::Operation* inst, xla::XlaBuilder* builder, + ConvertToHloModule::ValueLoweringMap* value_lowering) { if (auto xla_op = CreateXlaOperator(inst, value_lowering)) return success(); // TODO(riverriddle) We currently don't support lowering constant operations. @@ -187,7 +200,7 @@ LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder, // values returned, then create a tuple, else return value directly. xla::XlaOp return_value; unsigned num_return_values = ret.getNumOperands(); - if (num_return_values > 1) { + if (always_return_tuple_ || num_return_values > 1) { std::vector returns(num_return_values); for (unsigned i = 0, e = ret.getNumOperands(); i != e; ++i) { returns[i] = value_map[ret.getOperand(i)]; @@ -205,7 +218,7 @@ LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder, return failure(); } auto f = inst->getParentOfType(); - (*function_lowering)[f] = std::move(computation_or.ValueOrDie()); + lowered_computation_[f] = std::move(computation_or.ValueOrDie()); return success(); } inst->emitError("unable to lower operation of type '" + @@ -228,28 +241,42 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) { // Mapping from the Value to lowered XlaOp. The code below lowers in // program order and will fail if an operand is unseen. This can be improved. ValueLoweringMap lowering; - for (auto& bb : f) { - int num = 0; - for (auto& arg : bb.getArguments()) { + auto& bb = f.front(); + + // If using tuples as input, then there is only one input + // parameter that is a tuple. + if (use_tuple_args_) { + std::vector arg_shapes; + arg_shapes.reserve(bb.getNumArguments()); + for (auto& arg : bb.getArguments()) + arg_shapes.push_back(xla::TypeToShape(arg->getType())); + xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes); + auto tuple = xla::Parameter(&builder, 0, input_shape, "arg_tuple"); + for (auto& it : llvm::enumerate(bb.getArguments())) { + lowering[it.value()] = xla::GetTupleElement(tuple, it.index()); + } + } else { + for (auto& it : llvm::enumerate(bb.getArguments())) { + auto* arg = it.value(); + auto num = it.index(); xla::Shape shape = xla::TypeToShape(arg->getType()); lowering[arg] = xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num)); - ++num; } - - for (auto& inst : bb) - if (failed(Lower(&inst, &builder, &lowered_computation_, &lowering))) - return failure(); } + for (auto& inst : bb) + if (failed(Lower(&inst, &builder, &lowering))) return failure(); + return success(); } } // namespace -Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto) { +Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto, + bool use_tuple_args, bool always_return_tuple) { mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); - ConvertToHloModule converter(module); + ConvertToHloModule converter(module, use_tuple_args, always_return_tuple); if (failed(converter.Run())) return diag_handler.ConsumeStatus(); auto hlo_module = converter.ConsumeMainProto(); hlo_proto->mutable_hlo_module()->Swap(&hlo_module); diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h index b16636f039c..24d20fe7017 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h @@ -23,8 +23,12 @@ limitations under the License. namespace mlir { -// Converts a MLIR module in HLO dialect into a HloModuleProto. -Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto); +// Converts a MLIR module in HLO dialect into a HloModuleProto. If +// use_tuple_args is set, then functions will have a single tuple as input. If +// always_return_tuple is set, then functions will return tuple whether or not +// there is only one result. +Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto, + bool use_tuple_args, bool always_return_tuple); // Creates XlaOp equivalent of a given MLIR operation using the operand info // from `value_lowering` map. diff --git a/tensorflow/compiler/mlir/xla/tests/translate/add.mlir b/tensorflow/compiler/mlir/xla/tests/translate/add.mlir index a77b90ca083..a457ba59e22 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/add.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/add.mlir @@ -1,6 +1,12 @@ // RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text -emit-use-tuple-args %s | FileCheck %s --check-prefix=TUPLE-ARG +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text -emit-always-return-tuple %s | FileCheck %s --check-prefix=TUPLE-RET +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text -emit-use-tuple-args -emit-always-return-tuple %s | FileCheck %s --check-prefix=TUPLES -// CHECK-LABEL: ENTRY %main.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { +// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] +// TUPLE-ARG-LABEL: ENTRY %main.{{.*}} (arg_tuple.1: (f32[4], f32[4])) -> f32[4] +// TUPLE-RET-LABEL: ENTRY %main.{{.*}} (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> (f32[4]) +// TUPLES-LABEL: ENTRY %main.{{.*}} (arg_tuple.1: (f32[4], f32[4])) -> (f32[4]) func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %Arg_0.1 = f32[4] parameter(0) // CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/multiple_return_tuple.mlir b/tensorflow/compiler/mlir/xla/tests/translate/multiple_return_tuple.mlir new file mode 100644 index 00000000000..87817519870 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/multiple_return_tuple.mlir @@ -0,0 +1,14 @@ +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text -emit-use-tuple-args -emit-always-return-tuple %s | FileCheck %s --check-prefix=TUPLE + +// Test to verify that multiple result function with always emit return tuple +// does not result in nested tuples. + +// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: s32[4]) -> (s32[4], s32[1,2,3,4]) +// TUPLE-LABEL: ENTRY %main.{{.*}} (arg_tuple.1: (s32[4])) -> (s32[4], s32[1,2,3,4]) +func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<1x2x3x4xi32>) { + // CHECK-NEXT: %Arg_0.1 = s32[4] parameter(0) + // CHECK-NEXT: %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] %Arg_0.1), dimensions={3} + %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32> + return %arg0, %0 : tensor<4xi32>, tensor<1x2x3x4xi32> +} diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc index ad7e4724d90..7fbc5e4e2bc 100644 --- a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/xla_mlir_translate.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/ToolOutputFile.h" #include "mlir/IR/Module.h" // TF:local_config_mlir @@ -30,6 +31,18 @@ limitations under the License. using stream_executor::port::Status; using stream_executor::port::StatusOr; // NOLINT TODO(b/130822468) fix this +// NOLINTNEXTLINE +static llvm::cl::opt emit_use_tuple_arg( + "emit-use-tuple-args", + llvm::cl::desc("Emit HLO modules using tuples as args"), + llvm::cl::init(false)); + +// NOLINTNEXTLINE +static llvm::cl::opt emit_always_return_tuple( + "emit-always-return-tuple", + llvm::cl::desc("Emit HLO modules always return tuple"), + llvm::cl::init(false)); + namespace xla { namespace { @@ -122,7 +135,8 @@ static mlir::LogicalResult MlirHloToHloTranslateFunction( } HloProto hloProto; - Status status = mlir::ConvertMlirHloToHlo(module, &hloProto); + Status status = mlir::ConvertMlirHloToHlo( + module, &hloProto, emit_use_tuple_arg, emit_always_return_tuple); if (!status.ok()) { LOG(ERROR) << "Module conversion failed: " << status; return mlir::failure(); @@ -155,7 +169,8 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunction( } HloProto hloProto; - Status status = mlir::ConvertMlirHloToHlo(module, &hloProto); + Status status = mlir::ConvertMlirHloToHlo( + module, &hloProto, emit_use_tuple_arg, emit_always_return_tuple); if (!status.ok()) { LOG(ERROR) << "Module conversion failed: " << status; return mlir::failure();