Handle negative integer powers of 1 and -1.
Now: 1 ^ -X == 1 -1 ^ -X == 1 or -1 depending on the parity of X 0 ^ -X will still incorrectly return 0 though because we cannot return an error. PiperOrigin-RevId: 352000454 Change-Id: I347e79abb3aed61ef15e77d5b586c024b2cad69f
This commit is contained in:
parent
e78e3d36f4
commit
cb89425db9
@ -1070,16 +1070,20 @@ struct safe_pow : base<T, Eigen::internal::safe_scalar_binary_pow_op<T, T>> {
|
||||
static constexpr bool has_errors = true;
|
||||
};
|
||||
|
||||
// Version of safe_pow for integers which returns 0 if RHS is negative instead
|
||||
// of raising an error. For use on GPUs, where we cannot raise an error.
|
||||
// Version of safe_pow for integers which returns 0 if RHS is negative and LHS
|
||||
// is not 1 or -1. For use on GPUs, where we cannot raise an error.
|
||||
template <typename T>
|
||||
struct safe_pow_ignore_error_op {
|
||||
static_assert(std::is_integral<T>::value, "Integer type expected");
|
||||
EIGEN_EMPTY_STRUCT_CTOR(safe_pow_ignore_error_op)
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
|
||||
const T& y) const {
|
||||
if (y < 0) {
|
||||
return T{0};
|
||||
if (TF_PREDICT_FALSE(y < 0)) {
|
||||
if (x == T(-1)) {
|
||||
T trunc_mod = Eigen::internal::scalar_mod2_op<T>()(y, T(2));
|
||||
return trunc_mod == T(-1) ? T(-1) : T(1);
|
||||
}
|
||||
return x == T(1) ? T(1) : T(0);
|
||||
}
|
||||
return Eigen::internal::scalar_pow_op<T, T>{}(x, y);
|
||||
}
|
||||
|
@ -822,11 +822,12 @@ class BinaryOpTest(test.TestCase):
|
||||
def testPowNegativeExponentGpu(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("Requires GPU")
|
||||
# Negative integer powers return zero on GPUs
|
||||
x = np.array([2, 3, 4]).astype(np.int64)
|
||||
y = np.array([-1, 0, 1]).astype(np.int64)
|
||||
# Negative integer powers return zero on GPUs for abs(LHS) > 1. Negative
|
||||
# integer powers for 1 and -1 will return the correct result.
|
||||
x = np.array([2, 3, 1, -1, -1]).astype(np.int64)
|
||||
y = np.array([-1, 0, -2, -2, -3]).astype(np.int64)
|
||||
z = math_ops.pow(x, y)
|
||||
self.assertAllEqual(self.evaluate(z), [0, 1, 4])
|
||||
self.assertAllEqual(self.evaluate(z), [0, 1, 1, 1, -1])
|
||||
|
||||
|
||||
class ComparisonOpTest(test.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user