Register GPU kernels for EinsumOp.
This change is mostly boilerplate to make the nvcc compiler happy. Most of the heavy lifting in EinsumOp is done by the BatchMatmul and Reduction functors and by Eigen Tensor ops, which already have GPU kernels defined for them. This lets us obtain an efficient GPU implementation with very little new code.

PiperOrigin-RevId: 253736113
Commit: 222df6844f
Parent: ac1a9ebdbc
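Why so little GPU-specific code is needed: the kernels are written as device-templated functors over Eigen Tensor expressions, so the same functor body evaluates on whichever Eigen device it is handed. The standalone sketch below (plain Eigen with a hypothetical ScaleFunctor, not code from this commit; it assumes only that Eigen's unsupported Tensor module is available) shows the pattern. In a GPU build, an Eigen::GpuDevice would be passed in place of Eigen::DefaultDevice and nothing else changes.

#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

// Hypothetical functor illustrating the device-templated pattern the commit
// relies on: one body, instantiated per device, with Eigen doing the work.
template <typename Device>
struct ScaleFunctor {
  void operator()(const Device& d, const Eigen::Tensor<float, 1>& in,
                  Eigen::Tensor<float, 1>& out) {
    // The device decides where and how the expression is evaluated
    // (single thread, thread pool, or a GPU stream).
    out.device(d) = in * 2.0f;
  }
};

int main() {
  Eigen::Tensor<float, 1> in(3), out(3);
  in.setValues({1.f, 2.f, 3.f});

  Eigen::DefaultDevice cpu;  // an Eigen::GpuDevice would be used the same way
  ScaleFunctor<Eigen::DefaultDevice>()(cpu, in, out);

  for (Eigen::DenseIndex i = 0; i < out.dimension(0); ++i)
    std::cout << out(i) << " ";  // prints: 2 4 6
  std::cout << "\n";
  return 0;
}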
tensorflow/core/kernels/einsum_op.cc  (modified)

@@ -14,6 +14,11 @@ limitations under the License.
 ==============================================================================*/
 
 #define EIGEN_USE_THREADS
 
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif  // GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/einsum_op.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/str_split.h"
@@ -34,9 +39,14 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/einsum_op_util.h"
 
+#if GOOGLE_CUDA
+#include "tensorflow/core/kernels/reduction_ops_common_gpu.h"
+#endif  // GOOGLE_CUDA
+
 namespace tensorflow {
 
 using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
 
 namespace {
 
@@ -373,11 +383,15 @@ Status StrideOrInflate(OpKernelContext* ctx, const Tensor& input,
     if (should_inflate) {                                              \
       auto output_map = output->shaped<T, N>(reshape);                 \
       auto input_map = input.shaped<T, N>(strided_shape);              \
-      output_map.device(device) = input_map.inflate(strides);          \
+      functor::InflateFunctor<Device, T, N>()(                         \
+          device, input_map, TensorShape(strides).AsEigenDSizes<N>(),  \
+          output_map);                                                 \
     } else {                                                           \
       auto input_map = input.shaped<T, N>(reshape);                    \
       auto output_map = output->shaped<T, N>(strided_shape);           \
-      output_map.device(device) = input_map.stride(strides);           \
+      functor::StrideFunctor<Device, T, N>()(                          \
+          device, input_map, TensorShape(strides).AsEigenDSizes<N>(),  \
+          output_map);                                                 \
     }                                                                  \
   } break;
     NDIMS_CASE(1);
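For context on the two expressions being wrapped: Eigen's stride() keeps every k-th element along each dimension, and inflate() does the reverse, spacing the elements back out and filling the gaps with zeros. Routing them through StrideFunctor/InflateFunctor lets the GPU instantiations be compiled separately by nvcc (see einsum_op_gpu.cu.cc below). A small standalone Eigen example, not taken from the commit, of what the two expressions compute:

#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 1> x(4);
  x.setValues({1.f, 2.f, 3.f, 4.f});

  Eigen::array<Eigen::DenseIndex, 1> strides = {{2}};

  // stride(): keep every 2nd element -> {1, 3}
  Eigen::Tensor<float, 1> strided = x.stride(strides);

  // inflate(): re-space the elements, inserting zeros -> {1, 0, 3}
  Eigen::Tensor<float, 1> inflated = strided.inflate(strides);

  for (Eigen::DenseIndex i = 0; i < strided.dimension(0); ++i)
    std::cout << strided(i) << " ";
  std::cout << "\n";
  for (Eigen::DenseIndex i = 0; i < inflated.dimension(0); ++i)
    std::cout << inflated(i) << " ";
  std::cout << "\n";
  return 0;
}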
@@ -695,18 +709,59 @@ class EinsumOp : public OpKernel {
   bool output_has_ellipsis_ = false;
 };
 
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T, N)                                        \
+  template <>                                                         \
+  void StrideFunctor<GPUDevice, T, N>::operator()(                    \
+      const GPUDevice& d, typename TTypes<T, N>::ConstTensor input,   \
+      const Eigen::DSizes<Eigen::DenseIndex, N>& strides,             \
+      typename TTypes<T, N>::Tensor output);                          \
+  extern template struct StrideFunctor<GPUDevice, T, N>;              \
+  template <>                                                         \
+  void InflateFunctor<GPUDevice, T, N>::operator()(                   \
+      const GPUDevice& d, typename TTypes<T, N>::ConstTensor input,   \
+      const Eigen::DSizes<Eigen::DenseIndex, N>& strides,             \
+      typename TTypes<T, N>::Tensor output);                          \
+  extern template struct InflateFunctor<GPUDevice, T, N>;
+
+#define DECLARE_GPU_SPECS(T) \
+  DECLARE_GPU_SPEC(T, 1);    \
+  DECLARE_GPU_SPEC(T, 2);    \
+  DECLARE_GPU_SPEC(T, 3);    \
+  DECLARE_GPU_SPEC(T, 4);    \
+  DECLARE_GPU_SPEC(T, 5);    \
+  DECLARE_GPU_SPEC(T, 6);
+
+DECLARE_GPU_SPECS(double);
+DECLARE_GPU_SPECS(float);
+#undef DECLARE_GPU_SPEC
+#undef DECLARE_GPU_SPECS
+}  // namespace functor
+#endif  // GOOGLE_CUDA
+
 #define REGISTER_EINSUM(D, TYPE)                                   \
   REGISTER_KERNEL_BUILDER(                                         \
       Name("Einsum").Device(DEVICE_##D).TypeConstraint<TYPE>("T"), \
       EinsumOp<D##Device, TYPE>);
 
-// TODO(anudhyan): Also register GPU kernels for Einsum.
 #define REGISTER_CPU(TYPE) REGISTER_EINSUM(CPU, TYPE)
 TF_CALL_float(REGISTER_CPU);
 TF_CALL_double(REGISTER_CPU);
 TF_CALL_complex64(REGISTER_CPU);
 TF_CALL_complex128(REGISTER_CPU);
 #undef REGISTER_CPU
 
+#if GOOGLE_CUDA
+#define REGISTER_GPU(TYPE) REGISTER_EINSUM(GPU, TYPE)
+TF_CALL_float(REGISTER_GPU);
+TF_CALL_double(REGISTER_GPU);
+TF_CALL_complex64(REGISTER_GPU);
+TF_CALL_complex128(REGISTER_GPU);
+#undef REGISTER_GPU
+#endif  // GOOGLE_CUDA
+
 #undef REGISTER_EINSUM
 
 }  // namespace tensorflow
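The extern template struct lines above declare explicit instantiations without emitting them, so einsum_op.cc (built by the host compiler) can reference the GPU specializations while the actual instantiations are produced in einsum_op_gpu.cu.cc (built by nvcc). A minimal standalone sketch of that split, using a hypothetical Functor and with both halves placed in one file only so it compiles on its own:

#include <iostream>

template <typename T, int N>
struct Functor {
  void operator()(const T* in, T* out) const {
    for (int i = 0; i < N; ++i) out[i] = in[i] * 2;
  }
};

// "einsum_op.cc" side: promise that the instantiation exists in another TU.
extern template struct Functor<float, 4>;

// "einsum_op_gpu.cu.cc" side: provide the explicit instantiation (kept in the
// same file here only so the sketch is self-contained).
template struct Functor<float, 4>;

int main() {
  float in[4] = {1, 2, 3, 4};
  float out[4];
  Functor<float, 4>()(in, out);
  std::cout << out[3] << "\n";  // prints 8
  return 0;
}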
tensorflow/core/kernels/einsum_op.h  (new file, 48 lines)

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_EINSUM_OP_H_
#define TENSORFLOW_CORE_KERNELS_EINSUM_OP_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

namespace tensorflow {
namespace functor {

template <typename Device, typename T, int N>
struct StrideFunctor {
  void operator()(const Device& d, typename TTypes<T, N>::ConstTensor input,
                  const Eigen::DSizes<Eigen::DenseIndex, N>& strides,
                  typename TTypes<T, N>::Tensor output) {
    output.device(d) = input.stride(strides);
  }
};

template <typename Device, typename T, int N>
struct InflateFunctor {
  void operator()(const Device& d, typename TTypes<T, N>::ConstTensor input,
                  const Eigen::DSizes<Eigen::DenseIndex, N>& strides,
                  typename TTypes<T, N>::Tensor output) {
    output.device(d) = input.inflate(strides);
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_EINSUM_OP_H_
tensorflow/core/kernels/einsum_op_gpu.cu.cc  (new file, 46 lines)

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#define EIGEN_USE_GPU

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/einsum_op.h"

namespace tensorflow {

#define DECLARE_GPU_SPECS_NDIM(T, NDIM)                              \
  template struct functor::StrideFunctor<Eigen::GpuDevice, T, NDIM>; \
  template struct functor::InflateFunctor<Eigen::GpuDevice, T, NDIM>;

#define DECLARE_GPU_SPECS(T)    \
  DECLARE_GPU_SPECS_NDIM(T, 1); \
  DECLARE_GPU_SPECS_NDIM(T, 2); \
  DECLARE_GPU_SPECS_NDIM(T, 3); \
  DECLARE_GPU_SPECS_NDIM(T, 4); \
  DECLARE_GPU_SPECS_NDIM(T, 5); \
  DECLARE_GPU_SPECS_NDIM(T, 6);

TF_CALL_float(DECLARE_GPU_SPECS);
TF_CALL_double(DECLARE_GPU_SPECS);
TF_CALL_complex64(DECLARE_GPU_SPECS);
TF_CALL_complex128(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS_NDIM
#undef DECLARE_GPU_SPECS

}  // namespace tensorflow

#endif  // GOOGLE_CUDA
tensorflow/core/kernels/reduction_ops_common_gpu.h  (new file, 44 lines)

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_GPU_H_
#define TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_GPU_H_

#if !GOOGLE_CUDA
#error This file must only be included when building with Cuda support
#endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {
namespace functor {

template <typename Reducer>
struct ReduceFunctor<Eigen::GpuDevice, Reducer> {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Reducer& reducer);

  template <typename OUT_T>
  static void FillIdentity(const Eigen::GpuDevice& d, OUT_T out,
                           const Reducer& reducer);
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_GPU_H_
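This header exists so that einsum_op.cc can call the GPU ReduceFunctor specialization without pulling device code into a host translation unit: the specialization's member functions are only declared here, and their definitions live with TensorFlow's CUDA reduction kernels, which nvcc compiles separately. A minimal standalone sketch of that declare-here, define-elsewhere pattern, using hypothetical ReduceSketch and GpuLikeDevice types, with the definition kept in the same file only so the sketch compiles on its own:

#include <iostream>

struct GpuLikeDevice {};  // stand-in for Eigen::GpuDevice

template <typename Device>
struct ReduceSketch;  // primary template, defined for other devices elsewhere

// Header side: declare the specialization and its members only.
template <>
struct ReduceSketch<GpuLikeDevice> {
  static float Sum(const float* data, int n);
};

// "CUDA translation unit" side: provide the member definition.
float ReduceSketch<GpuLikeDevice>::Sum(const float* data, int n) {
  float acc = 0.f;
  for (int i = 0; i < n; ++i) acc += data[i];
  return acc;
}

int main() {
  const float v[3] = {1.f, 2.f, 3.f};
  std::cout << ReduceSketch<GpuLikeDevice>::Sum(v, 3) << "\n";  // prints 6
  return 0;
}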