Avoid using cuda specific GPU function to unbreak ROCM build.
Indirect them through core/util/gpu_functions.h PiperOrigin-RevId: 314622940 Change-Id: I0e5a349f759d0af6ff13acc43b34080a5104c9cc
This commit is contained in:
parent
cfc9a6852e
commit
a1e26e4298
@ -58,14 +58,14 @@ struct LeftUpdate<T, scatter_nd_op::UpdateOp::SUB> {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
struct LeftUpdate<T, scatter_nd_op::UpdateOp::MAX> {
|
struct LeftUpdate<T, scatter_nd_op::UpdateOp::MAX> {
|
||||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
|
||||||
CudaAtomicMax(out, val);
|
GpuAtomicMax(out, val);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct LeftUpdate<T, scatter_nd_op::UpdateOp::MIN> {
|
struct LeftUpdate<T, scatter_nd_op::UpdateOp::MIN> {
|
||||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
|
||||||
CudaAtomicMin(out, val);
|
GpuAtomicMin(out, val);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user