diff --git a/tensorflow/core/kernels/reduction_ops.h b/tensorflow/core/kernels/reduction_ops.h index 3c62dcfc081..46d8051fff1 100644 --- a/tensorflow/core/kernels/reduction_ops.h +++ b/tensorflow/core/kernels/reduction_ops.h @@ -72,6 +72,34 @@ struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes, } }; +// Specializations of ReduceEigenImpl for MeanReducer<ScalarType>: the input is cast to the wider IntermediateType, summed, divided by in.size() / out.size(), +// and cast back to ScalarType, so the running sum is accumulated in IntermediateType to avoid integer overflow. +#define CASTING_SPECIALIZATION(ScalarType, IntermediateType) \ + template <typename Device, typename OUT_T, typename IN_T, \ + typename ReductionAxes> \ + struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes, \ + functor::MeanReducer<ScalarType>> { \ + void operator()(const Device& d, OUT_T out, IN_T in, \ + const ReductionAxes& reduction_axes, \ + const functor::MeanReducer<ScalarType>& reducer) { \ + static_assert(std::is_same<ScalarType, typename OUT_T::Scalar>::value, \ + ""); \ + Eigen::internal::SumReducer<IntermediateType> sum_reducer; \ + out.device(d) = (in.template cast<IntermediateType>().reduce( \ + reduction_axes, sum_reducer) / \ + static_cast<IntermediateType>(in.size() / out.size())) \ + .template cast<ScalarType>(); \ + } \ + } + +CASTING_SPECIALIZATION(uint8, uint64); +CASTING_SPECIALIZATION(uint16, uint64); +CASTING_SPECIALIZATION(uint32, uint64); +CASTING_SPECIALIZATION(int8, int64); +CASTING_SPECIALIZATION(int16, int64); +CASTING_SPECIALIZATION(int32, int64); +#undef CASTING_SPECIALIZATION + // TODO(rmlarsen): Refactor this such that taking the sqrt can be optional // controlled by an attribute. 
template <typename Device, typename OUT_T, typename IN_T, diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py index 152e3a3bbf2..1b5fa201d8f 100644 --- a/tensorflow/python/kernel_tests/reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/reduction_ops_test.py @@ -444,6 +444,27 @@ class MeanReductionTest(BaseReductionTest): np_arr = self._makeRandom((2,) * rank, dtypes.uint8) self._compareAllAxes(np_arr) + # This tests the issue reported in b/145030710: each testSizeOverflow* input has one more element than the largest value its dtype can represent. + @test_util.run_deprecated_v1 + def testSizeOverflowUint8(self): + np_arr = self._makeRandom((2**8,), dtypes.uint8) + self._compareAllAxes(np_arr) + + @test_util.run_deprecated_v1 + def testSizeOverflowInt8(self): + np_arr = self._makeRandom((2**7,), dtypes.int8) + self._compareAllAxes(np_arr) + + @test_util.run_deprecated_v1 + def testSizeOverflowUint16(self): + np_arr = self._makeRandom((2**16,), dtypes.uint16) + self._compareAllAxes(np_arr) + + @test_util.run_deprecated_v1 + def testSizeOverflowInt16(self): + np_arr = self._makeRandom((2**15,), dtypes.int16) + self._compareAllAxes(np_arr) + @test_util.run_deprecated_v1 def testFloat32(self): for rank in range(1, _MAX_RANK + 1):