Add an explicit sharding attribute to the XlaSharding op.
This fixes the issue where a user-added attribute cannot be found when computing gradients.
PiperOrigin-RevId: 353364146
Change-Id: Iee02f8df1fb0c987bf5866dfb112361d9c9e374e
parent 01a9c4f7c0
commit d56de8e17e
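The functional core of the change is the XlaSharding gradient. As a reading aid, the sketch below reconstructs, from the diff hunks that follow, the variant of the gradient registration that reads an explicit, default-valued `sharding` op attribute. The import paths mirror the TensorFlow sources (`gen_xla_ops`, `attr_value_pb2`, `ops`); treat it as an illustration rather than the exact file contents.

```python
# Sketch: the XlaSharding gradient re-applies the forward op's sharding to the
# incoming gradient. Because "sharding" is a default-valued op attribute,
# op.get_attr("sharding") always succeeds, whereas a purely user-added
# "_XlaSharding" node attribute may be missing at gradient-construction time,
# which is the failure mode the commit message describes.
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import ops


@ops.RegisterGradient("XlaSharding")
def _sharding_grad(op, grad):
  sharding_attr = op.get_attr("sharding")
  # Annotate the gradient with the same sharding as the forward op.
  grad_sharding = gen_xla_ops.xla_sharding(grad, sharding=sharding_attr)
  # pylint: disable=protected-access
  grad_sharding.op._set_attr("_XlaSharding",
                             attr_value_pb2.AttrValue(s=sharding_attr))
  return [grad_sharding]
```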
@@ -1013,7 +1013,6 @@ An op which shards the input based on the given sharding attribute.
  let arguments = (ins
    TF_Tensor:$input,

    DefaultValuedAttr<StrAttr, "">:$sharding,
    OptionalAttr<StrAttr>:$_XlaSharding
  );
@@ -1636,7 +1636,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
  }
  func @tpu0_func(%arg0: tensor<8xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
    %1, %2 = "tf.A"(%arg0) : (tensor<8xi32>) -> (tensor<*xi32>, tensor<*xi1>)
    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
    return %1, %3 : tensor<*xi32>, tensor<*xi1>
  }
}
@@ -1704,7 +1704,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
  func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
    %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
    %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
    return %4, %3 : tensor<*xi32>, tensor<*xi1>
  }
}
@@ -1772,7 +1772,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
  func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
    %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
    %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
    return %4, %3 : tensor<*xi32>, tensor<*xi1>
  }
}
@@ -1815,7 +1815,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
  func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
    %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
    %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
    return %4, %3 : tensor<*xi32>, tensor<*xi1>
  }
}
@@ -1855,7 +1855,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
  func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi1>, tensor<*xi32>) {
    %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
    %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
    return %3, %4 : tensor<*xi1>, tensor<*xi32>
  }
}
@@ -1961,7 +1961,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
  func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
    %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
    %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
    return %4, %3 : tensor<*xi32>, tensor<*xi1>
  }
}
@@ -2068,7 +2068,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
  func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
    %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
    %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
    return %4, %3 : tensor<*xi32>, tensor<*xi1>
  }
}
@@ -2152,7 +2152,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
  func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
    %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
    %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
    return %4, %3 : tensor<*xi32>, tensor<*xi1>
  }
}
@@ -2237,7 +2237,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
  func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
    %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
    %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
    return %4, %3 : tensor<*xi32>, tensor<*xi1>
  }
}
@@ -2321,7 +2321,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
  func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
    %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
    %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
    %3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
    return %4, %3 : tensor<*xi32>, tensor<*xi1>
  }
}
@@ -75,7 +75,7 @@ func @check_sharding_for_input_correctly_identified(%arg0: tensor<*xi32>) {
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\01\02\03"})
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"})
func @inputs_with_sharding_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
  %0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03", sharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
  %0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
  %1 = "tf.A"(%0) : (tensor<*xi32>) -> (tensor<*xi32>)
  return %1 : tensor<*xi32>
}
@@ -97,11 +97,11 @@ func @check_sharding_for_multiple_inputs_outputs(%arg0: tensor<*xi32>, %arg1: te
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {mhlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"})
func @func_with_sharding(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
  %0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03", sharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
  %1 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06", sharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
  %0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
  %1 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
  %2, %3 = "tf.A"(%0, %1) : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
  %4 = "tf.XlaSharding"(%2) { _XlaSharding = "\0A\0B\0C", sharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
  %5 = "tf.XlaSharding"(%3) { _XlaSharding = "\0D\0E\0F", sharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
  %4 = "tf.XlaSharding"(%2) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
  %5 = "tf.XlaSharding"(%3) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
  return %4, %5 : tensor<*xi32> , tensor<*xi1>
}
@@ -123,11 +123,11 @@ func @check_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"})
func @func_with_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
  %0 = "tf.Identity"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
  %1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03", sharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
  %2 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06", sharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
  %1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
  %2 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
  %3, %4 = "tf.A"(%1, %2) : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
  %5 = "tf.XlaSharding"(%3) { _XlaSharding = "\0A\0B\0C", sharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
  %6 = "tf.XlaSharding"(%4) { _XlaSharding = "\0D\0E\0F", sharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
  %5 = "tf.XlaSharding"(%3) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
  %6 = "tf.XlaSharding"(%4) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
  return %5, %6 : tensor<*xi32> , tensor<*xi1>
}
@@ -149,13 +149,13 @@ func @check_sharding_after_read_variable(%arg0: tensor<*xi32>, %arg1: tensor<*xi
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"})
func @func_with_sharding_after_read_variable(%arg0: tensor<*x!tf.resource<tensor<32xf32>>>, %arg1: tensor<*x!tf.resource<tensor<32xf32>>>) -> (tensor<*xi32>, tensor<*xi1>) {
  %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
  %1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03", sharding = "\01\02\03" } : (tensor<32xf32>) -> tensor<32xf32>
  %1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03" } : (tensor<32xf32>) -> tensor<32xf32>
  %2 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
  %3 = "tf.Identity"(%2) : (tensor<32xf32>) -> tensor<32xf32>
  %4 = "tf.XlaSharding"(%3) { _XlaSharding = "\04\05\06", sharding = "\04\05\06" } : (tensor<32xf32>) -> tensor<32xf32>
  %4 = "tf.XlaSharding"(%3) { _XlaSharding = "\04\05\06" } : (tensor<32xf32>) -> tensor<32xf32>
  %5, %6 = "tf.A"(%1, %3) : (tensor<32xf32>, tensor<32xf32>) -> (tensor<*xi32>, tensor<*xi1>)
  %7 = "tf.XlaSharding"(%5) { _XlaSharding = "\0A\0B\0C", sharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
  %8 = "tf.XlaSharding"(%6) { _XlaSharding = "\0D\0E\0F", sharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
  %7 = "tf.XlaSharding"(%5) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
  %8 = "tf.XlaSharding"(%6) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
  return %7, %8 : tensor<*xi32> , tensor<*xi1>
}
@@ -178,11 +178,11 @@ func @check_sharding_after_cast_op(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
func @func_with_sharding_after_cast(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
  %0 = "tf.Identity"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
  %1 = "tf.Cast"(%0) : (tensor<*xi32>) -> tensor<*xi1>
  %2 = "tf.XlaSharding"(%1) { _XlaSharding = "\01\02\03", sharding = "\01\02\03" } : (tensor<*xi1>) -> tensor<*xi1>
  %3 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06", sharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
  %2 = "tf.XlaSharding"(%1) { _XlaSharding = "\01\02\03" } : (tensor<*xi1>) -> tensor<*xi1>
  %3 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
  %4, %5 = "tf.A"(%2, %3) : (tensor<*xi1>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
  %6 = "tf.XlaSharding"(%4) { _XlaSharding = "\0A\0B\0C", sharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
  %7 = "tf.XlaSharding"(%5) { _XlaSharding = "\0D\0E\0F", sharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
  %6 = "tf.XlaSharding"(%4) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
  %7 = "tf.XlaSharding"(%5) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
  return %6, %7 : tensor<*xi32> , tensor<*xi1>
}
@@ -208,22 +208,22 @@ func @func_with_device_training_loop(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>)
  %2 = "tf.PartitionedCall"(%arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_func_body} : (tensor<*xi1>) -> (tensor<i32>)
  %3, %4 = "tf.A"(%1#0, %2) : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, tensor<*xi1>)

  %5 = "tf.XlaSharding"(%3) { _XlaSharding = "\0A\0B\0C", sharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
  %6 = "tf.XlaSharding"(%4) { _XlaSharding = "\0D\0E\0F", sharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
  %5 = "tf.XlaSharding"(%3) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
  %6 = "tf.XlaSharding"(%4) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>

  return %5, %6 : tensor<*xi32> , tensor<*xi1>
}

// CHECK-LABEL: func @func_body
func @func_body(%arg0: tensor<*xi32>)-> (tensor<*xi32>, tensor<*xi1>) {
  %1 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03", sharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
  %1 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
  %2, %3 = "tf.C"(%1) : (tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
  return %2, %3 : tensor<*xi32> , tensor<*xi1>
}

// CHECK-LABEL: func @pcall_func_body
func @pcall_func_body(%arg0: tensor<*xi1>) -> tensor<i32> {
  %1 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\04\05\06", sharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
  %1 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
  %2 = "tf.D"(%1) : (tensor<*xi1>) -> (tensor<i32>)
  return %2 : tensor<i32>
}
@@ -248,7 +248,7 @@ func @cluster_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
}

func @func_body(%arg0: tensor<*xi32>)-> tensor<*xi32> {
  %0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03", sharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
  %0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
  %1 = "tf.Identity"(%0) : (tensor<*xi32>) -> (tensor<*xi32>)
  return %1 : tensor<*xi32>
}
@@ -4804,7 +4804,7 @@ func @avgpool_grad_bf16(%grad: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xb

// CHECK-LABEL: xla_sharding
func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> {
  // CHECK-NEXT: "mhlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", has_side_effect = false, mhlo.sharding = ""}
  %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "", sharding = ""} : (tensor<4x16xf32>) -> tensor<4x16xf32>
  %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = ""} : (tensor<4x16xf32>) -> tensor<4x16xf32>
  return %0 : tensor<4x16xf32>
}
@@ -899,7 +899,6 @@ REGISTER_OP("XlaSharding")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("sharding: string = ''")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
An op which shards the input based on the given sharding attribute.
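For reference, a hedged sketch of how this registered op is reached from Python through its generated binding, assuming the op variant that carries the explicit `sharding` attribute shown in the hunk above. The proto and import paths follow the TensorFlow sources; the graph-mode setup and the REPLICATED value are illustrative.

```python
import tensorflow.compat.v1 as tf
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.compiler.xla import xla_data_pb2

tf.disable_eager_execution()

x = tf.placeholder(tf.float32, shape=[4, 16])
# The string attribute carries a serialized xla.OpSharding message; a simple
# replicated sharding is used here as an example value.
proto = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)
y = gen_xla_ops.xla_sharding(x, sharding=proto.SerializeToString())
```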
@@ -444,11 +444,10 @@ sharding = gen_xla_ops.xla_sharding


@ops.RegisterGradient("XlaSharding")
def _sharding_grad(op, grad):
  sharding_attr = op.get_attr("sharding")
  grad_sharding = gen_xla_ops.xla_sharding(grad, sharding=sharding_attr)
  grad_sharding = gen_xla_ops.xla_sharding(grad)
  # pylint: disable=protected-access
  grad_sharding.op._set_attr("_XlaSharding",
                             attr_value_pb2.AttrValue(s=sharding_attr))
  grad_sharding.op._set_attr(
      "_XlaSharding", attr_value_pb2.AttrValue(s=op.get_attr("_XlaSharding")))
  return [grad_sharding]
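A hypothetical end-to-end exercise of this gradient path, using the experimental `xla_sharding` helpers in TF1 graph mode. The module path matches the TensorFlow tree; the shapes and the loss are illustrative only.

```python
import tensorflow.compat.v1 as tf
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

tf.disable_eager_execution()

x = tf.placeholder(tf.float32, shape=[4, 16])
# Annotate the tensor with a replicated sharding via an explicit sharding op.
y = xla_sharding.replicate(x, use_sharding_op=True)
loss = tf.reduce_sum(y * y)
# Building the gradient graph invokes the registered XlaSharding gradient,
# which must be able to recover the forward op's sharding annotation.
(dx,) = tf.gradients(loss, [x])
```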
@@ -161,19 +161,12 @@ class Sharding(object):
            tile_assignment_dimensions=tile_assignment_dims,
            tile_assignment_devices=range(num_devices)))

  def apply_to_tensor(self,
                      tensor,
                      assign_tuple_sharding=False,
                      use_sharding_op=False):
  def apply_to_tensor(self, tensor, assign_tuple_sharding=False):
    """Applies this Sharding attribute to `tensor`.

    Args:
      tensor: A tf.Tensor to split.
      assign_tuple_sharding: If the sharding type should be a tuple.
      use_sharding_op: whether to create a sharding op on `tensor`.

    Returns:
      The tensor with Sharding attribute.
    """
    if len(tensor.op.outputs) > 1 or assign_tuple_sharding:
      proto = self._get_or_create_tuple_proto(tensor.op)
@@ -185,15 +178,11 @@ class Sharding(object):
          type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings)
    else:
      proto = self._proto
    attr_value = proto.SerializeToString()

    if use_sharding_op:
      tensor = tf2xla.sharding(tensor, sharding=attr_value)
    attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString())
    # TODO(jmolloy): This need to be seriously revisited before declaring this
    # API available for public use.
    # pylint: disable=protected-access
    tensor.op._set_attr('_XlaSharding', attr_value_pb2.AttrValue(s=attr_value))
    return tensor
    tensor.op._set_attr('_XlaSharding', attr_value)

  def apply_to_operation(self, operation):
    """Applies this Sharding attribute to `operation`.
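The hunk above is where the annotation actually lands on the graph. Below is a simplified, self-contained sketch of the two code paths; the helper name `_attach_sharding` is made up for illustration, while `tf2xla.sharding` and the `_XlaSharding` node attribute are taken from the diff itself.

```python
from tensorflow.compiler.tf2xla.python import xla as tf2xla
from tensorflow.core.framework import attr_value_pb2


def _attach_sharding(tensor, proto, use_sharding_op=False):
  """Illustrative stand-in for Sharding.apply_to_tensor (not the real API)."""
  attr_value = proto.SerializeToString()
  if use_sharding_op:
    # Route the annotation through an explicit XlaSharding op so it is carried
    # by an op in the graph rather than only by a node attribute.
    tensor = tf2xla.sharding(tensor, sharding=attr_value)
  # Either way, the serialized xla.OpSharding also lands in the "_XlaSharding"
  # node attribute that the TF/XLA bridge looks for.
  # pylint: disable=protected-access
  tensor.op._set_attr('_XlaSharding', attr_value_pb2.AttrValue(s=attr_value))
  return tensor
```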
@@ -244,7 +233,7 @@ def copy_sharding(from_tensor, to_tensor, use_sharding_op=False):
    return to_tensor

  if use_sharding_op:
    to_tensor = tf2xla.sharding(to_tensor, sharding=sharding)
    to_tensor = tf2xla.sharding(to_tensor)
  attr_value = attr_value_pb2.AttrValue(s=sharding)
  # pylint: disable=protected-access
  to_tensor.op._set_attr('_XlaSharding', attr_value)
@@ -256,10 +245,12 @@ def copy_sharding(from_tensor, to_tensor, use_sharding_op=False):


def replicate(tensor, assign_tuple_sharding=False, use_sharding_op=False):
  return Sharding.replicate().apply_to_tensor(
  if use_sharding_op:
    tensor = tf2xla.sharding(tensor)
  Sharding.replicate().apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op)
      assign_tuple_sharding=assign_tuple_sharding)
  return tensor


def assign_device(tensor,
@@ -267,10 +258,13 @@ def assign_device(tensor,
                  assign_tuple_sharding=False,
                  use_sharding_op=False):
  """Returns a tensor that has AssignDevice sharding attribute."""
  return Sharding.assign_device(device).apply_to_tensor(
  if use_sharding_op:
    tensor = tf2xla.sharding(tensor)

  Sharding.assign_device(device).apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op)
      assign_tuple_sharding=assign_tuple_sharding)
  return tensor


def tile(tensor,
@@ -286,10 +280,13 @@ def tile(tensor,
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: If true, adds a sharding op to set the sharding.
  """
  return Sharding.tile(tile_assignment).apply_to_tensor(
  if use_sharding_op:
    tensor = tf2xla.sharding(tensor)
  Sharding.tile(tile_assignment).apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op)
      assign_tuple_sharding=assign_tuple_sharding
  )
  return tensor


def split(tensor,
@@ -308,11 +305,12 @@ def split(tensor,
    use_sharding_op: If true, adds a sharding op to set the sharding.
    input_shape: The full shape of the input tensor.
  """
  return Sharding.split(tensor, split_dimension, num_devices,
                        input_shape).apply_to_tensor(
                            tensor,
                            assign_tuple_sharding=assign_tuple_sharding,
                            use_sharding_op=use_sharding_op)
  if use_sharding_op:
    tensor = tf2xla.sharding(tensor)
  Sharding.split(
      tensor, split_dimension, num_devices, input_shape).apply_to_tensor(
          tensor, assign_tuple_sharding=assign_tuple_sharding)
  return tensor


def partial_tile(tensor, tile_assignment, use_sharding_op=False):
@@ -326,8 +324,10 @@ def partial_tile(tensor, tile_assignment, use_sharding_op=False):
      replicated tiles.
    use_sharding_op: If true, adds a sharding op to set the sharding.
  """
  return Sharding.partial_tile(tile_assignment).apply_to_tensor(
      tensor, use_sharding_op=use_sharding_op)
  if use_sharding_op:
    tensor = tf2xla.sharding(tensor)
  Sharding.partial_tile(tile_assignment).apply_to_tensor(tensor)
  return tensor


def get_op_sharding(op):
@@ -464,4 +464,7 @@ def mesh_split(tensor,
      rank.
  """
  sharding = mesh_split_sharding(device_mesh, tensor_split_dims_mapping)
  return sharding.apply_to_tensor(tensor, use_sharding_op=use_sharding_op)
  if use_sharding_op:
    tensor = tf2xla.sharding(tensor)
  sharding.apply_to_tensor(tensor)
  return tensor
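Finally, a hypothetical caller's view of the convenience wrappers touched above, again in TF1 graph mode; the shape and device count are illustrative.

```python
import tensorflow.compat.v1 as tf
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

tf.disable_eager_execution()

x = tf.placeholder(tf.float32, shape=[8, 16])
# Split 2-ways along dimension 0; with use_sharding_op=True the annotation is
# recorded on an inserted sharding op instead of only on x's producing op.
x_split = xla_sharding.split(
    x, split_dimension=0, num_devices=2, use_sharding_op=True)
```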