PR #43167: [INTEL MKL] Added missed bfloat16 CPU support for op math.rsqrt
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/43167
Copybara import of the project:
--
17ca1c7430
by Xiaoming (Jason) Cui <xiaoming.cui@intel.com>:
[INTEL MKL] Added missed bfloat16 CPU support for op math.rsqrt
PiperOrigin-RevId: 332304599
Change-Id: I9c66d29ed7cf1010388d70f30f84fef97de6dde6
This commit is contained in:
parent
69662b03be
commit
5d80f5900e
@ -16,15 +16,15 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
REGISTER6(UnaryOp, CPU, "Rsqrt", functor::rsqrt, float, Eigen::half, bfloat16,
|
||||
double, complex64, complex128);
|
||||
REGISTER5(UnaryOp, CPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double,
|
||||
complex64, complex128);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
REGISTER3(UnaryOp, GPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double);
|
||||
#endif
|
||||
|
||||
REGISTER6(SimpleBinaryOp, CPU, "RsqrtGrad", functor::rsqrt_grad, float,
|
||||
Eigen::half, bfloat16, double, complex64, complex128);
|
||||
REGISTER5(SimpleBinaryOp, CPU, "RsqrtGrad", functor::rsqrt_grad, float,
|
||||
Eigen::half, double, complex64, complex128);
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
REGISTER3(SimpleBinaryOp, GPU, "RsqrtGrad", functor::rsqrt_grad, float,
|
||||
Eigen::half, double);
|
||||
|
@ -405,7 +405,6 @@ class UnaryOpTest(test.TestCase):
|
||||
self._compareCpu(z, compute_f32(np.log), math_ops.log)
|
||||
self._compareCpu(z, compute_f32(np.log1p), math_ops.log1p)
|
||||
self._compareCpu(y, np.sign, math_ops.sign)
|
||||
self._compareCpu(z, self._rsqrt, math_ops.rsqrt)
|
||||
self._compareBoth(x, compute_f32(np.sin), math_ops.sin)
|
||||
self._compareBoth(x, compute_f32(np.cos), math_ops.cos)
|
||||
self._compareBoth(x, compute_f32(np.tan), math_ops.tan)
|
||||
|
Loading…
Reference in New Issue
Block a user