Avoid using cuda specific GPU function to unbreak ROCM build.
Indirect them through core/util/gpu_functions.h PiperOrigin-RevId: 314622940 Change-Id: I0e5a349f759d0af6ff13acc43b34080a5104c9cc
This commit is contained in:
parent
cfc9a6852e
commit
a1e26e4298
@ -58,14 +58,14 @@ struct LeftUpdate<T, scatter_nd_op::UpdateOp::SUB> {
|
||||
template <typename T>
|
||||
struct LeftUpdate<T, scatter_nd_op::UpdateOp::MAX> {
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
|
||||
CudaAtomicMax(out, val);
|
||||
GpuAtomicMax(out, val);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct LeftUpdate<T, scatter_nd_op::UpdateOp::MIN> {
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
|
||||
CudaAtomicMin(out, val);
|
||||
GpuAtomicMin(out, val);
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user