Register GPU kernels for EinsumOp.
This change is mostly boilerplate to make the nvcc compiler happy. Most of the heavy lifting in EinsumOp is done by BatchMatmul/Reduction functors and Eigen Tensor Ops which already have GPU kernels defined for them. This lets us easily obtain an efficient implementation on the GPU. PiperOrigin-RevId: 253736113
This commit is contained in:
parent
ac1a9ebdbc
commit
222df6844f
@ -14,6 +14,11 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
#if GOOGLE_CUDA
|
||||
#define EIGEN_USE_GPU
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#include "tensorflow/core/kernels/einsum_op.h"
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
@ -34,9 +39,14 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/einsum_op_util.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/core/kernels/reduction_ops_common_gpu.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
using GPUDevice = Eigen::GpuDevice;
|
||||
|
||||
namespace {
|
||||
|
||||
@ -368,17 +378,21 @@ Status StrideOrInflate(OpKernelContext* ctx, const Tensor& input,
|
||||
ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
|
||||
const Device& device = ctx->eigen_device<Device>();
|
||||
switch (reshape.size()) {
|
||||
#define NDIMS_CASE(N) \
|
||||
case N: { \
|
||||
if (should_inflate) { \
|
||||
auto output_map = output->shaped<T, N>(reshape); \
|
||||
auto input_map = input.shaped<T, N>(strided_shape); \
|
||||
output_map.device(device) = input_map.inflate(strides); \
|
||||
} else { \
|
||||
auto input_map = input.shaped<T, N>(reshape); \
|
||||
auto output_map = output->shaped<T, N>(strided_shape); \
|
||||
output_map.device(device) = input_map.stride(strides); \
|
||||
} \
|
||||
#define NDIMS_CASE(N) \
|
||||
case N: { \
|
||||
if (should_inflate) { \
|
||||
auto output_map = output->shaped<T, N>(reshape); \
|
||||
auto input_map = input.shaped<T, N>(strided_shape); \
|
||||
functor::InflateFunctor<Device, T, N>()( \
|
||||
device, input_map, TensorShape(strides).AsEigenDSizes<N>(), \
|
||||
output_map); \
|
||||
} else { \
|
||||
auto input_map = input.shaped<T, N>(reshape); \
|
||||
auto output_map = output->shaped<T, N>(strided_shape); \
|
||||
functor::StrideFunctor<Device, T, N>()( \
|
||||
device, input_map, TensorShape(strides).AsEigenDSizes<N>(), \
|
||||
output_map); \
|
||||
} \
|
||||
} break;
|
||||
NDIMS_CASE(1);
|
||||
NDIMS_CASE(2);
|
||||
@ -695,18 +709,59 @@ class EinsumOp : public OpKernel {
|
||||
bool output_has_ellipsis_ = false;
|
||||
};
|
||||
|
||||
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
//
// DECLARE_GPU_SPEC(T, N) declares the GPUDevice specializations of
// StrideFunctor / InflateFunctor::operator() for element type T and rank N and
// marks both structs `extern template`, so this translation unit refers to
// instantiations provided elsewhere instead of instantiating the functor
// bodies itself.
//
// The type list below must cover every type for which a GPU Einsum kernel is
// registered in this file (float, double, complex64, complex128); otherwise
// the missing types would be implicitly instantiated here.
namespace functor {
#define DECLARE_GPU_SPEC(T, N)                                      \
  template <>                                                       \
  void StrideFunctor<GPUDevice, T, N>::operator()(                  \
      const GPUDevice& d, typename TTypes<T, N>::ConstTensor input, \
      const Eigen::DSizes<Eigen::DenseIndex, N>& strides,           \
      typename TTypes<T, N>::Tensor output);                        \
  extern template struct StrideFunctor<GPUDevice, T, N>;            \
  template <>                                                       \
  void InflateFunctor<GPUDevice, T, N>::operator()(                 \
      const GPUDevice& d, typename TTypes<T, N>::ConstTensor input, \
      const Eigen::DSizes<Eigen::DenseIndex, N>& strides,           \
      typename TTypes<T, N>::Tensor output);                        \
  extern template struct InflateFunctor<GPUDevice, T, N>;

// Declares the specializations for every rank handled here (1 through 6).
#define DECLARE_GPU_SPECS(T) \
  DECLARE_GPU_SPEC(T, 1);    \
  DECLARE_GPU_SPEC(T, 2);    \
  DECLARE_GPU_SPEC(T, 3);    \
  DECLARE_GPU_SPEC(T, 4);    \
  DECLARE_GPU_SPEC(T, 5);    \
  DECLARE_GPU_SPEC(T, 6);

DECLARE_GPU_SPECS(double);
DECLARE_GPU_SPECS(float);
DECLARE_GPU_SPECS(complex64);
DECLARE_GPU_SPECS(complex128);
#undef DECLARE_GPU_SPEC
#undef DECLARE_GPU_SPECS
}  // namespace functor
#endif  // GOOGLE_CUDA
|
||||
|
||||
#define REGISTER_EINSUM(D, TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Einsum").Device(DEVICE_##D).TypeConstraint<TYPE>("T"), \
|
||||
EinsumOp<D##Device, TYPE>);
|
||||
|
||||
// Register CPU kernels here; GPU kernels are registered below under GOOGLE_CUDA.
|
||||
#define REGISTER_CPU(TYPE) REGISTER_EINSUM(CPU, TYPE)
|
||||
TF_CALL_float(REGISTER_CPU);
|
||||
TF_CALL_double(REGISTER_CPU);
|
||||
TF_CALL_complex64(REGISTER_CPU);
|
||||
TF_CALL_complex128(REGISTER_CPU);
|
||||
#undef REGISTER_CPU
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#define REGISTER_GPU(TYPE) REGISTER_EINSUM(GPU, TYPE)
|
||||
TF_CALL_float(REGISTER_GPU);
|
||||
TF_CALL_double(REGISTER_GPU);
|
||||
TF_CALL_complex64(REGISTER_GPU);
|
||||
TF_CALL_complex128(REGISTER_GPU);
|
||||
#undef REGISTER_GPU
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#undef REGISTER_EINSUM
|
||||
|
||||
} // namespace tensorflow
|
||||
|
48
tensorflow/core/kernels/einsum_op.h
Normal file
48
tensorflow/core/kernels/einsum_op.h
Normal file
@ -0,0 +1,48 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_EINSUM_OP_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_EINSUM_OP_H_
|
||||
|
||||
// EIGEN_USE_GPU has to be defined before the Eigen Tensor header is included;
// defining it after the include has no effect on that include.
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
|
||||
template <typename Device, typename T, int N>
|
||||
struct StrideFunctor {
|
||||
void operator()(const Device& d, typename TTypes<T, N>::ConstTensor input,
|
||||
const Eigen::DSizes<Eigen::DenseIndex, N>& strides,
|
||||
typename TTypes<T, N>::Tensor output) {
|
||||
output.device(d) = input.stride(strides);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename T, int N>
|
||||
struct InflateFunctor {
|
||||
void operator()(const Device& d, typename TTypes<T, N>::ConstTensor input,
|
||||
const Eigen::DSizes<Eigen::DenseIndex, N>& strides,
|
||||
typename TTypes<T, N>::Tensor output) {
|
||||
output.device(d) = input.inflate(strides);
|
||||
}
|
||||
};
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_EINSUM_OP_H_
|
46
tensorflow/core/kernels/einsum_op_gpu.cu.cc
Normal file
46
tensorflow/core/kernels/einsum_op_gpu.cu.cc
Normal file
@ -0,0 +1,46 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/kernels/einsum_op.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
#define DECLARE_GPU_SPECS_NDIM(T, NDIM) \
|
||||
template struct functor::StrideFunctor<Eigen::GpuDevice, T, NDIM>; \
|
||||
template struct functor::InflateFunctor<Eigen::GpuDevice, T, NDIM>;
|
||||
|
||||
#define DECLARE_GPU_SPECS(T) \
|
||||
DECLARE_GPU_SPECS_NDIM(T, 1); \
|
||||
DECLARE_GPU_SPECS_NDIM(T, 2); \
|
||||
DECLARE_GPU_SPECS_NDIM(T, 3); \
|
||||
DECLARE_GPU_SPECS_NDIM(T, 4); \
|
||||
DECLARE_GPU_SPECS_NDIM(T, 5); \
|
||||
DECLARE_GPU_SPECS_NDIM(T, 6);
|
||||
|
||||
TF_CALL_float(DECLARE_GPU_SPECS);
|
||||
TF_CALL_double(DECLARE_GPU_SPECS);
|
||||
TF_CALL_complex64(DECLARE_GPU_SPECS);
|
||||
TF_CALL_complex128(DECLARE_GPU_SPECS);
|
||||
|
||||
#undef DECLARE_GPU_SPECS_NDIM
|
||||
#undef DECLARE_GPU_SPECS
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
44
tensorflow/core/kernels/reduction_ops_common_gpu.h
Normal file
44
tensorflow/core/kernels/reduction_ops_common_gpu.h
Normal file
@ -0,0 +1,44 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_GPU_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_GPU_H_
|
||||
|
||||
#if !GOOGLE_CUDA
|
||||
#error This file must only be included when building with Cuda support
|
||||
#endif
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
|
||||
template <typename Reducer>
|
||||
struct ReduceFunctor<Eigen::GpuDevice, Reducer> {
|
||||
template <typename OUT_T, typename IN_T, typename ReductionAxes>
|
||||
static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
|
||||
const ReductionAxes& reduction_axes,
|
||||
const Reducer& reducer);
|
||||
|
||||
template <typename OUT_T>
|
||||
static void FillIdentity(const Eigen::GpuDevice& d, OUT_T out,
|
||||
const Reducer& reducer);
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_GPU_H_
|
Loading…
Reference in New Issue
Block a user