Add XLA implementation for tensor_scatter_nd_min and tensor_scatter_nd_max, and implement gradients for these functions.

PiperOrigin-RevId: 332948708
Change-Id: Ic5e3c138cd04a91a6d1fb1bccad464d146facadf
commit 7d3979c5ce
parent d526d49e19
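For context: these ops merge scattered `updates` into an existing tensor with an element-wise maximum or minimum, leaving untouched positions unchanged. A minimal usage sketch (assuming a TensorFlow 2.x build that exports the `tf.tensor_scatter_nd_max`/`tf.tensor_scatter_nd_min` endpoints):

import tensorflow as tf

t = tf.constant([1.0, 1.0, 1.0, 1.0])
indices = tf.constant([[1], [2]])
updates = tf.constant([5.0, 0.0])

# Positions named by `indices` become max(t[i], update); the rest keep t[i].
print(tf.tensor_scatter_nd_max(t, indices, updates).numpy())  # [1. 5. 1. 1.]
print(tf.tensor_scatter_nd_min(t, indices, updates).numpy())  # [1. 1. 0. 1.]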
@@ -2023,6 +2023,8 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
     "TensorListSplit",
     "TensorListStack",
     "TensorScatterAdd",
+    "TensorScatterMax",
+    "TensorScatterMin",
     "TensorScatterSub",
     "TensorScatterUpdate",
     "TridiagonalSolve",
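The allowlist change above is what lets XLA auto-clustering and explicit JIT compilation pick up the new ops. A quick smoke test, as a sketch (assumes a TF version where `tf.function` accepts `jit_compile`; older releases spelled it `experimental_compile`):

import tensorflow as tf

@tf.function(jit_compile=True)  # request XLA compilation of the body
def scatter_max(t, indices, updates):
  return tf.tensor_scatter_nd_max(t, indices, updates)

result = scatter_max(tf.ones([8]),
                     tf.constant([[3], [7]]),
                     tf.constant([2.0, 2.0]))
print(result.numpy())  # [1. 1. 1. 2. 1. 1. 1. 2.], matching max_result in the test below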
@@ -182,6 +182,32 @@ class TensorScatterAddOp : public XlaOpKernel {
   }
 };
 
+class TensorScatterMaxOp : public XlaOpKernel {
+ public:
+  explicit TensorScatterMaxOp(OpKernelConstruction* context)
+      : XlaOpKernel(context) {}
+
+  void Compile(XlaOpKernelContext* context) override {
+    CompileTensorScatter(context,
+                         [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
+                           return xla::Max(x, y);
+                         });
+  }
+};
+
+class TensorScatterMinOp : public XlaOpKernel {
+ public:
+  explicit TensorScatterMinOp(OpKernelConstruction* context)
+      : XlaOpKernel(context) {}
+
+  void Compile(XlaOpKernelContext* context) override {
+    CompileTensorScatter(context,
+                         [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
+                           return xla::Min(x, y);
+                         });
+  }
+};
+
 class TensorScatterSubOp : public XlaOpKernel {
  public:
   explicit TensorScatterSubOp(OpKernelConstruction* context)
@@ -207,6 +233,8 @@ class TensorScatterUpdateOp : public XlaOpKernel {
 };
 
 REGISTER_XLA_OP(Name("TensorScatterAdd"), TensorScatterAddOp);
+REGISTER_XLA_OP(Name("TensorScatterMax"), TensorScatterMaxOp);
+REGISTER_XLA_OP(Name("TensorScatterMin"), TensorScatterMinOp);
 REGISTER_XLA_OP(Name("TensorScatterSub"), TensorScatterSubOp);
 REGISTER_XLA_OP(Name("TensorScatterUpdate"), TensorScatterUpdateOp);
@@ -840,6 +840,45 @@ class ScatterNdTensorTest(test.TestCase):
     self.assertAllEqual(max_result,
                         constant_op.constant([1, 1, 1, 2, 1, 1, 1, 2]))
 
+  def testUpdateMinMaxGradients(self):
+    with self.cached_session():
+      x = array_ops.ones([4], dtype=dtypes.float32)
+      indices = constant_op.constant([[1], [2], [3], [3]])
+      updates = constant_op.constant([2.0, 0.5, 1.0, 1.0], dtype=dtypes.float32)
+
+      theoretical, _ = gradient_checker_v2.compute_gradient(
+          lambda x: array_ops.tensor_scatter_max(x, indices, updates), [x])
+      # Numerical gradient doesn't work for degenerate values because the
+      # derivative is not continuous. The manually entered gradient divides
+      # the gradient among all contributing elements at the discontinuity.
+      manual = array_ops.reshape(
+          array_ops.matrix_diag([1.0, 0.0, 1.0, 0.3333]), (1, 4, 4))
+      self.assertAllClose(theoretical, manual, 5e-4, 5e-4)
+
+      theoretical, _ = gradient_checker_v2.compute_gradient(
+          lambda x: array_ops.tensor_scatter_min(x, indices, updates), [x])
+      manual = array_ops.reshape(
+          array_ops.matrix_diag([1.0, 1.0, 0.0, 0.3333]), (1, 4, 4))
+      self.assertAllClose(theoretical, manual, 5e-4, 5e-4)
+
+      theoretical, _ = gradient_checker_v2.compute_gradient(
+          lambda updates: array_ops.tensor_scatter_max(x, indices, updates),
+          [updates])
+      manual = constant_op.constant(
+          [[[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 0.3333, 0.3333]]],
+          dtype=dtypes.float32)
+      self.assertAllClose(theoretical, manual, 5e-4, 5e-4)
+
+      theoretical, _ = gradient_checker_v2.compute_gradient(
+          lambda updates: array_ops.tensor_scatter_min(x, indices, updates),
+          [updates])
+      manual = constant_op.constant(
+          [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0],
+            [0.0, 0.0, 0.3333, 0.3333]]],
+          dtype=dtypes.float32)
+      self.assertAllClose(theoretical, manual, 5e-4, 5e-4)
+
   def testTensorScatterUpdateWithForwarding(self):
     for dtype in (dtypes.int32, dtypes.float32):
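The 0.3333 entries in the manual Jacobians above come from the tie-breaking rule: with x = ones([4]) and two updates of 1.0 both aimed at index 3, three values (x[3] and both updates) equal the max output 1.0, so each contributor receives one third of the incoming gradient. A plain-numpy sketch of the indicator arithmetic used by the gradient below:

import numpy as np

x = np.ones(4, np.float32)
indices = np.array([[1], [2], [3], [3]])
updates = np.array([2.0, 0.5, 1.0, 1.0], np.float32)

# Forward pass: scatter-max.
output = x.copy()
for i, u in zip(indices[:, 0], updates):
  output[i] = max(output[i], u)                    # [1., 2., 1., 1.]

x_indicators = (x == output).astype(np.float32)    # [1., 0., 1., 1.]
y_indicators = (updates == output[indices[:, 0]]).astype(np.float32)  # [1., 0., 1., 1.]
ys_indicators = np.zeros(4, np.float32)
for i, ind in zip(indices[:, 0], y_indicators):
  ys_indicators[i] += ind                          # [0., 1., 0., 2.]
indicators = x_indicators + ys_indicators          # [1., 1., 1., 3.]
print(x_indicators / indicators)  # [1. 0. 1. 0.3333]: diagonal of the first `manual`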
@@ -1140,6 +1140,37 @@ def _TensorScatterAddGrad(op, grad):
   return [tensor_grad, None, updates_grad]
 
 
+def _TensorScatterMinOrMaxGrad(op, grad):
+  """Gradient for TensorScatterMin and TensorScatterMax."""
+  indices = op.inputs[1]
+  x = op.inputs[0]
+  y = op.inputs[2]
+  output = op.outputs[0]
+  x_indicators = math_ops.cast(math_ops.equal(x, output), grad.dtype)
+  y_output = array_ops.gather_nd(output, indices)
+  y_indicators = math_ops.cast(math_ops.equal(y, y_output), grad.dtype)
+  ys_indicators = array_ops.scatter_nd(indices, y_indicators,
+                                       array_ops.shape(x))
+  indicators = x_indicators + ys_indicators  # All elements are >= 1.
+  # If there are multiple minimum or maximum elements then the gradient will be
+  # divided between them.
+  x_grad = grad * x_indicators / indicators
+  y_grad = array_ops.gather_nd(grad / indicators, indices) * y_indicators
+  return [x_grad, None, y_grad]
+
+
+@ops.RegisterGradient("TensorScatterMax")
+def _TensorScatterMaxGrad(op, grad):
+  """Gradient for TensorScatterMax op."""
+  return _TensorScatterMinOrMaxGrad(op, grad)
+
+
+@ops.RegisterGradient("TensorScatterMin")
+def _TensorScatterMinGrad(op, grad):
+  """Gradient for TensorScatterMin op."""
+  return _TensorScatterMinOrMaxGrad(op, grad)
+
+
 @ops.RegisterGradient("TensorScatterSub")
 def _TensorScatterSubGrad(op, grad):
   indices = op.inputs[1]
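With the gradients registered, the new ops compose with automatic differentiation. A sketch of the end-user view (same inputs as the test above; values follow from the division-at-ties rule):

import tensorflow as tf

x = tf.Variable(tf.ones([4]))
indices = tf.constant([[1], [2], [3], [3]])
updates = tf.constant([2.0, 0.5, 1.0, 1.0])

with tf.GradientTape() as tape:
  y = tf.tensor_scatter_nd_max(x, indices, updates)
  loss = tf.reduce_sum(y)
# x[1] is strictly below its update, so its slot gets 0; the three-way tie at
# index 3 splits that element's gradient into thirds.
print(tape.gradient(loss, x).numpy())  # [1. 0. 1. 0.3333]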