Merge pull request #33021 from fo40225:fix-msvc163-cuda101
PiperOrigin-RevId: 276092173 Change-Id: Id8be850452bb3c4f93a212eb36e48458d24052c0
This commit is contained in:
commit
e9597468d8
@ -64,7 +64,7 @@ struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::ADD> {
|
|||||||
std::complex<T>* out, const std::complex<T>& val) {
|
std::complex<T>* out, const std::complex<T>& val) {
|
||||||
T* ptr = reinterpret_cast<T*>(out);
|
T* ptr = reinterpret_cast<T*>(out);
|
||||||
GpuAtomicAdd(ptr, val.real());
|
GpuAtomicAdd(ptr, val.real());
|
||||||
GpuAtomicAdd(ptr, val.imag());
|
GpuAtomicAdd(ptr + 1, val.imag());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -72,7 +72,9 @@ template <typename T>
|
|||||||
struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::SUB> {
|
struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::SUB> {
|
||||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(
|
||||||
std::complex<T>* out, const std::complex<T>& val) {
|
std::complex<T>* out, const std::complex<T>& val) {
|
||||||
LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::ADD>()(out, -val);
|
T* ptr = reinterpret_cast<T*>(out);
|
||||||
|
GpuAtomicSub(ptr, val.real());
|
||||||
|
GpuAtomicSub(ptr + 1, val.imag());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user