Add uint32 & uint64 to TF_CALL_INTEGRAL_TYPES
Both uint32 and uint64 had been omitted from TF_CALL_INTEGRAL_TYPES over concerns about binary-size bloat. In practice, the size increase turns out to be only around 2 MB. Further, this fixes #39649, since tf.reduce_mean no longer falls back inadvertently to the XLA_CPU device.

PiperOrigin-RevId: 317259372
Change-Id: Iacf75eaedce198fbef4bd9fd59b6fefa584cbf34
parent 8e654afea4
commit e972c55726
@@ -153,16 +153,9 @@ limitations under the License.
 #endif  // defined(IS_MOBILE_PLATFORM)  - end of TF_CALL_type defines
 
 // Defines for sets of types.
-
-// TODO(b/111604096): Add uint32 and uint64 to TF_CALL_INTEGRAL_TYPES.
-//
-// The uint32 and uint64 types were introduced in 10/2017 to be used via XLA and
-// thus were not included in TF_CALL_INTEGRAL_TYPES. Including them in
-// TF_CALL_INTEGRAL_TYPES should only happen after evaluating the effect on the
-// TF binary size and performance.
-#define TF_CALL_INTEGRAL_TYPES(m)                                      \
-  TF_CALL_int64(m) TF_CALL_int32(m) TF_CALL_uint16(m) TF_CALL_int16(m) \
-      TF_CALL_uint8(m) TF_CALL_int8(m)
+#define TF_CALL_INTEGRAL_TYPES(m)                                       \
+  TF_CALL_uint64(m) TF_CALL_int64(m) TF_CALL_uint32(m) TF_CALL_int32(m) \
+      TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m)
 
 #define TF_CALL_FLOAT_TYPES(m) \
   TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)
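For readers new to these macros: TF_CALL_INTEGRAL_TYPES is an X-macro that applies its argument m once per type in the set, which is why adding uint32/uint64 in this one place makes every downstream TF_CALL_INTEGRAL_TYPES registration pick them up automatically. A minimal, self-contained sketch of the pattern (plain C++, all names invented for illustration; not TensorFlow code):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Per-type hooks: each applies the macro argument m to one type.
    #define CALL_int32(m) m(int32_t)
    #define CALL_uint32(m) m(uint32_t)
    #define CALL_uint64(m) m(uint64_t)

    // Set macro in the style of TF_CALL_INTEGRAL_TYPES above.
    #define CALL_INTEGRAL_TYPES(m) CALL_uint64(m) CALL_uint32(m) CALL_int32(m)

    // Stand-in for a kernel-registration macro.
    #define DEFINE_SIZE_FN(T) \
      std::size_t SizeOf(T) { return sizeof(T); }

    CALL_INTEGRAL_TYPES(DEFINE_SIZE_FN)  // one SizeOf overload per type

    int main() {
      std::printf("%zu %zu\n", SizeOf(uint64_t{0}), SizeOf(int32_t{0}));
    }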
@@ -174,10 +167,10 @@ limitations under the License.
 #define TF_CALL_REAL_NUMBER_TYPES_NO_BFLOAT16(m) \
   TF_CALL_INTEGRAL_TYPES(m) TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m)
 
-#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m)                              \
-  TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)   \
-      TF_CALL_int64(m) TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) \
-          TF_CALL_int8(m)
+#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m)                                \
+  TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)     \
+      TF_CALL_uint64(m) TF_CALL_int64(m) TF_CALL_uint32(m) TF_CALL_uint16(m) \
+          TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m)
 
 #define TF_CALL_COMPLEX_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m)
 
@@ -238,11 +238,6 @@ int DataTypeSize(DataType dt) {
     TF_CALL_qint16(CASE);
     TF_CALL_quint16(CASE);
 
-    // uint32 and uint64 aren't included in TF_CALL_POD_TYPES because we
-    // don't want to define kernels for them at this stage to avoid binary
-    // bloat.
-    TF_CALL_uint32(CASE);
-    TF_CALL_uint64(CASE);
     default:
       return 0;
   }
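Why the explicit DataTypeSize cases can simply go: CASE(T) expands to one case label per type, and with uint32/uint64 now riding along in the TF_CALL set macros, keeping the explicit cases would emit duplicate case labels and fail to compile. A hedged sketch of the switch/CASE pattern (the enum and trait names are invented for illustration):

    #include <cstdint>

    enum class DType { kInt32, kUint32, kUint64 };  // hypothetical dtype enum

    template <typename T> struct ToEnum;            // hypothetical type-to-enum trait
    template <> struct ToEnum<int32_t>  { static constexpr DType value = DType::kInt32; };
    template <> struct ToEnum<uint32_t> { static constexpr DType value = DType::kUint32; };
    template <> struct ToEnum<uint64_t> { static constexpr DType value = DType::kUint64; };

    #define CALL_POD_TYPES(m) m(int32_t) m(uint32_t) m(uint64_t)

    int DataTypeSizeSketch(DType dt) {
    #define CASE(T)          \
      case ToEnum<T>::value: \
        return sizeof(T);
      switch (dt) {
        CALL_POD_TYPES(CASE)  // expands to one case per type in the set
        default:
          return 0;
      }
    #undef CASE
    }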
@@ -4900,7 +4900,9 @@ tf_kernel_library(
         "topk_op_gpu_double.cu.cc",
         "topk_op_gpu_float.cu.cc",
         "topk_op_gpu_half.cu.cc",
+        "topk_op_gpu_uint64.cu.cc",
         "topk_op_gpu_int64.cu.cc",
+        "topk_op_gpu_uint32.cu.cc",
         "topk_op_gpu_int32.cu.cc",
         "topk_op_gpu_int16.cu.cc",
         "topk_op_gpu_uint16.cu.cc",
@@ -116,8 +116,6 @@ REGISTER(qint8)
 REGISTER(quint16)
 REGISTER(qint16)
 REGISTER(qint32)
-REGISTER(uint32)
-REGISTER(uint64)
 
 #if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) && \
     !defined(__ANDROID_TYPES_FULL__)
@@ -208,8 +208,6 @@ REGISTER_CONCAT(qint8);
 REGISTER_CONCAT(quint16);
 REGISTER_CONCAT(qint16);
 REGISTER_CONCAT(qint32);
-REGISTER_CONCAT(uint32);
-REGISTER_CONCAT(uint64);
 
 #undef REGISTER_CONCAT
 
@@ -211,7 +211,6 @@ TF_CALL_ALL_TYPES(REGISTER_CPU_KERNEL);
 // the conversion from uint8 to quint8.
 REGISTER_KERNEL(CPU, quint8);
 REGISTER_KERNEL(CPU, quint16);
-REGISTER_KERNEL(CPU, uint32);
 #undef REGISTER_CPU_KERNEL
 
 #ifdef TENSORFLOW_USE_SYCL
@@ -101,16 +101,12 @@ TF_CALL_ALL_TYPES(REGISTER_CPU_SWITCH);
 TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH);
 TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_SWITCH);
 TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_REF_SWITCH);
-REGISTER_CPU_SWITCH(uint64);
 
 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_SWITCH);
 TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_SWITCH);
 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH);
 TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_SWITCH);
-REGISTER_GPU_SWITCH(uint64);
 TF_CALL_variant(REGISTER_GPU_SWITCH);
-TF_CALL_uint32(REGISTER_GPU_SWITCH);
-TF_CALL_uint32(REGISTER_GPU_REF_SWITCH);
 
 #undef REGISTER_CPU_SWITCH
 #undef REGISTER_CPU_REF_SWITCH
@@ -311,7 +307,6 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_KERNEL);
 TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_KERNEL);
 REGISTER_GPU_KERNEL(bool);
 REGISTER_GPU_REF_KERNEL(bool);
-REGISTER_GPU_KERNEL(uint64);
 TF_CALL_variant(REGISTER_GPU_KERNEL);
 
 #undef REGISTER_GPU_KERNEL
@@ -220,8 +220,6 @@ Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
     break;
     TF_CALL_NUMBER_TYPES(CASE);
     TF_CALL_tstring(CASE);
-    TF_CALL_uint32(CASE);
-    TF_CALL_uint64(CASE);
     // TODO(feihugis): figure out how to support variant tensors.
 #undef CASE
     default:
@@ -98,7 +98,6 @@ typedef Eigen::SyclDevice SYCLDevice;
 
 TF_CALL_ALL_TYPES(REGISTER_KERNELS);
-// uint32 not included in ALL_TYPES
-TF_CALL_uint32(REGISTER_KERNELS);
 TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
 // quint16 not included in QUANTIZIED_TYPES
 TF_CALL_quint16(REGISTER_KERNELS);
@@ -164,8 +164,6 @@ class DynamicPartitionOp : public DynamicPartitionOp_Shared {
       DynamicPartitionOp<T>)
 
 TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_PARTITION);
-// For partitioning fingerprints.
-TF_CALL_uint64(REGISTER_DYNAMIC_PARTITION);
 #undef REGISTER_DYNAMIC_PARTITION
 
 }  // namespace tensorflow
@@ -45,6 +45,8 @@ DEFINE_SETZERO_CPU(Eigen::half);
 DEFINE_SETZERO_CPU(bfloat16);
 DEFINE_SETZERO_CPU(float);
 DEFINE_SETZERO_CPU(double);
+DEFINE_SETZERO_CPU(uint32);
+DEFINE_SETZERO_CPU(uint64);
 DEFINE_SETZERO_CPU(uint8);
 DEFINE_SETZERO_CPU(int8);
 DEFINE_SETZERO_CPU(uint16);
@@ -96,6 +98,8 @@ DEFINE_SETONE_CPU(Eigen::half);
 DEFINE_SETONE_CPU(bfloat16);
 DEFINE_SETONE_CPU(float);
 DEFINE_SETONE_CPU(double);
+DEFINE_SETONE_CPU(uint32);
+DEFINE_SETONE_CPU(uint64);
 DEFINE_SETONE_CPU(uint8);
 DEFINE_SETONE_CPU(int8);
 DEFINE_SETONE_CPU(uint16);
@@ -137,7 +141,6 @@ struct FillFunctor<Eigen::ThreadPoolDevice, T> {
 TF_CALL_ALL_TYPES(DEFINE_FILL_CPU);
 DEFINE_FILL_CPU(quint8);
 DEFINE_FILL_CPU(quint16);
-DEFINE_FILL_CPU(uint32);
 #undef DEFINE_FILL_CPU
 
 #ifdef TENSORFLOW_USE_SYCL
@@ -211,8 +211,6 @@ TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
 TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
 TF_CALL_quint16(REGISTER_GATHER_CPU);
 TF_CALL_qint16(REGISTER_GATHER_CPU);
-TF_CALL_uint32(REGISTER_GATHER_CPU);
-TF_CALL_uint64(REGISTER_GATHER_CPU);
 
 #undef REGISTER_GATHER_CPU
 
@@ -122,7 +122,6 @@ REGISTER_SYCL_HOST_KERNEL(bool);
 
 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
 REGISTER_GPU_KERNEL(Variant);
-TF_CALL_uint32(REGISTER_GPU_KERNEL);
 
 #undef REGISTER_GPU_KERNEL
 
@@ -296,8 +296,6 @@ TF_CALL_tstring(REGISTER_CPU_KERNEL);
 TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL);
 TF_CALL_quint16(REGISTER_CPU_KERNEL);
 TF_CALL_qint16(REGISTER_CPU_KERNEL);
-TF_CALL_uint32(REGISTER_CPU_KERNEL);
-TF_CALL_uint64(REGISTER_CPU_KERNEL);
 #undef REGISTER_CPU_KERNEL
 #undef REGISTER_CPU_KERNEL_WITH_INDEX_TYPE
 
@@ -308,8 +308,6 @@ TF_CALL_tstring(REGISTER_KERNELS);
 TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
 TF_CALL_quint16(REGISTER_KERNELS);
 TF_CALL_qint16(REGISTER_KERNELS);
-TF_CALL_uint32(REGISTER_KERNELS);
-TF_CALL_uint64(REGISTER_KERNELS);
 #undef REGISTER_KERNELS
 #undef REGISTER_KERNELS_WITH_SPLIT_TYPE
 }  // namespace tensorflow
@@ -561,8 +561,6 @@ TF_CALL_string(REGISTER_CPU_KERNEL);
 TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL);
 TF_CALL_quint16(REGISTER_CPU_KERNEL);
 TF_CALL_qint16(REGISTER_CPU_KERNEL);
-TF_CALL_uint32(REGISTER_CPU_KERNEL);
-TF_CALL_uint64(REGISTER_CPU_KERNEL);
 
 #undef REGISTER_CPU_KERNEL
 
@@ -213,8 +213,6 @@ TF_CALL_tstring(REGISTER_KERNELS);
 TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
 TF_CALL_quint16(REGISTER_KERNELS);
 TF_CALL_qint16(REGISTER_KERNELS);
-TF_CALL_uint32(REGISTER_KERNELS);
-TF_CALL_uint64(REGISTER_KERNELS);
 #undef REGISTER_KERNELS
 #undef REGISTER_KERNELS_WITH_SPLIT_TYPE
 }  // namespace tensorflow
@@ -512,7 +512,6 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
 
 TF_CALL_ALL_TYPES(REGISTER_KERNELS);
 TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
-TF_CALL_uint32(REGISTER_KERNELS);
 #undef REGISTER_KERNELS
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -43,7 +43,6 @@ void Split<Eigen::ThreadPoolDevice, T, NDims>::operator()(
 
 TF_CALL_ALL_TYPES(DEFINE_CPU_KERNELS)
 DEFINE_CPU_KERNELS(quint8)
-DEFINE_CPU_KERNELS(uint64)
 
 #ifdef TENSORFLOW_USE_SYCL
 template <typename T, int NDims>
@@ -404,7 +404,6 @@ class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {
 
 TF_CALL_ALL_TYPES(REGISTER_SPLIT);
 REGISTER_SPLIT(quint8);
-REGISTER_SPLIT(uint64);
 
 #undef REGISTER_SPLIT
 
@@ -440,8 +440,6 @@ class StridedSliceAssignOp : public OpKernel {
                           StridedSliceAssignOp<CPUDevice, type, true>)
 
 TF_CALL_ALL_TYPES(REGISTER_STRIDED_SLICE);
-TF_CALL_uint32(REGISTER_STRIDED_SLICE);
-TF_CALL_uint64(REGISTER_STRIDED_SLICE);
 
 #undef REGISTER_STRIDED_SLICE
 
@@ -287,8 +287,6 @@ TF_CALL_GPU_ALL_TYPES(DECLARE_FOR_N_GPU);
 #endif  // END GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU);
-TF_CALL_uint32(DECLARE_FOR_N_CPU);
-TF_CALL_uint64(DECLARE_FOR_N_CPU);
 
 #ifdef TENSORFLOW_USE_SYCL
 #define PREVENT_FOR_N_SYCL(T) \
@@ -258,7 +258,6 @@ namespace functor {
 
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
 TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);
-TF_CALL_uint32(DECLARE_GPU_SPEC);
 
 #undef DECLARE_GPU_SPEC
 
@@ -276,7 +275,6 @@ TF_CALL_uint32(DECLARE_GPU_SPEC);
 
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
 TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
-TF_CALL_uint32(REGISTER_KERNELS)
 #undef REGISTER_KERNELS
 
 #endif  // end GOOGLE_CUDA
tensorflow/core/kernels/topk_op_gpu_uint32.cu.cc (new file, 28 lines)
@@ -0,0 +1,28 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/topk_op.h"
+#include "tensorflow/core/kernels/topk_op_gpu.h"
+
+namespace tensorflow {
+using Eigen::GpuDevice;
+
+template struct functor::TopKFunctor<GPUDevice, uint32>;
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
tensorflow/core/kernels/topk_op_gpu_uint64.cu.cc (new file, 28 lines)
@@ -0,0 +1,28 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/topk_op.h"
+#include "tensorflow/core/kernels/topk_op_gpu.h"
+
+namespace tensorflow {
+using Eigen::GpuDevice;
+
+template struct functor::TopKFunctor<GPUDevice, uint64>;
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
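Both new files follow the existing per-type pattern visible in the BUILD hunk above (topk_op_gpu_int32.cu.cc and friends): each translation unit explicitly instantiates the TopK functor for exactly one type, so each nvcc invocation stays small and the per-type files can compile in parallel. A minimal sketch of explicit instantiation, in plain C++ and with invented names (illustrative only):

    #include <cstdint>
    #include <cstdio>

    template <typename T>
    struct TopKSketch {
      // Stand-in for functor::TopKFunctor: returns the largest element.
      static T Max(const T* data, int n) {
        T best = data[0];
        for (int i = 1; i < n; ++i)
          if (data[i] > best) best = data[i];
        return best;
      }
    };

    // One explicit instantiation per translation unit is the trick the new
    // topk_op_gpu_uint32/uint64 files use: each type gets its own file, so
    // the template body is compiled once per type, independently.
    template struct TopKSketch<uint32_t>;
    template struct TopKSketch<uint64_t>;

    int main() {
      const uint32_t v[] = {3, 9, 4};
      std::printf("%u\n", static_cast<unsigned>(TopKSketch<uint32_t>::Max(v, 3)));
    }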
@@ -182,8 +182,6 @@ Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index) {
   switch (element.dtype()) {
     TF_CALL_ALL_TYPES(HANDLE_TYPE);
     TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
-    TF_CALL_uint32(HANDLE_TYPE);
-    TF_CALL_uint64(HANDLE_TYPE);
 #undef HANDLE_TYPE
     default:
       return errors::Unimplemented("CopyElementToSlice Unhandled data type: ",
@@ -207,8 +205,6 @@ Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) {
   switch (parent.dtype()) {
     TF_CALL_ALL_TYPES(HANDLE_TYPE);
     TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
-    TF_CALL_uint32(HANDLE_TYPE);
-    TF_CALL_uint64(HANDLE_TYPE);
 #undef HANDLE_TYPE
     default:
       return errors::Unimplemented("CopySliceToElement Unhandled data type: ",
@@ -280,8 +276,6 @@ Status CopyContiguousSlices(const Tensor& src, int64 src_offset,
   switch (src.dtype()) {
     TF_CALL_ALL_TYPES(HANDLE_TYPE);
     TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
-    TF_CALL_uint32(HANDLE_TYPE);
-    TF_CALL_uint64(HANDLE_TYPE);
 #undef HANDLE_TYPE
     default:
       return errors::Unimplemented("CopyContiguousSlices unhandled data type: ",
@@ -308,8 +302,6 @@ Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index) {
   switch (parent->dtype()) {
     TF_CALL_ALL_TYPES(HANDLE_TYPE);
     TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
-    TF_CALL_uint32(HANDLE_TYPE);
-    TF_CALL_uint64(HANDLE_TYPE);
 #undef HANDLE_TYPE
     default:
       return errors::Unimplemented(
@@ -116,7 +116,9 @@ TENSOR_PROTO_EXTRACT_TYPE(double, double, double);
 TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex64, scomplex, float);
 TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex128, dcomplex, double);
 TENSOR_PROTO_EXTRACT_TYPE(int32, int, int32);
+TENSOR_PROTO_EXTRACT_TYPE(uint32, uint32, uint32);
 TENSOR_PROTO_EXTRACT_TYPE(int64, int64, protobuf_int64);
+TENSOR_PROTO_EXTRACT_TYPE(uint64, uint64, protobuf_uint64);
 TENSOR_PROTO_EXTRACT_TYPE(uint16, int, int32);
 TENSOR_PROTO_EXTRACT_TYPE(uint8, int, int32);
 TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32);
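This final hunk extends the tensor-coding path so uint32/uint64 tensors can be decoded now that kernels exist for them; the macro arguments pair a C++ type with the proto representation that carries it on the wire. A rough sketch of the idea (the struct, field, and helper names below are invented for illustration; the real macro works against TensorProto's accessors):

    #include <cstdint>
    #include <vector>

    // Invented stand-in for TensorProto's per-dtype repeated fields.
    struct ProtoSketch {
      std::vector<uint32_t> uint32_val;
      std::vector<uint64_t> uint64_val;
    };

    // EXTRACT(T, field) binds a C++ type to the field that carries it, loosely
    // mirroring TENSOR_PROTO_EXTRACT_TYPE(uint64, uint64, protobuf_uint64).
    #define EXTRACT(T, field)                                                   \
      inline const std::vector<T>& ExtractValues(const ProtoSketch& p, T tag) { \
        (void)tag; /* the tag argument only selects the overload */            \
        return p.field;                                                        \
      }

    EXTRACT(uint32_t, uint32_val)
    EXTRACT(uint64_t, uint64_val)
    #undef EXTRACT

    // Usage: ExtractValues(proto, uint32_t{0}) yields proto.uint32_val.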