Merge various kernel registrations with macros
We add the TF_CALL_COMPLEX_TYPES macro and update related kernel registrations to use the more compact macros rather than individual dtype listings. This should be a no-op, and it should give better visibility into the dtype coverage of many of our kernels.

PiperOrigin-RevId: 315224662
Change-Id: I14aad07711a407fa632a94d891238a48ae89bcab
This commit is contained in:
parent 9b37e09994
commit 4e7ce793d9
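For readers less familiar with the TF_CALL_* family, here is a minimal, self-contained sketch of the X-macro pattern these registrations rely on. ExampleKernel and REGISTER_EXAMPLE are hypothetical stand-ins for illustration; only the TF_CALL_COMPLEX_TYPES definition mirrors the macro this change introduces.

// Each TF_CALL_<dtype>(m) invokes the macro "m" once with that C++ type,
// so a single line can stamp out one registration per dtype.
#include <complex>

#define TF_CALL_complex64(m) m(std::complex<float>)
#define TF_CALL_complex128(m) m(std::complex<double>)

// The macro added by this change: one call covering both complex dtypes.
#define TF_CALL_COMPLEX_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m)

// Hypothetical stand-in for a per-dtype kernel; real TF kernels are
// OpKernel subclasses registered via REGISTER_KERNEL_BUILDER.
template <typename T>
struct ExampleKernel {
  T Add(T a, T b) const { return a + b; }
};

// Hypothetical stand-in registration macro: explicitly instantiate for T.
#define REGISTER_EXAMPLE(T) template struct ExampleKernel<T>;

// Expands to explicit instantiations for std::complex<float> and
// std::complex<double>, replacing two separate TF_CALL_* lines.
TF_CALL_COMPLEX_TYPES(REGISTER_EXAMPLE)

The point of merging registrations into such umbrella macros is that a kernel's dtype coverage becomes visible from one or two lines instead of a scattered per-dtype list.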
@@ -179,13 +179,14 @@ limitations under the License.
     TF_CALL_int64(m) TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) \
     TF_CALL_int8(m)

+// Call "m" for all number types, including complex64 and complex128.
+#define TF_CALL_COMPLEX_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m)
+
 // Call "m" for all number types, including complex types
 #define TF_CALL_NUMBER_TYPES(m) \
-  TF_CALL_REAL_NUMBER_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m)
+  TF_CALL_REAL_NUMBER_TYPES(m) TF_CALL_COMPLEX_TYPES(m)

 #define TF_CALL_NUMBER_TYPES_NO_INT32(m) \
-  TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \
-  TF_CALL_complex64(m) TF_CALL_complex128(m)
+  TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) TF_CALL_COMPLEX_TYPES(m)

 #define TF_CALL_POD_TYPES(m) TF_CALL_NUMBER_TYPES(m) TF_CALL_bool(m)

@@ -202,8 +203,7 @@ limitations under the License.

 // Call "m" on all types supported on GPU.
 #define TF_CALL_GPU_ALL_TYPES(m) \
-  TF_CALL_GPU_NUMBER_TYPES(m) \
-  TF_CALL_bool(m) TF_CALL_complex64(m) TF_CALL_complex128(m)
+  TF_CALL_GPU_NUMBER_TYPES(m) TF_CALL_COMPLEX_TYPES(m) TF_CALL_bool(m)

 #define TF_CALL_GPU_NUMBER_TYPES_NO_HALF(m) TF_CALL_float(m) TF_CALL_double(m)

@@ -213,11 +213,10 @@ limitations under the License.
     TF_CALL_qint8(m) TF_CALL_quint8(m) TF_CALL_qint32(m)

 // Types used for save and restore ops.
-#define TF_CALL_SAVE_RESTORE_TYPES(m) \
-  TF_CALL_INTEGRAL_TYPES(m) \
-  TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) TF_CALL_complex64(m) \
-  TF_CALL_complex128(m) TF_CALL_bool(m) TF_CALL_tstring(m) \
-  TF_CALL_QUANTIZED_TYPES(m)
+#define TF_CALL_SAVE_RESTORE_TYPES(m) \
+  TF_CALL_REAL_NUMBER_TYPES_NO_BFLOAT16(m) \
+  TF_CALL_COMPLEX_TYPES(m) \
+  TF_CALL_QUANTIZED_TYPES(m) TF_CALL_bool(m) TF_CALL_tstring(m)

 #ifdef TENSORFLOW_SYCL_NO_DOUBLE
 #define TF_CALL_SYCL_double(m)
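A quick way to see that the register_types.h change is behavior-preserving: the old and new spellings of TF_CALL_NUMBER_TYPES expand to the same token sequence. A sketch with abridged stand-in macros (the real definitions live in tensorflow/core/framework/register_types.h):

#define TF_CALL_complex64(m) m(complex64)
#define TF_CALL_complex128(m) m(complex128)
#define TF_CALL_REAL_NUMBER_TYPES(m) m(float) m(double) m(int32)  // abridged

#define TF_CALL_COMPLEX_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m)

// Old: TF_CALL_REAL_NUMBER_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m)
// New: TF_CALL_REAL_NUMBER_TYPES(m) TF_CALL_COMPLEX_TYPES(m)
// Both expand to: m(float) m(double) m(int32) m(complex64) m(complex128)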
@@ -48,11 +48,10 @@ REGISTER_ADDN_CPU(Variant);
 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
 #define REGISTER_ADDN_GPU(type) REGISTER_ADDN(type, GPU)
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_ADDN_GPU);
 TF_CALL_int64(REGISTER_ADDN_GPU);
-TF_CALL_complex64(REGISTER_ADDN_GPU);
-TF_CALL_complex128(REGISTER_ADDN_GPU);
 TF_CALL_variant(REGISTER_ADDN_GPU);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_ADDN_GPU);
+TF_CALL_COMPLEX_TYPES(REGISTER_ADDN_GPU);
 #undef REGISTER_ADDN_GPU

 // A special GPU kernel for int32.

@@ -154,10 +154,9 @@ struct Add9Functor<GPUDevice, T> {
   template struct functor::Add8pFunctor<GPUDevice, type>; \
   template struct functor::Add9Functor<GPUDevice, type>;

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_FUNCTORS);
 TF_CALL_int64(REGISTER_FUNCTORS);
-TF_CALL_complex64(REGISTER_FUNCTORS);
-TF_CALL_complex128(REGISTER_FUNCTORS);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_FUNCTORS);
+TF_CALL_COMPLEX_TYPES(REGISTER_FUNCTORS);

 #undef REGISTER_FUNCTORS
@@ -17,12 +17,10 @@ limitations under the License.

 namespace tensorflow {

-TF_CALL_complex64(REGISTER_BATCH_MATMUL_CPU);
-TF_CALL_complex128(REGISTER_BATCH_MATMUL_CPU);
+TF_CALL_COMPLEX_TYPES(REGISTER_BATCH_MATMUL_CPU);

 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-TF_CALL_complex64(REGISTER_BATCH_MATMUL_GPU);
-TF_CALL_complex128(REGISTER_BATCH_MATMUL_GPU);
+TF_CALL_COMPLEX_TYPES(REGISTER_BATCH_MATMUL_GPU);
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 }  // namespace tensorflow

@@ -87,8 +87,7 @@ namespace functor {
   extern template struct MatrixBandPartFunctor<GPUDevice, T>;

 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
-TF_CALL_complex64(DECLARE_GPU_SPEC);
-TF_CALL_complex128(DECLARE_GPU_SPEC);
+TF_CALL_COMPLEX_TYPES(DECLARE_GPU_SPEC);
 }  // namespace functor

 template <class Scalar>
@@ -64,15 +64,12 @@ void ConcatGPU(
     inputs_flat, \
     Tensor* output, typename TTypes<T, 2>::Tensor* output_flat);

-TF_CALL_GPU_NUMBER_TYPES(REGISTER);
-TF_CALL_complex64(REGISTER);
-TF_CALL_complex128(REGISTER);
 TF_CALL_int32(REGISTER);  // Needed for TensorLists.
 TF_CALL_int64(REGISTER);
 TF_CALL_int16(REGISTER);
 TF_CALL_bfloat16(REGISTER);
-TF_CALL_bool(REGISTER);
 TF_CALL_uint8(REGISTER);
+TF_CALL_GPU_ALL_TYPES(REGISTER);
 #undef REGISTER
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -98,15 +98,12 @@ void ConcatGPU(
     inputs_flat, \
     Tensor* output, typename TTypes<T, 2>::Tensor* output_flat);

-TF_CALL_GPU_NUMBER_TYPES(REGISTER);
-TF_CALL_complex64(REGISTER);
-TF_CALL_complex128(REGISTER);
 TF_CALL_int32(REGISTER);  // Needed for TensorLists.
 TF_CALL_int64(REGISTER);
 TF_CALL_int16(REGISTER);
 TF_CALL_bfloat16(REGISTER);
-TF_CALL_bool(REGISTER);
 TF_CALL_uint8(REGISTER);
+TF_CALL_GPU_ALL_TYPES(REGISTER);

 #undef REGISTER

@@ -66,15 +66,12 @@ void ConcatGPUImpl(const Eigen::GpuDevice& d,
     const GpuDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size, \
     int split_size, typename TTypes<T, 2>::Matrix* output);

-TF_CALL_GPU_NUMBER_TYPES(REGISTER);
-TF_CALL_complex64(REGISTER);
-TF_CALL_complex128(REGISTER);
 TF_CALL_int32(REGISTER);  // Needed for TensorLists.
 TF_CALL_int64(REGISTER);
 TF_CALL_int16(REGISTER);
 TF_CALL_bfloat16(REGISTER);
-TF_CALL_bool(REGISTER);
 TF_CALL_uint8(REGISTER);
+TF_CALL_GPU_ALL_TYPES(REGISTER);
 #undef REGISTER

 }  // namespace tensorflow
@@ -201,45 +201,33 @@ void ConcatGPUImpl(const Eigen::GpuDevice& gpu_device,
     const GpuDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size, \
     int split_size, typename TTypes<T, 2>::Matrix* output);

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT32);
-TF_CALL_complex64(REGISTER_GPUCONCAT32);
-TF_CALL_complex128(REGISTER_GPUCONCAT32);
 TF_CALL_int32(REGISTER_GPUCONCAT32);  // Needed for TensorLists.
 TF_CALL_int64(REGISTER_GPUCONCAT32);
 TF_CALL_int16(REGISTER_GPUCONCAT32);
 TF_CALL_uint8(REGISTER_GPUCONCAT32);
-REGISTER_GPUCONCAT32(bfloat16);
-REGISTER_GPUCONCAT32(bool);
+TF_CALL_bfloat16(REGISTER_GPUCONCAT32);
+TF_CALL_GPU_ALL_TYPES(REGISTER_GPUCONCAT32);

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT64);
-TF_CALL_complex64(REGISTER_GPUCONCAT64);
-TF_CALL_complex128(REGISTER_GPUCONCAT64);
 TF_CALL_int32(REGISTER_GPUCONCAT64);  // Needed for TensorLists.
 TF_CALL_int64(REGISTER_GPUCONCAT64);
 TF_CALL_int16(REGISTER_GPUCONCAT64);
 TF_CALL_uint8(REGISTER_GPUCONCAT64);
-REGISTER_GPUCONCAT64(bfloat16);
-REGISTER_GPUCONCAT64(bool);
+TF_CALL_bfloat16(REGISTER_GPUCONCAT64);
+TF_CALL_GPU_ALL_TYPES(REGISTER_GPUCONCAT64);

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU32);
-TF_CALL_complex64(REGISTER_GPU32);
-TF_CALL_complex128(REGISTER_GPU32);
 TF_CALL_int32(REGISTER_GPU32);  // Needed for TensorLists.
 TF_CALL_int64(REGISTER_GPU32);
 TF_CALL_int16(REGISTER_GPU32);
 TF_CALL_uint8(REGISTER_GPU32);
-REGISTER_GPU32(bfloat16);
-REGISTER_GPU32(bool);
+TF_CALL_bfloat16(REGISTER_GPU32);
+TF_CALL_GPU_ALL_TYPES(REGISTER_GPU32);

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU64);
-TF_CALL_complex64(REGISTER_GPU64);
-TF_CALL_complex128(REGISTER_GPU64);
 TF_CALL_int32(REGISTER_GPU64);  // Needed for TensorLists.
 TF_CALL_int64(REGISTER_GPU64);
 TF_CALL_int16(REGISTER_GPU64);
 TF_CALL_uint8(REGISTER_GPU64);
-REGISTER_GPU64(bfloat16);
-REGISTER_GPU64(bool);
+TF_CALL_bfloat16(REGISTER_GPU64);
+TF_CALL_GPU_ALL_TYPES(REGISTER_GPU64);

 #undef REGISTER_GPUCONCAT32
 #undef REGISTER_GPUCONCAT64
@@ -227,13 +227,10 @@ REGISTER_CONCAT(uint64);
         .HostMemory("axis"), \
     ConcatV2Op<GPUDevice, type>)

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
-REGISTER_GPU(bfloat16);
+TF_CALL_bfloat16(REGISTER_GPU);
 TF_CALL_uint8(REGISTER_GPU);
-TF_CALL_complex64(REGISTER_GPU);
-TF_CALL_complex128(REGISTER_GPU);
 TF_CALL_int64(REGISTER_GPU);
-REGISTER_GPU(bool);
+TF_CALL_GPU_ALL_TYPES(REGISTER_GPU);
 #undef REGISTER_GPU

 // A special GPU kernel for int32.
@@ -176,8 +176,7 @@ TF_CALL_double(REGISTER_DIAGOP);
 TF_CALL_float(REGISTER_DIAGOP);
 TF_CALL_int32(REGISTER_DIAGOP);
 TF_CALL_int64(REGISTER_DIAGOP);
-TF_CALL_complex64(REGISTER_DIAGOP);
-TF_CALL_complex128(REGISTER_DIAGOP);
+TF_CALL_COMPLEX_TYPES(REGISTER_DIAGOP);
 TF_CALL_half(REGISTER_DIAGOP);
 #undef REGISTER_DIAGOP

@@ -190,8 +189,7 @@ TF_CALL_double(REGISTER_DIAGPARTOP);
 TF_CALL_float(REGISTER_DIAGPARTOP);
 TF_CALL_int32(REGISTER_DIAGPARTOP);
 TF_CALL_int64(REGISTER_DIAGPARTOP);
-TF_CALL_complex64(REGISTER_DIAGPARTOP);
-TF_CALL_complex128(REGISTER_DIAGPARTOP);
+TF_CALL_COMPLEX_TYPES(REGISTER_DIAGPARTOP);
 TF_CALL_half(REGISTER_DIAGPARTOP);
 #undef REGISTER_DIAGPARTOP

@@ -217,8 +215,7 @@ TF_CALL_double(REGISTER_DIAGOP_GPU);
 TF_CALL_float(REGISTER_DIAGOP_GPU);
 TF_CALL_int32(REGISTER_DIAGOP_GPU);
 TF_CALL_int64(REGISTER_DIAGOP_GPU);
-TF_CALL_complex64(REGISTER_DIAGOP_GPU);
-TF_CALL_complex128(REGISTER_DIAGOP_GPU);
+TF_CALL_COMPLEX_TYPES(REGISTER_DIAGOP_GPU);
 TF_CALL_half(REGISTER_DIAGOP_GPU);
 #undef REGISTER_DIAGOP_GPU

@@ -242,8 +239,7 @@ TF_CALL_double(REGISTER_DIAGPARTOP_GPU);
 TF_CALL_float(REGISTER_DIAGPARTOP_GPU);
 TF_CALL_int32(REGISTER_DIAGPARTOP_GPU);
 TF_CALL_int64(REGISTER_DIAGPARTOP_GPU);
-TF_CALL_complex64(REGISTER_DIAGPARTOP_GPU);
-TF_CALL_complex128(REGISTER_DIAGPARTOP_GPU);
+TF_CALL_COMPLEX_TYPES(REGISTER_DIAGPARTOP_GPU);
 TF_CALL_half(REGISTER_DIAGPARTOP_GPU);
 #undef REGISTER_DIAGPARTOP_GPU
@@ -467,8 +467,7 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
                         DynamicPartitionOpGPU<T>)

 TF_CALL_GPU_NUMBER_TYPES(REGISTER_DYNAMIC_PARTITION_GPU);
-TF_CALL_complex64(REGISTER_DYNAMIC_PARTITION_GPU);
-TF_CALL_complex128(REGISTER_DYNAMIC_PARTITION_GPU);
+TF_CALL_COMPLEX_TYPES(REGISTER_DYNAMIC_PARTITION_GPU);
 #undef REGISTER_DYNAMIC_PARTITION_GPU

 }  // namespace tensorflow
@@ -147,11 +147,11 @@ void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
     const int32 first_dim_size, \
     const GpuDeviceArrayStruct<int32>& input_indices, \
     const GpuDeviceArrayStruct<const T*>& input_ptrs, T* output);
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
-TF_CALL_complex64(REGISTER_GPU);
-TF_CALL_complex128(REGISTER_GPU);
-TF_CALL_int64(REGISTER_GPU);

+TF_CALL_int32(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+TF_CALL_COMPLEX_TYPES(REGISTER_GPU);
 #undef REGISTER_GPU

 template <class T>

@@ -357,11 +357,10 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_DYNAMIC_STITCH);
         .HostMemory("merged"), \
     ParallelDynamicStitchOpCPU<type>)

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
-TF_CALL_complex64(REGISTER_DYNAMIC_STITCH_GPU);
-TF_CALL_complex128(REGISTER_DYNAMIC_STITCH_GPU);
-TF_CALL_int64(REGISTER_DYNAMIC_STITCH_GPU);
 TF_CALL_int32(REGISTER_DYNAMIC_STITCH_GPU);
+TF_CALL_int64(REGISTER_DYNAMIC_STITCH_GPU);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
+TF_CALL_COMPLEX_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
 #undef REGISTER_DYNAMIC_STITCH_GPU

 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -70,11 +70,10 @@ void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
     const GpuDeviceArrayStruct<int32>& input_indices, \
     const GpuDeviceArrayStruct<const T*>& input_ptrs, T* output);

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
-TF_CALL_complex64(REGISTER_GPU);
-TF_CALL_complex128(REGISTER_GPU);
-TF_CALL_int32(REGISTER_GPU);
 TF_CALL_int64(REGISTER_GPU);
+TF_CALL_int32(REGISTER_GPU)
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+TF_CALL_COMPLEX_TYPES(REGISTER_GPU);

 #undef REGISTER_GPU
@@ -33,11 +33,8 @@ namespace tensorflow {
     DECLARE_GPU_SPECS_NDIM(T, 5); \
     DECLARE_GPU_SPECS_NDIM(T, 6);

-TF_CALL_half(DECLARE_GPU_SPECS);
-TF_CALL_float(DECLARE_GPU_SPECS);
-TF_CALL_double(DECLARE_GPU_SPECS);
-TF_CALL_complex64(DECLARE_GPU_SPECS);
-TF_CALL_complex128(DECLARE_GPU_SPECS);
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
+TF_CALL_COMPLEX_TYPES(DECLARE_GPU_SPECS);

 #undef DECLARE_GPU_SPECS_NDIM
 #undef DECLARE_GPU_SPECS

@@ -39,8 +39,7 @@ namespace functor {

 TF_CALL_int64(DECLARE_GPU_SPECS);
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
-TF_CALL_complex64(DECLARE_GPU_SPECS);
-TF_CALL_complex128(DECLARE_GPU_SPECS);
+TF_CALL_COMPLEX_TYPES(DECLARE_GPU_SPECS);

 #undef DECLARE_GPU_SPECS
 #undef DECLARE_GPU_SPECS_INDEX

@@ -39,8 +39,7 @@ namespace functor {

 TF_CALL_int64(DECLARE_GPU_SPECS);
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
-TF_CALL_complex64(DECLARE_GPU_SPECS);
-TF_CALL_complex128(DECLARE_GPU_SPECS);
+TF_CALL_COMPLEX_TYPES(DECLARE_GPU_SPECS);

 #undef DECLARE_GPU_SPECS
 #undef DECLARE_GPU_SPECS_INDEX
@@ -31,12 +31,9 @@ typedef Eigen::GpuDevice GPUDevice;
     DEFINE_GPU_SPECS_INDEX(T, int32); \
     DEFINE_GPU_SPECS_INDEX(T, int64);

-TF_CALL_bool(DEFINE_GPU_SPECS);
 TF_CALL_int32(DEFINE_GPU_SPECS);
 TF_CALL_int64(DEFINE_GPU_SPECS);
-TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
-TF_CALL_complex64(DEFINE_GPU_SPECS);
-TF_CALL_complex128(DEFINE_GPU_SPECS);
+TF_CALL_GPU_ALL_TYPES(DEFINE_GPU_SPECS);

 #undef DEFINE_GPU_SPECS
 #undef DEFINE_GPU_SPECS_INDEX

@@ -31,12 +31,9 @@ typedef Eigen::GpuDevice GPUDevice;
     DEFINE_GPU_SPECS_INDEX(T, int32); \
     DEFINE_GPU_SPECS_INDEX(T, int64);

-TF_CALL_bool(DEFINE_GPU_SPECS);
 TF_CALL_int32(DEFINE_GPU_SPECS);
 TF_CALL_int64(DEFINE_GPU_SPECS);
-TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
-TF_CALL_complex64(DEFINE_GPU_SPECS);
-TF_CALL_complex128(DEFINE_GPU_SPECS);
+TF_CALL_GPU_ALL_TYPES(DEFINE_GPU_SPECS);

 #undef DEFINE_GPU_SPECS
 #undef DEFINE_GPU_SPECS_INDEX
@@ -105,8 +105,7 @@ namespace functor {
 TF_CALL_int32(DECLARE_GPU_SPECS);
 TF_CALL_int64(DECLARE_GPU_SPECS);
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
-TF_CALL_complex64(DECLARE_GPU_SPECS);
-TF_CALL_complex128(DECLARE_GPU_SPECS);
+TF_CALL_COMPLEX_TYPES(DECLARE_GPU_SPECS);

 #undef DECLARE_GPU_SPECS
 #undef DECLARE_GPU_SPECS_INDEX

@@ -118,8 +117,7 @@ TF_CALL_complex128(DECLARE_GPU_SPECS);
 TF_CALL_int32(REGISTER_GATHER_ND_GPU);
 TF_CALL_int64(REGISTER_GATHER_ND_GPU);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);
-TF_CALL_complex64(REGISTER_GATHER_ND_GPU);
-TF_CALL_complex128(REGISTER_GATHER_ND_GPU);
+TF_CALL_COMPLEX_TYPES(REGISTER_GATHER_ND_GPU);

 #undef REGISTER_GATHER_ND_GPU

@@ -121,8 +121,7 @@ struct GatherNdSlice<GPUDevice, T, Index, IXDIM> {
 TF_CALL_int32(DEFINE_GPU_SPECS);
 TF_CALL_int64(DEFINE_GPU_SPECS);
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
-TF_CALL_complex64(DEFINE_GPU_SPECS);
-TF_CALL_complex128(DEFINE_GPU_SPECS);
+TF_CALL_COMPLEX_TYPES(DEFINE_GPU_SPECS);

 #undef DEFINE_GPU_SPECS
 #undef DEFINE_GPU_SPECS_INDEX
@@ -221,12 +221,9 @@ TF_CALL_uint64(REGISTER_GATHER_CPU);
 // Registration of the GPU implementations.
 #define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)

-TF_CALL_bool(REGISTER_GATHER_GPU);
 TF_CALL_int32(REGISTER_GATHER_GPU);
 TF_CALL_int64(REGISTER_GATHER_GPU);
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_GPU);
-TF_CALL_complex64(REGISTER_GATHER_GPU);
-TF_CALL_complex128(REGISTER_GATHER_GPU);
+TF_CALL_GPU_ALL_TYPES(REGISTER_GATHER_GPU);

 #undef REGISTER_GATHER_GPU
@@ -105,13 +105,10 @@ typedef Eigen::GpuDevice GPUDevice;
         .HostMemory("lengths"), \
     TensorListSplit<GPUDevice, T>)

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_OPS_GPU);
-REGISTER_TENSOR_LIST_OPS_GPU(bfloat16);
-TF_CALL_complex64(REGISTER_TENSOR_LIST_OPS_GPU);
-TF_CALL_complex128(REGISTER_TENSOR_LIST_OPS_GPU);
 TF_CALL_int32(REGISTER_TENSOR_LIST_OPS_GPU);
 TF_CALL_int64(REGISTER_TENSOR_LIST_OPS_GPU);
-REGISTER_TENSOR_LIST_OPS_GPU(bool);
+TF_CALL_bfloat16(REGISTER_TENSOR_LIST_OPS_GPU);
+TF_CALL_GPU_ALL_TYPES(REGISTER_TENSOR_LIST_OPS_GPU);

 #undef REGISTER_TENSOR_LIST_OPS_GPU
@@ -581,21 +581,14 @@ struct MatMulFunctor<SYCLDevice, T> {
         .Label("cublas"), \
     MatMulOp<GPUDevice, T, true /* cublas */>)

-TF_CALL_float(REGISTER_CPU);
-TF_CALL_double(REGISTER_CPU);
-TF_CALL_half(REGISTER_CPU);
-TF_CALL_bfloat16(REGISTER_CPU);
 TF_CALL_int32(REGISTER_CPU);
 TF_CALL_int64(REGISTER_CPU);
-TF_CALL_complex64(REGISTER_CPU);
-TF_CALL_complex128(REGISTER_CPU);
+TF_CALL_FLOAT_TYPES(REGISTER_CPU);
+TF_CALL_COMPLEX_TYPES(REGISTER_CPU);

 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-TF_CALL_float(REGISTER_GPU);
-TF_CALL_double(REGISTER_GPU);
-TF_CALL_complex64(REGISTER_GPU);
-TF_CALL_complex128(REGISTER_GPU);
-TF_CALL_half(REGISTER_GPU);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+TF_CALL_COMPLEX_TYPES(REGISTER_GPU);
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 #ifdef TENSORFLOW_USE_SYCL
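The CPU matmul hunk above additionally leans on the pre-existing TF_CALL_FLOAT_TYPES macro. Assuming it matches its definition in register_types.h, the removed per-dtype floating-point lines and the one added line expand identically:

// Paraphrase of register_types.h; not part of this diff.
#define TF_CALL_FLOAT_TYPES(m) \
  TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)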
@@ -210,10 +210,7 @@ namespace functor {
   }; \
   extern template struct MatrixBandPartFunctor<GPUDevice, T>;

-TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
-TF_CALL_bool(DECLARE_GPU_SPEC);
-TF_CALL_complex64(DECLARE_GPU_SPEC);
-TF_CALL_complex128(DECLARE_GPU_SPEC);
+TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_SPEC);
 #undef DECLARE_GPU_SPEC
 }  // namespace functor

@@ -225,10 +222,7 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
         .HostMemory("num_lower") \
         .HostMemory("num_upper"), \
     MatrixBandPartOp<GPUDevice, type>);
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_BAND_PART_GPU);
-TF_CALL_bool(REGISTER_MATRIX_BAND_PART_GPU);
-TF_CALL_complex64(REGISTER_MATRIX_BAND_PART_GPU);
-TF_CALL_complex128(REGISTER_MATRIX_BAND_PART_GPU);
+TF_CALL_GPU_ALL_TYPES(REGISTER_MATRIX_BAND_PART_GPU);
 #undef REGISTER_MATRIX_BAND_PART_GPU

 // Registration of the deprecated kernel.

@@ -68,10 +68,7 @@ struct MatrixBandPartFunctor<GPUDevice, Scalar> {

 #define DEFINE_GPU_SPEC(T) template struct MatrixBandPartFunctor<GPUDevice, T>;

-TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
-TF_CALL_bool(DEFINE_GPU_SPEC);
-TF_CALL_complex64(DEFINE_GPU_SPEC);
-TF_CALL_complex128(DEFINE_GPU_SPEC);
+TF_CALL_GPU_ALL_TYPES(DEFINE_GPU_SPEC);

 #undef DEFINE_GPU_SPEC
 }  // namespace functor
@@ -469,10 +469,7 @@ namespace functor {
     const bool left_align_subdiagonal); \
   extern template struct MatrixDiagPart<GPUDevice, T>;

-TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
-TF_CALL_bool(DECLARE_GPU_SPEC);
-TF_CALL_complex64(DECLARE_GPU_SPEC);
-TF_CALL_complex128(DECLARE_GPU_SPEC);
+TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_SPEC);

 }  // namespace functor

@@ -513,10 +510,7 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
         .HostMemory("padding_value"), \
     MatrixDiagPartOp<GPUDevice, type>);

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_DIAG_GPU);
-TF_CALL_bool(REGISTER_MATRIX_DIAG_GPU);
-TF_CALL_complex64(REGISTER_MATRIX_DIAG_GPU);
-TF_CALL_complex128(REGISTER_MATRIX_DIAG_GPU);
+TF_CALL_GPU_ALL_TYPES(REGISTER_MATRIX_DIAG_GPU);
 #undef REGISTER_MATRIX_DIAG_GPU

 // Registration of the deprecated kernel.

@@ -163,10 +163,7 @@ struct MatrixDiagPart<GPUDevice, T> {
   template struct MatrixDiag<GPUDevice, T>; \
   template struct MatrixDiagPart<GPUDevice, T>;

-TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
-TF_CALL_bool(DEFINE_GPU_SPEC);
-TF_CALL_complex64(DEFINE_GPU_SPEC);
-TF_CALL_complex128(DEFINE_GPU_SPEC);
+TF_CALL_GPU_ALL_TYPES(DEFINE_GPU_SPEC);

 }  // namespace functor
 }  // namespace tensorflow
@@ -272,10 +272,7 @@ namespace functor {
     const bool left_align_superdiagonal, const bool left_align_subdiagonal); \
   extern template struct MatrixSetDiag<GPUDevice, T>;

-TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
-TF_CALL_bool(DECLARE_GPU_SPEC);
-TF_CALL_complex64(DECLARE_GPU_SPEC);
-TF_CALL_complex128(DECLARE_GPU_SPEC);
+TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_SPEC);

 }  // namespace functor

@@ -295,10 +292,7 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
         .HostMemory("k"), \
     MatrixSetDiagOp<GPUDevice, type>);

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_SET_DIAG_GPU);
-TF_CALL_bool(REGISTER_MATRIX_SET_DIAG_GPU);
-TF_CALL_complex64(REGISTER_MATRIX_SET_DIAG_GPU);
-TF_CALL_complex128(REGISTER_MATRIX_SET_DIAG_GPU);
+TF_CALL_GPU_ALL_TYPES(REGISTER_MATRIX_SET_DIAG_GPU);
 #undef REGISTER_MATRIX_SET_DIAG_GPU

 // Registration of the deprecated kernel.

@@ -136,10 +136,7 @@ struct MatrixSetDiag<GPUDevice, Scalar> {

 #define DEFINE_GPU_SPEC(T) template struct MatrixSetDiag<GPUDevice, T>;

-TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
-TF_CALL_bool(DEFINE_GPU_SPEC);
-TF_CALL_complex64(DEFINE_GPU_SPEC);
-TF_CALL_complex128(DEFINE_GPU_SPEC);
+TF_CALL_GPU_ALL_TYPES(DEFINE_GPU_SPEC);

 }  // namespace functor
 }  // namespace tensorflow
@@ -13,16 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+#include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/kernels/matrix_triangular_solve_op_impl.h"

 namespace tensorflow {

-TF_CALL_complex64(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU);
-TF_CALL_complex128(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU);
+TF_CALL_COMPLEX_TYPES(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU);

 #if GOOGLE_CUDA
-TF_CALL_complex64(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU);
-TF_CALL_complex128(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU);
+TF_CALL_COMPLEX_TYPES(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU);
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 }  // namespace tensorflow
@@ -316,13 +316,11 @@ class BatchMatMulMkl : public OpKernel {
 #ifdef ENABLE_MKL
 TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
 TF_CALL_double(REGISTER_BATCH_MATMUL_MKL);
-TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL);
-TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL);
+TF_CALL_COMPLEX_TYPES(REGISTER_BATCH_MATMUL_MKL);

 TF_CALL_float(REGISTER_BATCH_MATMUL_MKL_V2);
 TF_CALL_double(REGISTER_BATCH_MATMUL_MKL_V2);
-TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL_V2);
-TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL_V2);
+TF_CALL_COMPLEX_TYPES(REGISTER_BATCH_MATMUL_MKL_V2);

 #if defined(ENABLE_MKLDNN_V1) && defined(ENABLE_INTEL_MKL_BFLOAT16)
 TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL);
@@ -160,12 +160,9 @@ namespace functor {
     DECLARE_GPU_SPEC_INDEX(T, int32); \
     DECLARE_GPU_SPEC_INDEX(T, int64);

-TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
-TF_CALL_bool(DECLARE_GPU_SPEC);
 TF_CALL_int32(DECLARE_GPU_SPEC);
 TF_CALL_int64(DECLARE_GPU_SPEC);
-TF_CALL_complex64(DECLARE_GPU_SPEC);
-TF_CALL_complex128(DECLARE_GPU_SPEC);
+TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_SPEC);

 #undef DECLARE_GPU_SPEC_INDEX
 #undef DECLARE_GPU_SPEC

@@ -186,12 +183,9 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
     REGISTER_ONE_HOT_GPU_INDEX(type, int32); \
     REGISTER_ONE_HOT_GPU_INDEX(type, int64);

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_ONE_HOT_GPU);
-TF_CALL_bool(REGISTER_ONE_HOT_GPU);
 TF_CALL_int32(REGISTER_ONE_HOT_GPU);
 TF_CALL_int64(REGISTER_ONE_HOT_GPU);
-TF_CALL_complex64(REGISTER_ONE_HOT_GPU);
-TF_CALL_complex128(REGISTER_ONE_HOT_GPU);
+TF_CALL_GPU_ALL_TYPES(REGISTER_ONE_HOT_GPU);

 #undef REGISTER_ONE_HOT_GPU_INDEX
 #undef REGISTER_ONE_HOT_GPU

@@ -37,12 +37,9 @@ typedef Eigen::GpuDevice GPUDevice;
     DEFINE_GPU_SPEC_INDEX(T, int32); \
     DEFINE_GPU_SPEC_INDEX(T, int64)

-TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
-TF_CALL_bool(DEFINE_GPU_SPEC);
 TF_CALL_int32(DEFINE_GPU_SPEC);
 TF_CALL_int64(DEFINE_GPU_SPEC);
-TF_CALL_complex64(DEFINE_GPU_SPEC);
-TF_CALL_complex128(DEFINE_GPU_SPEC);
+TF_CALL_GPU_ALL_TYPES(DEFINE_GPU_SPEC);

 #undef DEFINE_GPU_SPEC_INDEX
 #undef DEFINE_GPU_SPEC
@@ -152,13 +152,10 @@ REGISTER_PACK(tstring);
     Name("Pack").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
     PackOp<GPUDevice, type>)

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
 TF_CALL_bfloat16(REGISTER_GPU);
 TF_CALL_int64(REGISTER_GPU);
 TF_CALL_int16(REGISTER_GPU);
-TF_CALL_bool(REGISTER_GPU);
-TF_CALL_complex64(REGISTER_GPU);
-TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_GPU_ALL_TYPES(REGISTER_GPU);
 #undef REGISTER_GPU

 // A special GPU kernel for int32.
@@ -52,8 +52,7 @@ TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
             functor::EuclideanNormReducer<type>>);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
 #if GOOGLE_CUDA
-TF_CALL_complex64(REGISTER_GPU_KERNELS);
-TF_CALL_complex128(REGISTER_GPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(REGISTER_GPU_KERNELS);
 #endif
 #undef REGISTER_GPU_KERNELS

@@ -52,8 +52,7 @@ TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
         ReductionOp<GPUDevice, type, int64, functor::MeanReducer<type>>);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
 #if GOOGLE_CUDA
-TF_CALL_complex64(REGISTER_GPU_KERNELS);
-TF_CALL_complex128(REGISTER_GPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(REGISTER_GPU_KERNELS);
 #endif
 #undef REGISTER_GPU_KERNELS
@@ -50,11 +50,10 @@ TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
             .HostMemory("reduction_indices"), \
         ReductionOp<GPUDevice, type, int64, \
                     Eigen::internal::ProdReducer<type>>);
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
 TF_CALL_int32(REGISTER_GPU_KERNELS);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
 #if GOOGLE_CUDA
-TF_CALL_complex64(REGISTER_GPU_KERNELS);
-TF_CALL_complex128(REGISTER_GPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(REGISTER_GPU_KERNELS);
 #endif
 #undef REGISTER_GPU_KERNELS

@@ -50,11 +50,10 @@ TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
             .TypeConstraint<int64>("Tidx") \
             .HostMemory("reduction_indices"), \
         ReductionOp<GPUDevice, type, int64, Eigen::internal::SumReducer<type>>);
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
 TF_CALL_int64(REGISTER_GPU_KERNELS);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
 #if GOOGLE_CUDA
-TF_CALL_complex64(REGISTER_GPU_KERNELS);
-TF_CALL_complex128(REGISTER_GPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(REGISTER_GPU_KERNELS);
 #endif
 #undef REGISTER_GPU_KERNELS
@@ -342,12 +342,7 @@ namespace functor {

 TF_CALL_uint8(DECLARE_GPU_SPEC);
 TF_CALL_int8(DECLARE_GPU_SPEC);
-TF_CALL_bool(DECLARE_GPU_SPEC);
-TF_CALL_half(DECLARE_GPU_SPEC);
-TF_CALL_float(DECLARE_GPU_SPEC);
-TF_CALL_double(DECLARE_GPU_SPEC);
-TF_CALL_complex64(DECLARE_GPU_SPEC);
-TF_CALL_complex128(DECLARE_GPU_SPEC);
+TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_SPEC);
 #undef DECLARE_GPU_SPEC
 #undef DECLARE_GPU_SPEC_DIM
 }  // namespace functor

@@ -373,12 +368,7 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
     ReverseV2Op<GPUDevice, T, int64>)
 TF_CALL_uint8(REGISTER_GPU_KERNELS);
 TF_CALL_int8(REGISTER_GPU_KERNELS);
-TF_CALL_bool(REGISTER_GPU_KERNELS);
-TF_CALL_half(REGISTER_GPU_KERNELS);
-TF_CALL_float(REGISTER_GPU_KERNELS);
-TF_CALL_double(REGISTER_GPU_KERNELS);
-TF_CALL_complex64(REGISTER_GPU_KERNELS);
-TF_CALL_complex128(REGISTER_GPU_KERNELS);
+TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
 #undef REGISTER_GPU_KERNEL

 // A special GPU kernel for int32.

@@ -40,12 +40,7 @@ typedef Eigen::GpuDevice GPUDevice;

 TF_CALL_uint8(DEFINE_REVERSE_ALL_DIMS);
 TF_CALL_int8(DEFINE_REVERSE_ALL_DIMS);
-TF_CALL_bool(DEFINE_REVERSE_ALL_DIMS);
-TF_CALL_half(DEFINE_REVERSE_ALL_DIMS);
-TF_CALL_float(DEFINE_REVERSE_ALL_DIMS);
-TF_CALL_double(DEFINE_REVERSE_ALL_DIMS);
-TF_CALL_complex64(DEFINE_REVERSE_ALL_DIMS);
-TF_CALL_complex128(DEFINE_REVERSE_ALL_DIMS);
+TF_CALL_GPU_ALL_TYPES(DEFINE_REVERSE_ALL_DIMS);
 #undef DEFINE_REVERSE
 #undef DEFINE_REVERSE_ALL_DIMS
@@ -397,10 +397,9 @@ TF_CALL_ALL_TYPES(REGISTER_CPU);

 TF_CALL_int32(REGISTER_KERNEL);
 TF_CALL_int64(REGISTER_KERNEL);
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL);
-TF_CALL_complex64(REGISTER_KERNEL);
-TF_CALL_complex128(REGISTER_KERNEL);
 TF_CALL_uint32(REGISTER_KERNEL);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL);
+TF_CALL_COMPLEX_TYPES(REGISTER_KERNEL);

 #undef REGISTER_KERNEL
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -93,10 +93,9 @@ struct Roll<GPUDevice, T> {

 TF_CALL_int32(DEFINE_GPU_SPECS);
 TF_CALL_int64(DEFINE_GPU_SPECS);
-TF_CALL_uint32(DEFINE_GPU_SPECS);
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
-TF_CALL_complex64(DEFINE_GPU_SPECS);
-TF_CALL_complex128(DEFINE_GPU_SPECS);
+TF_CALL_uint32(DEFINE_GPU_SPECS)
+TF_CALL_COMPLEX_TYPES(DEFINE_GPU_SPECS);

 #undef DEFINE_GPU_SPECS
 }  // namespace functor
@@ -502,8 +502,7 @@ TF_CALL_int64(REGISTER_SCATTER_ND_ALL_GPU);
 TF_CALL_int64(REGISTER_SCATTER_ND_MIN_MAX_GPU);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_ALL_GPU);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_MIN_MAX_GPU);
-TF_CALL_complex64(REGISTER_SCATTER_ND_ALL_GPU);
-TF_CALL_complex128(REGISTER_SCATTER_ND_ALL_GPU);
+TF_CALL_COMPLEX_TYPES(REGISTER_SCATTER_ND_ALL_GPU);

 #undef REGISTER_SCATTER_ND_ALL_GPU

@@ -563,8 +562,7 @@ TF_CALL_int64(REGISTER_SCATTER_ND_TENSOR_GPU);
 TF_CALL_int64(REGISTER_SCATTER_ND_TENSOR_GPU_MIN_MAX);
 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_TENSOR_GPU);
 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_TENSOR_GPU_MIN_MAX);
-TF_CALL_complex64(REGISTER_SCATTER_ND_TENSOR_GPU);
-TF_CALL_complex128(REGISTER_SCATTER_ND_TENSOR_GPU);
+TF_CALL_COMPLEX_TYPES(REGISTER_SCATTER_ND_TENSOR_GPU);

 #undef REGISTER_SCATTER_ND_ADD
 #undef REGISTER_SCATTER_ND_ADD_SUB

@@ -862,8 +860,7 @@ TF_CALL_int32(DECLARE_GPU_SPECS);
 TF_CALL_int32(DECLARE_GPU_SPECS_MIN_MAX);
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS_MIN_MAX);
-TF_CALL_complex64(DECLARE_GPU_SPECS);
-TF_CALL_complex128(DECLARE_GPU_SPECS);
+TF_CALL_COMPLEX_TYPES(DECLARE_GPU_SPECS);

 #undef DECLARE_GPU_SPECS_MIN_MAX
 #undef DECLARE_GPU_SPECS

@@ -200,8 +200,7 @@ TF_CALL_int64(DECLARE_GPU_SPECS);
 TF_CALL_int64(DECLARE_GPU_SPECS_MINMAX);
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS_MINMAX);
-TF_CALL_complex64(DECLARE_GPU_SPECS);
-TF_CALL_complex128(DECLARE_GPU_SPECS);
+TF_CALL_COMPLEX_TYPES(DECLARE_GPU_SPECS);

 #undef DECLARE_GPU_SPECS
 #undef DECLARE_GPU_SPECS_MINMAX
@@ -244,8 +244,7 @@ TF_CALL_int32(DEFINE_SUM_GPU_SPECS);

 // TODO(rocm): support atomicAdd for complex numbers on ROCm
 #if GOOGLE_CUDA
-TF_CALL_complex64(DEFINE_SUM_GPU_SPECS);
-TF_CALL_complex128(DEFINE_SUM_GPU_SPECS);
+TF_CALL_COMPLEX_TYPES(DEFINE_SUM_GPU_SPECS);
 #endif

 #undef DEFINE_SORTED_GPU_SPECS_INDEX

@@ -113,8 +113,7 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
 TF_CALL_int32(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
 // TODO(rocm): support atomicAdd for complex numbers on ROCm
 #if GOOGLE_CUDA
-TF_CALL_complex64(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
-TF_CALL_complex128(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
+TF_CALL_COMPLEX_TYPES(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
 #endif

 #undef REGISTER_GPU_KERNEL_UNSORTEDSEGMENT

@@ -113,8 +113,7 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
 TF_CALL_int32(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
 // TODO(rocm): support atomicAdd for complex numbers on ROCm
 #if GOOGLE_CUDA
-TF_CALL_complex64(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
-TF_CALL_complex128(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
+TF_CALL_COMPLEX_TYPES(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
 #endif

 #undef REGISTER_GPU_KERNEL_UNSORTEDSEGMENT
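Note that the segment-sum hunks keep complex registration inside #if GOOGLE_CUDA, so the merged macro changes only the spelling, not the conditional coverage (complex atomicAdd is still unsupported on ROCm, per the TODO). A minimal sketch of that guard pattern, with a hypothetical REGISTER_SUM_EXAMPLE stand-in:

#define TF_CALL_complex64(m) m(complex64)
#define TF_CALL_complex128(m) m(complex128)
#define TF_CALL_COMPLEX_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m)
#define REGISTER_SUM_EXAMPLE(T)  // per-dtype registration elided

#if GOOGLE_CUDA  // ROCm lacks atomicAdd for complex numbers
TF_CALL_COMPLEX_TYPES(REGISTER_SUM_EXAMPLE)
#endif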
@@ -300,14 +300,11 @@ namespace functor {
     DECLARE_GPU_SPEC(T, 7); \
     DECLARE_GPU_SPEC(T, 8);

-TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N);
-TF_CALL_complex64(DECLARE_FOR_N);
-TF_CALL_complex128(DECLARE_FOR_N);
 TF_CALL_bfloat16(DECLARE_FOR_N);
-TF_CALL_bool(DECLARE_FOR_N);
 TF_CALL_int8(DECLARE_FOR_N);
-TF_CALL_int32(DECLARE_FOR_N);
 TF_CALL_int64(DECLARE_FOR_N);
+DECLARE_FOR_N(int32);
+TF_CALL_GPU_ALL_TYPES(DECLARE_FOR_N);

 #undef DECLARE_FOR_N
 #undef DECLARE_GPU_SPEC

@@ -321,13 +318,10 @@ DECLARE_FOR_N(int32);
         .HostMemory("size"), \
     SliceOp<GPUDevice, type>)

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
-TF_CALL_complex64(REGISTER_GPU);
-TF_CALL_complex128(REGISTER_GPU);
 TF_CALL_bfloat16(REGISTER_GPU);
-TF_CALL_bool(REGISTER_GPU);
 TF_CALL_int8(REGISTER_GPU);
 TF_CALL_int64(REGISTER_GPU);
+TF_CALL_GPU_ALL_TYPES(REGISTER_GPU);

 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel

@@ -37,14 +37,11 @@ typedef Eigen::GpuDevice GPUDevice;
   template struct functor::Slice<GPUDevice, T, 7>; \
   template struct functor::Slice<GPUDevice, T, 8>;

-TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
-TF_CALL_complex64(DEFINE_GPU_KERNELS);
-TF_CALL_complex128(DEFINE_GPU_KERNELS);
 TF_CALL_bfloat16(DEFINE_GPU_KERNELS);
-TF_CALL_bool(DEFINE_GPU_KERNELS);
 TF_CALL_int8(DEFINE_GPU_KERNELS);
-DEFINE_GPU_KERNELS(int32);
-DEFINE_GPU_KERNELS(int64);
+TF_CALL_int32(DEFINE_GPU_KERNELS);
+TF_CALL_int64(DEFINE_GPU_KERNELS);
+TF_CALL_GPU_ALL_TYPES(DEFINE_GPU_KERNELS);

 #undef DEFINE_GPU_KERNELS
@@ -51,20 +51,16 @@ void SplitCustom<Device, T>::operator()(
   template struct Split<Eigen::GpuDevice, T, 2>; \
   template struct Split<Eigen::GpuDevice, T, 3>;

-TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
-TF_CALL_complex64(DEFINE_GPU_KERNELS);
-TF_CALL_complex128(DEFINE_GPU_KERNELS);
 TF_CALL_int64(DEFINE_GPU_KERNELS);
 TF_CALL_bfloat16(DEFINE_GPU_KERNELS);
 TF_CALL_uint8(DEFINE_GPU_KERNELS);
-TF_CALL_bool(DEFINE_GPU_KERNELS);
+TF_CALL_GPU_ALL_TYPES(DEFINE_GPU_KERNELS);

 #undef DEFINE_GPU_KERNELS
 #define DEFINE_GPU_KERNELS(T) template struct SplitCustom<Eigen::GpuDevice, T>;

 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
-TF_CALL_complex64(DEFINE_GPU_KERNELS);
-TF_CALL_complex128(DEFINE_GPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(DEFINE_GPU_KERNELS);
 TF_CALL_bfloat16(DEFINE_GPU_KERNELS);

 #undef DEFINE_GPU_KERNELS

@@ -248,8 +244,7 @@ void SplitVOpGPULaunch<T, IntType>::Run(
 #define REGISTER_GPU_KERNEL(T) template struct SplitOpGPULaunch<T>;

 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
-TF_CALL_complex64(REGISTER_GPU_KERNEL);
-TF_CALL_complex128(REGISTER_GPU_KERNEL);
+TF_CALL_COMPLEX_TYPES(REGISTER_GPU_KERNEL);
 TF_CALL_bfloat16(REGISTER_GPU_KERNEL);
 #undef REGISTER_GPU_KERNEL
 #define REGISTER_GPU_KERNEL(T) \

@@ -257,8 +252,7 @@ TF_CALL_bfloat16(REGISTER_GPU_KERNEL);
   template struct SplitVOpGPULaunch<T, int64>;

 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
-TF_CALL_complex64(REGISTER_GPU_KERNEL);
-TF_CALL_complex128(REGISTER_GPU_KERNEL);
+TF_CALL_COMPLEX_TYPES(REGISTER_GPU_KERNEL);
 TF_CALL_bfloat16(REGISTER_GPU_KERNEL);
 #undef REGISTER_GPU_KERNEL

@@ -50,12 +50,9 @@ struct SplitVOpGPULaunch {
   extern template struct SplitVOpGPULaunch<T, int32>; \
   extern template struct SplitVOpGPULaunch<T, int64>;

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
-TF_CALL_complex64(REGISTER_GPU_KERNEL);
-TF_CALL_complex128(REGISTER_GPU_KERNEL);
 TF_CALL_bfloat16(REGISTER_GPU_KERNEL);
 TF_CALL_uint8(REGISTER_GPU_KERNEL);
-TF_CALL_bool(REGISTER_GPU_KERNEL);
+TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNEL);
 #undef REGISTER_GPU_KERNEL

 }  // namespace tensorflow
@@ -417,10 +417,9 @@ REGISTER_SPLIT(uint64);
         .HostMemory("split_dim"), \
     SplitOpGPU<type>)

+TF_CALL_bfloat16(REGISTER_GPU);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
-TF_CALL_complex64(REGISTER_GPU);
-TF_CALL_complex128(REGISTER_GPU);
-REGISTER_GPU(bfloat16);
+TF_CALL_COMPLEX_TYPES(REGISTER_GPU);
 #undef REGISTER_GPU

 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -471,10 +471,9 @@ TF_CALL_ALL_TYPES(REGISTER_SPLIT_LEN);
   REGISTER_GPU(type, int32); \
   REGISTER_GPU(type, int64);

+TF_CALL_bfloat16(REGISTER_GPU_LEN);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_LEN);
-TF_CALL_complex64(REGISTER_GPU_LEN);
-TF_CALL_complex128(REGISTER_GPU_LEN);
-REGISTER_GPU_LEN(bfloat16);
+TF_CALL_COMPLEX_TYPES(REGISTER_GPU_LEN);
 #undef REGISTER_GPU_LEN
 #undef REGISTER_GPU
@@ -486,12 +486,9 @@ TF_CALL_uint64(REGISTER_STRIDED_SLICE);
         .HostMemory("strides"), \
     StridedSliceAssignOp<GPUDevice, type, true>)

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
-TF_CALL_bool(REGISTER_GPU);
 TF_CALL_int8(REGISTER_GPU);
-TF_CALL_complex64(REGISTER_GPU);
-TF_CALL_complex128(REGISTER_GPU);
 TF_CALL_int64(REGISTER_GPU);
+TF_CALL_GPU_ALL_TYPES(REGISTER_GPU);

 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel

@@ -21,8 +21,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/strided_slice_op_gpu_impl.h"

 namespace tensorflow {
-TF_CALL_complex64(DEFINE_GPU_KERNELS);
-TF_CALL_complex128(DEFINE_GPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(DEFINE_GPU_KERNELS);
 }  // end namespace tensorflow

 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -278,16 +278,12 @@ class HandleStridedSliceAssignCase<Device, T, 0> {

 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 TF_CALL_GPU_PROXY_TYPES(PREVENT_FOR_N_GPU);
-TF_CALL_complex64(PREVENT_FOR_N_GPU);
-TF_CALL_complex128(PREVENT_FOR_N_GPU);
+TF_CALL_COMPLEX_TYPES(PREVENT_FOR_N_GPU);

-TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N_GPU);
-TF_CALL_complex64(DECLARE_FOR_N_GPU);
-TF_CALL_complex128(DECLARE_FOR_N_GPU);
-TF_CALL_bool(DECLARE_FOR_N_GPU);
 TF_CALL_int8(DECLARE_FOR_N_GPU);
-DECLARE_FOR_N_GPU(int32);
-DECLARE_FOR_N_GPU(int64);
+TF_CALL_int32(DECLARE_FOR_N_GPU);
+TF_CALL_int64(DECLARE_FOR_N_GPU);
+TF_CALL_GPU_ALL_TYPES(DECLARE_FOR_N_GPU);
 #endif  // END GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU);
@@ -46,8 +46,7 @@ TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_CPU)

 #define TENSOR_ARRAY_WRITE_OR_ADD_GPU(T) TENSOR_ARRAY_WRITE_OR_ADD(GPUDevice, T)
 TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
-TF_CALL_complex64(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
-TF_CALL_complex128(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
+TF_CALL_COMPLEX_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
 #undef TENSOR_ARRAY_WRITE_OR_ADD_GPU

 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -71,8 +70,7 @@ TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU);

 #define TENSOR_ARRAY_SET_ZERO_GPU(T) TENSOR_ARRAY_SET_ZERO(GPUDevice, T)
 TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_GPU);
-TF_CALL_complex64(TENSOR_ARRAY_SET_ZERO_GPU);
-TF_CALL_complex128(TENSOR_ARRAY_SET_ZERO_GPU);
+TF_CALL_COMPLEX_TYPES(TENSOR_ARRAY_SET_ZERO_GPU);
 #undef TENSOR_ARRAY_SET_ZERO_GPU

 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -61,8 +61,7 @@ TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_CPU)

 #define TENSOR_ARRAY_WRITE_OR_ADD_GPU(T) TENSOR_ARRAY_WRITE_OR_ADD(GPUDevice, T)
 TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
-TF_CALL_complex64(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
-TF_CALL_complex128(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
+TF_CALL_COMPLEX_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
 #undef TENSOR_ARRAY_WRITE_OR_ADD_GPU

 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -89,8 +88,7 @@ TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU);

 #define TENSOR_ARRAY_SET_ZERO_GPU(T) TENSOR_ARRAY_SET_ZERO(GPUDevice, T)
 TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_GPU);
-TF_CALL_complex64(TENSOR_ARRAY_SET_ZERO_GPU);
-TF_CALL_complex128(TENSOR_ARRAY_SET_ZERO_GPU);
+TF_CALL_COMPLEX_TYPES(TENSOR_ARRAY_SET_ZERO_GPU);
 #undef TENSOR_ARRAY_SET_ZERO_GPU

 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -256,11 +256,10 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayV3").Device(DEVICE_CPU),
                             .HostMemory("handle"), \
                         TensorArrayOp);

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
-TF_CALL_complex64(REGISTER_GPU);
-TF_CALL_complex128(REGISTER_GPU);
 TF_CALL_int64(REGISTER_GPU);
-REGISTER_GPU(bfloat16);
+TF_CALL_bfloat16(REGISTER_GPU);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+TF_CALL_COMPLEX_TYPES(REGISTER_GPU);
 #undef REGISTER_GPU

 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -483,10 +482,9 @@ TF_CALL_ALL_TYPES(REGISTER_WRITE);
                             .HostMemory("index"), \
                         TensorArrayWriteOp<GPUDevice, type>);

+TF_CALL_bfloat16(REGISTER_GPU);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
-TF_CALL_complex64(REGISTER_GPU);
-TF_CALL_complex128(REGISTER_GPU);
-REGISTER_GPU(bfloat16);
+TF_CALL_COMPLEX_TYPES(REGISTER_GPU);
 #undef REGISTER_GPU

 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -572,11 +570,10 @@ TF_CALL_ALL_TYPES(REGISTER_READ)
                             .HostMemory("index"), \
                         TensorArrayReadOp<GPUDevice, type>);

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
-TF_CALL_complex64(REGISTER_GPU);
-TF_CALL_complex128(REGISTER_GPU);
 TF_CALL_int64(REGISTER_GPU);
-REGISTER_GPU(bfloat16);
+TF_CALL_bfloat16(REGISTER_GPU);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+TF_CALL_COMPLEX_TYPES(REGISTER_GPU);
 #undef REGISTER_GPU

 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -774,10 +771,9 @@ REGISTER_GATHER_AND_PACK(qint32);
             .HostMemory("handle"), \
         TensorArrayPackOrGatherOp<GPUDevice, type, false /* LEGACY_PACK */>);

+TF_CALL_bfloat16(REGISTER_GPU);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
-TF_CALL_complex64(REGISTER_GPU);
-TF_CALL_complex128(REGISTER_GPU);
-REGISTER_GPU(bfloat16);
+TF_CALL_COMPLEX_TYPES(REGISTER_GPU);
 #undef REGISTER_GPU

 // A special GPU kernel for int32.

@@ -995,10 +991,9 @@ REGISTER_CONCAT(qint32);
             .HostMemory("handle"), \
         TensorArrayConcatOp<GPUDevice, type>)

+TF_CALL_bfloat16(REGISTER_GPU);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
-TF_CALL_complex64(REGISTER_GPU);
-TF_CALL_complex128(REGISTER_GPU);
-REGISTER_GPU(bfloat16);
+TF_CALL_COMPLEX_TYPES(REGISTER_GPU);
 #undef REGISTER_GPU

 // A special GPU kernel for int32.

@@ -1215,10 +1210,9 @@ TF_CALL_ALL_TYPES(REGISTER_SCATTER_AND_UNPACK);
         TensorArrayUnpackOrScatterOp<GPUDevice, type, \
                                      false /* LEGACY_UNPACK */>);

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
-TF_CALL_complex64(REGISTER_GPU);
-TF_CALL_complex128(REGISTER_GPU);
 TF_CALL_int64(REGISTER_GPU);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+TF_CALL_COMPLEX_TYPES(REGISTER_GPU);
 #undef REGISTER_GPU

 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -1387,8 +1381,7 @@ TF_CALL_ALL_TYPES(REGISTER_SPLIT);
         TensorArraySplitOp<GPUDevice, type>);

 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
-TF_CALL_complex64(REGISTER_GPU);
-TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_COMPLEX_TYPES(REGISTER_GPU);
 #undef REGISTER_GPU

 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -727,12 +727,8 @@ class ApplyGradientDescentOp<SYCLDevice, T> : public OpKernel {
                         ApplyGradientDescentOp<D##Device, T>);
 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);

-TF_CALL_half(REGISTER_CPU_KERNELS);
-TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
-TF_CALL_float(REGISTER_CPU_KERNELS);
-TF_CALL_double(REGISTER_CPU_KERNELS);
-TF_CALL_complex64(REGISTER_CPU_KERNELS);
-TF_CALL_complex128(REGISTER_CPU_KERNELS);
+TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);

 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // Forward declarations of the functor specializations for GPU.

@@ -898,12 +894,8 @@ class ApplyAdadeltaOp : public OpKernel {
                         ApplyAdadeltaOp<D##Device, T>);
 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);

-TF_CALL_half(REGISTER_CPU_KERNELS);
-TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
-TF_CALL_float(REGISTER_CPU_KERNELS);
-TF_CALL_double(REGISTER_CPU_KERNELS);
-TF_CALL_complex64(REGISTER_CPU_KERNELS);
-TF_CALL_complex128(REGISTER_CPU_KERNELS);
+TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);

 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // Forward declarations of the functor specializations for GPU.

@@ -1089,12 +1081,8 @@ class SparseApplyAdadeltaOp : public OpKernel {
   REGISTER_KERNELS(T, int32); \
   REGISTER_KERNELS(T, int64);

-TF_CALL_half(REGISTER_CPU_KERNELS);
-TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
-TF_CALL_float(REGISTER_CPU_KERNELS);
-TF_CALL_double(REGISTER_CPU_KERNELS);
-TF_CALL_complex64(REGISTER_CPU_KERNELS);
-TF_CALL_complex128(REGISTER_CPU_KERNELS);
+TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);

 #undef REGISTER_CPU_KERNELS
 #undef REGISTER_KERNELS

@@ -1383,12 +1371,8 @@ class ApplyAdagradOp : public OpKernel {
                         ApplyAdagradOp<D##Device, T>);
 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);

-TF_CALL_half(REGISTER_CPU_KERNELS);
-TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
-TF_CALL_float(REGISTER_CPU_KERNELS);
-TF_CALL_double(REGISTER_CPU_KERNELS);
-TF_CALL_complex64(REGISTER_CPU_KERNELS);
-TF_CALL_complex128(REGISTER_CPU_KERNELS);
+TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);

 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // Forward declarations of the functor specializations for GPU.

@@ -1492,12 +1476,8 @@ class ApplyAdagradV2Op : public OpKernel {
                         ApplyAdagradV2Op<D##Device, T>);
 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);

-TF_CALL_half(REGISTER_CPU_KERNELS);
-TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
-TF_CALL_float(REGISTER_CPU_KERNELS);
-TF_CALL_double(REGISTER_CPU_KERNELS);
-TF_CALL_complex64(REGISTER_CPU_KERNELS);
-TF_CALL_complex128(REGISTER_CPU_KERNELS);
+TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);

 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // Forward declarations of the functor specializations for GPU.

@@ -1801,12 +1781,8 @@ class SparseApplyAdagradOp : public OpKernel {
   REGISTER_KERNELS(T, int32); \
   REGISTER_KERNELS(T, int64);

-TF_CALL_half(REGISTER_CPU_KERNELS);
-TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
-TF_CALL_float(REGISTER_CPU_KERNELS);
-TF_CALL_double(REGISTER_CPU_KERNELS);
-TF_CALL_complex64(REGISTER_CPU_KERNELS);
-TF_CALL_complex128(REGISTER_CPU_KERNELS);
+TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);

 #undef REGISTER_CPU_KERNELS
 #undef REGISTER_KERNELS

@@ -1976,12 +1952,8 @@ class SparseApplyAdagradV2Op : public OpKernel {
   REGISTER_KERNELS(T, int32); \
   REGISTER_KERNELS(T, int64);

-TF_CALL_half(REGISTER_CPU_KERNELS);
-TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
-TF_CALL_float(REGISTER_CPU_KERNELS);
-TF_CALL_double(REGISTER_CPU_KERNELS);
-TF_CALL_complex64(REGISTER_CPU_KERNELS);
-TF_CALL_complex128(REGISTER_CPU_KERNELS);
+TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);

 #undef REGISTER_CPU_KERNELS
 #undef REGISTER_KERNELS

@@ -3054,12 +3026,8 @@ class ApplyMomentumOp : public OpKernel {
                         ApplyMomentumOp<D##Device, T>);
 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);

-TF_CALL_half(REGISTER_CPU_KERNELS);
-TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
-TF_CALL_float(REGISTER_CPU_KERNELS);
-TF_CALL_double(REGISTER_CPU_KERNELS);
-TF_CALL_complex64(REGISTER_CPU_KERNELS);
-TF_CALL_complex128(REGISTER_CPU_KERNELS);
+TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);

 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // Forward declarations of the functor specializations for GPU.

@@ -3209,12 +3177,8 @@ class SparseApplyMomentumOp : public OpKernel {
   REGISTER_KERNELS(T, int32); \
   REGISTER_KERNELS(T, int64);

-TF_CALL_half(REGISTER_CPU_KERNELS);
-TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
-TF_CALL_float(REGISTER_CPU_KERNELS);
-TF_CALL_double(REGISTER_CPU_KERNELS);
-TF_CALL_complex64(REGISTER_CPU_KERNELS);
-TF_CALL_complex128(REGISTER_CPU_KERNELS);
+TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);

 #undef REGISTER_CPU_KERNELS
 #undef REGISTER_KERNELS

@@ -3288,12 +3252,8 @@ class ApplyKerasMomentumOp : public OpKernel {
                         ApplyKerasMomentumOp<D##Device, T>);
 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);

-TF_CALL_half(REGISTER_CPU_KERNELS);
-TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
-TF_CALL_float(REGISTER_CPU_KERNELS);
-TF_CALL_double(REGISTER_CPU_KERNELS);
-TF_CALL_complex64(REGISTER_CPU_KERNELS);
-TF_CALL_complex128(REGISTER_CPU_KERNELS);
+TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);

 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // Forward declarations of the functor specializations for GPU.

@@ -3423,12 +3383,9 @@ class SparseApplyKerasMomentumOp : public OpKernel {
   REGISTER_KERNELS(T, CPU, int32); \
   REGISTER_KERNELS(T, CPU, int64);

-TF_CALL_half(REGISTER_CPU_KERNELS);
-TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
-TF_CALL_float(REGISTER_CPU_KERNELS);
-TF_CALL_double(REGISTER_CPU_KERNELS);
-TF_CALL_complex64(REGISTER_CPU_KERNELS);
-TF_CALL_complex128(REGISTER_CPU_KERNELS);
+TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);

 #undef REGISTER_CPU_KERNELS

 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -3699,12 +3656,8 @@ class ApplyAdamOp<SYCLDevice, T> : public OpKernel {
                         ApplyAdamOp<D##Device, T>);
 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);

-TF_CALL_half(REGISTER_CPU_KERNELS);
-TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
-TF_CALL_float(REGISTER_CPU_KERNELS);
-TF_CALL_double(REGISTER_CPU_KERNELS);
-TF_CALL_complex64(REGISTER_CPU_KERNELS);
-TF_CALL_complex128(REGISTER_CPU_KERNELS);
+TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);

 #ifdef TENSORFLOW_USE_SYCL
 #define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T);

@@ -4226,12 +4179,8 @@ class ApplyCenteredRMSPropOp : public OpKernel {
                         ApplyCenteredRMSPropOp<D##Device, T>);
 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);

-TF_CALL_half(REGISTER_CPU_KERNELS);
-TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
-TF_CALL_float(REGISTER_CPU_KERNELS);
-TF_CALL_double(REGISTER_CPU_KERNELS);
-TF_CALL_complex64(REGISTER_CPU_KERNELS);
-TF_CALL_complex128(REGISTER_CPU_KERNELS);
+TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);

 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // Forward declarations of the functor specializations for GPU.
@@ -144,12 +144,9 @@ TF_CALL_ALL_TYPES(REGISTER_UNPACK);
     Name("Unpack").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
     UnpackOp<GPUDevice, type>)

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
 TF_CALL_bfloat16(REGISTER_GPU);
 TF_CALL_uint8(REGISTER_GPU);
-TF_CALL_bool(REGISTER_GPU);
-TF_CALL_complex64(REGISTER_GPU);
-TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_GPU_ALL_TYPES(REGISTER_GPU);
 #undef REGISTER_GPU

 // A special GPU kernel for int32.
@@ -250,11 +250,10 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);
                             .HostMemory("is_initialized"), \
                         IsVariableInitializedOp);

-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
-TF_CALL_complex64(REGISTER_GPU_KERNELS);
-TF_CALL_complex128(REGISTER_GPU_KERNELS);
 TF_CALL_int64(REGISTER_GPU_KERNELS);
 TF_CALL_uint32(REGISTER_GPU_KERNELS);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(REGISTER_GPU_KERNELS);
 #undef REGISTER_GPU_KERNELS
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM