Add an XLA implementation for tensor_scatter_nd_min and tensor_scatter_nd_max, and implement gradients for these functions.
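
For illustration, a minimal eager-mode sketch of the new gradient behavior (assuming a TF build that includes this change; the values mirror the unit test below):

import tensorflow as tf

x = tf.ones([4])
indices = tf.constant([[1], [2], [3], [3]])
updates = tf.constant([2.0, 0.5, 1.0, 1.0])

with tf.GradientTape() as tape:
  tape.watch(x)
  # Element-wise max of x with the scattered updates; the three-way tie at
  # index 3 (x[3] and both 1.0 updates) splits the incoming gradient.
  y = tf.tensor_scatter_nd_max(x, indices, updates)

print(tape.gradient(y, x))  # ~ [1.0, 0.0, 1.0, 0.3333]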

PiperOrigin-RevId: 332948708
Change-Id: Ic5e3c138cd04a91a6d1fb1bccad464d146facadf
A. Unique TensorFlower 2020-09-21 15:30:57 -07:00 committed by TensorFlower Gardener
parent d526d49e19
commit 7d3979c5ce
4 changed files with 100 additions and 0 deletions
tensorflow/compiler
tensorflow/python/kernel_tests/array_ops
tensorflow/python/ops


@@ -2023,6 +2023,8 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"TensorListSplit",
"TensorListStack",
"TensorScatterAdd",
"TensorScatterMax",
"TensorScatterMin",
"TensorScatterSub",
"TensorScatterUpdate",
"TridiagonalSolve",


@@ -182,6 +182,32 @@ class TensorScatterAddOp : public XlaOpKernel {
  }
};

class TensorScatterMaxOp : public XlaOpKernel {
 public:
  explicit TensorScatterMaxOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    CompileTensorScatter(context,
                         [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
                           return xla::Max(x, y);
                         });
  }
};

class TensorScatterMinOp : public XlaOpKernel {
 public:
  explicit TensorScatterMinOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    CompileTensorScatter(context,
                         [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
                           return xla::Min(x, y);
                         });
  }
};

class TensorScatterSubOp : public XlaOpKernel {
 public:
  explicit TensorScatterSubOp(OpKernelConstruction* context)

@@ -207,6 +233,8 @@ class TensorScatterUpdateOp : public XlaOpKernel {
};

REGISTER_XLA_OP(Name("TensorScatterAdd"), TensorScatterAddOp);
REGISTER_XLA_OP(Name("TensorScatterMax"), TensorScatterMaxOp);
REGISTER_XLA_OP(Name("TensorScatterMin"), TensorScatterMinOp);
REGISTER_XLA_OP(Name("TensorScatterSub"), TensorScatterSubOp);
REGISTER_XLA_OP(Name("TensorScatterUpdate"), TensorScatterUpdateOp);


@@ -840,6 +840,45 @@ class ScatterNdTensorTest(test.TestCase):
      self.assertAllEqual(max_result,
                          constant_op.constant([1, 1, 1, 2, 1, 1, 1, 2]))

  def testUpdateMinMaxGradients(self):
    with self.cached_session():
      x = array_ops.ones([4], dtype=dtypes.float32)
      indices = constant_op.constant([[1], [2], [3], [3]])
      updates = constant_op.constant([2.0, 0.5, 1.0, 1.0], dtype=dtypes.float32)

      theoretical, _ = gradient_checker_v2.compute_gradient(
          lambda x: array_ops.tensor_scatter_max(x, indices, updates), [x])
      # Numerical gradient doesn't work for degenerate values because the
      # derivative is not continuous. The manually entered gradient divides
      # the gradient among all contributing elements at the discontinuity.
      manual = array_ops.reshape(
          array_ops.matrix_diag([1.0, 0.0, 1.0, 0.3333]), (1, 4, 4))
      self.assertAllClose(theoretical, manual, 5e-4, 5e-4)

      theoretical, _ = gradient_checker_v2.compute_gradient(
          lambda x: array_ops.tensor_scatter_min(x, indices, updates), [x])
      manual = array_ops.reshape(
          array_ops.matrix_diag([1.0, 1.0, 0.0, 0.3333]), (1, 4, 4))
      self.assertAllClose(theoretical, manual, 5e-4, 5e-4)

      theoretical, _ = gradient_checker_v2.compute_gradient(
          lambda updates: array_ops.tensor_scatter_max(x, indices, updates),
          [updates])
      manual = constant_op.constant(
          [[[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.3333, 0.3333]]],
          dtype=dtypes.float32)
      self.assertAllClose(theoretical, manual, 5e-4, 5e-4)

      theoretical, _ = gradient_checker_v2.compute_gradient(
          lambda updates: array_ops.tensor_scatter_min(x, indices, updates),
          [updates])
      manual = constant_op.constant(
          [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.3333, 0.3333]]],
          dtype=dtypes.float32)
      self.assertAllClose(theoretical, manual, 5e-4, 5e-4)
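
  # Note: the 0.3333 entries above are 1/3 -- x[3] and the two updates aimed
  # at index 3 tie for the extremum, so the gradient at element 3 is split
  # three ways among them.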

  def testTensorScatterUpdateWithForwarding(self):
    for dtype in (dtypes.int32, dtypes.float32):


@@ -1140,6 +1140,37 @@ def _TensorScatterAddGrad(op, grad):
  return [tensor_grad, None, updates_grad]


def _TensorScatterMinOrMaxGrad(op, grad):
  """Gradient for TensorScatterMin and TensorScatterMax."""
  indices = op.inputs[1]
  x = op.inputs[0]
  y = op.inputs[2]
  output = op.outputs[0]
  x_indicators = math_ops.cast(math_ops.equal(x, output), grad.dtype)
  y_output = array_ops.gather_nd(output, indices)
  y_indicators = math_ops.cast(math_ops.equal(y, y_output), grad.dtype)
  ys_indicators = array_ops.scatter_nd(indices, y_indicators,
                                       array_ops.shape(x))
  indicators = x_indicators + ys_indicators  # All elements are >= 1.
  # If there are multiple minimum or maximum elements then the gradient will be
  # divided between them.
  x_grad = grad * x_indicators / indicators
  y_grad = array_ops.gather_nd(grad / indicators, indices) * y_indicators
  return [x_grad, None, y_grad]
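
# Worked example: with x = [1., 1., 1., 1.], indices = [[3], [3]] and
# updates = [1., 1.], output[3] stays 1.0 and ties three ways (x[3] plus both
# updates), so indicators[3] == 3 and each tied input receives grad[3] / 3.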
@ops.RegisterGradient("TensorScatterMax")
def _TensorScatterMaxGrad(op, grad):
"""Gradient for TensorScatterMax op."""
return _TensorScatterMinOrMaxGrad(op, grad)
@ops.RegisterGradient("TensorScatterMin")
def _TensorScatterMinGrad(op, grad):
"""Gradient for TensorScatterMin op."""
return _TensorScatterMinOrMaxGrad(op, grad)
@ops.RegisterGradient("TensorScatterSub")
def _TensorScatterSubGrad(op, grad):
indices = op.inputs[1]