Try to avoid overflows in accumulation results.

This can be done by upcasting to an integer type with more bits.

PiperOrigin-RevId: 282935503
Change-Id: Iaf62534f9832ee93ed84e33b9df85068bc5e6941
This commit is contained in:
Adrian Kuegel 2019-11-28 06:34:25 -08:00 committed by TensorFlower Gardener
parent d393702997
commit 75e5b5d70b

View File

@ -111,6 +111,13 @@ DataType XlaHelpers::SumAccumulationType(const DataType& dtype) {
if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
return DT_FLOAT;
}
// Upcast small integer types to 32 bit to avoid overflow.
if (dtype == DT_INT8 || dtype == DT_INT16) {
return DT_INT32;
}
if (dtype == DT_UINT8 || dtype == DT_UINT16) {
return DT_UINT32;
}
return dtype;
}