Update tf.InfeedDequeueTuple -> xla_hlo.infeed legalization to insert a default sharding (device 0) to account for token result.
Compared to tf.InfeedDequeueTuple, xla_hlo.infeed has an additional token result. As number of results must match the number of shardings, a sharding is inserted at the end when legalizing to xla_hlo.infeed. PiperOrigin-RevId: 307861971 Change-Id: I8f2828cc6036afc41a72dae88a2218771745e441
This commit is contained in:
parent
7eb1c830f7
commit
1c6795700a
@ -138,6 +138,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:padding",
|
||||
"//tensorflow/compiler/xla/client:sharding_builder",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/kernels:conv_grad_shape_utils",
|
||||
"@llvm-project//llvm:support",
|
||||
|
@ -1171,17 +1171,21 @@ func @infeed_dequeue_tuple() -> (tensor<3xi32>, tensor<4xf32>) {
|
||||
|
||||
// The following op sharding is used:
|
||||
// Proto debug string:
|
||||
// type: TUPLE
|
||||
// tuple_shardings {
|
||||
// type: MAXIMAL
|
||||
// tile_assignment_dimensions: 1
|
||||
// tile_assignment_devices: 0
|
||||
// }
|
||||
// Serialized string:
|
||||
// "\08\01\1A\01\01\22\01\00"
|
||||
// "\08\02*\08\08\01\1A\01\01\22\01\00"
|
||||
|
||||
// CHECK-LABEL: infeed_dequeue_tuple_sharding
|
||||
func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> {
|
||||
// CHECK: "xla_hlo.infeed"
|
||||
// CHECK-SAME: xla_hlo.sharding = "type: MAXIMAL\0Atile_assignment_dimensions: 1\0Atile_assignment_devices: 0\0A"
|
||||
%0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32>
|
||||
// An additional sharding is added at the end to account for token result.
|
||||
// CHECK-SAME: xla_hlo.sharding = "type: TUPLE\0Atuple_shardings {\0A type: MAXIMAL\0A tile_assignment_dimensions: 1\0A tile_assignment_devices: 0\0A}\0Atuple_shardings {\0A type: MAXIMAL\0A tile_assignment_dimensions: 1\0A tile_assignment_devices: 0\0A}\0A"
|
||||
%0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\02*\08\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32>
|
||||
return %0 : tensor<8xi32>
|
||||
}
|
||||
|
||||
|
@ -46,6 +46,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
||||
#include "tensorflow/compiler/xla/client/padding.h"
|
||||
#include "tensorflow/compiler/xla/client/sharding_builder.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/kernel_shape_util.h"
|
||||
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
|
||||
@ -3376,9 +3377,17 @@ class ConvertInfeedDequeueTupleOp
|
||||
// _XlaSharding attribute in TF is a serialized string of the OpSharding
|
||||
// proto, so convert to a text form here.
|
||||
::xla::OpSharding sharding_proto;
|
||||
if (!sharding_proto.ParseFromString(op._XlaSharding().getValue().str()))
|
||||
return failure();
|
||||
|
||||
// Token is a control signal and not a real data, so arbitrarily assign
|
||||
// the token to device 0.
|
||||
if (sharding_proto.type() == ::xla::OpSharding::TUPLE)
|
||||
*sharding_proto.add_tuple_shardings() =
|
||||
::xla::sharding_builder::AssignDevice(0);
|
||||
|
||||
std::string sharding_str;
|
||||
if (!sharding_proto.ParseFromString(op._XlaSharding().getValue().str()) ||
|
||||
!::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto,
|
||||
if (!::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto,
|
||||
&sharding_str))
|
||||
return failure();
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user