Update _XlaSharding/xla_hlo.sharding attributes to use serialized string of xla::OpSharding proto instead of textproto format.
Other parts of the bridge use the serialized string format. This updates the use case in HLO MLIR to be consistent with everywhere else. PiperOrigin-RevId: 314859898 Change-Id: Ia4cb3d71c2973f4d18c67e19f46a059d22faa765
This commit is contained in:
parent
f19c6efb4a
commit
ca8c3462f9
@ -381,21 +381,22 @@ static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
|
||||
return output;
|
||||
}
|
||||
|
||||
// Extracts sharding from attribute string.
|
||||
static absl::optional<xla::OpSharding> CreateOpShardingFromStringRef(
|
||||
llvm::StringRef sharding) {
|
||||
xla::OpSharding sharding_proto;
|
||||
if (!sharding_proto.ParseFromString(sharding.str())) return absl::nullopt;
|
||||
return sharding_proto;
|
||||
}
|
||||
|
||||
// Returns an OpSharding proto from the "sharding" attribute of the op. If the
|
||||
// op doesn't have a sharding attribute or the sharding attribute is invalid,
|
||||
// returns absl::nullopt.
|
||||
static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute(
|
||||
mlir::Operation* op) {
|
||||
auto sharding = op->getAttrOfType<mlir::StringAttr>(kShardingAttr);
|
||||
if (!sharding) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
::xla::OpSharding sharding_proto;
|
||||
if (!::tensorflow::protobuf::TextFormat::ParseFromString(
|
||||
sharding.getValue().str(), &sharding_proto)) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
return sharding_proto;
|
||||
if (!sharding) return absl::nullopt;
|
||||
return CreateOpShardingFromStringRef(sharding.getValue());
|
||||
}
|
||||
|
||||
// Checks if all shardings are set.
|
||||
@ -407,14 +408,6 @@ static bool AllOptionalShardingsAreSet(
|
||||
});
|
||||
}
|
||||
|
||||
// Extracts sharding from attribute string.
|
||||
static absl::optional<xla::OpSharding> CreateOpShardingFromStringRef(
|
||||
llvm::StringRef sharding) {
|
||||
xla::OpSharding sharding_proto;
|
||||
if (!sharding_proto.ParseFromString(sharding.str())) return absl::nullopt;
|
||||
return sharding_proto;
|
||||
}
|
||||
|
||||
// Extracts argument and result shardings from function.
|
||||
static void ExtractShardingsFromFunction(
|
||||
mlir::FuncOp function,
|
||||
|
@ -845,7 +845,19 @@ func @infeed_dequeue_tuple() -> (tensor<3xi32>, tensor<4xf32>) {
|
||||
func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> {
|
||||
// CHECK: "xla_hlo.infeed"
|
||||
// An additional sharding is added at the end to account for token result.
|
||||
// CHECK-SAME: xla_hlo.sharding = "type: TUPLE\0Atuple_shardings {\0A type: MAXIMAL\0A tile_assignment_dimensions: 1\0A tile_assignment_devices: 0\0A}\0Atuple_shardings {\0A type: MAXIMAL\0A tile_assignment_dimensions: 1\0A tile_assignment_devices: 0\0A}\0A"
|
||||
// Proto debug string:
|
||||
// type: TUPLE
|
||||
// tuple_shardings {
|
||||
// type: MAXIMAL
|
||||
// tile_assignment_dimensions: 1
|
||||
// tile_assignment_devices: 0
|
||||
// }
|
||||
// tuple_shardings {
|
||||
// type: MAXIMAL
|
||||
// tile_assignment_dimensions: 1
|
||||
// tile_assignment_devices: 0
|
||||
// }
|
||||
// CHECK-SAME: xla_hlo.sharding = "\08\02*\08\08\01\1A\01\01\22\01\00*\08\08\01\1A\01\01\22\01\00"
|
||||
%0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\02*\08\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32>
|
||||
return %0 : tensor<8xi32>
|
||||
}
|
||||
|
@ -963,9 +963,19 @@ func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||
|
||||
// -----
|
||||
|
||||
// The following op sharding is used:
|
||||
// Proto debug string:
|
||||
// type: OTHER
|
||||
// tile_assignment_dimensions: 1
|
||||
// tile_assignment_dimensions: 2
|
||||
// tile_assignment_devices: 0
|
||||
// tile_assignment_devices: 1
|
||||
// Serialized string:
|
||||
// "\08\03\1A\02\01\02\22\02\00\01"
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> {
|
||||
%0 = "xla_hlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", xla_hlo.sharding = "type: OTHER\ntile_assignment_dimensions: 1\ntile_assignment_dimensions: 2\ntile_assignment_devices: 0\ntile_assignment_devices: 1"} : (tensor<16x16xf32>) -> tensor<16x16xf32>
|
||||
%0 = "xla_hlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"} : (tensor<16x16xf32>) -> tensor<16x16xf32>
|
||||
return %0 : tensor<16x16xf32>
|
||||
}
|
||||
|
||||
|
@ -3806,17 +3806,15 @@ class ConvertInfeedDequeueTupleOp
|
||||
|
||||
// Token is a control signal and not a real data, so arbitrarily assign
|
||||
// the token to device 0.
|
||||
if (sharding_proto.type() == ::xla::OpSharding::TUPLE)
|
||||
if (sharding_proto.type() == ::xla::OpSharding::TUPLE) {
|
||||
*sharding_proto.add_tuple_shardings() =
|
||||
::xla::sharding_builder::AssignDevice(0);
|
||||
|
||||
std::string sharding_str;
|
||||
if (!::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto,
|
||||
&sharding_str))
|
||||
return failure();
|
||||
|
||||
data_and_token.setAttr(kShardingAttr,
|
||||
rewriter.getStringAttr(sharding_str));
|
||||
data_and_token.setAttr(
|
||||
kShardingAttr,
|
||||
rewriter.getStringAttr(sharding_proto.SerializeAsString()));
|
||||
} else {
|
||||
data_and_token.setAttr(kShardingAttr, op._XlaShardingAttr());
|
||||
}
|
||||
}
|
||||
|
||||
// The infeed instruction produces a tuple of the infeed data and a token
|
||||
@ -4359,21 +4357,12 @@ class ConvertXlaShardingOp : public OpRewritePattern<TF::XlaShardingOp> {
|
||||
// using a string.
|
||||
if (!op._XlaSharding().hasValue()) return failure();
|
||||
|
||||
// _XlaSharding attribute in TF is a serialized string of the OpSharding
|
||||
// proto, so convert to a text form here.
|
||||
::xla::OpSharding sharding_proto;
|
||||
std::string sharding_str;
|
||||
if (!sharding_proto.ParseFromString(op._XlaSharding().getValue().str()) ||
|
||||
!::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto,
|
||||
&sharding_str))
|
||||
return failure();
|
||||
|
||||
auto custom_call = rewriter.create<xla_hlo::CustomCallOp>(
|
||||
op.getLoc(), op.getType(), op.input(),
|
||||
/*call_target_name=*/rewriter.getStringAttr("Sharding"),
|
||||
/*has_side_effect=*/rewriter.getBoolAttr(false),
|
||||
/*backend_config=*/rewriter.getStringAttr(""));
|
||||
custom_call.setAttr(kShardingAttr, rewriter.getStringAttr(sharding_str));
|
||||
custom_call.setAttr(kShardingAttr, op._XlaShardingAttr());
|
||||
rewriter.replaceOp(op, custom_call.getResult());
|
||||
|
||||
return success();
|
||||
|
Loading…
Reference in New Issue
Block a user