diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 3c670ef0c6e..1de21ebc2a6 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -381,21 +381,22 @@ static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers( return output; } +// Extracts sharding from attribute string. +static absl::optional CreateOpShardingFromStringRef( + llvm::StringRef sharding) { + xla::OpSharding sharding_proto; + if (!sharding_proto.ParseFromString(sharding.str())) return absl::nullopt; + return sharding_proto; +} + // Returns an OpSharding proto from the "sharding" attribute of the op. If the // op doesn't have a sharding attribute or the sharding attribute is invalid, // returns absl::nullopt. static absl::optional CreateOpShardingFromAttribute( mlir::Operation* op) { auto sharding = op->getAttrOfType(kShardingAttr); - if (!sharding) { - return absl::nullopt; - } - ::xla::OpSharding sharding_proto; - if (!::tensorflow::protobuf::TextFormat::ParseFromString( - sharding.getValue().str(), &sharding_proto)) { - return absl::nullopt; - } - return sharding_proto; + if (!sharding) return absl::nullopt; + return CreateOpShardingFromStringRef(sharding.getValue()); } // Checks if all shardings are set. @@ -407,14 +408,6 @@ static bool AllOptionalShardingsAreSet( }); } -// Extracts sharding from attribute string. -static absl::optional CreateOpShardingFromStringRef( - llvm::StringRef sharding) { - xla::OpSharding sharding_proto; - if (!sharding_proto.ParseFromString(sharding.str())) return absl::nullopt; - return sharding_proto; -} - // Extracts argument and result shardings from function. static void ExtractShardingsFromFunction( mlir::FuncOp function, diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 9299cd55943..40cdfb84566 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -845,7 +845,19 @@ func @infeed_dequeue_tuple() -> (tensor<3xi32>, tensor<4xf32>) { func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> { // CHECK: "xla_hlo.infeed" // An additional sharding is added at the end to account for token result. - // CHECK-SAME: xla_hlo.sharding = "type: TUPLE\0Atuple_shardings {\0A type: MAXIMAL\0A tile_assignment_dimensions: 1\0A tile_assignment_devices: 0\0A}\0Atuple_shardings {\0A type: MAXIMAL\0A tile_assignment_dimensions: 1\0A tile_assignment_devices: 0\0A}\0A" + // Proto debug string: + // type: TUPLE + // tuple_shardings { + // type: MAXIMAL + // tile_assignment_dimensions: 1 + // tile_assignment_devices: 0 + // } + // tuple_shardings { + // type: MAXIMAL + // tile_assignment_dimensions: 1 + // tile_assignment_devices: 0 + // } + // CHECK-SAME: xla_hlo.sharding = "\08\02*\08\08\01\1A\01\01\22\01\00*\08\08\01\1A\01\01\22\01\00" %0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\02*\08\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32> return %0 : tensor<8xi32> } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index 20b43e8633d..d09e3efb11f 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -963,9 +963,19 @@ func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // ----- +// The following op sharding is used: +// Proto debug string: +// type: OTHER +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\03\1A\02\01\02\22\02\00\01" + // CHECK: HloModule func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> { - %0 = "xla_hlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", xla_hlo.sharding = "type: OTHER\ntile_assignment_dimensions: 1\ntile_assignment_dimensions: 2\ntile_assignment_devices: 0\ntile_assignment_devices: 1"} : (tensor<16x16xf32>) -> tensor<16x16xf32> + %0 = "xla_hlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"} : (tensor<16x16xf32>) -> tensor<16x16xf32> return %0 : tensor<16x16xf32> } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index ee0a5f9e190..a575698b276 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -3806,17 +3806,15 @@ class ConvertInfeedDequeueTupleOp // Token is a control signal and not a real data, so arbitrarily assign // the token to device 0. - if (sharding_proto.type() == ::xla::OpSharding::TUPLE) + if (sharding_proto.type() == ::xla::OpSharding::TUPLE) { *sharding_proto.add_tuple_shardings() = ::xla::sharding_builder::AssignDevice(0); - - std::string sharding_str; - if (!::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto, - &sharding_str)) - return failure(); - - data_and_token.setAttr(kShardingAttr, - rewriter.getStringAttr(sharding_str)); + data_and_token.setAttr( + kShardingAttr, + rewriter.getStringAttr(sharding_proto.SerializeAsString())); + } else { + data_and_token.setAttr(kShardingAttr, op._XlaShardingAttr()); + } } // The infeed instruction produces a tuple of the infeed data and a token @@ -4359,21 +4357,12 @@ class ConvertXlaShardingOp : public OpRewritePattern { // using a string. if (!op._XlaSharding().hasValue()) return failure(); - // _XlaSharding attribute in TF is a serialized string of the OpSharding - // proto, so convert to a text form here. - ::xla::OpSharding sharding_proto; - std::string sharding_str; - if (!sharding_proto.ParseFromString(op._XlaSharding().getValue().str()) || - !::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto, - &sharding_str)) - return failure(); - auto custom_call = rewriter.create( op.getLoc(), op.getType(), op.input(), /*call_target_name=*/rewriter.getStringAttr("Sharding"), /*has_side_effect=*/rewriter.getBoolAttr(false), /*backend_config=*/rewriter.getStringAttr("")); - custom_call.setAttr(kShardingAttr, rewriter.getStringAttr(sharding_str)); + custom_call.setAttr(kShardingAttr, op._XlaShardingAttr()); rewriter.replaceOp(op, custom_call.getResult()); return success();