diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index c749dd21fbf..29669cfa45b 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3757,6 +3757,7 @@ tf_kernel_library( "scan_ops_gpu_double.cu.cc", "scan_ops_gpu_float.cu.cc", "scan_ops_gpu_half.cu.cc", + "scan_ops_gpu_int.cu.cc", ], deps = MATH_DEPS + if_cuda([ "@cub_archive//:cub", diff --git a/tensorflow/core/kernels/scan_ops.cc b/tensorflow/core/kernels/scan_ops.cc index ea42fdefb41..126be3c5147 100644 --- a/tensorflow/core/kernels/scan_ops.cc +++ b/tensorflow/core/kernels/scan_ops.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" -#include "third_party/eigen3/Eigen/Core" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/kernels/scan_ops.h" @@ -106,6 +105,8 @@ namespace functor { DECLARE(Eigen::internal::ProdReducer, T); TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_ALL_REDUCERS); +DECLARE_FOR_ALL_REDUCERS(int32); +DECLARE_FOR_ALL_REDUCERS(int64); #undef DECLARE_FOR_ALL_REDUCERS #undef DECLARE @@ -147,6 +148,8 @@ TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS); .HostMemory("axis"), \ ScanOp, int64>) TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS) +REGISTER_GPU_KERNELS(int32); +REGISTER_GPU_KERNELS(int64); #undef REGISTER_GPU_KERNELS #endif // GOOGLE_CUDA @@ -184,6 +187,8 @@ TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS); .HostMemory("axis"), \ ScanOp, int64>) TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS) +REGISTER_GPU_KERNELS(int32); +REGISTER_GPU_KERNELS(int64); #undef REGISTER_GPU_KERNELS #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/scan_ops_gpu_int.cu.cc b/tensorflow/core/kernels/scan_ops_gpu_int.cu.cc new file mode 100644 index 00000000000..e9d0262ff7d --- /dev/null +++ b/tensorflow/core/kernels/scan_ops_gpu_int.cu.cc @@ -0,0 +1,35 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/scan_ops.h" +#include "tensorflow/core/kernels/scan_ops_gpu.h" + +namespace tensorflow { +using Eigen::GpuDevice; +template struct functor::Scan, + int64>; +template struct functor::Scan, + int64>; +template struct functor::Scan, + int32>; +template struct functor::Scan, + int32>; +} // namespace tensorflow + +#endif // GOOGLE_CUDA