Add gradient annotation for XlaSharding op.

PiperOrigin-RevId: 322858703
Change-Id: Ie7e6cfa4bf43b466449425a98a7682b0213614e3
Yuanzhong Xu 2020-07-23 13:51:05 -07:00, committed by TensorFlower Gardener
parent 21aff5af1f
commit a6e66d50a4
2 changed files with 15 additions and 4 deletions


@@ -30,8 +30,15 @@ class ShardingOp : public XlaOpKernel {
   ~ShardingOp() override = default;
   void Compile(XlaOpKernelContext* ctx) override {
-    xla::XlaOp input = ctx->Input(0);
-    auto shape_or = ctx->InputXlaShape(0);
+    xla::XlaOp input;
+    {
+      // The builder might create a broadcast from a constant, so we clear
+      // sharding for the input.
+      xla::XlaScopedShardingAssignment no_sharding(ctx->builder(),
+                                                   absl::nullopt);
+      input = ctx->Input(0);
+    }
+    auto shape_or = ctx->builder()->GetShape(input);
     OP_REQUIRES_OK(ctx, shape_or.status());
     ctx->SetOutput(
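
Context for the hunk above (not part of the commit): the inline comment notes that the XLA builder may materialize a constant input as a broadcast while the kernel fetches it, so the RAII XlaScopedShardingAssignment temporarily clears the builder's sharding and restores it at scope exit, keeping that implicit broadcast unannotated. A minimal sketch of a graph that feeds XlaSharding a constant and so can exercise this path, assuming the internal tf2xla bindings are importable; experimental_compile is the TF 2.3-era spelling of jit_compile:

import tensorflow as tf
from tensorflow.compiler.tf2xla.python import xla

@tf.function(experimental_compile=True)  # spelled jit_compile=True on newer TF
def annotated_constant():
  # tf.zeros is a compile-time constant; the XLA builder may lower it as a
  # broadcast while ShardingOp::Compile fetches its input. The scoped
  # no-sharding assignment keeps that broadcast free of the op's sharding.
  return xla.sharding(tf.zeros([8, 8]))

print(annotated_constant().shape)  # (8, 8)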


@@ -28,6 +28,7 @@ from __future__ import division
 from __future__ import print_function
 from tensorflow.compiler.tf2xla.ops import gen_xla_ops
+from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -415,8 +416,11 @@ sharding = gen_xla_ops.xla_sharding
 @ops.RegisterGradient("XlaSharding")
 def _sharding_grad(op, grad):
-  del op  # Unused
-  return [grad]
+  grad_sharding = gen_xla_ops.xla_sharding(grad)
+  # pylint: disable=protected-access
+  grad_sharding.op._set_attr(
+      "_XlaSharding", attr_value_pb2.AttrValue(s=op.get_attr("_XlaSharding")))
+  return [grad_sharding]
 
 spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape
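
The new gradient rule makes the backward pass re-annotate the gradient rather than pass it through bare: it wraps grad in a fresh XlaSharding op and copies the forward op's _XlaSharding attr onto it. A minimal graph-mode sketch of the resulting behavior (not from the commit), assuming these internal modules are importable; the empty bytes are a stand-in for a serialized xla.OpSharding proto that a real caller (e.g. the experimental xla_sharding helpers) would attach:

import tensorflow as tf
from tensorflow.compiler.tf2xla.python import xla  # importing registers the gradient
from tensorflow.core.framework import attr_value_pb2

with tf.Graph().as_default():
  x = tf.ones([4, 4])
  y = xla.sharding(x)
  # Annotate the forward op; b"" stands in for a serialized OpSharding proto.
  y.op._set_attr(  # pylint: disable=protected-access
      "_XlaSharding", attr_value_pb2.AttrValue(s=b""))
  loss = tf.reduce_sum(y)
  dx, = tf.gradients(loss, [x])
  # With this change, dx is produced by a new XlaSharding op that carries the
  # same annotation; previously the gradient was returned unannotated.
  print(dx.op.type)                      # "XlaSharding"
  print(dx.op.get_attr("_XlaSharding"))  # b""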