Do mean reductions for integer types in 64-bit arithmetic to mitigate overflow in the sum and/or denominator.

PiperOrigin-RevId: 282847676
Change-Id: I267823932b2c3e1f9916ea0edfcce3efb5d4430a
This commit is contained in:
A. Unique TensorFlower 2019-11-27 15:54:18 -08:00 committed by TensorFlower Gardener
parent 74faaeb08f
commit 23fde233bf
2 changed files with 49 additions and 0 deletions

View File

@ -72,6 +72,34 @@ struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
}
};
// Specialization of ReduceEigenImpl for functor::MeanReducer over narrow
// integer ScalarTypes, computed in a wider IntermediateType to avoid
// integer overflow. The input is cast element-wise to IntermediateType,
// summed with Eigen's SumReducer, divided by the number of elements being
// reduced per output (in.size() / out.size()), and the quotient is cast
// back to ScalarType. Both the running sum and the element count are thus
// held in 64 bits, where the narrow ScalarType could overflow.
// NOTE: the static_assert pins OUT_T's scalar to ScalarType so the final
// cast back is exact by construction.
#define CASTING_SPECIALIZATION(ScalarType, IntermediateType) \
template <typename Device, typename OUT_T, typename IN_T, \
typename ReductionAxes> \
struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes, \
functor::MeanReducer<ScalarType>> { \
void operator()(const Device& d, OUT_T out, IN_T in, \
const ReductionAxes& reduction_axes, \
const functor::MeanReducer<ScalarType>& reducer) { \
static_assert(std::is_same<ScalarType, typename OUT_T::Scalar>::value, \
""); \
Eigen::internal::SumReducer<IntermediateType> sum_reducer; \
out.device(d) = (in.template cast<IntermediateType>().reduce( \
reduction_axes, sum_reducer) / \
static_cast<IntermediateType>(in.size() / out.size())) \
.template cast<ScalarType>(); \
} \
}
CASTING_SPECIALIZATION(uint8, uint64);   // 8-bit unsigned summed in 64 bits
CASTING_SPECIALIZATION(uint16, uint64);  // 16-bit unsigned summed in 64 bits
CASTING_SPECIALIZATION(uint32, uint64);  // 32-bit unsigned summed in 64 bits
CASTING_SPECIALIZATION(int8, int64);     // 8-bit signed summed in 64 bits
CASTING_SPECIALIZATION(int16, int64);    // 16-bit signed summed in 64 bits
CASTING_SPECIALIZATION(int32, int64);    // 32-bit signed summed in 64 bits
#undef CASTING_SPECIALIZATION
// TODO(rmlarsen): Refactor this such that taking the sqrt can be optional
// controlled by an attribute.
template <typename Device, typename OUT_T, typename IN_T,

View File

@ -444,6 +444,27 @@ class MeanReductionTest(BaseReductionTest):
np_arr = self._makeRandom((2,) * rank, dtypes.uint8)
self._compareAllAxes(np_arr)
# Regression test for the issue reported in b/145030710.
@test_util.run_deprecated_v1
def testSizeOverflowUint8(self):
  """Mean over 256 uint8 values; the sum would overflow a uint8 accumulator."""
  arr = self._makeRandom((256,), dtypes.uint8)  # 256 == 2**8
  self._compareAllAxes(arr)
@test_util.run_deprecated_v1
def testSizeOverflowInt8(self):
  """Mean over 128 int8 values; the sum can exceed the int8 range."""
  arr = self._makeRandom((128,), dtypes.int8)  # 128 == 2**7
  self._compareAllAxes(arr)
@test_util.run_deprecated_v1
def testSizeOverflowUint16(self):
  """Mean over 65536 uint16 values; the sum would overflow a uint16 accumulator."""
  arr = self._makeRandom((65536,), dtypes.uint16)  # 65536 == 2**16
  self._compareAllAxes(arr)
@test_util.run_deprecated_v1
def testSizeOverflowInt16(self):
  """Mean over 32768 int16 values; the sum can exceed the int16 range."""
  arr = self._makeRandom((32768,), dtypes.int16)  # 32768 == 2**15
  self._compareAllAxes(arr)
@test_util.run_deprecated_v1
def testFloat32(self):
for rank in range(1, _MAX_RANK + 1):