Handle negative integer powers of 1 and -1.

Now:
1 ^ -X == 1
-1 ^ -X == 1 or -1 depending on the parity of X
0 ^ -X will still incorrectly return 0 though because we cannot return an error.
PiperOrigin-RevId: 352000454
Change-Id: I347e79abb3aed61ef15e77d5b586c024b2cad69f
This commit is contained in:
Tres Popp 2021-01-15 06:58:45 -08:00 committed by TensorFlower Gardener
parent e78e3d36f4
commit cb89425db9
2 changed files with 13 additions and 8 deletions

View File

@ -1070,16 +1070,20 @@ struct safe_pow : base<T, Eigen::internal::safe_scalar_binary_pow_op<T, T>> {
static constexpr bool has_errors = true;
};
// Version of safe_pow for integers which returns 0 if RHS is negative instead
// of raising an error. For use on GPUs, where we cannot raise an error.
// Version of safe_pow for integers which returns 0 if RHS is negative and LHS
// is not 1 or -1. For use on GPUs, where we cannot raise an error.
template <typename T>
struct safe_pow_ignore_error_op {
static_assert(std::is_integral<T>::value, "Integer type expected");
EIGEN_EMPTY_STRUCT_CTOR(safe_pow_ignore_error_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
const T& y) const {
if (y < 0) {
return T{0};
if (TF_PREDICT_FALSE(y < 0)) {
if (x == T(-1)) {
T trunc_mod = Eigen::internal::scalar_mod2_op<T>()(y, T(2));
return trunc_mod == T(-1) ? T(-1) : T(1);
}
return x == T(1) ? T(1) : T(0);
}
return Eigen::internal::scalar_pow_op<T, T>{}(x, y);
}

View File

@ -822,11 +822,12 @@ class BinaryOpTest(test.TestCase):
def testPowNegativeExponentGpu(self):
if not test_util.is_gpu_available():
self.skipTest("Requires GPU")
# Negative integer powers return zero on GPUs
x = np.array([2, 3, 4]).astype(np.int64)
y = np.array([-1, 0, 1]).astype(np.int64)
# Negative integer powers return zero on GPUs for abs(LHS) > 1. Negative
# integer powers for 1 and -1 will return the correct result.
x = np.array([2, 3, 1, -1, -1]).astype(np.int64)
y = np.array([-1, 0, -2, -2, -3]).astype(np.int64)
z = math_ops.pow(x, y)
self.assertAllEqual(self.evaluate(z), [0, 1, 4])
self.assertAllEqual(self.evaluate(z), [0, 1, 1, 1, -1])
class ComparisonOpTest(test.TestCase):