Add an XLA implementation for tensor_scatter_nd_min and tensor_scatter_nd_max, and implement gradients for these functions.
PiperOrigin-RevId: 332948708
Change-Id: Ic5e3c138cd04a91a6d1fb1bccad464d146facadf
commit 7d3979c5ce (parent d526d49e19)
Changed paths: tensorflow/compiler, tensorflow/python
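For context, a minimal usage sketch, outside this diff, of what the change enables once the gradients below are registered. It assumes the public `tf.tensor_scatter_nd_max` endpoint that the `TensorScatterMax` op backs:

import tensorflow as tf

x = tf.ones([4])
indices = tf.constant([[1], [2]])
updates = tf.constant([2.0, 0.5])

with tf.GradientTape() as tape:
  tape.watch(x)
  out = tf.tensor_scatter_nd_max(x, indices, updates)  # [1., 2., 1., 1.]
  loss = tf.reduce_sum(out)

# Gradient flows to x wherever the original value was the maximum:
# x wins at index 2 (1.0 > 0.5) and is untouched at indices 0 and 3.
print(tape.gradient(loss, x))  # [1., 0., 1., 1.]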
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -2023,6 +2023,8 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
       "TensorListSplit",
       "TensorListStack",
       "TensorScatterAdd",
+      "TensorScatterMax",
+      "TensorScatterMin",
       "TensorScatterSub",
       "TensorScatterUpdate",
       "TridiagonalSolve",
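The allowlist entries above let `TensorScatterMax`/`TensorScatterMin` be auto-clustered by XLA. A hedged sketch of exercising the new kernels under explicit XLA compilation; `jit_compile` assumes a recent TF 2.x, while releases contemporary with this change spelled the flag `experimental_compile`:

import tensorflow as tf

@tf.function(jit_compile=True)  # experimental_compile=True on older releases
def scatter_max(x, indices, updates):
  return tf.tensor_scatter_nd_max(x, indices, updates)

x = tf.ones([4])
print(scatter_max(x, tf.constant([[1], [2]]), tf.constant([2.0, 0.5])))
# [1. 2. 1. 1.]  -- compiled through the TensorScatterMaxOp kernel below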
--- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc
@@ -182,6 +182,32 @@ class TensorScatterAddOp : public XlaOpKernel {
   }
 };
 
+class TensorScatterMaxOp : public XlaOpKernel {
+ public:
+  explicit TensorScatterMaxOp(OpKernelConstruction* context)
+      : XlaOpKernel(context) {}
+
+  void Compile(XlaOpKernelContext* context) override {
+    CompileTensorScatter(context,
+                         [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
+                           return xla::Max(x, y);
+                         });
+  }
+};
+
+class TensorScatterMinOp : public XlaOpKernel {
+ public:
+  explicit TensorScatterMinOp(OpKernelConstruction* context)
+      : XlaOpKernel(context) {}
+
+  void Compile(XlaOpKernelContext* context) override {
+    CompileTensorScatter(context,
+                         [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
+                           return xla::Min(x, y);
+                         });
+  }
+};
+
 class TensorScatterSubOp : public XlaOpKernel {
  public:
   explicit TensorScatterSubOp(OpKernelConstruction* context)
@@ -207,6 +233,8 @@ class TensorScatterUpdateOp : public XlaOpKernel {
 };
 
 REGISTER_XLA_OP(Name("TensorScatterAdd"), TensorScatterAddOp);
+REGISTER_XLA_OP(Name("TensorScatterMax"), TensorScatterMaxOp);
+REGISTER_XLA_OP(Name("TensorScatterMin"), TensorScatterMinOp);
 REGISTER_XLA_OP(Name("TensorScatterSub"), TensorScatterSubOp);
 REGISTER_XLA_OP(Name("TensorScatterUpdate"), TensorScatterUpdateOp);
 
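Because `CompileTensorScatter` folds each update into the current buffer value with the supplied combiner lambda, duplicate indices compose naturally: every update at an index is combined, via max or min, with whatever is already there. A small sketch, outside this diff, of the resulting semantics:

import tensorflow as tf

x = tf.ones([4])
indices = tf.constant([[3], [3]])  # duplicate index
updates = tf.constant([2.0, 5.0])
print(tf.tensor_scatter_nd_max(x, indices, updates))
# [1. 1. 1. 5.]  -- max of the original 1.0 and both updates
print(tf.tensor_scatter_nd_min(x, indices, updates))
# [1. 1. 1. 1.]  -- min(1.0, 2.0, 5.0)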
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -840,6 +840,45 @@ class ScatterNdTensorTest(test.TestCase):
     self.assertAllEqual(max_result,
                         constant_op.constant([1, 1, 1, 2, 1, 1, 1, 2]))
 
+  def testUpdateMinMaxGradients(self):
+    with self.cached_session():
+      x = array_ops.ones([4], dtype=dtypes.float32)
+      indices = constant_op.constant([[1], [2], [3], [3]])
+      updates = constant_op.constant([2.0, 0.5, 1.0, 1.0],
+                                     dtype=dtypes.float32)
+
+      theoretical, _ = gradient_checker_v2.compute_gradient(
+          lambda x: array_ops.tensor_scatter_max(x, indices, updates), [x])
+      # Numerical gradient doesn't work for degenerate values because the
+      # derivative is not continuous. The manually entered gradient divides
+      # the gradient among all contributing elements at the discontinuity.
+      manual = array_ops.reshape(
+          array_ops.matrix_diag([1.0, 0.0, 1.0, 0.3333]), (1, 4, 4))
+      self.assertAllClose(theoretical, manual, 5e-4, 5e-4)
+
+      theoretical, _ = gradient_checker_v2.compute_gradient(
+          lambda x: array_ops.tensor_scatter_min(x, indices, updates), [x])
+      manual = array_ops.reshape(
+          array_ops.matrix_diag([1.0, 1.0, 0.0, 0.3333]), (1, 4, 4))
+      self.assertAllClose(theoretical, manual, 5e-4, 5e-4)
+
+      theoretical, _ = gradient_checker_v2.compute_gradient(
+          lambda updates: array_ops.tensor_scatter_max(x, indices, updates),
+          [updates])
+      manual = constant_op.constant(
+          [[[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 0.3333, 0.3333]]],
+          dtype=dtypes.float32)
+      self.assertAllClose(theoretical, manual, 5e-4, 5e-4)
+
+      theoretical, _ = gradient_checker_v2.compute_gradient(
+          lambda updates: array_ops.tensor_scatter_min(x, indices, updates),
+          [updates])
+      manual = constant_op.constant(
+          [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0],
+            [0.0, 0.0, 0.3333, 0.3333]]],
+          dtype=dtypes.float32)
+      self.assertAllClose(theoretical, manual, 5e-4, 5e-4)
+
   def testTensorScatterUpdateWithForwarding(self):
     for dtype in (dtypes.int32, dtypes.float32):
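A note on why the test compares `theoretical` against a hand-written `manual` Jacobian rather than the numerical one: a central finite difference straddles the discontinuity at a tie. A scalar sketch in plain Python, with a hypothetical helper `f` standing in for `output[3]` above:

def f(x3):
  # Stand-in for output[3] = max(x[3], updates[2], updates[3]) with both
  # updates fixed at 1.0.
  return max(x3, 1.0, 1.0)

eps = 1e-3
numeric = (f(1.0 + eps) - f(1.0 - eps)) / (2 * eps)
print(numeric)  # 0.5 -- not the 1/3 that the registered gradient assigns
                # to each of the three tied contributors.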
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -1140,6 +1140,37 @@ def _TensorScatterAddGrad(op, grad):
   return [tensor_grad, None, updates_grad]
 
 
+def _TensorScatterMinOrMaxGrad(op, grad):
+  """Gradient for TensorScatterMin and TensorScatterMax."""
+  indices = op.inputs[1]
+  x = op.inputs[0]
+  y = op.inputs[2]
+  output = op.outputs[0]
+  x_indicators = math_ops.cast(math_ops.equal(x, output), grad.dtype)
+  y_output = array_ops.gather_nd(output, indices)
+  y_indicators = math_ops.cast(math_ops.equal(y, y_output), grad.dtype)
+  ys_indicators = array_ops.scatter_nd(indices, y_indicators,
+                                       array_ops.shape(x))
+  indicators = x_indicators + ys_indicators  # All elements are >= 1.
+  # If there are multiple minimum or maximum elements then the gradient will
+  # be divided between them.
+  x_grad = grad * x_indicators / indicators
+  y_grad = array_ops.gather_nd(grad / indicators, indices) * y_indicators
+  return [x_grad, None, y_grad]
+
+
+@ops.RegisterGradient("TensorScatterMax")
+def _TensorScatterMaxGrad(op, grad):
+  """Gradient for TensorScatterMax op."""
+  return _TensorScatterMinOrMaxGrad(op, grad)
+
+
+@ops.RegisterGradient("TensorScatterMin")
+def _TensorScatterMinGrad(op, grad):
+  """Gradient for TensorScatterMin op."""
+  return _TensorScatterMinOrMaxGrad(op, grad)
+
+
 @ops.RegisterGradient("TensorScatterSub")
 def _TensorScatterSubGrad(op, grad):
   indices = op.inputs[1]
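The indicator arithmetic in `_TensorScatterMinOrMaxGrad` can be replayed step by step in eager mode. A sketch, assuming TF 2.x eager and reusing the test's inputs, for the max case:

import tensorflow as tf

x = tf.constant([1.0, 1.0, 1.0, 1.0])
indices = tf.constant([[1], [2], [3], [3]])
y = tf.constant([2.0, 0.5, 1.0, 1.0])
output = tf.tensor_scatter_nd_max(x, indices, y)             # [1. 2. 1. 1.]
grad = tf.ones_like(output)                                  # upstream grad

x_indicators = tf.cast(tf.equal(x, output), grad.dtype)      # [1. 0. 1. 1.]
y_output = tf.gather_nd(output, indices)                     # [2. 1. 1. 1.]
y_indicators = tf.cast(tf.equal(y, y_output), grad.dtype)    # [1. 0. 1. 1.]
ys_indicators = tf.scatter_nd(indices, y_indicators,
                              tf.shape(x))                   # [0. 1. 0. 2.]
indicators = x_indicators + ys_indicators                    # [1. 1. 1. 3.]
x_grad = grad * x_indicators / indicators                    # [1. 0. 1. 1/3]
y_grad = tf.gather_nd(grad / indicators, indices) * y_indicators
# y_grad == [1. 0. 1/3 1/3]: the tied slot at index 3 splits its gradient
# three ways between x[3], y[2], and y[3].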