Make tf.minimum and tf.maximum always propagate NaNs. Previously these ops followed the semantics of std::min and std::max, which forward the first argument when either input is NaN.

Similarly, make tf.reduce_min and tf.reduce_max return NaN if any of the reduced elements is NaN.
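
To illustrate, a minimal sketch of the intended semantics after this change (assuming TF 2.x eager execution; tf.minimum, tf.maximum, tf.reduce_min and tf.reduce_max are the Python entry points for the kernels modified below):

    import tensorflow as tf

    nan = float("nan")

    # Elementwise: a NaN in either input now yields NaN in the output.
    tf.maximum([1.0, nan], [nan, 2.0])  # -> [nan, nan] (previously a non-NaN first argument was returned)
    tf.minimum([1.0, nan], [nan, 2.0])  # -> [nan, nan]

    # Reductions: any NaN among the reduced elements now yields NaN.
    x = tf.constant([1.0, nan, 3.0])
    tf.reduce_max(x)  # -> nan
    tf.reduce_min(x)  # -> nan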

PiperOrigin-RevId: 338194828
Change-Id: I337729b083468500694bfd7b5846633415be1710
A. Unique TensorFlower authored 2020-10-20 21:24:11 -07:00; committed by TensorFlower Gardener
parent c0f7f14ba6
commit 666fe19293
13 changed files with 288 additions and 143 deletions


@@ -1071,10 +1071,12 @@ struct safe_pow : base<T, Eigen::internal::safe_scalar_binary_pow_op<T, T>> {
};
template <typename T>
struct maximum : base<T, Eigen::internal::scalar_max_op<T>> {};
struct maximum
: base<T, Eigen::internal::scalar_max_op<T, T, Eigen::PropagateNaN>> {};
template <typename T>
struct minimum : base<T, Eigen::internal::scalar_min_op<T>> {};
struct minimum
: base<T, Eigen::internal::scalar_min_op<T, T, Eigen::PropagateNaN>> {};
template <typename T>
struct igamma : base<T, Eigen::internal::scalar_igamma_op<T>> {};
@@ -1097,9 +1099,7 @@ struct scalar_atan2_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_atan2_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
operator()(const Scalar& y, const Scalar& x) const {
#if GOOGLE_CUDA
return std::atan2(y, x);
#elif TENSORFLOW_USE_ROCM
#if TENSORFLOW_USE_ROCM
return ::atan2(y, x);
#else
return std::atan2(y, x);


@@ -73,6 +73,20 @@ struct DividesBy {
__host__ __device__ OUT_T operator()(const T& x) const { return x / divisor; }
};
struct MaxPropagateNaN {
template <typename T>
__host__ __device__ inline T operator()(const T& a, const T& b) const {
return (a != a ? a : (a > b ? a : b));
}
};
struct MinPropagateNaN {
template <typename T>
__host__ __device__ inline T operator()(const T& a, const T& b) const {
return (a != a ? a : (a < b ? a : b));
}
};
#if GOOGLE_CUDA
// TODO(rocm) : enable this once ROCm platform has support for complex datatypes
//
@@ -986,15 +1000,19 @@ struct IsSum {
template <typename T, typename Op>
struct IsMax {
constexpr static bool value =
(std::is_same<Op, gpuprim::Max>::value ||
std::is_same<Op, Eigen::internal::MaxReducer<T>>::value);
(std::is_same<Op, MaxPropagateNaN>::value ||
std::is_same<Op, gpuprim::Max>::value ||
std::is_same<
Op, Eigen::internal::MaxReducer<T, Eigen::PropagateNaN>>::value);
};
template <typename T, typename Op>
struct IsMin {
constexpr static bool value =
(std::is_same<Op, gpuprim::Min>::value ||
std::is_same<Op, Eigen::internal::MinReducer<T>>::value);
(std::is_same<Op, MinPropagateNaN>::value ||
std::is_same<Op, gpuprim::Min>::value ||
std::is_same<
Op, Eigen::internal::MinReducer<T, Eigen::PropagateNaN>>::value);
};
template <typename T, typename Op>
@@ -1222,41 +1240,47 @@ struct ReduceFunctor<GPUDevice, functor::MeanReducer<Eigen::half>> {
};
template <typename T>
struct ReduceFunctor<GPUDevice, Eigen::internal::MaxReducer<T>> {
struct ReduceFunctor<GPUDevice,
Eigen::internal::MaxReducer<T, Eigen::PropagateNaN>> {
template <typename OUT_T, typename IN_T, typename ReductionAxes>
static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
const ReductionAxes& reduction_axes,
const Eigen::internal::MaxReducer<T>& reducer) {
ReduceImpl<T, gpuprim::Max, T*, T*, ReductionAxes>(
static void Reduce(
OpKernelContext* ctx, OUT_T out, IN_T in,
const ReductionAxes& reduction_axes,
const Eigen::internal::MaxReducer<T, Eigen::PropagateNaN>& reducer) {
ReduceImpl<T, MaxPropagateNaN, T*, T*, ReductionAxes>(
ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
in.rank() >= 2 ? in.dimension(1) : 1,
in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
gpuprim::Max());
MaxPropagateNaN());
}
template <typename OUT_T>
static void FillIdentity(const GPUDevice& d, OUT_T out,
const Eigen::internal::MaxReducer<T>& reducer) {
static void FillIdentity(
const GPUDevice& d, OUT_T out,
const Eigen::internal::MaxReducer<T, Eigen::PropagateNaN>& reducer) {
FillIdentityEigenImpl(d, To32Bit(out), reducer);
}
};
template <typename T>
struct ReduceFunctor<GPUDevice, Eigen::internal::MinReducer<T>> {
struct ReduceFunctor<GPUDevice,
Eigen::internal::MinReducer<T, Eigen::PropagateNaN>> {
template <typename OUT_T, typename IN_T, typename ReductionAxes>
static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
const ReductionAxes& reduction_axes,
const Eigen::internal::MinReducer<T>& reducer) {
ReduceImpl<T, gpuprim::Min, T*, T*, ReductionAxes>(
static void Reduce(
OpKernelContext* ctx, OUT_T out, IN_T in,
const ReductionAxes& reduction_axes,
const Eigen::internal::MinReducer<T, Eigen::PropagateNaN>& reducer) {
ReduceImpl<T, MinPropagateNaN, T*, T*, ReductionAxes>(
ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
in.rank() >= 2 ? in.dimension(1) : 1,
in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
gpuprim::Min());
MinPropagateNaN());
}
template <typename OUT_T>
static void FillIdentity(const GPUDevice& d, OUT_T out,
const Eigen::internal::MinReducer<T>& reducer) {
static void FillIdentity(
const GPUDevice& d, OUT_T out,
const Eigen::internal::MinReducer<T, Eigen::PropagateNaN>& reducer) {
FillIdentityEigenImpl(d, To32Bit(out), reducer);
}
};


@@ -44,22 +44,27 @@ typedef TTypes<float>::Tensor::Index Index;
template void ReduceFunctor<GPUDevice, REDUCER>::FillIdentity( \
const GPUDevice& d, TTypes<T>::Vec out, const REDUCER& reducer);
#define DEFINE_FOR_TYPE_AND_R(T, R) \
DEFINE(T, R, 1, 1); \
DEFINE(T, R, 2, 1); \
DEFINE(T, R, 3, 1); \
DEFINE(T, R, 3, 2); \
DEFINE_IDENTITY(T, R)
#define SINGLE_ARG(...) __VA_ARGS__
#define DEFINE_FOR_ALL_REDUCERS(T) \
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer<T>); \
DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer<T>); \
DEFINE_FOR_TYPE_AND_R(T, functor::EuclideanNormReducer<T>); \
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer<T>); \
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer<T>); \
#define DEFINE_FOR_TYPE_AND_R(T, R) \
DEFINE(T, SINGLE_ARG(R), 1, 1); \
DEFINE(T, SINGLE_ARG(R), 2, 1); \
DEFINE(T, SINGLE_ARG(R), 3, 1); \
DEFINE(T, SINGLE_ARG(R), 3, 2); \
DEFINE_IDENTITY(T, SINGLE_ARG(R))
#define DEFINE_FOR_ALL_REDUCERS(T) \
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer<T>); \
DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer<T>); \
DEFINE_FOR_TYPE_AND_R(T, functor::EuclideanNormReducer<T>); \
DEFINE_FOR_TYPE_AND_R( \
T, SINGLE_ARG(Eigen::internal::MinReducer<T, Eigen::PropagateNaN>)); \
DEFINE_FOR_TYPE_AND_R( \
T, SINGLE_ARG(Eigen::internal::MaxReducer<T, Eigen::PropagateNaN>)); \
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::ProdReducer<T>)
DEFINE_FOR_ALL_REDUCERS(double);
#undef SINGLE_ARG
#undef DEFINE_FOR_ALL_REDUCERS
#undef DEFINE_FOR_TYPE_AND_R
#undef DEFINE


@@ -44,22 +44,27 @@ typedef TTypes<float>::Tensor::Index Index;
template void ReduceFunctor<GPUDevice, REDUCER>::FillIdentity( \
const GPUDevice& d, TTypes<T>::Vec out, const REDUCER& reducer);
#define DEFINE_FOR_TYPE_AND_R(T, R) \
DEFINE(T, R, 1, 1); \
DEFINE(T, R, 2, 1); \
DEFINE(T, R, 3, 1); \
DEFINE(T, R, 3, 2); \
DEFINE_IDENTITY(T, R)
#define SINGLE_ARG(...) __VA_ARGS__
#define DEFINE_FOR_ALL_REDUCERS(T) \
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer<T>); \
DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer<T>); \
DEFINE_FOR_TYPE_AND_R(T, functor::EuclideanNormReducer<T>); \
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer<T>); \
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer<T>); \
#define DEFINE_FOR_TYPE_AND_R(T, R) \
DEFINE(T, SINGLE_ARG(R), 1, 1); \
DEFINE(T, SINGLE_ARG(R), 2, 1); \
DEFINE(T, SINGLE_ARG(R), 3, 1); \
DEFINE(T, SINGLE_ARG(R), 3, 2); \
DEFINE_IDENTITY(T, SINGLE_ARG(R))
#define DEFINE_FOR_ALL_REDUCERS(T) \
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer<T>); \
DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer<T>); \
DEFINE_FOR_TYPE_AND_R(T, functor::EuclideanNormReducer<T>); \
DEFINE_FOR_TYPE_AND_R( \
T, SINGLE_ARG(Eigen::internal::MinReducer<T, Eigen::PropagateNaN>)); \
DEFINE_FOR_TYPE_AND_R( \
T, SINGLE_ARG(Eigen::internal::MaxReducer<T, Eigen::PropagateNaN>)); \
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::ProdReducer<T>)
DEFINE_FOR_ALL_REDUCERS(float);
#undef SINGLE_ARG
#undef DEFINE_FOR_ALL_REDUCERS
#undef DEFINE_FOR_TYPE_AND_R
#undef DEFINE


@@ -44,23 +44,28 @@ typedef TTypes<float>::Tensor::Index Index;
template void ReduceFunctor<GPUDevice, REDUCER>::FillIdentity( \
const GPUDevice& d, TTypes<T>::Vec out, const REDUCER& reducer);
#define DEFINE_FOR_TYPE_AND_R(T, R) \
DEFINE(T, R, 1, 1); \
DEFINE(T, R, 2, 1); \
DEFINE(T, R, 3, 1); \
DEFINE(T, R, 3, 2); \
DEFINE_IDENTITY(T, R)
#define SINGLE_ARG(...) __VA_ARGS__
#define DEFINE_FOR_ALL_REDUCERS(T) \
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer<T>); \
DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer<T>); \
DEFINE_FOR_TYPE_AND_R(T, functor::EuclideanNormReducer<T>); \
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer<T>); \
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer<T>); \
#define DEFINE_FOR_TYPE_AND_R(T, R) \
DEFINE(T, SINGLE_ARG(R), 1, 1); \
DEFINE(T, SINGLE_ARG(R), 2, 1); \
DEFINE(T, SINGLE_ARG(R), 3, 1); \
DEFINE(T, SINGLE_ARG(R), 3, 2); \
DEFINE_IDENTITY(T, SINGLE_ARG(R))
#define DEFINE_FOR_ALL_REDUCERS(T) \
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer<T>); \
DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer<T>); \
DEFINE_FOR_TYPE_AND_R(T, functor::EuclideanNormReducer<T>); \
DEFINE_FOR_TYPE_AND_R( \
T, SINGLE_ARG(Eigen::internal::MinReducer<T, Eigen::PropagateNaN>)); \
DEFINE_FOR_TYPE_AND_R( \
T, SINGLE_ARG(Eigen::internal::MaxReducer<T, Eigen::PropagateNaN>)); \
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::ProdReducer<T>)
DEFINE_FOR_ALL_REDUCERS(int32);
DEFINE_FOR_ALL_REDUCERS(int64);
#undef SINGLE_ARG
#undef DEFINE_FOR_ALL_REDUCERS
#undef DEFINE_FOR_TYPE_AND_R
#undef DEFINE


@@ -44,19 +44,24 @@ typedef TTypes<float>::Tensor::Index Index;
template void ReduceFunctor<GPUDevice, REDUCER>::FillIdentity( \
const GPUDevice& d, TTypes<T>::Vec out, const REDUCER& reducer);
#define DEFINE_FOR_TYPE_AND_R(T, R) \
DEFINE(T, R, 1, 1); \
DEFINE(T, R, 2, 1); \
DEFINE(T, R, 3, 1); \
DEFINE(T, R, 3, 2); \
DEFINE_IDENTITY(T, R)
#define SINGLE_ARG(...) __VA_ARGS__
#define DEFINE_FOR_ALL_REDUCERS(T) \
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer<T>); \
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer<T>); \
#define DEFINE_FOR_TYPE_AND_R(T, R) \
DEFINE(T, SINGLE_ARG(R), 1, 1); \
DEFINE(T, SINGLE_ARG(R), 2, 1); \
DEFINE(T, SINGLE_ARG(R), 3, 1); \
DEFINE(T, SINGLE_ARG(R), 3, 2); \
DEFINE_IDENTITY(T, SINGLE_ARG(R))
#define DEFINE_FOR_ALL_REDUCERS(T) \
DEFINE_FOR_TYPE_AND_R( \
T, SINGLE_ARG(Eigen::internal::MinReducer<T, Eigen::PropagateNaN>)); \
DEFINE_FOR_TYPE_AND_R( \
T, SINGLE_ARG(Eigen::internal::MaxReducer<T, Eigen::PropagateNaN>)); \
DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::ProdReducer<T>)
DEFINE_FOR_ALL_REDUCERS(Eigen::half);
#undef SINGLE_ARG
#undef DEFINE_FOR_ALL_REDUCERS
#undef DEFINE_FOR_TYPE_AND_R
#undef DEFINE


@@ -17,39 +17,43 @@ limitations under the License.
namespace tensorflow {
#define REGISTER_CPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Max") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
ReductionOp<CPUDevice, type, int32, Eigen::internal::MaxReducer<type>>); \
REGISTER_KERNEL_BUILDER( \
Name("Max") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx"), \
ReductionOp<CPUDevice, type, int64, Eigen::internal::MaxReducer<type>>);
#define REGISTER_CPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Max") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
ReductionOp<CPUDevice, type, int32, \
Eigen::internal::MaxReducer<type, Eigen::PropagateNaN>>); \
REGISTER_KERNEL_BUILDER( \
Name("Max") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx"), \
ReductionOp<CPUDevice, type, int64, \
Eigen::internal::MaxReducer<type, Eigen::PropagateNaN>>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Max") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int32, Eigen::internal::MaxReducer<type>>); \
REGISTER_KERNEL_BUILDER( \
Name("Max") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int64, Eigen::internal::MaxReducer<type>>);
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Max") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int32, \
Eigen::internal::MaxReducer<type, Eigen::PropagateNaN>>); \
REGISTER_KERNEL_BUILDER( \
Name("Max") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int64, \
Eigen::internal::MaxReducer<type, Eigen::PropagateNaN>>);
REGISTER_GPU_KERNELS(Eigen::half);
REGISTER_GPU_KERNELS(float);


@@ -17,39 +17,43 @@ limitations under the License.
namespace tensorflow {
#define REGISTER_CPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Min") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
ReductionOp<CPUDevice, type, int32, Eigen::internal::MinReducer<type>>); \
REGISTER_KERNEL_BUILDER( \
Name("Min") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx"), \
ReductionOp<CPUDevice, type, int64, Eigen::internal::MinReducer<type>>);
#define REGISTER_CPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Min") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
ReductionOp<CPUDevice, type, int32, \
Eigen::internal::MinReducer<type, Eigen::PropagateNaN>>); \
REGISTER_KERNEL_BUILDER( \
Name("Min") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx"), \
ReductionOp<CPUDevice, type, int64, \
Eigen::internal::MinReducer<type, Eigen::PropagateNaN>>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Min") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int32, Eigen::internal::MinReducer<type>>); \
REGISTER_KERNEL_BUILDER( \
Name("Min") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int64, Eigen::internal::MinReducer<type>>);
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Min") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int32, \
Eigen::internal::MinReducer<type, Eigen::PropagateNaN>>); \
REGISTER_KERNEL_BUILDER( \
Name("Min") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int64, \
Eigen::internal::MinReducer<type, Eigen::PropagateNaN>>);
REGISTER_GPU_KERNELS(Eigen::half);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);


@@ -111,7 +111,7 @@ struct LogSumExpReducer {
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T
finalizeBoth(const T saccum, const Packet& vaccum) const {
auto max_reducer = Eigen::internal::MaxReducer<T>();
auto max_reducer = Eigen::internal::MaxReducer<T, Eigen::PropagateNaN>();
auto sum_reducer = Eigen::internal::SumReducer<T>();
auto exp = Eigen::internal::scalar_exp_op<T>();
auto cmp_lt =


@@ -767,6 +767,14 @@ class MinMaxOpTest(test.TestCase):
self._compare(x.astype(t), y.astype(t), use_gpu=False)
self._compare(x.astype(t), y.astype(t), use_gpu=True)
def testNaNPropagation(self):
x = np.array([1., np.nan, 1., np.nan], dtype=np.float64)
y = np.array([1., 1., np.nan, np.nan], dtype=np.float64)
for t in [np.float16, np.float32, np.float64]:
with self.subTest(t=t):
self._compare(x.astype(t), y.astype(t), use_gpu=False)
self._compare(x.astype(t), y.astype(t), use_gpu=True)
def testDifferentShapes(self):
x = np.random.rand(1, 3, 2) * 100.
y = np.random.rand(2) * 100. # should broadcast


@@ -728,9 +728,7 @@ class MinReductionTest(test.TestCase):
def _compareAll(self, x, reduction_axes):
self._compare(x, reduction_axes, False, use_gpu=True)
self._compare(x, reduction_axes, False, use_gpu=False)
self._compare(x, reduction_axes, True, use_gpu=True)
self._compare(x, reduction_axes, True, use_gpu=False)
def testAxesType(self):
for dtype in [dtypes.int64, dtypes.int32]:
@@ -739,13 +737,12 @@ class MinReductionTest(test.TestCase):
tf_v = self.evaluate(v)
self.assertAllEqual(tf_v, 0)
@test_util.run_deprecated_v1
def testInfinity(self):
def testSpecialValues(self):
for dtype in [np.float32, np.float64]:
for special_value_x in [-np.inf, np.inf]:
for special_value_y in [-np.inf, np.inf]:
np_arr = np.array([special_value_x, special_value_y]).astype(dtype)
self._compareAll(np_arr, None)
for size in range(1, 4):
for arr in itertools.product([-np.inf, 1., np.nan, np.inf],
repeat=size):
self._compareAll(np.array(arr, dtype=dtype), None)
def testFloatReduce3D(self):
# Create a 3D array of floats and reduce across all possible
@@ -847,9 +844,7 @@ class MaxReductionTest(test.TestCase):
def _compareAll(self, x, reduction_axes):
self._compare(x, reduction_axes, False, use_gpu=True)
self._compare(x, reduction_axes, False, use_gpu=False)
self._compare(x, reduction_axes, True, use_gpu=True)
self._compare(x, reduction_axes, True, use_gpu=False)
def testAxesType(self):
for dtype in [dtypes.int64, dtypes.int32]:
@@ -858,13 +853,12 @@ class MaxReductionTest(test.TestCase):
tf_v = self.evaluate(v)
self.assertAllEqual(tf_v, 0)
@test_util.run_deprecated_v1
def testInfinity(self):
def testSpecialValues(self):
for dtype in [np.float32, np.float64]:
for special_value_x in [-np.inf, np.inf]:
for special_value_y in [-np.inf, np.inf]:
np_arr = np.array([special_value_x, special_value_y]).astype(dtype)
self._compareAll(np_arr, None)
for size in range(1, 4):
for arr in itertools.product([-np.inf, 1., np.nan, np.inf],
repeat=size):
self._compareAll(np.array(arr, dtype=dtype), None)
def testInt64Reduce3D(self):
# Create a 3D array of int64s and reduce across all possible


@@ -105,6 +105,15 @@ class CumsumTest(test.TestCase):
axis = constant_op.constant(0, axis_dtype)
tf_out = math_ops.cumsum(x, axis).eval()
@test_util.run_deprecated_v1
def testNaN(self):
for dtype in (np.float16, np.float32, np.float64):
for nan_idx in range(0, 5):
x = np.arange(1, 6).reshape([5]).astype(dtype)
x[nan_idx] = np.nan
for axis in (-1, 0):
self._compareAll(x, axis)
@test_util.run_deprecated_v1
def test1D(self):
for dtype in self.valid_dtypes:
@@ -229,6 +238,15 @@ class CumprodTest(test.TestCase):
axis = constant_op.constant(0, axis_dtype)
tf_out = math_ops.cumprod(x, axis).eval()
@test_util.run_deprecated_v1
def testNaN(self):
for dtype in (np.float16, np.float32, np.float64):
for nan_idx in range(0, 5):
x = np.arange(1, 6).reshape([5]).astype(dtype)
x[nan_idx] = np.nan
for axis in (-1, 0):
self._compareAll(x, axis)
@test_util.run_deprecated_v1
def test1D(self):
for dtype in self.valid_dtypes:


@@ -23,3 +23,76 @@ diff -ru a/Eigen/src/Geometry/arch/Geometry_SSE.h b/Eigen/src/Geometry/arch/Geom
return res;
}
};
diff -ru a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h
--- a/Eigen/src/Core/GenericPacketMath.h
+++ b/Eigen/src/Core/GenericPacketMath.h
@@ -255,49 +255,43 @@
return std::complex<RealScalar>(b, b);
}
-template <typename Packet, typename Op>
-EIGEN_DEVICE_FUNC inline Packet bitwise_helper(const Packet& a, const Packet& b, Op op) {
+/** \internal \returns the bitwise and of \a a and \a b */
+template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
+pand(const Packet& a, const Packet& b) {
const unsigned char* a_ptr = reinterpret_cast<const unsigned char*>(&a);
const unsigned char* b_ptr = reinterpret_cast<const unsigned char*>(&b);
Packet c;
unsigned char* c_ptr = reinterpret_cast<unsigned char*>(&c);
for (size_t i = 0; i < sizeof(Packet); ++i) {
- *c_ptr++ = op(*a_ptr++, *b_ptr++);
+ *c_ptr++ = *a_ptr++ & *b_ptr++;
}
return c;
}
-/** \internal \returns the bitwise and of \a a and \a b */
-template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
-pand(const Packet& a, const Packet& b) {
-#if defined(EIGEN_HIP_DEVICE_COMPILE)
- return bitwise_helper(a ,b, std::bit_and<unsigned char>());
-#else
- EIGEN_USING_STD(bit_and);
- return bitwise_helper(a ,b, bit_and<unsigned char>());
-#endif
-}
-
/** \internal \returns the bitwise or of \a a and \a b */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
por(const Packet& a, const Packet& b) {
-#if defined(EIGEN_HIP_DEVICE_COMPILE)
- return bitwise_helper(a ,b, std::bit_or<unsigned char>());
-#else
- EIGEN_USING_STD(bit_or);
- return bitwise_helper(a ,b, bit_or<unsigned char>());
-#endif
+ const unsigned char* a_ptr = reinterpret_cast<const unsigned char*>(&a);
+ const unsigned char* b_ptr = reinterpret_cast<const unsigned char*>(&b);
+ Packet c;
+ unsigned char* c_ptr = reinterpret_cast<unsigned char*>(&c);
+ for (size_t i = 0; i < sizeof(Packet); ++i) {
+ *c_ptr++ = *a_ptr++ | *b_ptr++;
+ }
+ return c;
}
/** \internal \returns the bitwise xor of \a a and \a b */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pxor(const Packet& a, const Packet& b) {
-#if defined(EIGEN_HIP_DEVICE_COMPILE)
- return bitwise_helper(a ,b, std::bit_xor<unsigned char>());
-#else
- EIGEN_USING_STD(bit_xor);
- return bitwise_helper(a ,b, bit_xor<unsigned char>());
-#endif
+ const unsigned char* a_ptr = reinterpret_cast<const unsigned char*>(&a);
+ const unsigned char* b_ptr = reinterpret_cast<const unsigned char*>(&b);
+ Packet c;
+ unsigned char* c_ptr = reinterpret_cast<unsigned char*>(&c);
+ for (size_t i = 0; i < sizeof(Packet); ++i) {
+ *c_ptr++ = *a_ptr++ ^ *b_ptr++;
+ }
+ return c;
}
/** \internal \returns the bitwise and of \a a and not \a b */