Handle negative integer powers of 1 and -1.
Now: 1 ^ -X == 1 -1 ^ -X == 1 or -1 depending on the parity of X 0 ^ -X will still incorrectly return 0 though because we cannot return an error. PiperOrigin-RevId: 352000454 Change-Id: I347e79abb3aed61ef15e77d5b586c024b2cad69f
This commit is contained in:
parent
e78e3d36f4
commit
cb89425db9
@ -1070,16 +1070,20 @@ struct safe_pow : base<T, Eigen::internal::safe_scalar_binary_pow_op<T, T>> {
|
|||||||
static constexpr bool has_errors = true;
|
static constexpr bool has_errors = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Version of safe_pow for integers which returns 0 if RHS is negative instead
|
// Version of safe_pow for integers which returns 0 if RHS is negative and LHS
|
||||||
// of raising an error. For use on GPUs, where we cannot raise an error.
|
// is not 1 or -1. For use on GPUs, where we cannot raise an error.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct safe_pow_ignore_error_op {
|
struct safe_pow_ignore_error_op {
|
||||||
static_assert(std::is_integral<T>::value, "Integer type expected");
|
static_assert(std::is_integral<T>::value, "Integer type expected");
|
||||||
EIGEN_EMPTY_STRUCT_CTOR(safe_pow_ignore_error_op)
|
EIGEN_EMPTY_STRUCT_CTOR(safe_pow_ignore_error_op)
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
|
||||||
const T& y) const {
|
const T& y) const {
|
||||||
if (y < 0) {
|
if (TF_PREDICT_FALSE(y < 0)) {
|
||||||
return T{0};
|
if (x == T(-1)) {
|
||||||
|
T trunc_mod = Eigen::internal::scalar_mod2_op<T>()(y, T(2));
|
||||||
|
return trunc_mod == T(-1) ? T(-1) : T(1);
|
||||||
|
}
|
||||||
|
return x == T(1) ? T(1) : T(0);
|
||||||
}
|
}
|
||||||
return Eigen::internal::scalar_pow_op<T, T>{}(x, y);
|
return Eigen::internal::scalar_pow_op<T, T>{}(x, y);
|
||||||
}
|
}
|
||||||
|
@ -822,11 +822,12 @@ class BinaryOpTest(test.TestCase):
|
|||||||
def testPowNegativeExponentGpu(self):
|
def testPowNegativeExponentGpu(self):
|
||||||
if not test_util.is_gpu_available():
|
if not test_util.is_gpu_available():
|
||||||
self.skipTest("Requires GPU")
|
self.skipTest("Requires GPU")
|
||||||
# Negative integer powers return zero on GPUs
|
# Negative integer powers return zero on GPUs for abs(LHS) > 1. Negative
|
||||||
x = np.array([2, 3, 4]).astype(np.int64)
|
# integer powers for 1 and -1 will return the correct result.
|
||||||
y = np.array([-1, 0, 1]).astype(np.int64)
|
x = np.array([2, 3, 1, -1, -1]).astype(np.int64)
|
||||||
|
y = np.array([-1, 0, -2, -2, -3]).astype(np.int64)
|
||||||
z = math_ops.pow(x, y)
|
z = math_ops.pow(x, y)
|
||||||
self.assertAllEqual(self.evaluate(z), [0, 1, 4])
|
self.assertAllEqual(self.evaluate(z), [0, 1, 1, 1, -1])
|
||||||
|
|
||||||
|
|
||||||
class ComparisonOpTest(test.TestCase):
|
class ComparisonOpTest(test.TestCase):
|
||||||
|
Loading…
Reference in New Issue
Block a user