Do not generate NaN when backpropagating through operators with singularities at the origin if the incoming gradient (dy) from the layer above is zero. Example: y = sqrt(x = 0) = 0, dx = 0.5 / sqrt(x) * dy = inf * 0 = NaN if dy = 0.

To address the problem, we replace the division (or multiplication) by the incoming gradient with an op that always returns 0 when dy is zero. This change is the first of several and fixes the issue for the gradients of reciprocal, sqrt, and rsqrt.

PiperOrigin-RevId: 238463599
commit 0ccd78f71e
parent b3975ae622
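As a quick illustration of the behavior this commit enforces, the sketch below builds the same graph as the SingularGradientOpTest added at the bottom of this diff, but through the public TF 1.x-style API (tf.compat.v1) instead of the internal test modules; that API choice is an assumption for readability, not part of the change. Before the fix the gradient at x = 0 with a zero upstream gradient comes out as NaN; after it, 0.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

for op in (tf.math.reciprocal, tf.math.rsqrt, tf.math.sqrt):
  for dtype in (tf.float32, tf.float64):
    x = tf.constant(0, dtype=dtype)
    grad_y = tf.constant(0, dtype=dtype)  # zero gradient arriving from the layer above
    y = op(x)
    # Backprop the zero upstream gradient through the singular point x = 0.
    g = tf.gradients(y, [x], grad_ys=grad_y)
    with tf.Session() as sess:
      print(op.__name__, dtype.name, sess.run(g))  # prints [0.0] after this change, NaN before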
@@ -1149,9 +1149,9 @@ class GbdtTest(test_util.TensorFlowTestCase):
     expected_leaf_1 = [-3.4480, -3.4429, 13.8490, -3.45, -3.4508]
     expected_leaf_2 = [-1.2547, -1.3145, 1.52, 2.3875, -1.3264]
     self.assertArrayNear(expected_leaf_1,
-                         output.trees[0].nodes[1].leaf.vector.value, 1e-3)
+                         output.trees[0].nodes[1].leaf.vector.value, 2e-3)
     self.assertArrayNear(expected_leaf_2,
-                         output.trees[0].nodes[2].leaf.vector.value, 1e-3)
+                         output.trees[0].nodes[2].leaf.vector.value, 2e-3)

   def testTrainFnMulticlassDiagonalHessian(self):
     """Tests the GBDT train for multiclass diagonal hessian."""
@@ -75,14 +75,19 @@ struct scalar_inverse_gradient_op {
   EIGEN_EMPTY_STRUCT_CTOR(scalar_inverse_gradient_op)
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
   operator()(const T& output, const T& output_gradient) const {
-    const T out_conj = numext::conj(output);
-    return -output_gradient * out_conj * out_conj;
+    if (output_gradient == T(0)) {
+      return T(0);
+    } else {
+      const T out_conj = numext::conj(output);
+      return -out_conj * out_conj * output_gradient;
+    }
   }
   template <typename Packet>
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
   packetOp(const Packet& output, const Packet& output_gradient) const {
     const Packet out_conj = pconj(output);
-    return pnegate(pmul(output_gradient, pmul(out_conj, out_conj)));
+    return mul_no_nan_op<T>().packetOp(pnegate(pmul(out_conj, out_conj)),
+                                       output_gradient);
   }
 };
 template <typename T>
@@ -99,15 +104,20 @@ struct scalar_sqrt_gradient_op {
   EIGEN_EMPTY_STRUCT_CTOR(scalar_sqrt_gradient_op)
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
   operator()(const T& output, const T& output_gradient) const {
-    const T out_conj = numext::conj(output);
-    return static_cast<T>(0.5) * output_gradient / out_conj;
+    if (output_gradient == T(0)) {
+      return T(0);
+    } else {
+      const T out_conj = numext::conj(output);
+      return (static_cast<T>(0.5) * output_gradient) / out_conj;
+    }
   }
   template <typename Packet>
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
   packetOp(const Packet& output, const Packet& output_gradient) const {
     const Packet const_half = pset1<Packet>(static_cast<T>(0.5));
     const Packet out_conj = pconj(output);
-    return pdiv(pmul(const_half, output_gradient), out_conj);
+    return mul_no_nan_op<T>().packetOp(pdiv(const_half, out_conj),
+                                       output_gradient);
   }
 };
 template <typename T>
@@ -124,17 +134,24 @@ struct scalar_rsqrt_gradient_op {
   EIGEN_EMPTY_STRUCT_CTOR(scalar_rsqrt_gradient_op)
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
   operator()(const T& output, const T& output_gradient) const {
-    const T out_conj = numext::conj(output);
-    return static_cast<T>(-0.5) * (output_gradient * out_conj) *
-           (out_conj * out_conj);
+    if (output_gradient == T(0)) {
+      return T(0);
+    } else {
+      const T out_conj = numext::conj(output);
+      return static_cast<T>(-0.5) * (output_gradient * out_conj) *
+             (out_conj * out_conj);
+    }
   }
   template <typename Packet>
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
   packetOp(const Packet& output, const Packet& output_gradient) const {
     const Packet const_half = pset1<Packet>(static_cast<T>(-0.5));
     const Packet out_conj = pconj(output);
-    return pmul(const_half, pmul(pmul(output_gradient, out_conj),
-                                 pmul(out_conj, out_conj)));
+    auto safe_pmul = [](const Packet& a, const Packet& b) {
+      return mul_no_nan_op<T>().packetOp(a, b);
+    };
+    return safe_pmul(pmul(const_half, pmul(out_conj, out_conj)),
+                     safe_pmul(out_conj, output_gradient));
   }
 };
 template <typename T>
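mul_no_nan_op itself is not part of this diff; the new packetOp bodies above only call it. The NumPy sketch below (a hypothetical helper, named here only for illustration) shows the semantics those calls appear to rely on: an elementwise multiply whose result is forced to 0 wherever the gradient operand is 0, so an infinite local derivative never gets multiplied by a zero upstream gradient.

import numpy as np

def mul_no_nan(x, y):
  # Elementwise x * y, except the result is 0 wherever y == 0,
  # even if the corresponding x is inf or NaN.
  out = np.zeros_like(x)
  nonzero = y != 0
  out[nonzero] = x[nonzero] * y[nonzero]
  return out

coeff = np.array([np.inf, 2.0])  # e.g. 0.5 / sqrt(x), which blows up at x = 0
dy = np.array([0.0, 1.0])        # upstream gradient; zero at the singular point
print(mul_no_nan(coeff, dy))     # [0. 2.]  -- a plain coeff * dy would give [nan 2.]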
@@ -29,6 +29,7 @@ from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import test
@@ -543,5 +544,23 @@ class UnaryOpTest(test.TestCase):
     self.assertAllClose(analytical, numerical, rtol=tol, atol=tol)


+class SingularGradientOpTest(test.TestCase):
+
+  @test_util.run_deprecated_v1
+  def testGradientAtOrigin(self):
+    ops_to_test = [
+        gen_math_ops.reciprocal, gen_math_ops.rsqrt, gen_math_ops.sqrt
+    ]
+    for op in ops_to_test:
+      for dtype in (dtypes_lib.float32, dtypes_lib.float64):
+        with self.cached_session():
+          x = constant_op.constant(0, dtype=dtype)
+          grad_y = constant_op.constant(0, dtype=dtype)
+          y = op(x)
+          g = gradients_impl.gradients(y, [x], grad_ys=grad_y)
+          g_val = self.evaluate(g)
+          self.assertAllEqual(g_val, [0])
+
+
 if __name__ == "__main__":
   test.main()