diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index b07193cc1f6..5c5491f9471 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -138,6 +138,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client:sharding_builder", "//tensorflow/core:framework", "//tensorflow/core/kernels:conv_grad_shape_utils", "@llvm-project//llvm:support", diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 0650e5afd32..8aacc051453 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -1171,17 +1171,21 @@ func @infeed_dequeue_tuple() -> (tensor<3xi32>, tensor<4xf32>) { // The following op sharding is used: // Proto debug string: -// type: MAXIMAL -// tile_assignment_dimensions: 1 -// tile_assignment_devices: 0 +// type: TUPLE +// tuple_shardings { +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// } // Serialized string: -// "\08\01\1A\01\01\22\01\00" +// "\08\02*\08\08\01\1A\01\01\22\01\00" // CHECK-LABEL: infeed_dequeue_tuple_sharding func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> { // CHECK: "xla_hlo.infeed" - // CHECK-SAME: xla_hlo.sharding = "type: MAXIMAL\0Atile_assignment_dimensions: 1\0Atile_assignment_devices: 0\0A" - %0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32> + // An additional sharding is added at the end to account for token result.
+ // CHECK-SAME: xla_hlo.sharding = "type: TUPLE\0Atuple_shardings {\0A type: MAXIMAL\0A tile_assignment_dimensions: 1\0A tile_assignment_devices: 0\0A}\0Atuple_shardings {\0A type: MAXIMAL\0A tile_assignment_dimensions: 1\0A tile_assignment_devices: 0\0A}\0A" + %0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\02*\08\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32> return %0 : tensor<8xi32> } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 57d7f2b548d..c7f72b921b0 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -46,6 +46,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/kernels/conv_grad_shape_utils.h" @@ -3376,9 +3377,17 @@ class ConvertInfeedDequeueTupleOp // _XlaSharding attribute in TF is a serialized string of the OpSharding // proto, so convert to a text form here. ::xla::OpSharding sharding_proto; + if (!sharding_proto.ParseFromString(op._XlaSharding().getValue().str())) + return failure(); + + // Token is a control signal and not real data, so arbitrarily assign + // the token to device 0. + if (sharding_proto.type() == ::xla::OpSharding::TUPLE) + *sharding_proto.add_tuple_shardings() = + ::xla::sharding_builder::AssignDevice(0); + std::string sharding_str; - if (!sharding_proto.ParseFromString(op._XlaSharding().getValue().str()) || - !::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto, + if (!::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto, &sharding_str)) return failure();