Try to avoid overflows in accumulation results.
This can be done by upcasting to an integer type with more bits. PiperOrigin-RevId: 282935503 Change-Id: Iaf62534f9832ee93ed84e33b9df85068bc5e6941
This commit is contained in:
parent
d393702997
commit
75e5b5d70b
@ -111,6 +111,13 @@ DataType XlaHelpers::SumAccumulationType(const DataType& dtype) {
|
||||
if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
|
||||
return DT_FLOAT;
|
||||
}
|
||||
// Upcast small integer types to 32 bit to avoid overflow.
|
||||
if (dtype == DT_INT8 || dtype == DT_INT16) {
|
||||
return DT_INT32;
|
||||
}
|
||||
if (dtype == DT_UINT8 || dtype == DT_UINT16) {
|
||||
return DT_UINT32;
|
||||
}
|
||||
return dtype;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user