Reviewer requested changes
This commit is contained in:
parent
9760afc119
commit
83d44dd581
@ -151,7 +151,8 @@ __device__ std::complex<T> impl_sqrt(std::complex<T> x) {
|
||||
T re = x.real(), im = x.imag();
|
||||
T mod_x = sqrt(re * re + im * im);
|
||||
const T root2 = 0.7071067811865475;
|
||||
// we pick the root with the same sign of the imaginary component as the input
|
||||
// We pick the root with the same sign of the imaginary component as
|
||||
// the input.
|
||||
T root[2] = {T(sqrt(mod_x + re) * root2),
|
||||
T(sqrt(mod_x - re) * root2 * (im >= 0 ? 1. : -1.))};
|
||||
// hcc/clang is really weird with its support of complex in device code;
|
||||
@ -256,9 +257,6 @@ __global__ void ApplyCenteredRMSPropKernel(GpuLaunchConfig cfg, T* var, T* mg,
|
||||
var[i] -= mom[i];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
|
||||
namespace kernel_forward {
|
||||
// Pass-through overload for plain bool arguments: the value is handed back
// unchanged.
// NOTE(review): presumably paired with other to_pointers overloads in this
// namespace that convert non-bool arguments — confirm against the full file.
bool to_pointers(bool x) {
  return x;
}
|
||||
|
Loading…
Reference in New Issue
Block a user