Do mean reductions for integer types in 64-bit arithmetic to mitigate overflow in the sum and/or the denominator (the element count).
PiperOrigin-RevId: 282847676 Change-Id: I267823932b2c3e1f9916ea0edfcce3efb5d4430a
This commit is contained in:
parent
74faaeb08f
commit
23fde233bf
@ -72,6 +72,34 @@ struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
|
||||
}
|
||||
};
|
||||
|
||||
// Defines a partial specialization of ReduceEigenImpl for
// functor::MeanReducer<ScalarType> that performs the whole mean computation in
// IntermediateType and casts back to ScalarType only at the very end. Both the
// running sum and the divisor (in.size() / out.size()) are held in
// IntermediateType, so neither is limited to the range of ScalarType.
//
// Macro parameters:
//   ScalarType       - scalar type of the output expression; verified against
//                      OUT_T::Scalar at compile time.
//   IntermediateType - type used for the sum and the division; must be
//                      strictly wider than ScalarType (verified at compile
//                      time), otherwise the specialization would not help.
//
// The generated operator() does not use its `reducer` argument; the mean is
// computed as sum / count with an Eigen SumReducer over IntermediateType.
// The division is carried out in IntermediateType.
//
// NOTE(review): in.size() / out.size() presumes out.size() != 0 -- confirm
// that callers never dispatch here with an empty output.
//
// Only block comments may be used inside the macro body: a `//` comment would
// swallow the line continuation and the code that follows it.
#define CASTING_SPECIALIZATION(ScalarType, IntermediateType)                   \
  template <typename Device, typename OUT_T, typename IN_T,                    \
            typename ReductionAxes>                                            \
  struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,                   \
                         functor::MeanReducer<ScalarType>> {                   \
    void operator()(const Device& d, OUT_T out, IN_T in,                       \
                    const ReductionAxes& reduction_axes,                       \
                    const functor::MeanReducer<ScalarType>& reducer) {         \
      static_assert(std::is_same<ScalarType, typename OUT_T::Scalar>::value,   \
                    "ScalarType must match the output scalar type");           \
      static_assert(sizeof(IntermediateType) > sizeof(ScalarType),             \
                    "IntermediateType must be wider than ScalarType");         \
      Eigen::internal::SumReducer<IntermediateType> sum_reducer;               \
      /* Divisor of the mean, widened before it is used. */                    \
      const IntermediateType num_reduced =                                     \
          static_cast<IntermediateType>(in.size() / out.size());               \
      out.device(d) = (in.template cast<IntermediateType>().reduce(            \
                           reduction_axes, sum_reducer) /                      \
                       num_reduced)                                            \
                          .template cast<ScalarType>();                        \
    }                                                                          \
  }
|
||||
|
||||
// Instantiate the specialization for the 8-, 16- and 32-bit integer types:
// unsigned types are reduced in uint64 and signed types in int64. No
// specialization is instantiated here for the 64-bit integer types. The helper
// macro is undefined afterwards so it does not leak past this point.
CASTING_SPECIALIZATION(uint8, uint64);
|
||||
CASTING_SPECIALIZATION(uint16, uint64);
|
||||
CASTING_SPECIALIZATION(uint32, uint64);
|
||||
CASTING_SPECIALIZATION(int8, int64);
|
||||
CASTING_SPECIALIZATION(int16, int64);
|
||||
CASTING_SPECIALIZATION(int32, int64);
|
||||
#undef CASTING_SPECIALIZATION
|
||||
|
||||
// TODO(rmlarsen): Refactor this such that taking the sqrt can be optional,
|
||||
// controlled by an attribute.
|
||||
template <typename Device, typename OUT_T, typename IN_T,
|
||||
|
@ -444,6 +444,27 @@ class MeanReductionTest(BaseReductionTest):
|
||||
np_arr = self._makeRandom((2,) * rank, dtypes.uint8)
|
||||
self._compareAllAxes(np_arr)
|
||||
|
||||
# This tests the issue reported in b/145030710.
|
||||
@test_util.run_deprecated_v1
def testSizeOverflowUint8(self):
  # 2**8 = 256 elements: one more than the largest value a uint8 can hold, so
  # the element count itself does not fit in the input dtype.
  num_elements = 2**8
  values = self._makeRandom((num_elements,), dtypes.uint8)
  self._compareAllAxes(values)
|
||||
|
||||
@test_util.run_deprecated_v1
def testSizeOverflowInt8(self):
  # 2**7 = 128 elements: one more than the largest value an int8 can hold, so
  # the element count itself does not fit in the input dtype.
  num_elements = 2**7
  values = self._makeRandom((num_elements,), dtypes.int8)
  self._compareAllAxes(values)
|
||||
|
||||
@test_util.run_deprecated_v1
def testSizeOverflowUint16(self):
  # 2**16 = 65536 elements: one more than the largest value a uint16 can
  # hold, so the element count itself does not fit in the input dtype.
  num_elements = 2**16
  values = self._makeRandom((num_elements,), dtypes.uint16)
  self._compareAllAxes(values)
|
||||
|
||||
@test_util.run_deprecated_v1
def testSizeOverflowInt16(self):
  # 2**15 = 32768 elements: one more than the largest value an int16 can
  # hold, so the element count itself does not fit in the input dtype.
  num_elements = 2**15
  values = self._makeRandom((num_elements,), dtypes.int16)
  self._compareAllAxes(values)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFloat32(self):
|
||||
for rank in range(1, _MAX_RANK + 1):
|
||||
|
Loading…
Reference in New Issue
Block a user