Update tf.InfeedDequeueTuple -> xla_hlo.infeed legalization to insert a default sharding (device 0) to account for the token result.
Compared to tf.InfeedDequeueTuple, xla_hlo.infeed has an additional token result. As the number of results must match the number of shardings, a sharding is inserted at the end when legalizing to xla_hlo.infeed. PiperOrigin-RevId: 307861971 Change-Id: I8f2828cc6036afc41a72dae88a2218771745e441
This commit is contained in:
parent
7eb1c830f7
commit
1c6795700a
tensorflow/compiler/mlir/xla
@ -138,6 +138,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
|
"//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
"//tensorflow/compiler/xla/client:padding",
|
"//tensorflow/compiler/xla/client:padding",
|
||||||
|
"//tensorflow/compiler/xla/client:sharding_builder",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core/kernels:conv_grad_shape_utils",
|
"//tensorflow/core/kernels:conv_grad_shape_utils",
|
||||||
"@llvm-project//llvm:support",
|
"@llvm-project//llvm:support",
|
||||||
|
@ -1171,17 +1171,21 @@ func @infeed_dequeue_tuple() -> (tensor<3xi32>, tensor<4xf32>) {
|
|||||||
|
|
||||||
// The following op sharding is used:
|
// The following op sharding is used:
|
||||||
// Proto debug string:
|
// Proto debug string:
|
||||||
|
// type: TUPLE
|
||||||
|
// tuple_shardings {
|
||||||
// type: MAXIMAL
|
// type: MAXIMAL
|
||||||
// tile_assignment_dimensions: 1
|
// tile_assignment_dimensions: 1
|
||||||
// tile_assignment_devices: 0
|
// tile_assignment_devices: 0
|
||||||
|
// }
|
||||||
// Serialized string:
|
// Serialized string:
|
||||||
// "\08\01\1A\01\01\22\01\00"
|
// "\08\02*\08\08\01\1A\01\01\22\01\00"
|
||||||
|
|
||||||
// CHECK-LABEL: infeed_dequeue_tuple_sharding
|
// CHECK-LABEL: infeed_dequeue_tuple_sharding
|
||||||
func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> {
|
func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> {
|
||||||
// CHECK: "xla_hlo.infeed"
|
// CHECK: "xla_hlo.infeed"
|
||||||
// CHECK-SAME: xla_hlo.sharding = "type: MAXIMAL\0Atile_assignment_dimensions: 1\0Atile_assignment_devices: 0\0A"
|
// An additional sharding is added at the end to account for token result.
|
||||||
%0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32>
|
// CHECK-SAME: xla_hlo.sharding = "type: TUPLE\0Atuple_shardings {\0A type: MAXIMAL\0A tile_assignment_dimensions: 1\0A tile_assignment_devices: 0\0A}\0Atuple_shardings {\0A type: MAXIMAL\0A tile_assignment_dimensions: 1\0A tile_assignment_devices: 0\0A}\0A"
|
||||||
|
%0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\02*\08\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32>
|
||||||
return %0 : tensor<8xi32>
|
return %0 : tensor<8xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -46,6 +46,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
|
#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
|
||||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
||||||
#include "tensorflow/compiler/xla/client/padding.h"
|
#include "tensorflow/compiler/xla/client/padding.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/sharding_builder.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/framework/kernel_shape_util.h"
|
#include "tensorflow/core/framework/kernel_shape_util.h"
|
||||||
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
|
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
|
||||||
@ -3376,9 +3377,17 @@ class ConvertInfeedDequeueTupleOp
|
|||||||
// _XlaSharding attribute in TF is a serialized string of the OpSharding
|
// _XlaSharding attribute in TF is a serialized string of the OpSharding
|
||||||
// proto, so convert to a text form here.
|
// proto, so convert to a text form here.
|
||||||
::xla::OpSharding sharding_proto;
|
::xla::OpSharding sharding_proto;
|
||||||
|
if (!sharding_proto.ParseFromString(op._XlaSharding().getValue().str()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Token is a control signal and not a real data, so arbitrarily assign
|
||||||
|
// the token to device 0.
|
||||||
|
if (sharding_proto.type() == ::xla::OpSharding::TUPLE)
|
||||||
|
*sharding_proto.add_tuple_shardings() =
|
||||||
|
::xla::sharding_builder::AssignDevice(0);
|
||||||
|
|
||||||
std::string sharding_str;
|
std::string sharding_str;
|
||||||
if (!sharding_proto.ParseFromString(op._XlaSharding().getValue().str()) ||
|
if (!::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto,
|
||||||
!::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto,
|
|
||||||
&sharding_str))
|
&sharding_str))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user