Adding integer type GPU kernels to tf.cumsum and tf.cumprod.

PiperOrigin-RevId: 253327313
Authored by A. Unique TensorFlower on 2019-06-14 17:35:12 -07:00; committed by TensorFlower Gardener
parent 1ae09760d1
commit db3ecd34a4
3 changed files with 42 additions and 1 deletion
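
For context, the Cumsum and Cumprod GPU kernels registered below compute a cumulative (prefix) sum or product along one axis, with optional exclusive and reverse behaviour. A minimal, self-contained C++ sketch of those semantics for the 1-D integer case follows; it is reference code for illustration only, not part of the commit.

#include <cstdint>
#include <vector>

// Reference semantics of a 1-D cumulative sum over int64 values.
// exclusive = false, reverse = false is the default tf.cumsum behaviour:
//   out[i] = in[0] + in[1] + ... + in[i].
// Swapping += for *= (and seeding acc with 1) gives the cumprod semantics.
std::vector<int64_t> CumsumRef(const std::vector<int64_t>& in,
                               bool exclusive = false, bool reverse = false) {
  const int n = static_cast<int>(in.size());
  std::vector<int64_t> out(n, 0);
  int64_t acc = 0;  // running sum
  for (int k = 0; k < n; ++k) {
    const int i = reverse ? n - 1 - k : k;  // scan direction
    if (exclusive) {
      out[i] = acc;  // current element excluded from its own prefix
      acc += in[i];
    } else {
      acc += in[i];
      out[i] = acc;  // current element included
    }
  }
  return out;
}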

tensorflow/core/kernels/BUILD

@@ -3757,6 +3757,7 @@ tf_kernel_library(
"scan_ops_gpu_double.cu.cc",
"scan_ops_gpu_float.cu.cc",
"scan_ops_gpu_half.cu.cc",
"scan_ops_gpu_int.cu.cc",
],
deps = MATH_DEPS + if_cuda([
"@cub_archive//:cub",

tensorflow/core/kernels/scan_ops.cc

@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/scan_ops.h"
@@ -106,6 +105,8 @@ namespace functor {
DECLARE(Eigen::internal::ProdReducer<T>, T);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_ALL_REDUCERS);
DECLARE_FOR_ALL_REDUCERS(int32);
DECLARE_FOR_ALL_REDUCERS(int64);
#undef DECLARE_FOR_ALL_REDUCERS
#undef DECLARE
@@ -147,6 +148,8 @@ TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
.HostMemory("axis"), \
ScanOp<GPUDevice, type, Eigen::internal::SumReducer<type>, int64>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS)
REGISTER_GPU_KERNELS(int32);
REGISTER_GPU_KERNELS(int64);
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA
@@ -184,6 +187,8 @@ TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
.HostMemory("axis"), \
ScanOp<GPUDevice, type, Eigen::internal::ProdReducer<type>, int64>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS)
REGISTER_GPU_KERNELS(int32);
REGISTER_GPU_KERNELS(int64);
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA
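
The two hunks above widen scan_ops.cc's local REGISTER_GPU_KERNELS macro (one for Cumsum with a SumReducer, one for Cumprod with a ProdReducer) to cover int32 and int64. As a hedged sketch, one of the registrations that REGISTER_GPU_KERNELS(int32) produces for the cumsum case looks roughly like the following; the Name("Cumsum") call and the "T"/"Tidx" attribute names are assumed from the usual REGISTER_KERNEL_BUILDER pattern, since only the .HostMemory("axis") and ScanOp fragments are visible in this diff.

// Sketch only; this would live inside scan_ops.cc, where ScanOp and the
// REGISTER_GPU_KERNELS macro are defined.
REGISTER_KERNEL_BUILDER(
    Name("Cumsum")                      // assumed op name
        .Device(DEVICE_GPU)
        .TypeConstraint<int32>("T")     // element type now includes int32
        .TypeConstraint<int64>("Tidx")  // axis index type, per the visible fragment
        .HostMemory("axis"),            // the axis scalar stays in host memory
    ScanOp<GPUDevice, int32, Eigen::internal::SumReducer<int32>, int64>);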

tensorflow/core/kernels/scan_ops_gpu_int.cu.cc (new file)

@@ -0,0 +1,35 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/scan_ops.h"
#include "tensorflow/core/kernels/scan_ops_gpu.h"
namespace tensorflow {
using Eigen::GpuDevice;
template struct functor::Scan<GpuDevice, Eigen::internal::SumReducer<int64>,
int64>;
template struct functor::Scan<GpuDevice, Eigen::internal::ProdReducer<int64>,
int64>;
template struct functor::Scan<GpuDevice, Eigen::internal::SumReducer<int32>,
int32>;
template struct functor::Scan<GpuDevice, Eigen::internal::ProdReducer<int32>,
int32>;
} // namespace tensorflow
#endif // GOOGLE_CUDA
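
The new translation unit contains nothing but explicit instantiations: it mirrors the existing scan_ops_gpu_float/half/double .cu.cc files, presumably so the per-type instantiations can be compiled in parallel while the shared kernel code stays in scan_ops_gpu.h. For reference, a hedged sketch of the functor::Scan interface these instantiations refer to; the real declaration is in tensorflow/core/kernels/scan_ops.h (not part of this diff), so the exact parameter list shown here is an assumption.

#include "tensorflow/core/framework/tensor_types.h"  // TTypes

namespace tensorflow {
namespace functor {

// Assumed shape of the Scan functor that scan_ops_gpu_int.cu.cc instantiates.
// It scans along the middle dimension of an [outer, scan, inner] reshaped
// tensor, honouring the op's reverse/exclusive attributes.
template <typename Device, typename Reducer, typename T>
struct Scan {
  void operator()(const Device& d, typename TTypes<T, 3>::ConstTensor in,
                  typename TTypes<T, 3>::Tensor out, const Reducer& reducer,
                  bool reverse, bool exclusive);
};

}  // namespace functor
}  // namespace tensorflow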