Add gradient annotation for XlaSharding op.
PiperOrigin-RevId: 322858703
Change-Id: Ie7e6cfa4bf43b466449425a98a7682b0213614e3
parent 21aff5af1f
commit a6e66d50a4
@@ -30,8 +30,15 @@ class ShardingOp : public XlaOpKernel {
   ~ShardingOp() override = default;
 
   void Compile(XlaOpKernelContext* ctx) override {
-    xla::XlaOp input = ctx->Input(0);
-    auto shape_or = ctx->InputXlaShape(0);
+    xla::XlaOp input;
+    {
+      // The builder might create a broadcast from a constant, so we clear
+      // sharding for the input.
+      xla::XlaScopedShardingAssignment no_sharding(ctx->builder(),
+                                                   absl::nullopt);
+      input = ctx->Input(0);
+    }
+    auto shape_or = ctx->builder()->GetShape(input);
     OP_REQUIRES_OK(ctx, shape_or.status());
 
     ctx->SetOutput(
@@ -28,6 +28,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.compiler.tf2xla.ops import gen_xla_ops
+from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -415,8 +416,11 @@ sharding = gen_xla_ops.xla_sharding
 
 @ops.RegisterGradient("XlaSharding")
 def _sharding_grad(op, grad):
-  del op  # Unused
-  return [grad]
+  grad_sharding = gen_xla_ops.xla_sharding(grad)
+  # pylint: disable=protected-access
+  grad_sharding.op._set_attr(
+      "_XlaSharding", attr_value_pb2.AttrValue(s=op.get_attr("_XlaSharding")))
+  return [grad_sharding]
 
 
 spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape
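For context, a rough usage sketch of the behavior this commit enables. It is not part of the change itself: it assumes a TensorFlow build in which tensorflow.compiler.tf2xla.python.xla is importable, uses an empty bytes placeholder where a real serialized xla.OpSharding proto would go, and only builds and inspects a graph without running it. The point it illustrates is that the gradient flowing through an XlaSharding op is now rewrapped in another XlaSharding op carrying the forward op's _XlaSharding attribute.

# Sketch only: the b"" attribute value below is an illustrative stand-in.
import tensorflow.compat.v1 as tf
from tensorflow.compiler.tf2xla.python import xla  # importing registers the XlaSharding gradient
from tensorflow.core.framework import attr_value_pb2

with tf.Graph().as_default():
  x = tf.placeholder(tf.float32, shape=[8, 4])
  y = xla.sharding(x)  # forward XlaSharding op
  # Mimic what the sharding helpers do: attach a serialized sharding proto.
  # b"" stands in for a real xla.OpSharding.SerializeToString() value.
  y.op._set_attr("_XlaSharding", attr_value_pb2.AttrValue(s=b""))  # pylint: disable=protected-access

  loss = tf.reduce_sum(y * y)
  (dx,) = tf.gradients(loss, [x])

  # With this change, the gradient w.r.t. x is itself produced by an
  # XlaSharding op annotated with the same _XlaSharding value.
  print(dx.op.type)                      # expected: "XlaSharding"
  print(dx.op.get_attr("_XlaSharding"))  # expected: b""

Previously _sharding_grad returned the incoming gradient unchanged, so the sharding annotation was dropped on the backward pass. The C++ hunk additionally reads the kernel input under a cleared sharding scope so that a constant broadcast created by the builder does not inherit the sharding.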