Add an explicit sharding attribute to the XlaSharding op. This fixes an issue where the user-added attribute cannot be found when computing gradients.

PiperOrigin-RevId: 353364146
Change-Id: Iee02f8df1fb0c987bf5866dfb112361d9c9e374e
A. Unique TensorFlower 2021-01-22 19:52:35 -08:00 committed by TensorFlower Gardener
parent 01a9c4f7c0
commit d56de8e17e
7 changed files with 71 additions and 71 deletions
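
For context, the `_XlaSharding` attribute referenced throughout this change stores a serialized `xla.OpSharding` proto, as the Python sharding helpers in this diff show. A minimal sketch of what that value looks like; the proto import path is an assumption, since the diff does not name its files:

from tensorflow.compiler.xla import xla_data_pb2

# The bytes placed in the _XlaSharding node attribute are a serialized
# xla.OpSharding proto (here, a REPLICATED sharding).
replicated = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)
serialized = replicated.SerializeToString()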

View File

@ -1013,7 +1013,6 @@ An op which shards the input based on the given sharding attribute.
let arguments = (ins
TF_Tensor:$input,
DefaultValuedAttr<StrAttr, "">:$sharding,
OptionalAttr<StrAttr>:$_XlaSharding
);

View File

@ -1636,7 +1636,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
}
func @tpu0_func(%arg0: tensor<8xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<8xi32>) -> (tensor<*xi32>, tensor<*xi1>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
return %1, %3 : tensor<*xi32>, tensor<*xi1>
}
}
@ -1704,7 +1704,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}
@ -1772,7 +1772,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}
@ -1815,7 +1815,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}
@ -1855,7 +1855,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi1>, tensor<*xi32>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
return %3, %4 : tensor<*xi1>, tensor<*xi32>
}
}
@ -1961,7 +1961,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}
@ -2068,7 +2068,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}
@ -2152,7 +2152,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}
@ -2237,7 +2237,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}
@ -2321,7 +2321,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}

View File

@ -75,7 +75,7 @@ func @check_sharding_for_input_correctly_identified(%arg0: tensor<*xi32>) {
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\01\02\03"})
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"})
func @inputs_with_sharding_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
%0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03", sharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
%0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.A"(%0) : (tensor<*xi32>) -> (tensor<*xi32>)
return %1 : tensor<*xi32>
}
@ -97,11 +97,11 @@ func @check_sharding_for_multiple_inputs_outputs(%arg0: tensor<*xi32>, %arg1: te
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {mhlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"})
func @func_with_sharding(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
%0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03", sharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06", sharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
%0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
%2, %3 = "tf.A"(%0, %1) : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.XlaSharding"(%2) { _XlaSharding = "\0A\0B\0C", sharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
%5 = "tf.XlaSharding"(%3) { _XlaSharding = "\0D\0E\0F", sharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
%4 = "tf.XlaSharding"(%2) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
%5 = "tf.XlaSharding"(%3) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
return %4, %5 : tensor<*xi32> , tensor<*xi1>
}
@ -123,11 +123,11 @@ func @check_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"})
func @func_with_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
%0 = "tf.Identity"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03", sharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
%2 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06", sharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
%1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
%2 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
%3, %4 = "tf.A"(%1, %2) : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
%5 = "tf.XlaSharding"(%3) { _XlaSharding = "\0A\0B\0C", sharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
%6 = "tf.XlaSharding"(%4) { _XlaSharding = "\0D\0E\0F", sharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
%5 = "tf.XlaSharding"(%3) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
%6 = "tf.XlaSharding"(%4) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
return %5, %6 : tensor<*xi32> , tensor<*xi1>
}
@ -149,13 +149,13 @@ func @check_sharding_after_read_variable(%arg0: tensor<*xi32>, %arg1: tensor<*xi
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"})
func @func_with_sharding_after_read_variable(%arg0: tensor<*x!tf.resource<tensor<32xf32>>>, %arg1: tensor<*x!tf.resource<tensor<32xf32>>>) -> (tensor<*xi32>, tensor<*xi1>) {
%0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
%1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03", sharding = "\01\02\03" } : (tensor<32xf32>) -> tensor<32xf32>
%1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03" } : (tensor<32xf32>) -> tensor<32xf32>
%2 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
%3 = "tf.Identity"(%2) : (tensor<32xf32>) -> tensor<32xf32>
%4 = "tf.XlaSharding"(%3) { _XlaSharding = "\04\05\06", sharding = "\04\05\06" } : (tensor<32xf32>) -> tensor<32xf32>
%4 = "tf.XlaSharding"(%3) { _XlaSharding = "\04\05\06" } : (tensor<32xf32>) -> tensor<32xf32>
%5, %6 = "tf.A"(%1, %3) : (tensor<32xf32>, tensor<32xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%7 = "tf.XlaSharding"(%5) { _XlaSharding = "\0A\0B\0C", sharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
%8 = "tf.XlaSharding"(%6) { _XlaSharding = "\0D\0E\0F", sharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
%7 = "tf.XlaSharding"(%5) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
%8 = "tf.XlaSharding"(%6) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
return %7, %8 : tensor<*xi32> , tensor<*xi1>
}
@ -178,11 +178,11 @@ func @check_sharding_after_cast_op(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
func @func_with_sharding_after_cast(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
%0 = "tf.Identity"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.Cast"(%0) : (tensor<*xi32>) -> tensor<*xi1>
%2 = "tf.XlaSharding"(%1) { _XlaSharding = "\01\02\03", sharding = "\01\02\03" } : (tensor<*xi1>) -> tensor<*xi1>
%3 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06", sharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
%2 = "tf.XlaSharding"(%1) { _XlaSharding = "\01\02\03" } : (tensor<*xi1>) -> tensor<*xi1>
%3 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
%4, %5 = "tf.A"(%2, %3) : (tensor<*xi1>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
%6 = "tf.XlaSharding"(%4) { _XlaSharding = "\0A\0B\0C", sharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
%7 = "tf.XlaSharding"(%5) { _XlaSharding = "\0D\0E\0F", sharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
%6 = "tf.XlaSharding"(%4) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
%7 = "tf.XlaSharding"(%5) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
return %6, %7 : tensor<*xi32> , tensor<*xi1>
}
@ -208,22 +208,22 @@ func @func_with_device_training_loop(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>)
%2 = "tf.PartitionedCall"(%arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_func_body} : (tensor<*xi1>) -> (tensor<i32>)
%3, %4 = "tf.A"(%1#0, %2) : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, tensor<*xi1>)
%5 = "tf.XlaSharding"(%3) { _XlaSharding = "\0A\0B\0C", sharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
%6 = "tf.XlaSharding"(%4) { _XlaSharding = "\0D\0E\0F", sharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
%5 = "tf.XlaSharding"(%3) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
%6 = "tf.XlaSharding"(%4) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
return %5, %6 : tensor<*xi32> , tensor<*xi1>
}
// CHECK-LABEL: func @func_body
func @func_body(%arg0: tensor<*xi32>)-> (tensor<*xi32>, tensor<*xi1>) {
%1 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03", sharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
%2, %3 = "tf.C"(%1) : (tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
return %2, %3 : tensor<*xi32> , tensor<*xi1>
}
// CHECK-LABEL: func @pcall_func_body
func @pcall_func_body(%arg0: tensor<*xi1>) -> tensor<i32> {
%1 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\04\05\06", sharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
%1 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
%2 = "tf.D"(%1) : (tensor<*xi1>) -> (tensor<i32>)
return %2 : tensor<i32>
}
@ -248,7 +248,7 @@ func @cluster_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
}
func @func_body(%arg0: tensor<*xi32>)-> tensor<*xi32> {
%0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03", sharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
%0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.Identity"(%0) : (tensor<*xi32>) -> (tensor<*xi32>)
return %1 : tensor<*xi32>
}

View File

@ -4804,7 +4804,7 @@ func @avgpool_grad_bf16(%grad: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xb
// CHECK-LABEL: xla_sharding
func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> {
// CHECK-NEXT: "mhlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", has_side_effect = false, mhlo.sharding = ""}
%0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "", sharding = ""} : (tensor<4x16xf32>) -> tensor<4x16xf32>
%0 = "tf.XlaSharding"(%arg0) {_XlaSharding = ""} : (tensor<4x16xf32>) -> tensor<4x16xf32>
return %0 : tensor<4x16xf32>
}

View File

@ -899,7 +899,6 @@ REGISTER_OP("XlaSharding")
.Input("input: T")
.Output("output: T")
.Attr("T: type")
.Attr("sharding: string = ''")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
An op which shards the input based on the given sharding attribute.
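
With the explicit `sharding` attr gone from this registration, the generated Python wrapper takes only the input tensor, and the sharding proto is attached afterwards as the `_XlaSharding` node attribute. A hedged sketch of that surface, assuming the usual TensorFlow internal import paths (an assumption, not stated in the diff):

import tensorflow as tf
from tensorflow.core.framework import attr_value_pb2
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.compiler.xla import xla_data_pb2

with tf.Graph().as_default():
  x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
  y = gen_xla_ops.xla_sharding(x)  # no sharding= keyword after this change
  proto = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)
  # pylint: disable=protected-access
  y.op._set_attr("_XlaSharding",
                 attr_value_pb2.AttrValue(s=proto.SerializeToString()))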

View File

@ -444,11 +444,10 @@ sharding = gen_xla_ops.xla_sharding
@ops.RegisterGradient("XlaSharding")
def _sharding_grad(op, grad):
sharding_attr = op.get_attr("sharding")
grad_sharding = gen_xla_ops.xla_sharding(grad, sharding=sharding_attr)
grad_sharding = gen_xla_ops.xla_sharding(grad)
# pylint: disable=protected-access
grad_sharding.op._set_attr("_XlaSharding",
attr_value_pb2.AttrValue(s=sharding_attr))
grad_sharding.op._set_attr(
"_XlaSharding", attr_value_pb2.AttrValue(s=op.get_attr("_XlaSharding")))
return [grad_sharding]
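
To illustrate what the updated gradient is expected to do, here is a hedged end-to-end check in graph mode: the backward XlaSharding op created by `_sharding_grad` should carry the same `_XlaSharding` bytes it copied from the forward op. The workflow and import paths are assumptions, not part of the diff:

import tensorflow as tf
from tensorflow.core.framework import attr_value_pb2
from tensorflow.compiler.tf2xla.python import xla as tf2xla  # registers the XlaSharding gradient
from tensorflow.compiler.xla import xla_data_pb2

with tf.Graph().as_default() as g:
  x = tf.compat.v1.placeholder(tf.float32, shape=[4, 16])
  y = tf2xla.sharding(x)  # forward XlaSharding op
  proto = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)
  # pylint: disable=protected-access
  y.op._set_attr("_XlaSharding",
                 attr_value_pb2.AttrValue(s=proto.SerializeToString()))
  tf.gradients(tf.reduce_sum(y), [x])
  # Every XlaSharding op other than the forward one was created by
  # _sharding_grad and should hold the same serialized proto.
  grad_ops = [op for op in g.get_operations()
              if op.type == "XlaSharding" and op is not y.op]
  assert all(op.get_attr("_XlaSharding") == proto.SerializeToString()
             for op in grad_ops)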

View File

@ -161,19 +161,12 @@ class Sharding(object):
tile_assignment_dimensions=tile_assignment_dims,
tile_assignment_devices=range(num_devices)))
def apply_to_tensor(self,
tensor,
assign_tuple_sharding=False,
use_sharding_op=False):
def apply_to_tensor(self, tensor, assign_tuple_sharding=False):
"""Applies this Sharding attribute to `tensor`.
Args:
tensor: A tf.Tensor to split.
assign_tuple_sharding: If the sharding type should be a tuple.
use_sharding_op: whether to create a sharding op on `tensor`.
Returns:
The tensor with Sharding attribute.
"""
if len(tensor.op.outputs) > 1 or assign_tuple_sharding:
proto = self._get_or_create_tuple_proto(tensor.op)
@ -185,15 +178,11 @@ class Sharding(object):
type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings)
else:
proto = self._proto
attr_value = proto.SerializeToString()
if use_sharding_op:
tensor = tf2xla.sharding(tensor, sharding=attr_value)
attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString())
# TODO(jmolloy): This need to be seriously revisited before declaring this
# API available for public use.
# pylint: disable=protected-access
tensor.op._set_attr('_XlaSharding', attr_value_pb2.AttrValue(s=attr_value))
return tensor
tensor.op._set_attr('_XlaSharding', attr_value)
def apply_to_operation(self, operation):
"""Applies this Sharding attribute to `operation`.
@ -244,7 +233,7 @@ def copy_sharding(from_tensor, to_tensor, use_sharding_op=False):
return to_tensor
if use_sharding_op:
to_tensor = tf2xla.sharding(to_tensor, sharding=sharding)
to_tensor = tf2xla.sharding(to_tensor)
attr_value = attr_value_pb2.AttrValue(s=sharding)
# pylint: disable=protected-access
to_tensor.op._set_attr('_XlaSharding', attr_value)
@ -256,10 +245,12 @@ def copy_sharding(from_tensor, to_tensor, use_sharding_op=False):
def replicate(tensor, assign_tuple_sharding=False, use_sharding_op=False):
return Sharding.replicate().apply_to_tensor(
if use_sharding_op:
tensor = tf2xla.sharding(tensor)
Sharding.replicate().apply_to_tensor(
tensor,
assign_tuple_sharding=assign_tuple_sharding,
use_sharding_op=use_sharding_op)
assign_tuple_sharding=assign_tuple_sharding)
return tensor
def assign_device(tensor,
@ -267,10 +258,13 @@ def assign_device(tensor,
assign_tuple_sharding=False,
use_sharding_op=False):
"""Returns a tensor that has AssignDevice sharding attribute."""
return Sharding.assign_device(device).apply_to_tensor(
if use_sharding_op:
tensor = tf2xla.sharding(tensor)
Sharding.assign_device(device).apply_to_tensor(
tensor,
assign_tuple_sharding=assign_tuple_sharding,
use_sharding_op=use_sharding_op)
assign_tuple_sharding=assign_tuple_sharding)
return tensor
def tile(tensor,
@ -286,10 +280,13 @@ def tile(tensor,
assign_tuple_sharding: If the sharding type should be a tuple.
use_sharding_op: If true, adds a sharding op to set the sharding.
"""
return Sharding.tile(tile_assignment).apply_to_tensor(
if use_sharding_op:
tensor = tf2xla.sharding(tensor)
Sharding.tile(tile_assignment).apply_to_tensor(
tensor,
assign_tuple_sharding=assign_tuple_sharding,
use_sharding_op=use_sharding_op)
assign_tuple_sharding=assign_tuple_sharding
)
return tensor
def split(tensor,
@ -308,11 +305,12 @@ def split(tensor,
use_sharding_op: If true, adds a sharding op to set the sharding.
input_shape: The full shape of the input tensor.
"""
return Sharding.split(tensor, split_dimension, num_devices,
input_shape).apply_to_tensor(
tensor,
assign_tuple_sharding=assign_tuple_sharding,
use_sharding_op=use_sharding_op)
if use_sharding_op:
tensor = tf2xla.sharding(tensor)
Sharding.split(
tensor, split_dimension, num_devices, input_shape).apply_to_tensor(
tensor, assign_tuple_sharding=assign_tuple_sharding)
return tensor
def partial_tile(tensor, tile_assignment, use_sharding_op=False):
@ -326,8 +324,10 @@ def partial_tile(tensor, tile_assignment, use_sharding_op=False):
replicated tiles.
use_sharding_op: If true, adds a sharding op to set the sharding.
"""
return Sharding.partial_tile(tile_assignment).apply_to_tensor(
tensor, use_sharding_op=use_sharding_op)
if use_sharding_op:
tensor = tf2xla.sharding(tensor)
Sharding.partial_tile(tile_assignment).apply_to_tensor(tensor)
return tensor
def get_op_sharding(op):
@ -464,4 +464,7 @@ def mesh_split(tensor,
rank.
"""
sharding = mesh_split_sharding(device_mesh, tensor_split_dims_mapping)
return sharding.apply_to_tensor(tensor, use_sharding_op=use_sharding_op)
if use_sharding_op:
tensor = tf2xla.sharding(tensor)
sharding.apply_to_tensor(tensor)
return tensor
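
Taken together, these wrapper changes move the optional sharding-op insertion out of apply_to_tensor and into each helper: with use_sharding_op=True the helper first wraps the tensor in an XlaSharding op and then stamps `_XlaSharding` on that op. A brief usage sketch, assuming the usual module path (the diff does not name the file); shapes, device counts, and tile assignments are illustrative only:

import numpy as np
import tensorflow as tf
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

with tf.Graph().as_default():
  x = tf.compat.v1.placeholder(tf.float32, shape=[8, 128])
  # Split dimension 0 across 2 devices; returns the sharding op's output.
  x_split = xla_sharding.split(x, split_dimension=0, num_devices=2,
                               use_sharding_op=True)
  # Tile across a 2x1 logical device grid; tile_assignment maps tile
  # coordinates to device ids.
  x_tiled = xla_sharding.tile(x, tile_assignment=np.array([[0], [1]]),
                              use_sharding_op=True)
  w = tf.compat.v1.placeholder(tf.float32, shape=[128])
  # Without use_sharding_op, the attribute is set directly on w's producing op.
  xla_sharding.replicate(w)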