Add gradient annotation for XlaSharding op.
PiperOrigin-RevId: 322858703 Change-Id: Ie7e6cfa4bf43b466449425a98a7682b0213614e3
This commit is contained in:
parent
21aff5af1f
commit
a6e66d50a4
@ -30,8 +30,15 @@ class ShardingOp : public XlaOpKernel {
|
||||
~ShardingOp() override = default;
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::XlaOp input = ctx->Input(0);
|
||||
auto shape_or = ctx->InputXlaShape(0);
|
||||
xla::XlaOp input;
|
||||
{
|
||||
// The builder might create a broadcast from a constant, so we clear
|
||||
// sharding for the input.
|
||||
xla::XlaScopedShardingAssignment no_sharding(ctx->builder(),
|
||||
absl::nullopt);
|
||||
input = ctx->Input(0);
|
||||
}
|
||||
auto shape_or = ctx->builder()->GetShape(input);
|
||||
OP_REQUIRES_OK(ctx, shape_or.status());
|
||||
|
||||
ctx->SetOutput(
|
||||
|
@ -28,6 +28,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -415,8 +416,11 @@ sharding = gen_xla_ops.xla_sharding
|
||||
|
||||
@ops.RegisterGradient("XlaSharding")
|
||||
def _sharding_grad(op, grad):
|
||||
del op # Unused
|
||||
return [grad]
|
||||
grad_sharding = gen_xla_ops.xla_sharding(grad)
|
||||
# pylint: disable=protected-access
|
||||
grad_sharding.op._set_attr(
|
||||
"_XlaSharding", attr_value_pb2.AttrValue(s=op.get_attr("_XlaSharding")))
|
||||
return [grad_sharding]
|
||||
|
||||
|
||||
spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape
|
||||
|
Loading…
x
Reference in New Issue
Block a user