Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
seo young Joung
tensorflow
Commits
f3bbee40
Unverified
Commit
f3bbee40
authored
4 years ago
by
Christian Sigg
Committed by
GitHub
4 years ago
Browse files
Options
Download
Email Patches
Plain Diff
Improve ROCm's sqrt and rsqrt for std::complex.
parent
f3b556a9
rocm_sqrt
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
tensorflow/core/kernels/training_ops_gpu.cu.cc
+16
-27
tensorflow/core/kernels/training_ops_gpu.cu.cc
with
16 additions
and
27 deletions
+16
-27
tensorflow/core/kernels/training_ops_gpu.cu.cc
View file @
f3bbee40
...
...
@@ -148,41 +148,30 @@ __device__ Eigen::half impl_rsqrt(Eigen::half x) {
template
<
class
T
>
__device__
std
::
complex
<
T
>
impl_sqrt
(
std
::
complex
<
T
>
x
)
{
T
re
=
x
.
real
(),
im
=
x
.
imag
();
T
mod_x
=
sqrt
(
re
*
re
+
im
*
im
);
const
T
root2
=
0.7071067811865475
;
T
a
=
x
.
real
();
T
b
=
x
.
imag
();
T
r
=
impl_sqrt
(
norm
(
x
));
// returns sqrt(0.5 * (r - v)) where v may be close to r.
auto
helper
=
[
&
](
const
T
&
v
)
{
T
diff
=
r
-
v
;
if
(
diff
<
T
(
1e-5
)
*
r
)
{
// |a| >> |b|, use rsqrt(1+x) ~= 1 + x/2.
return
T
(
0.5
)
*
fabs
(
b
)
*
impl_rsqrt
(
r
);
}
return
sqrt
(
T
(
0.5
)
*
diff
);
};
// We pick the root with the same sign of the imaginary component as
// the input.
T
r
oo
t
[
2
]
=
{
T
(
sqrt
(
mod_x
+
re
)
*
root2
),
T
(
sqrt
(
mod_x
-
re
)
*
r
oot2
*
(
im
>=
0
?
1.
:
-
1.
))};
T
r
esul
t
[
2
]
=
{
helper
(
-
a
),
[
&
](
const
T
&
v
)
{
r
eturn
b
>=
0
?
v
:
-
v
;
}(
helper
(
a
))};
// hcc/clang is really weird with its support of complex in device code;
// for some reason it does not permit a 2-argument constructor
return
*
(
reinterpret_cast
<
std
::
complex
<
T
>*>
(
&
root
));
}
template
<
class
T
>
__device__
T
rsqrt_helper
(
T
x
)
{
return
0.5
*
x
+
0.125
*
x
*
x
+
0.0625
*
x
*
x
*
x
;
return
reinterpret_cast
<
const
std
::
complex
<
T
>&>
(
result
);
}
template
<
class
T
>
__device__
std
::
complex
<
T
>
impl_rsqrt
(
std
::
complex
<
T
>
x
)
{
T
re
=
x
.
real
(),
im
=
x
.
imag
();
T
r
=
rsqrt
(
re
*
re
+
im
*
im
);
T
ar2
=
re
*
r
*
r
;
const
T
root2
=
0.7071067811865475
;
T
root
[
2
];
// With float, calculating 1+re*r and 1-re*r may result in excessive errors
// due to subtraction of two close values. We have to get fancy
root
[
0
]
=
sqrt
(
r
*
((
std
::
is_same
<
T
,
float
>::
value
&&
re
*
r
<
-
0.98
)
?
rsqrt_helper
(
im
*
im
*
r
*
r
)
:
1
+
re
*
r
))
*
root2
;
root
[
1
]
=
sqrt
(
r
*
((
std
::
is_same
<
T
,
float
>::
value
&&
re
*
r
>
0.98
)
?
rsqrt_helper
(
im
*
im
*
r
*
r
)
:
1
-
re
*
r
))
*
root2
*
(
im
>=
0
?
-
1.
:
1.
);
return
*
(
reinterpret_cast
<
std
::
complex
<
T
>*>
(
&
root
));
return
conj
(
impl_sqrt
(
x
))
*
impl_rsqrt
(
norm
(
x
));
}
template
<
typename
T
>
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment
Menu
Projects
Groups
Snippets
Help