add bfloat16 support to some GPU ops: concat, constant, fill, pack, reshape,
slice, split, unpack

PiperOrigin-RevId: 179255814

commit dcb0666a2b
parent 4f4abcaced
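
Most of the change is mechanical: the GPU registration lists either gain a bfloat16 entry or are switched from direct REGISTER(type)-style calls to the TF_CALL_type(REGISTER) helpers, so bfloat16 is handled the same way as the other dtypes. As a rough, self-contained sketch of what that macro layer does (the two TF_CALL_* definitions below are toy stand-ins, not copies of tensorflow/core/framework/register_types.h):

#include <cstdint>
#include <iostream>

// Toy stand-ins: in TensorFlow these helpers live in register_types.h and can
// be compiled out per platform; each one just applies the macro `m` to a type.
struct bfloat16 { std::uint16_t value; };
#define TF_CALL_bfloat16(m) m(bfloat16)
#define TF_CALL_bool(m) m(bool)

// A per-op registration macro, as in the hunks below; the real body would be
// a REGISTER_KERNEL_BUILDER(...) call rather than a function that prints.
#define REGISTER(T) \
  void RegisterKernelFor_##T() { std::cout << "registered " #T "\n"; }

TF_CALL_bfloat16(REGISTER);  // expands to REGISTER(bfloat16)
TF_CALL_bool(REGISTER);      // expands to REGISTER(bool)
#undef REGISTER

int main() {
  RegisterKernelFor_bfloat16();
  RegisterKernelFor_bool();
}
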
@@ -116,8 +116,8 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER);
 TF_CALL_complex64(REGISTER);
 TF_CALL_complex128(REGISTER);
 TF_CALL_int64(REGISTER);
-REGISTER(bfloat16);
-REGISTER(bool);
+TF_CALL_bfloat16(REGISTER);
+TF_CALL_bool(REGISTER);

 #undef REGISTER

@@ -250,6 +250,7 @@ REGISTER_KERNEL_BUILDER(Name("Fill")

 #if GOOGLE_CUDA
 REGISTER_KERNEL(GPU, Eigen::half);
+REGISTER_KERNEL(GPU, bfloat16);
 REGISTER_KERNEL(GPU, float);
 REGISTER_KERNEL(GPU, double);
 REGISTER_KERNEL(GPU, uint8);

@@ -328,6 +329,7 @@ REGISTER_KERNEL_BUILDER(Name("ZerosLike")
 #if GOOGLE_CUDA
 REGISTER_KERNEL(bool, GPU);
 REGISTER_KERNEL(Eigen::half, GPU);
+REGISTER_KERNEL(bfloat16, GPU);
 REGISTER_KERNEL(float, GPU);
 REGISTER_KERNEL(double, GPU);
 REGISTER_KERNEL(complex64, GPU);

@@ -380,6 +382,7 @@ REGISTER_KERNEL_BUILDER(Name("OnesLike")
 #if GOOGLE_CUDA
 REGISTER_KERNEL(bool, GPU);
 REGISTER_KERNEL(Eigen::half, GPU);
+REGISTER_KERNEL(bfloat16, GPU);
 REGISTER_KERNEL(float, GPU);
 REGISTER_KERNEL(double, GPU);
 REGISTER_KERNEL(complex64, GPU);

@@ -77,7 +77,8 @@ struct FillFunctor<GPUDevice, T> {

 #define DEFINE_FILL_GPU(T) template struct FillFunctor<GPUDevice, T>;
 TF_CALL_REAL_NUMBER_TYPES(DEFINE_FILL_GPU);
-DEFINE_FILL_GPU(bool);
+TF_CALL_bfloat16(DEFINE_FILL_GPU);
+TF_CALL_bool(DEFINE_FILL_GPU);
 #undef DEFINE_FILL_GPU

 // Partial specialization of FillFunctor<Device=GPUDevice, T>.

@@ -88,15 +89,10 @@ struct SetZeroFunctor<GPUDevice, T> {
   }
 };

-#define DEFINE_SETZERO_GPU(T) template struct SetZeroFunctor<GPUDevice, T>
-DEFINE_SETZERO_GPU(bool);
-DEFINE_SETZERO_GPU(Eigen::half);
-DEFINE_SETZERO_GPU(float);
-DEFINE_SETZERO_GPU(double);
-DEFINE_SETZERO_GPU(complex64);
-DEFINE_SETZERO_GPU(complex128);
-DEFINE_SETZERO_GPU(int32);
-DEFINE_SETZERO_GPU(int64);
+#define DEFINE_SETZERO_GPU(T) template struct SetZeroFunctor<GPUDevice, T>;
+TF_CALL_NUMBER_TYPES(DEFINE_SETZERO_GPU);
+TF_CALL_bfloat16(DEFINE_SETZERO_GPU);
+TF_CALL_bool(DEFINE_SETZERO_GPU);
 #undef DEFINE_SETZERO_GPU

 // Partial specialization of FillFunctor<Device=GPUDevice, T>.

@@ -107,15 +103,10 @@ struct SetOneFunctor<GPUDevice, T> {
   }
 };

-#define DEFINE_SETONE_GPU(T) template struct SetOneFunctor<GPUDevice, T>
-DEFINE_SETONE_GPU(bool);
-DEFINE_SETONE_GPU(Eigen::half);
-DEFINE_SETONE_GPU(float);
-DEFINE_SETONE_GPU(double);
-DEFINE_SETONE_GPU(complex64);
-DEFINE_SETONE_GPU(complex128);
-DEFINE_SETONE_GPU(int32);
-DEFINE_SETONE_GPU(int64);
+#define DEFINE_SETONE_GPU(T) template struct SetOneFunctor<GPUDevice, T>;
+TF_CALL_NUMBER_TYPES(DEFINE_SETONE_GPU);
+TF_CALL_bfloat16(DEFINE_SETONE_GPU);
+TF_CALL_bool(DEFINE_SETONE_GPU);
 #undef DEFINE_SETONE_GPU

 }  // end namespace functor

@@ -158,6 +158,7 @@ REGISTER_PACK(string);
                           PackOp<GPUDevice, type>)

 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+TF_CALL_bfloat16(REGISTER_GPU);
 TF_CALL_int64(REGISTER_GPU);
 REGISTER_GPU(bool);
 #undef REGISTER_GPU

@@ -43,7 +43,8 @@ REGISTER_KERNEL_BUILDER(Name("Reshape")
                               .TypeConstraint<int64>("Tshape"), \
                           ReshapeOp);
 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
-REGISTER_GPU_KERNEL(bool);
+TF_CALL_bfloat16(REGISTER_GPU_KERNEL);
+TF_CALL_bool(REGISTER_GPU_KERNEL);
 #undef REGISTER_GPU_KERNEL

 #ifdef TENSORFLOW_USE_SYCL

@@ -439,7 +439,7 @@ namespace functor {
   DECLARE_CPU_SPEC(T, 7);

 TF_CALL_ALL_TYPES(DECLARE_FOR_N);
-DECLARE_FOR_N(bfloat16);
+TF_CALL_bfloat16(DECLARE_FOR_N);

 #undef DECLARE_FOR_N
 #undef DECLARE_CPU_SPEC

@@ -456,7 +456,7 @@ DECLARE_FOR_N(bfloat16);

 TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
 TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
-REGISTER_SLICE(bfloat16);
+TF_CALL_bfloat16(REGISTER_SLICE);
 #undef REGISTER_SLICE
 #else
 #define REGISTER_SLICE(type) \

@@ -469,7 +469,7 @@ REGISTER_SLICE(bfloat16);

 TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
 TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
-REGISTER_SLICE(bfloat16);
+TF_CALL_bfloat16(REGISTER_SLICE);
 #undef REGISTER_SLICE
 #endif  // INTEL_MKL

@@ -497,6 +497,7 @@ namespace functor {
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N);
 TF_CALL_complex64(DECLARE_FOR_N);
 TF_CALL_complex128(DECLARE_FOR_N);
+TF_CALL_bfloat16(DECLARE_FOR_N);
 DECLARE_FOR_N(int32);

 #undef DECLARE_FOR_N

@@ -515,6 +516,7 @@ DECLARE_FOR_N(int32);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
 TF_CALL_complex64(REGISTER_GPU);
 TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_bfloat16(REGISTER_GPU);

 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel

@@ -39,6 +39,7 @@ typedef Eigen::GpuDevice GPUDevice;
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
 TF_CALL_complex64(DEFINE_GPU_KERNELS);
 TF_CALL_complex128(DEFINE_GPU_KERNELS);
+TF_CALL_bfloat16(DEFINE_GPU_KERNELS);
 DEFINE_GPU_KERNELS(int32);

 #undef DEFINE_GPU_KERNELS

@@ -52,7 +52,7 @@ void SplitCustom<Device, T>::operator()(
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
 TF_CALL_complex64(DEFINE_GPU_KERNELS);
 TF_CALL_complex128(DEFINE_GPU_KERNELS);
-DEFINE_GPU_KERNELS(bfloat16);
+TF_CALL_bfloat16(DEFINE_GPU_KERNELS);

 #undef DEFINE_GPU_KERNELS
 #define DEFINE_GPU_KERNELS(T) template struct SplitCustom<Eigen::GpuDevice, T>;

@@ -60,7 +60,7 @@ DEFINE_GPU_KERNELS(bfloat16);
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
 TF_CALL_complex64(DEFINE_GPU_KERNELS);
 TF_CALL_complex128(DEFINE_GPU_KERNELS);
-DEFINE_GPU_KERNELS(bfloat16);
+TF_CALL_bfloat16(DEFINE_GPU_KERNELS);

 #undef DEFINE_GPU_KERNELS

@@ -243,6 +243,7 @@ struct SplitVOpGPULaunch {
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
 TF_CALL_complex64(REGISTER_GPU_KERNEL);
 TF_CALL_complex128(REGISTER_GPU_KERNEL);
+TF_CALL_bfloat16(REGISTER_GPU_KERNEL);
 #undef REGISTER_GPU_KERNEL
 #define REGISTER_GPU_KERNEL(T) \
   template struct SplitVOpGPULaunch<T, int32>; \

@@ -251,7 +252,7 @@ TF_CALL_complex128(REGISTER_GPU_KERNEL);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
 TF_CALL_complex64(REGISTER_GPU_KERNEL);
 TF_CALL_complex128(REGISTER_GPU_KERNEL);
-REGISTER_GPU_KERNEL(bfloat16);
+TF_CALL_bfloat16(REGISTER_GPU_KERNEL);
 #undef REGISTER_GPU_KERNEL

 }  // namespace tensorflow

@@ -377,6 +377,7 @@ REGISTER_SPLIT(bfloat16);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
 TF_CALL_complex64(REGISTER_GPU);
 TF_CALL_complex128(REGISTER_GPU);
+REGISTER_GPU(bfloat16);
 #undef REGISTER_GPU

 #endif  // GOOGLE_CUDA

@@ -142,6 +142,7 @@ TF_CALL_ALL_TYPES(REGISTER_UNPACK);
                           UnpackOp<GPUDevice, type>)

 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+TF_CALL_bfloat16(REGISTER_GPU);
 #undef REGISTER_GPU

 // A special GPU kernel for int32.

@@ -374,6 +374,20 @@ __device__ __host__ inline Eigen::half ldg(const Eigen::half* address) {
 #endif
 }

+template <>
+__device__ __host__ inline tensorflow::bfloat16 ldg(
+    const tensorflow::bfloat16* address) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
+  tensorflow::bfloat16 return_value;
+  asm volatile("ld.global.nc.u16 %0, [%1];"
+               : "=h"(return_value.value)
+               : "l"(address));
+  return return_value;
+#else
+  return *address;
+#endif
+}
+
 template <>
 __device__ __host__ inline bool ldg(const bool* address) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350

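The final hunk adds a bfloat16 specialization of ldg() to the CUDA kernel helpers: on sm_35 and newer it pulls the 16-bit payload through the read-only data cache via ld.global.nc.u16, and otherwise it falls back to a plain dereference. A hypothetical caller looks the same as for any other dtype; the kernel below is an illustration, not code from this commit, and it assumes the helper header of the time was tensorflow/core/util/cuda_kernel_helper.h:

#include "tensorflow/core/util/cuda_kernel_helper.h"

// Hypothetical grid-stride copy kernel: ldg(in + i) now resolves for
// tensorflow::bfloat16, so immutable inputs of that dtype also take the
// read-only-cache load path on sm_35+ (and a normal load elsewhere).
__global__ void CopyBf16(const tensorflow::bfloat16* in,
                         tensorflow::bfloat16* out, int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    out[i] = ldg(in + i);
  }
}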