From 222df6844f8621706049bfd9f7e16cbadf72e8ed Mon Sep 17 00:00:00 2001 From: Anudhyan Boral Date: Tue, 18 Jun 2019 00:13:59 -0700 Subject: [PATCH] Register GPU kernels for EinsumOp. This change is mostly boilerplate to make the nvcc compiler happy. Most of the heavy lifting in EinsumOp is done by BatchMatmul/Reduction functors and Eigen Tensor Ops which already have GPU kernels defined for them. This lets us easily obtain an efficient implementation on the GPU. PiperOrigin-RevId: 253736113 --- tensorflow/core/kernels/einsum_op.cc | 79 ++++++++++++++++--- tensorflow/core/kernels/einsum_op.h | 48 +++++++++++ tensorflow/core/kernels/einsum_op_gpu.cu.cc | 46 +++++++++++ .../core/kernels/reduction_ops_common_gpu.h | 44 +++++++++++ 4 files changed, 205 insertions(+), 12 deletions(-) create mode 100644 tensorflow/core/kernels/einsum_op.h create mode 100644 tensorflow/core/kernels/einsum_op_gpu.cu.cc create mode 100644 tensorflow/core/kernels/reduction_ops_common_gpu.h diff --git a/tensorflow/core/kernels/einsum_op.cc b/tensorflow/core/kernels/einsum_op.cc index ae5733c19d6..db968d4b89e 100644 --- a/tensorflow/core/kernels/einsum_op.cc +++ b/tensorflow/core/kernels/einsum_op.cc @@ -14,6 +14,11 @@ limitations under the License. ==============================================================================*/ #define EIGEN_USE_THREADS +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include "tensorflow/core/kernels/einsum_op.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/str_split.h" @@ -34,9 +39,14 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/einsum_op_util.h" +#if GOOGLE_CUDA +#include "tensorflow/core/kernels/reduction_ops_common_gpu.h" +#endif // GOOGLE_CUDA + namespace tensorflow { using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; namespace { @@ -368,17 +378,21 @@ Status StrideOrInflate(OpKernelContext* ctx, const Tensor& input, ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); const Device& device = ctx->eigen_device(); switch (reshape.size()) { -#define NDIMS_CASE(N) \ - case N: { \ - if (should_inflate) { \ - auto output_map = output->shaped(reshape); \ - auto input_map = input.shaped(strided_shape); \ - output_map.device(device) = input_map.inflate(strides); \ - } else { \ - auto input_map = input.shaped(reshape); \ - auto output_map = output->shaped(strided_shape); \ - output_map.device(device) = input_map.stride(strides); \ - } \ +#define NDIMS_CASE(N) \ + case N: { \ + if (should_inflate) { \ + auto output_map = output->shaped(reshape); \ + auto input_map = input.shaped(strided_shape); \ + functor::InflateFunctor()( \ + device, input_map, TensorShape(strides).AsEigenDSizes(), \ + output_map); \ + } else { \ + auto input_map = input.shaped(reshape); \ + auto output_map = output->shaped(strided_shape); \ + functor::StrideFunctor()( \ + device, input_map, TensorShape(strides).AsEigenDSizes(), \ + output_map); \ + } \ } break; NDIMS_CASE(1); NDIMS_CASE(2); @@ -695,18 +709,59 @@ class EinsumOp : public OpKernel { bool output_has_ellipsis_ = false; }; +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T, N) \ + template <> \ + void StrideFunctor::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + const Eigen::DSizes& strides, \ + typename TTypes::Tensor output); \ + extern template struct StrideFunctor; \ + template <> \ + void InflateFunctor::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + const Eigen::DSizes& strides, \ + typename TTypes::Tensor output); \ + extern template struct InflateFunctor; + +#define DECLARE_GPU_SPECS(T) \ + DECLARE_GPU_SPEC(T, 1); \ + DECLARE_GPU_SPEC(T, 2); \ + DECLARE_GPU_SPEC(T, 3); \ + DECLARE_GPU_SPEC(T, 4); \ + DECLARE_GPU_SPEC(T, 5); \ + DECLARE_GPU_SPEC(T, 6); + +DECLARE_GPU_SPECS(double); +DECLARE_GPU_SPECS(float); +#undef DECLARE_GPU_SPEC +#undef DECLARE_GPU_SPECS +} // namespace functor +#endif // GOOGLE_CUDA + #define REGISTER_EINSUM(D, TYPE) \ REGISTER_KERNEL_BUILDER( \ Name("Einsum").Device(DEVICE_##D).TypeConstraint("T"), \ EinsumOp); -// TODO(anudhyan): Also register GPU kernels for Einsum. #define REGISTER_CPU(TYPE) REGISTER_EINSUM(CPU, TYPE) TF_CALL_float(REGISTER_CPU); TF_CALL_double(REGISTER_CPU); TF_CALL_complex64(REGISTER_CPU); TF_CALL_complex128(REGISTER_CPU); #undef REGISTER_CPU + +#if GOOGLE_CUDA +#define REGISTER_GPU(TYPE) REGISTER_EINSUM(GPU, TYPE) +TF_CALL_float(REGISTER_GPU); +TF_CALL_double(REGISTER_GPU); +TF_CALL_complex64(REGISTER_GPU); +TF_CALL_complex128(REGISTER_GPU); +#undef REGISTER_GPU +#endif // GOOGLE_CUDA + #undef REGISTER_EINSUM } // namespace tensorflow diff --git a/tensorflow/core/kernels/einsum_op.h b/tensorflow/core/kernels/einsum_op.h new file mode 100644 index 00000000000..8ac1bbc5fe5 --- /dev/null +++ b/tensorflow/core/kernels/einsum_op.h @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_EINSUM_OP_H_ +#define TENSORFLOW_CORE_KERNELS_EINSUM_OP_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +namespace tensorflow { +namespace functor { + +template +struct StrideFunctor { + void operator()(const Device& d, typename TTypes::ConstTensor input, + const Eigen::DSizes& strides, + typename TTypes::Tensor output) { + output.device(d) = input.stride(strides); + } +}; + +template +struct InflateFunctor { + void operator()(const Device& d, typename TTypes::ConstTensor input, + const Eigen::DSizes& strides, + typename TTypes::Tensor output) { + output.device(d) = input.inflate(strides); + } +}; +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_EINSUM_OP_H_ diff --git a/tensorflow/core/kernels/einsum_op_gpu.cu.cc b/tensorflow/core/kernels/einsum_op_gpu.cu.cc new file mode 100644 index 00000000000..e7adbe571e7 --- /dev/null +++ b/tensorflow/core/kernels/einsum_op_gpu.cu.cc @@ -0,0 +1,46 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/einsum_op.h" + +namespace tensorflow { + +#define DECLARE_GPU_SPECS_NDIM(T, NDIM) \ + template struct functor::StrideFunctor; \ + template struct functor::InflateFunctor; + +#define DECLARE_GPU_SPECS(T) \ + DECLARE_GPU_SPECS_NDIM(T, 1); \ + DECLARE_GPU_SPECS_NDIM(T, 2); \ + DECLARE_GPU_SPECS_NDIM(T, 3); \ + DECLARE_GPU_SPECS_NDIM(T, 4); \ + DECLARE_GPU_SPECS_NDIM(T, 5); \ + DECLARE_GPU_SPECS_NDIM(T, 6); + +TF_CALL_float(DECLARE_GPU_SPECS); +TF_CALL_double(DECLARE_GPU_SPECS); +TF_CALL_complex64(DECLARE_GPU_SPECS); +TF_CALL_complex128(DECLARE_GPU_SPECS); + +#undef DECLARE_GPU_SPECS_NDIM +#undef DECLARE_GPU_SPECS + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/reduction_ops_common_gpu.h b/tensorflow/core/kernels/reduction_ops_common_gpu.h new file mode 100644 index 00000000000..9af43f885f9 --- /dev/null +++ b/tensorflow/core/kernels/reduction_ops_common_gpu.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_GPU_H_ + +#if !GOOGLE_CUDA +#error This file must only be included when building with Cuda support +#endif + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +template +struct ReduceFunctor { + template + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Reducer& reducer); + + template + static void FillIdentity(const Eigen::GpuDevice& d, OUT_T out, + const Reducer& reducer); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_GPU_H_