Adding ROCm support for the "matmul" op

This commit is contained in:
Deven Desai 2019-05-13 16:47:01 +00:00
parent fdd04bd83d
commit d29a326f46
3 changed files with 13 additions and 11 deletions

View File

@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
#define TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include <unordered_map>
@ -214,6 +214,6 @@ Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
} // namespace tensorflow
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_

View File

@ -26,9 +26,11 @@ limitations under the License.
#include "tensorflow/core/util/matmul_autotune.h"
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace tensorflow {
@ -111,11 +113,11 @@ bool ExplicitVectorMatrixOptimization<Eigen::half>(
template <typename Device, typename T>
struct LaunchMatMulBase {
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
typedef se::blas::AlgorithmType AlgorithmType;
#else
typedef int64 AlgorithmType;
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
static void launch(
OpKernelContext* ctx, const Tensor& a, const Tensor& b,
@ -154,7 +156,7 @@ template <typename T, bool USE_CUBLAS>
struct LaunchMatMul<SYCLDevice, T, USE_CUBLAS> : public LaunchMatMulSYCL<T> {};
#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace {
@ -433,7 +435,7 @@ struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
}
};
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename Device, typename T, bool USE_CUBLAS>
class MatMulOp : public OpKernel {
@ -622,13 +624,13 @@ TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
#endif // INTEL_MKL && ENABLE_MKL
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
TF_CALL_half(REGISTER_GPU);
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL(T) \

View File

@ -58,7 +58,7 @@ struct MatMulFunctor {
} // end namespace functor
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Encapsulate all the shape information that is used in matmul operations.
class MatmulParameters {
public:
@ -117,7 +117,7 @@ class MatmulParameters {
typedef Eigen::GpuDevice GPUDevice;
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // end namespace tensorflow