From e972c5572634efd188696038e9241b75cdcd69bc Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Fri, 19 Jun 2020 00:07:20 -0700 Subject: [PATCH] Add uint32 & uint64 to TF_CALL_INTEGRAL_TYPES Both uint32 & uint64 had been omitted from TF_CALL_INTEGRAL_TYPES due to suggested concerns of size bloat. In reality it seems that the size increase is only around 2MB. Further, this fixes #39649 since we are no longer inadvertently using the XLA_CPU device to perform tf.reduce_mean. PiperOrigin-RevId: 317259372 Change-Id: Iacf75eaedce198fbef4bd9fd59b6fefa584cbf34 --- tensorflow/core/framework/register_types.h | 21 +++++--------- tensorflow/core/framework/types.cc | 5 ---- tensorflow/core/kernels/BUILD | 2 ++ tensorflow/core/kernels/concat_lib_cpu.cc | 2 -- tensorflow/core/kernels/concat_op.cc | 2 -- tensorflow/core/kernels/constant_op.cc | 1 - tensorflow/core/kernels/control_flow_ops.cc | 5 ---- .../core/kernels/data/dataset_test_base.cc | 2 -- tensorflow/core/kernels/dense_update_ops.cc | 1 - .../core/kernels/dynamic_partition_op.cc | 2 -- tensorflow/core/kernels/fill_functor.cc | 5 +++- tensorflow/core/kernels/gather_op.cc | 2 -- tensorflow/core/kernels/identity_op.cc | 1 - tensorflow/core/kernels/ragged_gather_op.cc | 2 -- .../kernels/ragged_tensor_from_variant_op.cc | 2 -- .../kernels/ragged_tensor_to_tensor_op.cc | 2 -- .../kernels/ragged_tensor_to_variant_op.cc | 2 -- .../core/kernels/resource_variable_ops.cc | 1 - tensorflow/core/kernels/split_lib_cpu.cc | 1 - tensorflow/core/kernels/split_op.cc | 1 - tensorflow/core/kernels/strided_slice_op.cc | 2 -- .../core/kernels/strided_slice_op_impl.h | 2 -- tensorflow/core/kernels/topk_op.cc | 2 -- .../core/kernels/topk_op_gpu_uint32.cu.cc | 28 +++++++++++++++++++ .../core/kernels/topk_op_gpu_uint64.cu.cc | 28 +++++++++++++++++++ tensorflow/core/util/batch_util.cc | 8 ------ .../core/util/saved_tensor_slice_util.h | 2 ++ 27 files changed, 71 insertions(+), 63 deletions(-) create mode 100644 tensorflow/core/kernels/topk_op_gpu_uint32.cu.cc create mode 100644 tensorflow/core/kernels/topk_op_gpu_uint64.cu.cc diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h index bc3e5e1743b..0cf6536e8c2 100644 --- a/tensorflow/core/framework/register_types.h +++ b/tensorflow/core/framework/register_types.h @@ -153,16 +153,9 @@ limitations under the License. #endif // defined(IS_MOBILE_PLATFORM) - end of TF_CALL_type defines // Defines for sets of types. - -// TODO(b/111604096): Add uint32 and uint64 to TF_CALL_INTEGRAL_TYPES. -// -// The uint32 and uint64 types were introduced in 10/2017 to be used via XLA and -// thus were not included in TF_CALL_INTEGRAL_TYPES. Including them in -// TF_CALL_INTEGRAL_TYPES should only happen after evaluating the effect on the -// TF binary size and performance. -#define TF_CALL_INTEGRAL_TYPES(m) \ - TF_CALL_int64(m) TF_CALL_int32(m) TF_CALL_uint16(m) TF_CALL_int16(m) \ - TF_CALL_uint8(m) TF_CALL_int8(m) +#define TF_CALL_INTEGRAL_TYPES(m) \ + TF_CALL_uint64(m) TF_CALL_int64(m) TF_CALL_uint32(m) TF_CALL_int32(m) \ + TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m) #define TF_CALL_FLOAT_TYPES(m) \ TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m) @@ -174,10 +167,10 @@ limitations under the License. #define TF_CALL_REAL_NUMBER_TYPES_NO_BFLOAT16(m) \ TF_CALL_INTEGRAL_TYPES(m) TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) -#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \ - TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m) \ - TF_CALL_int64(m) TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) \ - TF_CALL_int8(m) +#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \ + TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m) \ + TF_CALL_uint64(m) TF_CALL_int64(m) TF_CALL_uint32(m) TF_CALL_uint16(m) \ + TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m) #define TF_CALL_COMPLEX_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m) diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc index 97eaec98ffe..d6455e012d0 100644 --- a/tensorflow/core/framework/types.cc +++ b/tensorflow/core/framework/types.cc @@ -238,11 +238,6 @@ int DataTypeSize(DataType dt) { TF_CALL_qint16(CASE); TF_CALL_quint16(CASE); - // uint32 and uint64 aren't included in TF_CALL_POD_TYPES because we - // don't want to define kernels for them at this stage to avoid binary - // bloat. - TF_CALL_uint32(CASE); - TF_CALL_uint64(CASE); default: return 0; } diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 279dff92c58..97f974c6af4 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -4900,7 +4900,9 @@ tf_kernel_library( "topk_op_gpu_double.cu.cc", "topk_op_gpu_float.cu.cc", "topk_op_gpu_half.cu.cc", + "topk_op_gpu_uint64.cu.cc", "topk_op_gpu_int64.cu.cc", + "topk_op_gpu_uint32.cu.cc", "topk_op_gpu_int32.cu.cc", "topk_op_gpu_int16.cu.cc", "topk_op_gpu_uint16.cu.cc", diff --git a/tensorflow/core/kernels/concat_lib_cpu.cc b/tensorflow/core/kernels/concat_lib_cpu.cc index da73d3d2c56..1dec589d3ff 100644 --- a/tensorflow/core/kernels/concat_lib_cpu.cc +++ b/tensorflow/core/kernels/concat_lib_cpu.cc @@ -116,8 +116,6 @@ REGISTER(qint8) REGISTER(quint16) REGISTER(qint16) REGISTER(qint32) -REGISTER(uint32) -REGISTER(uint64) #if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) && \ !defined(__ANDROID_TYPES_FULL__) diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc index be3e9a67c5f..d3f3a04f33b 100644 --- a/tensorflow/core/kernels/concat_op.cc +++ b/tensorflow/core/kernels/concat_op.cc @@ -208,8 +208,6 @@ REGISTER_CONCAT(qint8); REGISTER_CONCAT(quint16); REGISTER_CONCAT(qint16); REGISTER_CONCAT(qint32); -REGISTER_CONCAT(uint32); -REGISTER_CONCAT(uint64); #undef REGISTER_CONCAT diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index 4bcbc076446..dc178d17d49 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -211,7 +211,6 @@ TF_CALL_ALL_TYPES(REGISTER_CPU_KERNEL); // the conversion from uint8 to quint8. REGISTER_KERNEL(CPU, quint8); REGISTER_KERNEL(CPU, quint16); -REGISTER_KERNEL(CPU, uint32); #undef REGISTER_CPU_KERNEL #ifdef TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc index c8e83b6f672..accb2c59540 100644 --- a/tensorflow/core/kernels/control_flow_ops.cc +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -101,16 +101,12 @@ TF_CALL_ALL_TYPES(REGISTER_CPU_SWITCH); TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH); TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_SWITCH); TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_REF_SWITCH); -REGISTER_CPU_SWITCH(uint64); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_SWITCH); TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_SWITCH); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH); TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_SWITCH); -REGISTER_GPU_SWITCH(uint64); TF_CALL_variant(REGISTER_GPU_SWITCH); -TF_CALL_uint32(REGISTER_GPU_SWITCH); -TF_CALL_uint32(REGISTER_GPU_REF_SWITCH); #undef REGISTER_CPU_SWITCH #undef REGISTER_CPU_REF_SWITCH @@ -311,7 +307,6 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_KERNEL); TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_KERNEL); REGISTER_GPU_KERNEL(bool); REGISTER_GPU_REF_KERNEL(bool); -REGISTER_GPU_KERNEL(uint64); TF_CALL_variant(REGISTER_GPU_KERNEL); #undef REGISTER_GPU_KERNEL diff --git a/tensorflow/core/kernels/data/dataset_test_base.cc b/tensorflow/core/kernels/data/dataset_test_base.cc index b91ab9b733c..e41e35be1e9 100644 --- a/tensorflow/core/kernels/data/dataset_test_base.cc +++ b/tensorflow/core/kernels/data/dataset_test_base.cc @@ -220,8 +220,6 @@ Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) { break; TF_CALL_NUMBER_TYPES(CASE); TF_CALL_tstring(CASE); - TF_CALL_uint32(CASE); - TF_CALL_uint64(CASE); // TODO(feihugis): figure out how to support variant tensors. #undef CASE default: diff --git a/tensorflow/core/kernels/dense_update_ops.cc b/tensorflow/core/kernels/dense_update_ops.cc index 55e4cd7606a..71235fca143 100644 --- a/tensorflow/core/kernels/dense_update_ops.cc +++ b/tensorflow/core/kernels/dense_update_ops.cc @@ -98,7 +98,6 @@ typedef Eigen::SyclDevice SYCLDevice; TF_CALL_ALL_TYPES(REGISTER_KERNELS); // uint32 not included in ALL_TYPES -TF_CALL_uint32(REGISTER_KERNELS); TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); // quint16 not included in QUANTIZIED_TYPES TF_CALL_quint16(REGISTER_KERNELS); diff --git a/tensorflow/core/kernels/dynamic_partition_op.cc b/tensorflow/core/kernels/dynamic_partition_op.cc index 90ed71dccce..95af19c4c48 100644 --- a/tensorflow/core/kernels/dynamic_partition_op.cc +++ b/tensorflow/core/kernels/dynamic_partition_op.cc @@ -164,8 +164,6 @@ class DynamicPartitionOp : public DynamicPartitionOp_Shared { DynamicPartitionOp) TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_PARTITION); -// For partitioning fingerprints. -TF_CALL_uint64(REGISTER_DYNAMIC_PARTITION); #undef REGISTER_DYNAMIC_PARTITION } // namespace tensorflow diff --git a/tensorflow/core/kernels/fill_functor.cc b/tensorflow/core/kernels/fill_functor.cc index 10dd3df1915..174a4e45a79 100644 --- a/tensorflow/core/kernels/fill_functor.cc +++ b/tensorflow/core/kernels/fill_functor.cc @@ -45,6 +45,8 @@ DEFINE_SETZERO_CPU(Eigen::half); DEFINE_SETZERO_CPU(bfloat16); DEFINE_SETZERO_CPU(float); DEFINE_SETZERO_CPU(double); +DEFINE_SETZERO_CPU(uint32); +DEFINE_SETZERO_CPU(uint64); DEFINE_SETZERO_CPU(uint8); DEFINE_SETZERO_CPU(int8); DEFINE_SETZERO_CPU(uint16); @@ -96,6 +98,8 @@ DEFINE_SETONE_CPU(Eigen::half); DEFINE_SETONE_CPU(bfloat16); DEFINE_SETONE_CPU(float); DEFINE_SETONE_CPU(double); +DEFINE_SETONE_CPU(uint32); +DEFINE_SETONE_CPU(uint64); DEFINE_SETONE_CPU(uint8); DEFINE_SETONE_CPU(int8); DEFINE_SETONE_CPU(uint16); @@ -137,7 +141,6 @@ struct FillFunctor { TF_CALL_ALL_TYPES(DEFINE_FILL_CPU); DEFINE_FILL_CPU(quint8); DEFINE_FILL_CPU(quint16); -DEFINE_FILL_CPU(uint32); #undef DEFINE_FILL_CPU #ifdef TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc index 6d493a5f2ea..948567e019a 100644 --- a/tensorflow/core/kernels/gather_op.cc +++ b/tensorflow/core/kernels/gather_op.cc @@ -211,8 +211,6 @@ TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU); TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU); TF_CALL_quint16(REGISTER_GATHER_CPU); TF_CALL_qint16(REGISTER_GATHER_CPU); -TF_CALL_uint32(REGISTER_GATHER_CPU); -TF_CALL_uint64(REGISTER_GATHER_CPU); #undef REGISTER_GATHER_CPU diff --git a/tensorflow/core/kernels/identity_op.cc b/tensorflow/core/kernels/identity_op.cc index fd94df9a768..daa8a1ddb25 100644 --- a/tensorflow/core/kernels/identity_op.cc +++ b/tensorflow/core/kernels/identity_op.cc @@ -122,7 +122,6 @@ REGISTER_SYCL_HOST_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); REGISTER_GPU_KERNEL(Variant); -TF_CALL_uint32(REGISTER_GPU_KERNEL); #undef REGISTER_GPU_KERNEL diff --git a/tensorflow/core/kernels/ragged_gather_op.cc b/tensorflow/core/kernels/ragged_gather_op.cc index 88c0d1ebd69..3bf82cba050 100644 --- a/tensorflow/core/kernels/ragged_gather_op.cc +++ b/tensorflow/core/kernels/ragged_gather_op.cc @@ -296,8 +296,6 @@ TF_CALL_tstring(REGISTER_CPU_KERNEL); TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL); TF_CALL_quint16(REGISTER_CPU_KERNEL); TF_CALL_qint16(REGISTER_CPU_KERNEL); -TF_CALL_uint32(REGISTER_CPU_KERNEL); -TF_CALL_uint64(REGISTER_CPU_KERNEL); #undef REGISTER_CPU_KERNEL #undef REGISTER_CPU_KERNEL_WITH_INDEX_TYPE diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc index f83bcb38c6c..ad0712e6fd0 100644 --- a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc +++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc @@ -308,8 +308,6 @@ TF_CALL_tstring(REGISTER_KERNELS); TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); TF_CALL_quint16(REGISTER_KERNELS); TF_CALL_qint16(REGISTER_KERNELS); -TF_CALL_uint32(REGISTER_KERNELS); -TF_CALL_uint64(REGISTER_KERNELS); #undef REGISTER_KERNELS #undef REGISTER_KERNELS_WITH_SPLIT_TYPE } // namespace tensorflow diff --git a/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc b/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc index d729c43f25a..9ae5d7ffbdc 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc @@ -561,8 +561,6 @@ TF_CALL_string(REGISTER_CPU_KERNEL); TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL); TF_CALL_quint16(REGISTER_CPU_KERNEL); TF_CALL_qint16(REGISTER_CPU_KERNEL); -TF_CALL_uint32(REGISTER_CPU_KERNEL); -TF_CALL_uint64(REGISTER_CPU_KERNEL); #undef REGISTER_CPU_KERNEL diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc index 7a5ae1c6240..64c372b005e 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc @@ -213,8 +213,6 @@ TF_CALL_tstring(REGISTER_KERNELS); TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); TF_CALL_quint16(REGISTER_KERNELS); TF_CALL_qint16(REGISTER_KERNELS); -TF_CALL_uint32(REGISTER_KERNELS); -TF_CALL_uint64(REGISTER_KERNELS); #undef REGISTER_KERNELS #undef REGISTER_KERNELS_WITH_SPLIT_TYPE } // namespace tensorflow diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index 0fc1d53749f..79a64cb9219 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -512,7 +512,6 @@ class AssignVariableOp : public OpKernel { TF_CALL_ALL_TYPES(REGISTER_KERNELS); TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); -TF_CALL_uint32(REGISTER_KERNELS); #undef REGISTER_KERNELS #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/split_lib_cpu.cc b/tensorflow/core/kernels/split_lib_cpu.cc index 0cb0a94d498..a3060e4e90d 100644 --- a/tensorflow/core/kernels/split_lib_cpu.cc +++ b/tensorflow/core/kernels/split_lib_cpu.cc @@ -43,7 +43,6 @@ void Split::operator()( TF_CALL_ALL_TYPES(DEFINE_CPU_KERNELS) DEFINE_CPU_KERNELS(quint8) -DEFINE_CPU_KERNELS(uint64) #ifdef TENSORFLOW_USE_SYCL template diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc index f09740c6198..08575f01f67 100644 --- a/tensorflow/core/kernels/split_op.cc +++ b/tensorflow/core/kernels/split_op.cc @@ -404,7 +404,6 @@ class SplitOpSYCL : public SplitOpBase { TF_CALL_ALL_TYPES(REGISTER_SPLIT); REGISTER_SPLIT(quint8); -REGISTER_SPLIT(uint64); #undef REGISTER_SPLIT diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc index ccc1984bb98..b4099213303 100644 --- a/tensorflow/core/kernels/strided_slice_op.cc +++ b/tensorflow/core/kernels/strided_slice_op.cc @@ -440,8 +440,6 @@ class StridedSliceAssignOp : public OpKernel { StridedSliceAssignOp) TF_CALL_ALL_TYPES(REGISTER_STRIDED_SLICE); -TF_CALL_uint32(REGISTER_STRIDED_SLICE); -TF_CALL_uint64(REGISTER_STRIDED_SLICE); #undef REGISTER_STRIDED_SLICE diff --git a/tensorflow/core/kernels/strided_slice_op_impl.h b/tensorflow/core/kernels/strided_slice_op_impl.h index 1ae959b7b3f..5ce1d773e33 100644 --- a/tensorflow/core/kernels/strided_slice_op_impl.h +++ b/tensorflow/core/kernels/strided_slice_op_impl.h @@ -287,8 +287,6 @@ TF_CALL_GPU_ALL_TYPES(DECLARE_FOR_N_GPU); #endif // END GOOGLE_CUDA || TENSORFLOW_USE_ROCM TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU); -TF_CALL_uint32(DECLARE_FOR_N_CPU); -TF_CALL_uint64(DECLARE_FOR_N_CPU); #ifdef TENSORFLOW_USE_SYCL #define PREVENT_FOR_N_SYCL(T) \ diff --git a/tensorflow/core/kernels/topk_op.cc b/tensorflow/core/kernels/topk_op.cc index c555b42f005..50325b7bcfe 100644 --- a/tensorflow/core/kernels/topk_op.cc +++ b/tensorflow/core/kernels/topk_op.cc @@ -258,7 +258,6 @@ namespace functor { TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC); -TF_CALL_uint32(DECLARE_GPU_SPEC); #undef DECLARE_GPU_SPEC @@ -276,7 +275,6 @@ TF_CALL_uint32(DECLARE_GPU_SPEC); TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS); TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS); -TF_CALL_uint32(REGISTER_KERNELS) #undef REGISTER_KERNELS #endif // end GOOGLE_CUDA diff --git a/tensorflow/core/kernels/topk_op_gpu_uint32.cu.cc b/tensorflow/core/kernels/topk_op_gpu_uint32.cu.cc new file mode 100644 index 00000000000..16e2e0e9420 --- /dev/null +++ b/tensorflow/core/kernels/topk_op_gpu_uint32.cu.cc @@ -0,0 +1,28 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/topk_op.h" +#include "tensorflow/core/kernels/topk_op_gpu.h" + +namespace tensorflow { +using Eigen::GpuDevice; + +template struct functor::TopKFunctor; +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/topk_op_gpu_uint64.cu.cc b/tensorflow/core/kernels/topk_op_gpu_uint64.cu.cc new file mode 100644 index 00000000000..895247a63a2 --- /dev/null +++ b/tensorflow/core/kernels/topk_op_gpu_uint64.cu.cc @@ -0,0 +1,28 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/topk_op.h" +#include "tensorflow/core/kernels/topk_op_gpu.h" + +namespace tensorflow { +using Eigen::GpuDevice; + +template struct functor::TopKFunctor; +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/util/batch_util.cc b/tensorflow/core/util/batch_util.cc index b88c365ced0..e03188b04da 100644 --- a/tensorflow/core/util/batch_util.cc +++ b/tensorflow/core/util/batch_util.cc @@ -182,8 +182,6 @@ Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index) { switch (element.dtype()) { TF_CALL_ALL_TYPES(HANDLE_TYPE); TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); - TF_CALL_uint32(HANDLE_TYPE); - TF_CALL_uint64(HANDLE_TYPE); #undef HANDLE_TYPE default: return errors::Unimplemented("CopyElementToSlice Unhandled data type: ", @@ -207,8 +205,6 @@ Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) { switch (parent.dtype()) { TF_CALL_ALL_TYPES(HANDLE_TYPE); TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); - TF_CALL_uint32(HANDLE_TYPE); - TF_CALL_uint64(HANDLE_TYPE); #undef HANDLE_TYPE default: return errors::Unimplemented("CopySliceToElement Unhandled data type: ", @@ -280,8 +276,6 @@ Status CopyContiguousSlices(const Tensor& src, int64 src_offset, switch (src.dtype()) { TF_CALL_ALL_TYPES(HANDLE_TYPE); TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); - TF_CALL_uint32(HANDLE_TYPE); - TF_CALL_uint64(HANDLE_TYPE); #undef HANDLE_TYPE default: return errors::Unimplemented("CopyContiguousSlices unhandled data type: ", @@ -308,8 +302,6 @@ Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index) { switch (parent->dtype()) { TF_CALL_ALL_TYPES(HANDLE_TYPE); TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); - TF_CALL_uint32(HANDLE_TYPE); - TF_CALL_uint64(HANDLE_TYPE); #undef HANDLE_TYPE default: return errors::Unimplemented( diff --git a/tensorflow/core/util/saved_tensor_slice_util.h b/tensorflow/core/util/saved_tensor_slice_util.h index 09b9235b711..1f9768f5163 100644 --- a/tensorflow/core/util/saved_tensor_slice_util.h +++ b/tensorflow/core/util/saved_tensor_slice_util.h @@ -116,7 +116,9 @@ TENSOR_PROTO_EXTRACT_TYPE(double, double, double); TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex64, scomplex, float); TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex128, dcomplex, double); TENSOR_PROTO_EXTRACT_TYPE(int32, int, int32); +TENSOR_PROTO_EXTRACT_TYPE(uint32, uint32, uint32); TENSOR_PROTO_EXTRACT_TYPE(int64, int64, protobuf_int64); +TENSOR_PROTO_EXTRACT_TYPE(uint64, uint64, protobuf_uint64); TENSOR_PROTO_EXTRACT_TYPE(uint16, int, int32); TENSOR_PROTO_EXTRACT_TYPE(uint8, int, int32); TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32);