From a6e66d50a4b419a4b21c2fecb87252f93e531767 Mon Sep 17 00:00:00 2001
From: Yuanzhong Xu
Date: Thu, 23 Jul 2020 13:51:05 -0700
Subject: [PATCH] Add gradient annotation for XlaSharding op.

PiperOrigin-RevId: 322858703
Change-Id: Ie7e6cfa4bf43b466449425a98a7682b0213614e3
---
 tensorflow/compiler/tf2xla/kernels/sharding_op.cc | 11 +++++++++--
 tensorflow/compiler/tf2xla/python/xla.py          |  8 ++++++--
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/kernels/sharding_op.cc b/tensorflow/compiler/tf2xla/kernels/sharding_op.cc
index 1047580264b..da268fe283c 100644
--- a/tensorflow/compiler/tf2xla/kernels/sharding_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sharding_op.cc
@@ -30,8 +30,15 @@ class ShardingOp : public XlaOpKernel {
   ~ShardingOp() override = default;
 
   void Compile(XlaOpKernelContext* ctx) override {
-    xla::XlaOp input = ctx->Input(0);
-    auto shape_or = ctx->InputXlaShape(0);
+    xla::XlaOp input;
+    {
+      // The builder might create a broadcast from a constant, so we clear
+      // sharding for the input.
+      xla::XlaScopedShardingAssignment no_sharding(ctx->builder(),
+                                                   absl::nullopt);
+      input = ctx->Input(0);
+    }
+    auto shape_or = ctx->builder()->GetShape(input);
     OP_REQUIRES_OK(ctx, shape_or.status());
 
     ctx->SetOutput(
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index 0ebca2d546f..846dafa2570 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -28,6 +28,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.compiler.tf2xla.ops import gen_xla_ops
+from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -415,8 +416,11 @@ sharding = gen_xla_ops.xla_sharding
 
 @ops.RegisterGradient("XlaSharding")
 def _sharding_grad(op, grad):
-  del op  # Unused
-  return [grad]
+  grad_sharding = gen_xla_ops.xla_sharding(grad)
+  # pylint: disable=protected-access
+  grad_sharding.op._set_attr(
+      "_XlaSharding", attr_value_pb2.AttrValue(s=op.get_attr("_XlaSharding")))
+  return [grad_sharding]
 
 spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape
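
Usage sketch (illustrative; not part of the patch). With this change, _sharding_grad
re-annotates the incoming gradient with a fresh XlaSharding op and copies the forward
op's serialized _XlaSharding attribute onto it, so the backward path carries the same
annotation as the forward path. The snippet below assumes a TF source tree at roughly
this revision and graph mode (the registered gradient only fires during graph gradient
construction); the REPLICATED OpSharding proto is an arbitrary choice for the demo, and
setting it via op._set_attr mirrors the mechanism the patch itself uses.

  import tensorflow.compat.v1 as tf

  from tensorflow.compiler.tf2xla.ops import gen_xla_ops
  from tensorflow.compiler.xla import xla_data_pb2
  from tensorflow.core.framework import attr_value_pb2

  tf.disable_eager_execution()

  with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[8, 4])
    y = gen_xla_ops.xla_sharding(x)
    # Attach a sharding proto to the forward op. REPLICATED is arbitrary here;
    # any serialized xla.OpSharding would do.
    proto = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)
    # pylint: disable=protected-access
    y.op._set_attr(
        "_XlaSharding", attr_value_pb2.AttrValue(s=proto.SerializeToString()))

    (dx,) = tf.gradients(tf.reduce_sum(y), [x])
    # With the patch applied, the gradient of x is produced by an XlaSharding
    # op whose _XlaSharding attr equals the forward op's; before the patch,
    # dx was the un-annotated incoming gradient.
    assert dx.op.type == "XlaSharding"
    assert dx.op.get_attr("_XlaSharding") == y.op.get_attr("_XlaSharding")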