Unverified Commit f3bbee40 authored by Christian Sigg's avatar Christian Sigg Committed by GitHub
Browse files

Improve ROCm's sqrt and rsqrt for std::complex.

No related merge requests found
Showing with 16 additions and 27 deletions
+16 -27
......@@ -148,41 +148,30 @@ __device__ Eigen::half impl_rsqrt(Eigen::half x) {
template <class T>
__device__ std::complex<T> impl_sqrt(std::complex<T> x) {
T re = x.real(), im = x.imag();
T mod_x = sqrt(re * re + im * im);
const T root2 = 0.7071067811865475;
T a = x.real();
T b = x.imag();
T r = impl_sqrt(norm(x));
// returns sqrt(0.5 * (r - v)) where v may be close to r.
auto helper = [&](const T& v) {
T diff = r - v;
if (diff < T(1e-5) * r) {
// |a| >> |b|, use rsqrt(1+x) ~= 1 + x/2.
return T(0.5) * fabs(b) * impl_rsqrt(r);
}
return sqrt(T(0.5) * diff);
};
// We pick the root with the same sign of the imaginary component as
// the input.
T root[2] = {T(sqrt(mod_x + re) * root2),
T(sqrt(mod_x - re) * root2 * (im >= 0 ? 1. : -1.))};
T result[2] = {helper(-a),
[&](const T& v) { return b >= 0 ? v : -v; }(helper(a))};
// hcc/clang is really weird with its support of complex in device code;
// for some reason it does not permit a 2-argument constructor
return *(reinterpret_cast<std::complex<T>*>(&root));
}
template <class T>
__device__ T rsqrt_helper(T x) {
return 0.5 * x + 0.125 * x * x + 0.0625 * x * x * x;
return reinterpret_cast<const std::complex<T>&>(result);
}
template <class T>
__device__ std::complex<T> impl_rsqrt(std::complex<T> x) {
T re = x.real(), im = x.imag();
T r = rsqrt(re * re + im * im);
T ar2 = re * r * r;
const T root2 = 0.7071067811865475;
T root[2];
// With float, calculating 1+re*r and 1-re*r may result in excessive errors
// due to subtraction of two close values. We have to get fancy
root[0] = sqrt(r * ((std::is_same<T, float>::value && re * r < -0.98)
? rsqrt_helper(im * im * r * r)
: 1 + re * r)) *
root2;
root[1] = sqrt(r * ((std::is_same<T, float>::value && re * r > 0.98)
? rsqrt_helper(im * im * r * r)
: 1 - re * r)) *
root2 * (im >= 0 ? -1. : 1.);
return *(reinterpret_cast<std::complex<T>*>(&root));
return conj(impl_sqrt(x)) * impl_rsqrt(norm(x));
}
template <typename T>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment