[MLIR] Add a switch for enabling layouts in MLIR -> XLA HLO
Previously, MLIR could carry an optional "minor_to_major" attribute to indicate a layout. With layouts enabled, the absence of that attribute means a descending layout. However, XLA builders do not assign descending layouts by default, and the shape inferencer sometimes forwards layouts from the inputs even though those layouts are not meant to be forwarded. Add a switch to explicitly assign layouts to all ops.

Also fix a bug where a literal's layout was not correctly set.

PiperOrigin-RevId: 335091078
Change-Id: I4649afc78e401806dadc6165ba819d1282a18147
commit a70cc367bc (parent 724b82d94c)
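For orientation, here is a minimal, hypothetical caller showing how the new switch is meant to be used once this change lands. It relies only on the API added in the hunks below (MlirToHloConversionOptions, propagate_layouts, and the extended ConvertMlirHloToHlo signature); the include path, the ExportWithLayouts wrapper, and the error handling are illustrative assumptions, not part of the commit.

    // Sketch only; assumes an MHLO-dialect mlir::ModuleOp is already available.
    #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"  // assumed path

    xla::HloProto ExportWithLayouts(mlir::ModuleOp module) {
      mlir::MlirToHloConversionOptions options;
      // Forward per-op "minor_to_major" attributes (or the default descending
      // layout) into the emitted HLO shapes.
      options.propagate_layouts = true;

      xla::HloProto hlo_proto;
      tensorflow::Status status = mlir::ConvertMlirHloToHlo(
          module, &hlo_proto, /*use_tuple_args=*/false, /*return_tuple=*/false,
          /*shape_representation_fn=*/nullptr, options);
      if (!status.ok()) LOG(FATAL) << status;  // illustrative error handling
      return hlo_proto;
    }

Leaving `options` at its default keeps the previous behavior: no layouts are attached to the exported shapes.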
@@ -762,17 +762,10 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
                       ImportInstructionImpl(instruction, func_builder));
   if (op == nullptr) return op;
 
-  // Best-effort propagation of the layouts. These layouts serve as performance
-  // hints to the backend.
+  // See MlirToHloConversionOptions for more about layouts.
   //
   // Minor-to-major is a permutation of [0, rank), presenting tensor dimensions
   // in physical minor-to-major order.
-  //
-  // Note that non-array shapes are not carrying layouts, and users have to
-  // figure out the proper layouts of them through context. This is one of the
-  // reasons why the attribute-based solution is temporary.
-  //
-  // TODO(timshen): Investigate the necessity of having layouts in MHLO.
   if (instruction->shape().IsArray() &&
       instruction->shape().layout() !=
           LayoutUtil::MakeDescendingLayout(
@@ -499,12 +499,14 @@ class ConvertToHloModule {
   // single value.
   explicit ConvertToHloModule(
       mlir::ModuleOp module, bool use_tuple_args, bool return_tuple,
-      tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn)
+      tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+      MlirToHloConversionOptions options)
       : module_(module),
         module_builder_("main"),
         use_tuple_args_(use_tuple_args),
         return_tuple_(return_tuple),
-        shape_representation_fn_(shape_representation_fn) {
+        shape_representation_fn_(shape_representation_fn),
+        options_(options) {
     if (!shape_representation_fn_)
       shape_representation_fn_ = tensorflow::IdentityShapeRepresentationFn();
   }
@@ -585,6 +587,8 @@ class ConvertToHloModule {
 
   // Unique suffix to give to the name of the next lowered region.
   size_t region_id_ = 0;
+
+  MlirToHloConversionOptions options_;
 };
 
 }  // namespace
@@ -1087,18 +1091,19 @@ LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) {
 namespace mlir {
 namespace {
 
-StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
+StatusOr<xla::Literal> CreateArrayLiteralFromAttr(ElementsAttr attr,
+                                                  xla::Layout layout) {
   if (attr.isa<OpaqueElementsAttr>())
     return tensorflow::errors::Unimplemented(
         "Opaque elements attr not supported");
 
   xla::Shape shape = xla::TypeToShape(attr.getType());
 
 #define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type)                         \
   case xla_type: {                                                           \
     xla::Array<cpp_type> source_data(shape.dimensions());                    \
     source_data.SetValues(attr.getValues<cpp_type>());                       \
-    return xla::LiteralUtil::CreateFromArray(source_data);                   \
+    return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout); \
   }
 
   switch (shape.element_type()) {
@@ -1128,7 +1133,7 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
       }
       xla::Array<xla::half> source_data(shape.dimensions());
       source_data.SetValues(values);
-      return xla::LiteralUtil::CreateFromArray(source_data);
+      return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout);
     }
     case xla::PrimitiveType::BF16: {
       xla::Array<double> source_data(shape.dimensions());
@@ -1145,7 +1150,7 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
       }
       source_data.SetValues(values_double);
       return xla::LiteralUtil::ConvertF64ToBF16(
-          xla::LiteralUtil::CreateFromArray(source_data));
+          xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout));
     }
     default:
       return tensorflow::errors::Internal(absl::StrCat(
@@ -1154,25 +1159,33 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
 #undef ELEMENTS_ATTR_TO_LITERAL
 }
 
+xla::Layout ExtractLayout(mlir::Operation* op, int rank) {
+  if (auto attr =
+          op->getAttrOfType<mlir::DenseIntElementsAttr>("minor_to_major")) {
+    llvm::SmallVector<int64, 4> minor_to_major;
+    minor_to_major.reserve(attr.size());
+    for (const llvm::APInt& i : attr) {
+      minor_to_major.push_back(i.getZExtValue());
+    }
+    return xla::LayoutUtil::MakeLayout(minor_to_major);
+  }
+  return xla::LayoutUtil::MakeDescendingLayout(rank);
+}
+
 LogicalResult ConvertToHloModule::Lower(
     mlir::Operation* inst, bool is_entry_function,
     llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
     xla::XlaBuilder* builder,
     ConvertToHloModule::ValueLoweringMap* value_lowering,
     xla::XlaComputation* result) {
-  // See hlo_function_importer.cc for documentation about layouts in MHLO.
-  auto propagate_layouts = [](mlir::Operation* inst, xla::XlaOp xla_op) {
-    auto attr =
-        inst->getAttrOfType<mlir::DenseIntElementsAttr>("minor_to_major");
-    if (!attr) return;
-
-    auto* v = xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
-                  ->mutable_shape()
-                  ->mutable_layout()
-                  ->mutable_minor_to_major();
-    v->Clear();
-    for (const llvm::APInt& i : attr) {
-      *v->Add() = i.getZExtValue();
-    }
+  // See MlirToHloConversionOptions for more about layouts.
+  auto propagate_layouts = [this](mlir::Operation* inst, xla::XlaOp xla_op) {
+    if (options_.propagate_layouts) {
+      auto* shape = xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
+                        ->mutable_shape();
+      if (shape->tuple_shapes().empty())
+        *shape->mutable_layout() =
+            ExtractLayout(inst, shape->dimensions().size()).ToProto();
+    }
   };
 
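A brief, hedged illustration of what the ExtractLayout helper above yields; the values follow from the documented behavior of the existing LayoutUtil helpers and are not code from this commit:

    // An op carrying  minor_to_major = dense<[0, 1]> : tensor<2xindex>
    // resolves to the equivalent of:
    xla::Layout explicit_layout = xla::LayoutUtil::MakeLayout({0, 1});
    // printed as {0,1}: dimension 0 is the minor-most (fastest-varying) one.

    // An op with no such attribute falls back to the default descending layout:
    xla::Layout default_layout = xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2);
    // printed as {1,0}, which is what XLA assumes when no layout is given.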
@@ -1216,12 +1229,14 @@ LogicalResult ConvertToHloModule::Lower(
   }
 
   if (matchPattern(inst, m_Constant(&const_attr))) {
-    auto literal_or = CreateLiteralFromAttr(const_attr);
+    xla::Layout layout;
+    layout = ExtractLayout(inst, const_attr.getType().getRank());
+    auto literal_or = CreateArrayLiteralFromAttr(const_attr, layout);
     if (!literal_or.ok())
       return inst->emitError(literal_or.status().ToString());
     auto constant = xla::ConstantLiteral(builder, literal_or.ValueOrDie());
     value_map[inst->getResult(0)] = constant;
-    propagate_layouts(inst, constant);
     return success();
   }
 
@@ -1674,22 +1689,24 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
 }  // namespace
 
 Status ConvertRegionToComputation(mlir::Region* region,
-                                  xla::XlaComputation* func) {
+                                  xla::XlaComputation* func,
+                                  MlirToHloConversionOptions options) {
   mlir::ModuleOp module;
-  ConvertToHloModule converter(module, true, true, {});
+  ConvertToHloModule converter(module, true, true, {}, options);
   if (failed(converter.LowerRegionAsComputation(region, func)))
     return tensorflow::errors::Internal(
         "failed to convert region to computation");
   return Status::OK();
 }
 
-Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
-                           bool use_tuple_args, bool return_tuple,
-                           const tensorflow::XlaHelpers::ShapeRepresentationFn
-                               shape_representation_fn) {
+Status ConvertMlirHloToHlo(
+    mlir::ModuleOp module, xla::HloProto* hlo_proto, bool use_tuple_args,
+    bool return_tuple,
+    const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    MlirToHloConversionOptions options) {
   mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
   ConvertToHloModule converter(module, use_tuple_args, return_tuple,
-                               shape_representation_fn);
+                               shape_representation_fn, options);
   if (failed(converter.Run())) return diag_handler.ConsumeStatus();
   auto hlo_module = converter.ConsumeMainProto();
   hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
@@ -25,6 +25,18 @@ limitations under the License.
 
 namespace mlir {
 
+struct MlirToHloConversionOptions {
+  // Best-effort propagation of the layouts. These layouts serve as performance
+  // hints to the backend.
+  //
+  // Note that non-array shapes are not carrying layouts, and users have to
+  // figure out the proper layouts of them through context. This is one of the
+  // reasons why the attribute-based solution is temporary.
+  //
+  // TODO(timshen): Investigate the necessity of having layouts in MHLO.
+  bool propagate_layouts = false;
+};
+
 // Converts a MLIR module in HLO dialect into a HloModuleProto. If
 // use_tuple_args is set, then the entry computations's arguments are converted
 // to a tuple and passed as a single parameter.
@@ -32,15 +44,19 @@ namespace mlir {
 // are converted to a tuple even when there is only a single return value.
 // Multiple return values are always converted to a tuple and returned as a
 // single value.
+//
+// TODO(timshen): move other options into `options`.
 Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto,
                            bool use_tuple_args, bool return_tuple,
                            const tensorflow::XlaHelpers::ShapeRepresentationFn
-                               shape_representation_fn = nullptr);
+                               shape_representation_fn = nullptr,
+                           MlirToHloConversionOptions options = {});
 
 // Converts a region to a computation. It returns a standalone module that
 // contains the converted region as the entry computation.
 Status ConvertRegionToComputation(mlir::Region* region,
-                                  ::xla::XlaComputation* func);
+                                  ::xla::XlaComputation* func,
+                                  MlirToHloConversionOptions options = {});
 
 // Creates XlaOp equivalent of a given MLIR operation using the operand info
 // from `value_lowering` map.
@@ -26,5 +26,9 @@ func @main(%arg0: tensor<128x224x224x4xf16>, %arg1: tensor<64x7x7x4xf16>) -> ten
     rhs_dilations = dense<1> : tensor<2xi64>,
     window_strides = dense<2> : tensor<2xi64>
   } : (tensor<128x224x224x4xf16>, tensor<64x7x7x4xf16>)-> tensor<128x64x112x112xf16> loc("root.42")
+
+  // CHECK: s32[1,1]{0,1} constant({ {42} })
+  %cst_1 = "std.constant"() {value = dense<[[42]]> : tensor<1x1xi32>, minor_to_major = dense<[0, 1]> : tensor<2xindex>} : () -> tensor<1x1xi32>
+
   return %0 : tensor<128x64x112x112xf16>
 }
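The `{0,1}` suffix in the CHECK line above is XLA's textual layout notation: the dimension numbers in minor-to-major order. As a hedged sketch (not code from this commit), the new constant-lowering path materializes that test constant roughly like this, using the existing LiteralUtil helper the diff switches to:

    // Illustrative reconstruction of the lowering of %cst_1.
    xla::Array<int32_t> data({{42}});  // 1x1 array holding the value 42
    xla::Layout layout = xla::LayoutUtil::MakeLayout({0, 1});
    xla::Literal literal =
        xla::LiteralUtil::CreateFromArrayWithLayout(data, layout);
    // The literal's shape now renders as s32[1,1]{0,1}, matching the CHECK.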
@@ -129,8 +129,11 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunctionImpl(
   if (!module) return mlir::failure();
 
   HloProto hloProto;
+  mlir::MlirToHloConversionOptions options;
+  options.propagate_layouts = with_layouts;
   Status status = mlir::ConvertMlirHloToHlo(
-      module, &hloProto, emit_use_tuple_arg, emit_return_tuple);
+      module, &hloProto, emit_use_tuple_arg, emit_return_tuple,
+      /*shape_representation_fn=*/nullptr, options);
   if (!status.ok()) {
     LOG(ERROR) << "Module conversion failed: " << status;
     return mlir::failure();