Support frontend attributes in HLO MLIR and on export to HLO proto.
Ops in HLO can be annotated with frontend_attributes, providing the XLA compiler with hints/configuration options for such ops. Currently these are necessary for host and device communication via TPU sends and recvs. This follows a similar approach to OpSharding annotations. PiperOrigin-RevId: 315027131 Change-Id: I92be98c571968314fa8cddf6d564d28604d1edf6
This commit is contained in:
parent
c5ac451efa
commit
e59ee30dd2
@ -73,6 +73,7 @@ constexpr char kPaddingMapAttr[] = "xla_hlo.padding_map";
|
||||
// Names of MLIR attributes consumed during HLO <-> MLIR conversion. Collected
// here so every use site shares one spelling of each string literal.
constexpr char kShapeIndicesAttr[] = "shape_indices";
constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices";
// Per-op sharding annotation, exported as an xla::OpSharding proto.
constexpr char kShardingAttr[] = "xla_hlo.sharding";
// Per-op frontend attributes, exported as an xla::FrontendAttributes proto.
constexpr char kFrontendAttributesAttr[] = "xla_hlo.frontend_attributes";
// NOTE(review): "Repication" looks like a typo for "Replication", but the
// identifier (and its string value) may be referenced elsewhere -- renaming
// would break callers, so confirm before fixing.
constexpr char kRepicationAttr[] = "xla_hlo.is_same_data_across_replicas";
||||
// Passes through everything except for unique_ptr, on which it calls get().
|
||||
@ -399,6 +400,25 @@ static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute(
|
||||
return CreateOpShardingFromStringRef(sharding.getValue());
|
||||
}
|
||||
|
||||
// Returns a FrontendAttributes proto from the "frontend_attributes" attribute
|
||||
// of the op. An empty FrontendAttributes proto is returned if an op does not
|
||||
// have frontend attributes.
|
||||
static xla::FrontendAttributes CreateOpFrontendAttributesFromAttribute(
|
||||
mlir::Operation* op) {
|
||||
xla::FrontendAttributes frontend_attributes;
|
||||
auto frontend_attributes_dict =
|
||||
op->getAttrOfType<mlir::DictionaryAttr>(kFrontendAttributesAttr);
|
||||
|
||||
if (!frontend_attributes_dict) return frontend_attributes;
|
||||
|
||||
for (const auto& attr : frontend_attributes_dict)
|
||||
if (auto value_str_attr = attr.second.dyn_cast<mlir::StringAttr>())
|
||||
frontend_attributes.mutable_map()->insert(
|
||||
{attr.first.str(), value_str_attr.getValue().str()});
|
||||
|
||||
return frontend_attributes;
|
||||
}
|
||||
|
||||
// Checks if all shardings are set.
|
||||
static bool AllOptionalShardingsAreSet(
|
||||
llvm::ArrayRef<absl::optional<xla::OpSharding>> shardings) {
|
||||
|
@ -138,6 +138,13 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
|
||||
os << " xla::XlaScopedShardingAssignment sharding(lowering_context.builder, "
|
||||
"CreateOpShardingFromAttribute(op));\n\n";
|
||||
|
||||
// Create a scoped object to assign frontend attributes to generated XLA ops.
|
||||
// Any HLO can have an attribute of "frontend_attributes", which are used to
|
||||
// pass hints / configuration options.
|
||||
os << " xla::XlaScopedFrontendAttributesAssignment "
|
||||
"frontend_attributes(lowering_context.builder, "
|
||||
"CreateOpFrontendAttributesFromAttribute(op));\n\n";
|
||||
|
||||
// Retrieve all the definitions derived from HLO_Op and sort by record name.
|
||||
for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) {
|
||||
// Skip operations that have a custom exporter.
|
||||
|
@ -1018,19 +1018,19 @@ func @main(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f64>>) -> (ten
|
||||
// -----
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main(%arg0: tensor<4xui8>) -> (tensor<4xui8>) {
|
||||
func @main(%arg0: tensor<4xui8>) -> tensor<4xui8> {
|
||||
%0 = "xla_hlo.not"(%arg0) : (tensor<4xui8>) -> tensor<4xui8>
|
||||
return %0 : tensor<4xui8>
|
||||
}
|
||||
|
||||
// CHECK: ENTRY
|
||||
// CHECK: %[[ARG0:.*]] = u8[4] parameter(0)
|
||||
// ROOT %[[RESULT:.*]] = u8[4] not(u8[4] %[[ARG0]])
|
||||
// CHECK: ROOT %[[RESULT:.*]] = u8[4] not(u8[4] %[[ARG0]])
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main(%arg0: tensor<4xi32>) -> (tensor<*xi32>) {
|
||||
func @main(%arg0: tensor<4xi32>) -> tensor<*xi32> {
|
||||
%0 = "xla_hlo.not"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
|
||||
%1 = tensor_cast %0 : tensor<4xi32> to tensor<*xi32>
|
||||
return %1 : tensor<*xi32>
|
||||
@ -1038,4 +1038,52 @@ func @main(%arg0: tensor<4xi32>) -> (tensor<*xi32>) {
|
||||
|
||||
// CHECK: ENTRY
|
||||
// CHECK: %[[ARG0:.*]] = s32[4] parameter(0)
|
||||
// ROOT %[[RESULT:.*]] = s32[4] not(s32[4] %[[ARG0]])
|
||||
// CHECK: ROOT %[[RESULT:.*]] = s32[4] not(s32[4] %[[ARG0]])
|
||||
|
||||
// -----
|
||||
|
||||
// Tests ops with different frontend attributes have such attributes set
|
||||
// correctly in HloModule as frontend_attributes.
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> tuple<tensor<3x4xf32>, !xla_hlo.token> {
|
||||
%0 = "xla_hlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true, xla_hlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "channel_dtoh_0"}} : (tensor<3x4xf32>, !xla_hlo.token) -> !xla_hlo.token
|
||||
%1 = "xla_hlo.recv"(%0) {channel_id = {handle = 2 : i64, type = 3 : i64}, is_host_transfer = true, xla_hlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "channel_htod_0"}} : (!xla_hlo.token) -> tuple<tensor<3x4xf32>, !xla_hlo.token>
|
||||
return %1 : tuple<tensor<3x4xf32>, !xla_hlo.token>
|
||||
}
|
||||
|
||||
// CHECK: ENTRY
|
||||
// CHECK: %[[SEND:.*]] = (f32[3,4], u32[], token[]) send
|
||||
// CHECK-SAME: frontend_attributes={_xla_host_transfer_original_type="f32",_xla_host_transfer_rendezvous="channel_dtoh_0"}
|
||||
// CHECK: %[[SEND_DONE:.*]] = token[] send-done((f32[3,4], u32[], token[]) %[[SEND]])
|
||||
// CHECK-SAME: frontend_attributes={_xla_host_transfer_original_type="f32",_xla_host_transfer_rendezvous="channel_dtoh_0"}
|
||||
// CHECK: %[[RECV:.*]] = (f32[3,4], u32[], token[]) recv(token[] %[[SEND_DONE]])
|
||||
// CHECK-SAME: frontend_attributes={_xla_host_transfer_original_type="f32",_xla_host_transfer_rendezvous="channel_htod_0"}
|
||||
// CHECK: ROOT %{{.*}} = (f32[3,4], token[]) recv-done((f32[3,4], u32[], token[]) %[[RECV]])
|
||||
// CHECK-SAME: frontend_attributes={_xla_host_transfer_original_type="f32",_xla_host_transfer_rendezvous="channel_htod_0"}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests ops with empty frontend attributes do not have frontend_attributes
|
||||
// populated in HloModule.
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> !xla_hlo.token {
|
||||
%0 = "xla_hlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true, xla_hlo.frontend_attributes = {}} : (tensor<3x4xf32>, !xla_hlo.token) -> !xla_hlo.token
|
||||
return %0 : !xla_hlo.token
|
||||
}
|
||||
|
||||
// CHECK-NOT: frontend_attributes
|
||||
|
||||
// -----
|
||||
|
||||
// Tests ops with no frontend attributes do not have frontend_attributes
|
||||
// populated in HloModule.
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> !xla_hlo.token {
|
||||
%0 = "xla_hlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true} : (tensor<3x4xf32>, !xla_hlo.token) -> !xla_hlo.token
|
||||
return %0 : !xla_hlo.token
|
||||
}
|
||||
|
||||
// CHECK-NOT: frontend_attributes
|
||||
|
Loading…
x
Reference in New Issue
Block a user