Split the scan_ops GPU kernel into one translation unit per element type.

Summary of the hunks below:
  * BUILD: the "scan_ops" tf_kernel_library rule drops `prefix = "scan_ops"`
    and lists srcs, hdrs and gpu_srcs explicitly.
  * scan_ops_gpu.cu.cc is renamed to scan_ops_gpu.h, gains an include guard,
    and loses the macro block that explicitly instantiated functor::Scan
    through TF_CALL_GPU_NUMBER_TYPES.
  * scan_ops_gpu_{double,float,half}.cu.cc are new files; each one explicitly
    instantiates functor::Scan for the Sum and Prod reducers of a single type.

NOTE(review): this copy of the patch was rebuilt from a whitespace-collapsed
extraction in which text between '<' and '>' had been dropped. Line breaks and
blank lines were restored to match the hunk line counts. Template arguments in
the added files follow the pattern that survived in the half-precision file.
The `GPUDevice` spelling in the removed DEFINE line and in the second hunk
heading of scan_ops_gpu.h, and the intra-line spacing of context/removed lines
(two spaces before trailing comments, macro backslash alignment), are assumed;
confirm against the pre-image before applying.

diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 3df63d34e36..b934c64bfc5 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3279,7 +3279,15 @@ tf_kernel_library(
 
 tf_kernel_library(
     name = "scan_ops",
-    prefix = "scan_ops",
+    srcs = ["scan_ops.cc"],
+    hdrs = ["scan_ops.h"],
+    gpu_srcs = [
+        "scan_ops.h",
+        "scan_ops_gpu.h",
+        "scan_ops_gpu_double.cu.cc",
+        "scan_ops_gpu_float.cu.cc",
+        "scan_ops_gpu_half.cu.cc",
+    ],
     deps = MATH_DEPS + if_cuda(["@cub_archive//:cub"]),
 )
 
diff --git a/tensorflow/core/kernels/scan_ops_gpu.cu.cc b/tensorflow/core/kernels/scan_ops_gpu.h
similarity index 97%
rename from tensorflow/core/kernels/scan_ops_gpu.cu.cc
rename to tensorflow/core/kernels/scan_ops_gpu.h
index ed66c02dc58..976b2215405 100644
--- a/tensorflow/core/kernels/scan_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/scan_ops_gpu.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#ifndef TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_
+#define TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_
+
 #if GOOGLE_CUDA
 
 #define EIGEN_USE_GPU
@@ -290,17 +293,8 @@ struct Scan<GPUDevice, Eigen::internal::ProdReducer<T>, T> {
 };
 
 }  // namespace functor
-
-#define DEFINE(REDUCER, T) template struct functor::Scan<GPUDevice, REDUCER, T>;
-
-#define DEFINE_FOR_ALL_REDUCERS(T)           \
-  DEFINE(Eigen::internal::SumReducer<T>, T); \
-  DEFINE(Eigen::internal::ProdReducer<T>, T);
-
-TF_CALL_GPU_NUMBER_TYPES(DEFINE_FOR_ALL_REDUCERS);
-#undef DEFINE_FOR_ALL_REDUCERS
-#undef DEFINE
-
 }  // end namespace tensorflow
 
 #endif  // GOOGLE_CUDA
+
+#endif  // TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_
diff --git a/tensorflow/core/kernels/scan_ops_gpu_double.cu.cc b/tensorflow/core/kernels/scan_ops_gpu_double.cu.cc
new file mode 100644
index 00000000000..adce37e473c
--- /dev/null
+++ b/tensorflow/core/kernels/scan_ops_gpu_double.cu.cc
@@ -0,0 +1,31 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/scan_ops.h"
+#include "tensorflow/core/kernels/scan_ops_gpu.h"
+
+namespace tensorflow {
+using Eigen::GpuDevice;
+template struct functor::Scan<GpuDevice, Eigen::internal::SumReducer<double>,
+                              double>;
+template struct functor::Scan<GpuDevice, Eigen::internal::ProdReducer<double>,
+                              double>;
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/scan_ops_gpu_float.cu.cc b/tensorflow/core/kernels/scan_ops_gpu_float.cu.cc
new file mode 100644
index 00000000000..b72415822d0
--- /dev/null
+++ b/tensorflow/core/kernels/scan_ops_gpu_float.cu.cc
@@ -0,0 +1,31 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/scan_ops.h"
+#include "tensorflow/core/kernels/scan_ops_gpu.h"
+
+namespace tensorflow {
+using Eigen::GpuDevice;
+template struct functor::Scan<GpuDevice, Eigen::internal::SumReducer<float>,
+                              float>;
+template struct functor::Scan<GpuDevice, Eigen::internal::ProdReducer<float>,
+                              float>;
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/scan_ops_gpu_half.cu.cc b/tensorflow/core/kernels/scan_ops_gpu_half.cu.cc
new file mode 100644
index 00000000000..f9fb528be98
--- /dev/null
+++ b/tensorflow/core/kernels/scan_ops_gpu_half.cu.cc
@@ -0,0 +1,31 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/scan_ops.h"
+#include "tensorflow/core/kernels/scan_ops_gpu.h"
+
+namespace tensorflow {
+using Eigen::GpuDevice;
+template struct functor::Scan<
+    GpuDevice, Eigen::internal::SumReducer<Eigen::half>, Eigen::half>;
+template struct functor::Scan<
+    GpuDevice, Eigen::internal::ProdReducer<Eigen::half>, Eigen::half>;
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA