Support frontend attributes in HLO MLIR and on export to HLO proto.

Ops in HLO can be annotated with frontend_attributes, which provide the XLA compiler with hints/configuration options for those ops. Currently these are necessary for host/device communication via TPU sends and recvs. This follows an approach similar to the OpSharding annotations.
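
For illustration only (not part of this change), a minimal sketch of how such attributes surface on the XlaBuilder side, assuming the existing xla::XlaScopedFrontendAttributesAssignment helper from xla_builder.h; the function name and attribute value below are made up:

    #include "tensorflow/compiler/xla/client/xla_builder.h"

    // While `assign` is in scope, ops built through `builder` record these
    // frontend_attributes in the exported HLO proto.
    void AnnotateExample(xla::XlaBuilder* builder, xla::XlaOp operand) {
      xla::FrontendAttributes attrs;
      (*attrs.mutable_map())["_xla_host_transfer_rendezvous"] = "channel_dtoh_0";
      xla::XlaScopedFrontendAttributesAssignment assign(builder, attrs);
      xla::Neg(operand);  // This instruction carries the attributes above.
    }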

PiperOrigin-RevId: 315027131
Change-Id: I92be98c571968314fa8cddf6d564d28604d1edf6
Authored by Andy Ly on 2020-06-05 17:43:37 -07:00; committed by TensorFlower Gardener
parent c5ac451efa
commit e59ee30dd2
3 changed files with 79 additions and 4 deletions


@@ -73,6 +73,7 @@ constexpr char kPaddingMapAttr[] = "xla_hlo.padding_map";
constexpr char kShapeIndicesAttr[] = "shape_indices";
constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices";
constexpr char kShardingAttr[] = "xla_hlo.sharding";
constexpr char kFrontendAttributesAttr[] = "xla_hlo.frontend_attributes";
constexpr char kRepicationAttr[] = "xla_hlo.is_same_data_across_replicas";
// Passes through everything except for unique_ptr, on which it calls get().
@@ -399,6 +400,25 @@ static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute(
return CreateOpShardingFromStringRef(sharding.getValue());
}
// Returns a FrontendAttributes proto from the "frontend_attributes" attribute
// of the op. An empty FrontendAttributes proto is returned if an op does not
// have frontend attributes.
static xla::FrontendAttributes CreateOpFrontendAttributesFromAttribute(
mlir::Operation* op) {
xla::FrontendAttributes frontend_attributes;
auto frontend_attributes_dict =
op->getAttrOfType<mlir::DictionaryAttr>(kFrontendAttributesAttr);
if (!frontend_attributes_dict) return frontend_attributes;
for (const auto& attr : frontend_attributes_dict)
if (auto value_str_attr = attr.second.dyn_cast<mlir::StringAttr>())
frontend_attributes.mutable_map()->insert(
{attr.first.str(), value_str_attr.getValue().str()});
return frontend_attributes;
}
// Checks if all shardings are set.
static bool AllOptionalShardingsAreSet(
llvm::ArrayRef<absl::optional<xla::OpSharding>> shardings) {
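
Sketch of how an op would be annotated from C++ so that the helper above picks it up (hypothetical function; it only uses the standard mlir::Builder attribute constructors):

    #include "mlir/IR/Builders.h"
    #include "mlir/IR/Operation.h"

    // Attach frontend attributes to `op`; afterwards
    // CreateOpFrontendAttributesFromAttribute(op) returns a proto whose map
    // contains {"_xla_host_transfer_rendezvous": "channel_dtoh_0"}.
    void AnnotateOp(mlir::Operation* op, mlir::Builder& builder) {
      auto dict = builder.getDictionaryAttr({builder.getNamedAttr(
          "_xla_host_transfer_rendezvous",
          builder.getStringAttr("channel_dtoh_0"))});
      op->setAttr("xla_hlo.frontend_attributes", dict);
    }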


@@ -138,6 +138,13 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
os << " xla::XlaScopedShardingAssignment sharding(lowering_context.builder, "
"CreateOpShardingFromAttribute(op));\n\n";
// Create a scoped object to assign frontend attributes to generated XLA ops.
// Any HLO can have an attribute of "frontend_attributes", which are used to
// pass hints / configuration options.
os << " xla::XlaScopedFrontendAttributesAssignment "
"frontend_attributes(lowering_context.builder, "
"CreateOpFrontendAttributesFromAttribute(op));\n\n";
// Retrieve all the definitions derived from HLO_Op and sort by record name.
for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) {
// Skip operations that have a custom exporter.
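
For reference, the strings emitted above render to the following lines in each generated per-op exporter (excerpt of generated code, surrounding body elided):

    // Generated inside the exporter for each HLO_Op:
    xla::XlaScopedShardingAssignment sharding(
        lowering_context.builder, CreateOpShardingFromAttribute(op));
    xla::XlaScopedFrontendAttributesAssignment frontend_attributes(
        lowering_context.builder, CreateOpFrontendAttributesFromAttribute(op));
    // Both scoped objects stay alive while the op's XLA counterpart is built,
    // so the builder stamps sharding and frontend_attributes onto it.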


@@ -1018,19 +1018,19 @@ func @main(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f64>>) -> (ten
// -----
// CHECK: HloModule
-func @main(%arg0: tensor<4xui8>) -> (tensor<4xui8>) {
+func @main(%arg0: tensor<4xui8>) -> tensor<4xui8> {
%0 = "xla_hlo.not"(%arg0) : (tensor<4xui8>) -> tensor<4xui8>
return %0 : tensor<4xui8>
}
// CHECK: ENTRY
// CHECK: %[[ARG0:.*]] = u8[4] parameter(0)
-// ROOT %[[RESULT:.*]] = u8[4] not(u8[4] %[[ARG0]])
+// CHECK: ROOT %[[RESULT:.*]] = u8[4] not(u8[4] %[[ARG0]])
// -----
// CHECK: HloModule
-func @main(%arg0: tensor<4xi32>) -> (tensor<*xi32>) {
+func @main(%arg0: tensor<4xi32>) -> tensor<*xi32> {
%0 = "xla_hlo.not"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
%1 = tensor_cast %0 : tensor<4xi32> to tensor<*xi32>
return %1 : tensor<*xi32>
@@ -1038,4 +1038,52 @@ func @main(%arg0: tensor<4xi32>) -> (tensor<*xi32>) {
// CHECK: ENTRY
// CHECK: %[[ARG0:.*]] = s32[4] parameter(0)
-// ROOT %[[RESULT:.*]] = s32[4] not(s32[4] %[[ARG0]])
+// CHECK: ROOT %[[RESULT:.*]] = s32[4] not(s32[4] %[[ARG0]])
// -----
// Tests ops with different frontend attributes have such attributes set
// correctly in HloModule as frontend_attributes.
// CHECK: HloModule
func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> tuple<tensor<3x4xf32>, !xla_hlo.token> {
%0 = "xla_hlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true, xla_hlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "channel_dtoh_0"}} : (tensor<3x4xf32>, !xla_hlo.token) -> !xla_hlo.token
%1 = "xla_hlo.recv"(%0) {channel_id = {handle = 2 : i64, type = 3 : i64}, is_host_transfer = true, xla_hlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "channel_htod_0"}} : (!xla_hlo.token) -> tuple<tensor<3x4xf32>, !xla_hlo.token>
return %1 : tuple<tensor<3x4xf32>, !xla_hlo.token>
}
// CHECK: ENTRY
// CHECK: %[[SEND:.*]] = (f32[3,4], u32[], token[]) send
// CHECK-SAME: frontend_attributes={_xla_host_transfer_original_type="f32",_xla_host_transfer_rendezvous="channel_dtoh_0"}
// CHECK: %[[SEND_DONE:.*]] = token[] send-done((f32[3,4], u32[], token[]) %[[SEND]])
// CHECK-SAME: frontend_attributes={_xla_host_transfer_original_type="f32",_xla_host_transfer_rendezvous="channel_dtoh_0"}
// CHECK: %[[RECV:.*]] = (f32[3,4], u32[], token[]) recv(token[] %[[SEND_DONE]])
// CHECK-SAME: frontend_attributes={_xla_host_transfer_original_type="f32",_xla_host_transfer_rendezvous="channel_htod_0"}
// CHECK: ROOT %{{.*}} = (f32[3,4], token[]) recv-done((f32[3,4], u32[], token[]) %[[RECV]])
// CHECK-SAME: frontend_attributes={_xla_host_transfer_original_type="f32",_xla_host_transfer_rendezvous="channel_htod_0"}
// -----
// Tests ops with empty frontend attributes do not have frontend_attributes
// populated in HloModule.
// CHECK: HloModule
func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> !xla_hlo.token {
%0 = "xla_hlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true, xla_hlo.frontend_attributes = {}} : (tensor<3x4xf32>, !xla_hlo.token) -> !xla_hlo.token
return %0 : !xla_hlo.token
}
// CHECK-NOT: frontend_attributes
// -----
// Tests ops with no frontend attributes do not have frontend_attributes
// populated in HloModule.
// CHECK: HloModule
func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> !xla_hlo.token {
%0 = "xla_hlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true} : (tensor<3x4xf32>, !xla_hlo.token) -> !xla_hlo.token
return %0 : !xla_hlo.token
}
// CHECK-NOT: frontend_attributes
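
A quick way to sanity-check an exported module outside of FileCheck (hypothetical helper, assuming the usual xla::HloModule API):

    #include <iostream>
    #include "tensorflow/compiler/xla/service/hlo_module.h"

    // Print any frontend attributes carried by instructions in the entry
    // computation of an exported module.
    void DumpFrontendAttributes(const xla::HloModule& module) {
      for (const xla::HloInstruction* instr :
           module.entry_computation()->instructions()) {
        for (const auto& kv : instr->frontend_attributes().map()) {
          std::cout << instr->name() << ": " << kv.first << "=" << kv.second
                    << "\n";
        }
      }
    }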