[MLIR] Add a switch for enabling layouts in MLIR -> XLA HLO

Previously MLIR may carry an optional attribute "minor_to_major", to indicate a
layout. With layouts, the lack of such attribute means descending layout. However,
XLA builders don't put descending layouts by default. Sometimes the shape
inferencer forwards layouts from the input. Those layouts are not meant to be forwarded.

Add a switch to explicitly assign layouts to all ops.

Also fixed a bug that literal's layout is not correctly set.

PiperOrigin-RevId: 335091078
Change-Id: I4649afc78e401806dadc6165ba819d1282a18147
This commit is contained in:
Tim Shen 2020-10-02 13:35:19 -07:00 committed by TensorFlower Gardener
parent 724b82d94c
commit a70cc367bc
5 changed files with 76 additions and 43 deletions

View File

@ -762,17 +762,10 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
ImportInstructionImpl(instruction, func_builder));
if (op == nullptr) return op;
// Best-effort propagation of the layouts. These layouts serve as performance
// hints to the backend.
// See MlirToHloConversionOptions for more about layouts.
//
// Minor-to-major is a permutation of [0, rank), presenting tensor dimensions
// in physical minor-to-major order.
//
// Note that non-array shapes are not carrying layouts, and users have to
// figure out the proper layouts of them through context. This is one of the
// reasons why the attribute-based solution is temporary.
//
// TODO(timshen): Investigate the necessity of having layouts in MHLO.
if (instruction->shape().IsArray() &&
instruction->shape().layout() !=
LayoutUtil::MakeDescendingLayout(

View File

@ -499,12 +499,14 @@ class ConvertToHloModule {
// single value.
explicit ConvertToHloModule(
mlir::ModuleOp module, bool use_tuple_args, bool return_tuple,
tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn)
tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
MlirToHloConversionOptions options)
: module_(module),
module_builder_("main"),
use_tuple_args_(use_tuple_args),
return_tuple_(return_tuple),
shape_representation_fn_(shape_representation_fn) {
shape_representation_fn_(shape_representation_fn),
options_(options) {
if (!shape_representation_fn_)
shape_representation_fn_ = tensorflow::IdentityShapeRepresentationFn();
}
@ -585,6 +587,8 @@ class ConvertToHloModule {
// Unique suffix to give to the name of the next lowered region.
size_t region_id_ = 0;
MlirToHloConversionOptions options_;
};
} // namespace
@ -1087,18 +1091,19 @@ LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) {
namespace mlir {
namespace {
StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
StatusOr<xla::Literal> CreateArrayLiteralFromAttr(ElementsAttr attr,
xla::Layout layout) {
if (attr.isa<OpaqueElementsAttr>())
return tensorflow::errors::Unimplemented(
"Opaque elements attr not supported");
xla::Shape shape = xla::TypeToShape(attr.getType());
#define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type) \
case xla_type: { \
xla::Array<cpp_type> source_data(shape.dimensions()); \
source_data.SetValues(attr.getValues<cpp_type>()); \
return xla::LiteralUtil::CreateFromArray(source_data); \
#define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type) \
case xla_type: { \
xla::Array<cpp_type> source_data(shape.dimensions()); \
source_data.SetValues(attr.getValues<cpp_type>()); \
return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout); \
}
switch (shape.element_type()) {
@ -1128,7 +1133,7 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
}
xla::Array<xla::half> source_data(shape.dimensions());
source_data.SetValues(values);
return xla::LiteralUtil::CreateFromArray(source_data);
return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout);
}
case xla::PrimitiveType::BF16: {
xla::Array<double> source_data(shape.dimensions());
@ -1145,7 +1150,7 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
}
source_data.SetValues(values_double);
return xla::LiteralUtil::ConvertF64ToBF16(
xla::LiteralUtil::CreateFromArray(source_data));
xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout));
}
default:
return tensorflow::errors::Internal(absl::StrCat(
@ -1154,25 +1159,33 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
#undef ELEMENTS_ATTR_TO_LITERAL
}
xla::Layout ExtractLayout(mlir::Operation* op, int rank) {
if (auto attr =
op->getAttrOfType<mlir::DenseIntElementsAttr>("minor_to_major")) {
llvm::SmallVector<int64, 4> minor_to_major;
minor_to_major.reserve(attr.size());
for (const llvm::APInt& i : attr) {
minor_to_major.push_back(i.getZExtValue());
}
return xla::LayoutUtil::MakeLayout(minor_to_major);
}
return xla::LayoutUtil::MakeDescendingLayout(rank);
}
LogicalResult ConvertToHloModule::Lower(
mlir::Operation* inst, bool is_entry_function,
llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
xla::XlaBuilder* builder,
ConvertToHloModule::ValueLoweringMap* value_lowering,
xla::XlaComputation* result) {
// See hlo_function_importer.cc for documentation about layouts in MHLO.
auto propagate_layouts = [](mlir::Operation* inst, xla::XlaOp xla_op) {
auto attr =
inst->getAttrOfType<mlir::DenseIntElementsAttr>("minor_to_major");
if (!attr) return;
auto* v = xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
->mutable_shape()
->mutable_layout()
->mutable_minor_to_major();
v->Clear();
for (const llvm::APInt& i : attr) {
*v->Add() = i.getZExtValue();
// See MlirToHloConversionOptions for more about layouts.
auto propagate_layouts = [this](mlir::Operation* inst, xla::XlaOp xla_op) {
if (options_.propagate_layouts) {
auto* shape = xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
->mutable_shape();
if (shape->tuple_shapes().empty())
*shape->mutable_layout() =
ExtractLayout(inst, shape->dimensions().size()).ToProto();
}
};
@ -1216,12 +1229,14 @@ LogicalResult ConvertToHloModule::Lower(
}
if (matchPattern(inst, m_Constant(&const_attr))) {
auto literal_or = CreateLiteralFromAttr(const_attr);
xla::Layout layout;
layout = ExtractLayout(inst, const_attr.getType().getRank());
auto literal_or = CreateArrayLiteralFromAttr(const_attr, layout);
if (!literal_or.ok())
return inst->emitError(literal_or.status().ToString());
auto constant = xla::ConstantLiteral(builder, literal_or.ValueOrDie());
value_map[inst->getResult(0)] = constant;
propagate_layouts(inst, constant);
return success();
}
@ -1674,22 +1689,24 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
} // namespace
Status ConvertRegionToComputation(mlir::Region* region,
xla::XlaComputation* func) {
xla::XlaComputation* func,
MlirToHloConversionOptions options) {
mlir::ModuleOp module;
ConvertToHloModule converter(module, true, true, {});
ConvertToHloModule converter(module, true, true, {}, options);
if (failed(converter.LowerRegionAsComputation(region, func)))
return tensorflow::errors::Internal(
"failed to convert region to computation");
return Status::OK();
}
Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
bool use_tuple_args, bool return_tuple,
const tensorflow::XlaHelpers::ShapeRepresentationFn
shape_representation_fn) {
Status ConvertMlirHloToHlo(
mlir::ModuleOp module, xla::HloProto* hlo_proto, bool use_tuple_args,
bool return_tuple,
const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
MlirToHloConversionOptions options) {
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
ConvertToHloModule converter(module, use_tuple_args, return_tuple,
shape_representation_fn);
shape_representation_fn, options);
if (failed(converter.Run())) return diag_handler.ConsumeStatus();
auto hlo_module = converter.ConsumeMainProto();
hlo_proto->mutable_hlo_module()->Swap(&hlo_module);

View File

@ -25,6 +25,18 @@ limitations under the License.
namespace mlir {
struct MlirToHloConversionOptions {
// Best-effort propagation of the layouts. These layouts serve as performance
// hints to the backend.
//
// Note that non-array shapes are not carrying layouts, and users have to
// figure out the proper layouts of them through context. This is one of the
// reasons why the attribute-based solution is temporary.
//
// TODO(timshen): Investigate the necessity of having layouts in MHLO.
bool propagate_layouts = false;
};
// Converts a MLIR module in HLO dialect into a HloModuleProto. If
// use_tuple_args is set, then the entry computations's arguments are converted
// to a tuple and passed as a single parameter.
@ -32,15 +44,19 @@ namespace mlir {
// are converted to a tuple even when there is only a single return value.
// Multiple return values are always converted to a tuple and returned as a
// single value.
//
// TODO(timshen): move other options into `options`.
Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto,
bool use_tuple_args, bool return_tuple,
const tensorflow::XlaHelpers::ShapeRepresentationFn
shape_representation_fn = nullptr);
shape_representation_fn = nullptr,
MlirToHloConversionOptions options = {});
// Converts a region to a computation. It returns a standalone module that
// contains the converted region as the entry computation.
Status ConvertRegionToComputation(mlir::Region* region,
::xla::XlaComputation* func);
::xla::XlaComputation* func,
MlirToHloConversionOptions options = {});
// Creates XlaOp equivalent of a given MLIR operation using the operand info
// from `value_lowering` map.

View File

@ -26,5 +26,9 @@ func @main(%arg0: tensor<128x224x224x4xf16>, %arg1: tensor<64x7x7x4xf16>) -> ten
rhs_dilations = dense<1> : tensor<2xi64>,
window_strides = dense<2> : tensor<2xi64>
} : (tensor<128x224x224x4xf16>, tensor<64x7x7x4xf16>)-> tensor<128x64x112x112xf16> loc("root.42")
// CHECK: s32[1,1]{0,1} constant({ {42} })
%cst_1 = "std.constant"() {value = dense<[[42]]> : tensor<1x1xi32>, minor_to_major = dense<[0, 1]> : tensor<2xindex>} : () -> tensor<1x1xi32>
return %0 : tensor<128x64x112x112xf16>
}

View File

@ -129,8 +129,11 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunctionImpl(
if (!module) return mlir::failure();
HloProto hloProto;
mlir::MlirToHloConversionOptions options;
options.propagate_layouts = with_layouts;
Status status = mlir::ConvertMlirHloToHlo(
module, &hloProto, emit_use_tuple_arg, emit_return_tuple);
module, &hloProto, emit_use_tuple_arg, emit_return_tuple,
/*shape_representation_fn=*/nullptr, options);
if (!status.ok()) {
LOG(ERROR) << "Module conversion failed: " << status;
return mlir::failure();