Do not generate NaN when backpropagating through operators with singularities at the origin if the incoming gradient from the layer above is zero. Example: for y = sqrt(x) at x = 0, y = 0 and dy/dx = 0.5/sqrt(x) = inf, so the backpropagated gradient dx = dy/dx * dy = inf * 0 = NaN even when dy = 0.
To address the problem, we replace the multiplication (or division) involving the incoming gradient with an op that always returns 0 when that gradient is zero. This change is the first of several and addresses the issue for reciprocal, sqrt, and rsqrt.

PiperOrigin-RevId: 238463599
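For illustration only, a minimal NumPy sketch of the failure mode and of the "return 0 when the incoming gradient is 0" rule; the helper name safe_mul is hypothetical and merely mirrors the mul_no_nan_op used in the C++ diff below, and NumPy stands in for the actual Eigen functors:

    import numpy as np

    def sqrt_grad_naive(y, dy):
      # Backprop rule for y = sqrt(x): dx = 0.5 / y * dy.
      # At x = 0 we have y = 0, so 0.5 / y = inf and inf * 0 = nan.
      return 0.5 / y * dy

    def safe_mul(a, b):
      # Hypothetical stand-in for mul_no_nan: return 0 wherever the
      # incoming gradient b is 0, regardless of a (even if a is inf).
      return np.where(b == 0, 0.0, a * b)

    def sqrt_grad_fixed(y, dy):
      return safe_mul(0.5 / y, dy)

    with np.errstate(divide="ignore", invalid="ignore"):
      y, dy = np.float64(0.0), np.float64(0.0)
      print(sqrt_grad_naive(y, dy))  # nan
      print(sqrt_grad_fixed(y, dy))  # 0.0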
commit 0ccd78f71e (parent b3975ae622)
@@ -1149,9 +1149,9 @@ class GbdtTest(test_util.TensorFlowTestCase):
       expected_leaf_1 = [-3.4480, -3.4429, 13.8490, -3.45, -3.4508]
       expected_leaf_2 = [-1.2547, -1.3145, 1.52, 2.3875, -1.3264]
       self.assertArrayNear(expected_leaf_1,
-                           output.trees[0].nodes[1].leaf.vector.value, 1e-3)
+                           output.trees[0].nodes[1].leaf.vector.value, 2e-3)
       self.assertArrayNear(expected_leaf_2,
-                           output.trees[0].nodes[2].leaf.vector.value, 1e-3)
+                           output.trees[0].nodes[2].leaf.vector.value, 2e-3)
 
   def testTrainFnMulticlassDiagonalHessian(self):
     """Tests the GBDT train for multiclass diagonal hessian."""
@@ -75,14 +75,19 @@ struct scalar_inverse_gradient_op {
   EIGEN_EMPTY_STRUCT_CTOR(scalar_inverse_gradient_op)
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
   operator()(const T& output, const T& output_gradient) const {
-    const T out_conj = numext::conj(output);
-    return -output_gradient * out_conj * out_conj;
+    if (output_gradient == T(0)) {
+      return T(0);
+    } else {
+      const T out_conj = numext::conj(output);
+      return -out_conj * out_conj * output_gradient;
+    }
   }
   template <typename Packet>
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
   packetOp(const Packet& output, const Packet& output_gradient) const {
     const Packet out_conj = pconj(output);
-    return pnegate(pmul(output_gradient, pmul(out_conj, out_conj)));
+    return mul_no_nan_op<T>().packetOp(pnegate(pmul(out_conj, out_conj)),
+                                       output_gradient);
   }
 };
 template <typename T>
@@ -99,15 +104,20 @@ struct scalar_sqrt_gradient_op {
   EIGEN_EMPTY_STRUCT_CTOR(scalar_sqrt_gradient_op)
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
   operator()(const T& output, const T& output_gradient) const {
-    const T out_conj = numext::conj(output);
-    return static_cast<T>(0.5) * output_gradient / out_conj;
+    if (output_gradient == T(0)) {
+      return T(0);
+    } else {
+      const T out_conj = numext::conj(output);
+      return (static_cast<T>(0.5) * output_gradient) / out_conj;
+    }
   }
   template <typename Packet>
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
   packetOp(const Packet& output, const Packet& output_gradient) const {
     const Packet const_half = pset1<Packet>(static_cast<T>(0.5));
     const Packet out_conj = pconj(output);
-    return pdiv(pmul(const_half, output_gradient), out_conj);
+    return mul_no_nan_op<T>().packetOp(pdiv(const_half, out_conj),
+                                       output_gradient);
   }
 };
 template <typename T>
@@ -124,17 +134,24 @@ struct scalar_rsqrt_gradient_op {
   EIGEN_EMPTY_STRUCT_CTOR(scalar_rsqrt_gradient_op)
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
   operator()(const T& output, const T& output_gradient) const {
-    const T out_conj = numext::conj(output);
-    return static_cast<T>(-0.5) * (output_gradient * out_conj) *
-           (out_conj * out_conj);
+    if (output_gradient == T(0)) {
+      return T(0);
+    } else {
+      const T out_conj = numext::conj(output);
+      return static_cast<T>(-0.5) * (output_gradient * out_conj) *
+             (out_conj * out_conj);
+    }
   }
   template <typename Packet>
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
   packetOp(const Packet& output, const Packet& output_gradient) const {
     const Packet const_half = pset1<Packet>(static_cast<T>(-0.5));
     const Packet out_conj = pconj(output);
-    return pmul(const_half, pmul(pmul(output_gradient, out_conj),
-                                 pmul(out_conj, out_conj)));
+    auto safe_pmul = [](const Packet& a, const Packet& b) {
+      return mul_no_nan_op<T>().packetOp(a, b);
+    };
+    return safe_pmul(pmul(const_half, pmul(out_conj, out_conj)),
+                     safe_pmul(out_conj, output_gradient));
   }
 };
 template <typename T>
@@ -29,6 +29,7 @@ from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import test
@@ -543,5 +544,23 @@ class UnaryOpTest(test.TestCase):
       self.assertAllClose(analytical, numerical, rtol=tol, atol=tol)
 
 
+class SingularGradientOpTest(test.TestCase):
+
+  @test_util.run_deprecated_v1
+  def testGradientAtOrigin(self):
+    ops_to_test = [
+        gen_math_ops.reciprocal, gen_math_ops.rsqrt, gen_math_ops.sqrt
+    ]
+    for op in ops_to_test:
+      for dtype in (dtypes_lib.float32, dtypes_lib.float64):
+        with self.cached_session():
+          x = constant_op.constant(0, dtype=dtype)
+          grad_y = constant_op.constant(0, dtype=dtype)
+          y = op(x)
+          g = gradients_impl.gradients(y, [x], grad_ys=grad_y)
+          g_val = self.evaluate(g)
+          self.assertAllEqual(g_val, [0])
+
+
 if __name__ == "__main__":
   test.main()