diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index 6005fe6e6dd..a3f68411cc3 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -762,17 +762,10 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
                       ImportInstructionImpl(instruction, func_builder));
   if (op == nullptr) return op;
 
-  // Best-effort propagation of the layouts. These layouts serve as performance
-  // hints to the backend.
+  // See MlirToHloConversionOptions for more about layouts.
   //
   // Minor-to-major is a permutation of [0, rank), presenting tensor dimensions
   // in physical minor-to-major order.
-  //
-  // Note that non-array shapes are not carrying layouts, and users have to
-  // figure out the proper layouts of them through context. This is one of the
-  // reasons why the attribute-based solution is temporary.
-  //
-  // TODO(timshen): Investigate the necessity of having layouts in MHLO.
   if (instruction->shape().IsArray() &&
       instruction->shape().layout() !=
           LayoutUtil::MakeDescendingLayout(
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index c1d07702100..0923f247cd2 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -499,12 +499,14 @@ class ConvertToHloModule {
   // single value.
   explicit ConvertToHloModule(
       mlir::ModuleOp module, bool use_tuple_args, bool return_tuple,
-      tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn)
+      tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+      MlirToHloConversionOptions options)
       : module_(module),
         module_builder_("main"),
         use_tuple_args_(use_tuple_args),
         return_tuple_(return_tuple),
-        shape_representation_fn_(shape_representation_fn) {
+        shape_representation_fn_(shape_representation_fn),
+        options_(options) {
     if (!shape_representation_fn_)
       shape_representation_fn_ = tensorflow::IdentityShapeRepresentationFn();
   }
@@ -585,6 +587,8 @@ class ConvertToHloModule {
 
   // Unique suffix to give to the name of the next lowered region.
   size_t region_id_ = 0;
+
+  MlirToHloConversionOptions options_;
 };
 
 }  // namespace
@@ -1087,18 +1091,19 @@ LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) {
 namespace mlir {
 namespace {
 
-StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
+StatusOr<xla::Literal> CreateArrayLiteralFromAttr(ElementsAttr attr,
+                                                  xla::Layout layout) {
   if (attr.isa<OpaqueElementsAttr>())
     return tensorflow::errors::Unimplemented(
         "Opaque elements attr not supported");
 
   xla::Shape shape = xla::TypeToShape(attr.getType());
 
-#define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type)       \
-  case xla_type: {                                         \
-    xla::Array<cpp_type> source_data(shape.dimensions());  \
-    source_data.SetValues(attr.getValues<cpp_type>());     \
-    return xla::LiteralUtil::CreateFromArray(source_data); \
+#define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type)                         \
+  case xla_type: {                                                           \
+    xla::Array<cpp_type> source_data(shape.dimensions());                    \
+    source_data.SetValues(attr.getValues<cpp_type>());                       \
+    return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout); \
   }
 
   switch (shape.element_type()) {
@@ -1128,7 +1133,7 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
       }
       xla::Array<xla::half> source_data(shape.dimensions());
       source_data.SetValues(values);
-      return xla::LiteralUtil::CreateFromArray(source_data);
+      return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout);
     }
     case xla::PrimitiveType::BF16: {
       xla::Array<double> source_data(shape.dimensions());
@@ -1145,7 +1150,7 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
       }
       source_data.SetValues(values_double);
       return xla::LiteralUtil::ConvertF64ToBF16(
-          xla::LiteralUtil::CreateFromArray(source_data));
+          xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout));
     }
     default:
       return tensorflow::errors::Internal(absl::StrCat(
@@ -1154,25 +1159,33 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
 #undef ELEMENTS_ATTR_TO_LITERAL
 }
 
+xla::Layout ExtractLayout(mlir::Operation* op, int rank) {
+  if (auto attr =
+          op->getAttrOfType<mlir::DenseIntElementsAttr>("minor_to_major")) {
+    llvm::SmallVector<int64, 4> minor_to_major;
+    minor_to_major.reserve(attr.size());
+    for (const llvm::APInt& i : attr) {
+      minor_to_major.push_back(i.getZExtValue());
+    }
+    return xla::LayoutUtil::MakeLayout(minor_to_major);
+  }
+  return xla::LayoutUtil::MakeDescendingLayout(rank);
+}
+
 LogicalResult ConvertToHloModule::Lower(
     mlir::Operation* inst, bool is_entry_function,
     llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
     xla::XlaBuilder* builder,
     ConvertToHloModule::ValueLoweringMap* value_lowering,
     xla::XlaComputation* result) {
-  // See hlo_function_importer.cc for documentation about layouts in MHLO.
-  auto propagate_layouts = [](mlir::Operation* inst, xla::XlaOp xla_op) {
-    auto attr =
-        inst->getAttrOfType<mlir::DenseIntElementsAttr>("minor_to_major");
-    if (!attr) return;
-
-    auto* v = xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
-                  ->mutable_shape()
-                  ->mutable_layout()
-                  ->mutable_minor_to_major();
-    v->Clear();
-    for (const llvm::APInt& i : attr) {
-      *v->Add() = i.getZExtValue();
+  // See MlirToHloConversionOptions for more about layouts.
+  auto propagate_layouts = [this](mlir::Operation* inst, xla::XlaOp xla_op) {
+    if (options_.propagate_layouts) {
+      auto* shape = xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
+                        ->mutable_shape();
+      if (shape->tuple_shapes().empty())
+        *shape->mutable_layout() =
+            ExtractLayout(inst, shape->dimensions().size()).ToProto();
     }
   };
 
@@ -1216,12 +1229,14 @@ LogicalResult ConvertToHloModule::Lower(
   }
 
   if (matchPattern(inst, m_Constant(&const_attr))) {
-    auto literal_or = CreateLiteralFromAttr(const_attr);
+    xla::Layout layout;
+    layout = ExtractLayout(inst, const_attr.getType().getRank());
+    auto literal_or = CreateArrayLiteralFromAttr(const_attr, layout);
     if (!literal_or.ok())
       return inst->emitError(literal_or.status().ToString());
     auto constant = xla::ConstantLiteral(builder, literal_or.ValueOrDie());
     value_map[inst->getResult(0)] = constant;
-    propagate_layouts(inst, constant);
+
     return success();
   }
 
@@ -1674,22 +1689,24 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
 }  // namespace
 
 Status ConvertRegionToComputation(mlir::Region* region,
-                                  xla::XlaComputation* func) {
+                                  xla::XlaComputation* func,
+                                  MlirToHloConversionOptions options) {
   mlir::ModuleOp module;
-  ConvertToHloModule converter(module, true, true, {});
+  ConvertToHloModule converter(module, true, true, {}, options);
   if (failed(converter.LowerRegionAsComputation(region, func)))
     return tensorflow::errors::Internal(
         "failed to convert region to computation");
   return Status::OK();
 }
 
-Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
-                           bool use_tuple_args, bool return_tuple,
-                           const tensorflow::XlaHelpers::ShapeRepresentationFn
-                               shape_representation_fn) {
+Status ConvertMlirHloToHlo(
+    mlir::ModuleOp module, xla::HloProto* hlo_proto, bool use_tuple_args,
+    bool return_tuple,
+    const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    MlirToHloConversionOptions options) {
   mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
   ConvertToHloModule converter(module, use_tuple_args, return_tuple,
-                               shape_representation_fn);
+                               shape_representation_fn, options);
   if (failed(converter.Run())) return diag_handler.ConsumeStatus();
   auto hlo_module = converter.ConsumeMainProto();
   hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h
index 6f2b5a6db95..4ca3e586128 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h
@@ -25,6 +25,18 @@ limitations under the License.
 
 namespace mlir {
 
+struct MlirToHloConversionOptions {
+  // Best-effort propagation of the layouts. These layouts serve as performance
+  // hints to the backend.
+  //
+  // Note that non-array shapes do not carry layouts, and users have to figure
+  // out the proper layouts from context. This is one of the reasons why the
+  // attribute-based solution is temporary.
+  //
+  // TODO(timshen): Investigate the necessity of having layouts in MHLO.
+  bool propagate_layouts = false;
+};
+
 // Converts a MLIR module in HLO dialect into a HloModuleProto. If
 // use_tuple_args is set, then the entry computations's arguments are converted
 // to a tuple and passed as a single parameter.
@@ -32,15 +44,19 @@ namespace mlir {
 // are converted to a tuple even when there is only a single return value.
 // Multiple return values are always converted to a tuple and returned as a
 // single value.
+//
+// TODO(timshen): move other options into `options`.
 Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto,
                            bool use_tuple_args, bool return_tuple,
                            const tensorflow::XlaHelpers::ShapeRepresentationFn
-                               shape_representation_fn = nullptr);
+                               shape_representation_fn = nullptr,
+                           MlirToHloConversionOptions options = {});
 
 // Converts a region to a computation. It returns a standalone module that
 // contains the converted region as the entry computation.
 Status ConvertRegionToComputation(mlir::Region* region,
-                                  ::xla::XlaComputation* func);
+                                  ::xla::XlaComputation* func,
+                                  MlirToHloConversionOptions options = {});
 
 // Creates XlaOp equivalent of a given MLIR operation using the operand info
 // from `value_lowering` map.
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir b/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir
index 6a7debc8c6c..2ef0aaf3f50 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir
@@ -26,5 +26,9 @@ func @main(%arg0: tensor<128x224x224x4xf16>, %arg1: tensor<64x7x7x4xf16>) -> ten
     rhs_dilations = dense<1> : tensor<2xi64>,
     window_strides = dense<2> : tensor<2xi64>
   } : (tensor<128x224x224x4xf16>, tensor<64x7x7x4xf16>)-> tensor<128x64x112x112xf16> loc("root.42")
+
+  // CHECK: s32[1,1]{0,1} constant({ {42} })
+  %cst_1 = "std.constant"() {value = dense<[[42]]> : tensor<1x1xi32>, minor_to_major = dense<[0, 1]> : tensor<2xindex>} : () -> tensor<1x1xi32>
+
   return %0 : tensor<128x64x112x112xf16>
 }
diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc
index 55833bf9939..3ee70db1813 100644
--- a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc
+++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc
@@ -129,8 +129,11 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunctionImpl(
   if (!module) return mlir::failure();
 
   HloProto hloProto;
+  mlir::MlirToHloConversionOptions options;
+  options.propagate_layouts = with_layouts;
   Status status = mlir::ConvertMlirHloToHlo(
-      module, &hloProto, emit_use_tuple_arg, emit_return_tuple,
+      module, &hloProto, emit_use_tuple_arg, emit_return_tuple,
+      /*shape_representation_fn=*/nullptr, options);
   if (!status.ok()) {
     LOG(ERROR) << "Module conversion failed: " << status;
     return mlir::failure();
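
For readers coming to this patch without an XLA background: `minor_to_major` is a permutation of [0, rank) listing tensor dimensions from fastest-varying (minor-most) to slowest-varying in physical memory. It is what `ExtractLayout` above decodes and what the `{0,1}` suffix in the new test's `s32[1,1]{0,1}` expresses. Below is a minimal self-contained sketch of the indexing rule; `PhysicalIndex` is a hypothetical helper written only for illustration, not an XLA API:

```cpp
#include <cstdint>
#include <vector>

// Physical linear offset of a logical index under a minor-to-major layout.
// minor_to_major[0] names the fastest-varying dimension.
int64_t PhysicalIndex(const std::vector<int64_t>& index,
                      const std::vector<int64_t>& dims,
                      const std::vector<int64_t>& minor_to_major) {
  int64_t offset = 0;
  int64_t stride = 1;
  for (int64_t dim : minor_to_major) {
    offset += index[dim] * stride;
    stride *= dims[dim];
  }
  return offset;
}

// For a 2x3 tensor with logical index (i, j):
//   layout {1, 0} (row-major, the descending default): offset = i * 3 + j
//   layout {0, 1} (column-major):                      offset = j * 2 + i
// xla::LayoutUtil::MakeDescendingLayout(rank) yields {rank - 1, ..., 1, 0}.
```

The FileCheck line in layouts_and_names.mlir checks exactly this list, verifying that the `minor_to_major` attribute survives the round trip onto the HLO constant.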
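
This also explains why the constant path drops its `propagate_layouts(inst, constant)` call: `CreateArrayLiteralFromAttr` now threads the layout into the literal itself via `xla::LiteralUtil::CreateFromArrayWithLayout`, so nothing needs to be patched onto the instruction's shape afterwards. A standalone sketch of that building block, with values mirroring the new test case (the printing at the end is illustrative; the exact `ToString` formatting may differ):

```cpp
#include <iostream>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"

int main() {
  // A 1x1 s32 array holding 42, as in the layouts_and_names.mlir test.
  xla::Array2D<int32_t> data({{42}});

  // {0, 1} lists dimension 0 as minor-most, i.e. column-major order.
  xla::Layout layout = xla::LayoutUtil::MakeLayout({0, 1});

  // The layout is carried by the literal's shape from the start.
  xla::Literal literal =
      xla::LiteralUtil::CreateFromArrayWithLayout(data, layout);
  std::cout << literal.ToString() << std::endl;  // e.g. s32[1,1]{0,1} ...
  return 0;
}
```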
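
Finally, a sketch of how a client opts in to layout propagation through the new `MlirToHloConversionOptions` argument, modeled on the `xla_mlir_translate.cc` call site above; the wrapper function itself is an illustrative assumption, not part of the patch:

```cpp
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"

// Hypothetical caller: converts an MHLO module to an HloProto, honoring any
// minor_to_major attributes on the module's ops.
tensorflow::Status ExportWithLayouts(mlir::ModuleOp module,
                                     xla::HloProto* hlo_proto) {
  mlir::MlirToHloConversionOptions options;
  options.propagate_layouts = true;  // off by default
  return mlir::ConvertMlirHloToHlo(module, hlo_proto,
                                   /*use_tuple_args=*/false,
                                   /*return_tuple=*/false,
                                   /*shape_representation_fn=*/nullptr,
                                   options);
}
```

Because `options` defaults to `{}` in both public signatures, every existing call site compiles unchanged; the options struct also gives future flags (per the TODO) a home without further signature churn.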