Enable batch matmul for half types again, which was accidentally disabled in CL 197137612.
PiperOrigin-RevId: 208022865
commit cfdc469565
parent e2d9cd3fa7
@@ -2854,6 +2854,8 @@ tf_kernel_library(
     srcs = [] + if_mkl([
         "mkl_batch_matmul_op.cc",
     ]),
+    # <prefix>*impl.h are excluded by default from the CPU build, add explicitly.
+    hdrs = ["batch_matmul_op_impl.h"],
     # Override EIGEN_STRONG_INLINE to inline when --define=override_eigen_strong_inline=true,
     # to avoid long compiling time. See https://github.com/tensorflow/tensorflow/issues/10521
     copts = if_override_eigen_strong_inline(["/DEIGEN_STRONG_INLINE=inline"]),
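For context on the copts line above: a sketch of what the EIGEN_STRONG_INLINE override amounts to, assuming Eigen's stock macro defaults (the #define below is illustrative, not code from this commit):

// On MSVC, Eigen defaults EIGEN_STRONG_INLINE to __forceinline, which can
// make large template instantiations very slow to compile. Passing
// /DEIGEN_STRONG_INLINE=inline pre-empts that default, roughly as if
// Eigen's Macros.h resolved to:
#ifndef EIGEN_STRONG_INLINE
#define EIGEN_STRONG_INLINE inline  // a hint only, instead of __forceinline
#endif
// This trades some runtime performance for much shorter build times
// (see https://github.com/tensorflow/tensorflow/issues/10521).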
@@ -31,8 +31,7 @@ TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
 #if GOOGLE_CUDA
 TF_CALL_float(REGISTER_BATCH_MATMUL_GPU);
 TF_CALL_double(REGISTER_BATCH_MATMUL_GPU);
-// TODO(csigg): Implement Stream::ThenBlasGemv for Eigen::half and uncomment.
-// TF_CALL_half(REGISTER_BATCH_MATMUL_GPU);
+TF_CALL_half(REGISTER_BATCH_MATMUL_GPU);
 #endif  // GOOGLE_CUDA
 
 #ifdef TENSORFLOW_USE_SYCL
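For context on the re-enabled registration: a sketch of what TF_CALL_half(REGISTER_BATCH_MATMUL_GPU); expands to, assuming the macro shapes in register_types.h and batch_matmul_op_impl.h (paraphrased, not copied from the source):

// In CUDA builds, TF_CALL_half(m) invokes m(Eigen::half), so the
// re-enabled line registers a GPU BatchMatMul kernel specialized for
// Eigen::half, roughly equivalent to:
REGISTER_KERNEL_BUILDER(
    Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    BatchMatMul<GPUDevice, Eigen::half>);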