Adding ROCm support for the "matmul" op
commit d29a326f46
parent fdd04bd83d
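The change is mechanical: preprocessor guards that previously gated the GPU matmul path on CUDA alone are widened so the same code also compiles in ROCm builds, while genuinely CUDA-specific pieces (such as the cuda.h include) keep the narrower guard. A minimal sketch of the resulting pattern, drawn from the include hunk below; GOOGLE_CUDA and TENSORFLOW_USE_ROCM are the real TensorFlow build macros, everything else is illustrative:

// CUDA-specific headers keep the narrow guard.
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif

// Code shared by both GPU backends sits behind the widened guard.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM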
@@ -16,7 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
 #define TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #include <unordered_map>
 
@@ -214,6 +214,6 @@ Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
 
 }  // namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #endif  // TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
@@ -26,9 +26,11 @@ limitations under the License.
 #include "tensorflow/core/util/matmul_autotune.h"
 #if GOOGLE_CUDA
 #include "third_party/gpus/cuda/include/cuda.h"
+#endif
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/kernels/gpu_utils.h"
 #include "tensorflow/core/platform/stream_executor.h"
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 namespace tensorflow {
 
@@ -111,11 +113,11 @@ bool ExplicitVectorMatrixOptimization<Eigen::half>(
 
 template <typename Device, typename T>
 struct LaunchMatMulBase {
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   typedef se::blas::AlgorithmType AlgorithmType;
 #else
   typedef int64 AlgorithmType;
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
   static void launch(
       OpKernelContext* ctx, const Tensor& a, const Tensor& b,
@@ -154,7 +156,7 @@ template <typename T, bool USE_CUBLAS>
 struct LaunchMatMul<SYCLDevice, T, USE_CUBLAS> : public LaunchMatMulSYCL<T> {};
 #endif  // TENSORFLOW_USE_SYCL
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 namespace {
 
@@ -433,7 +435,7 @@ struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
   }
 };
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 template <typename Device, typename T, bool USE_CUBLAS>
 class MatMulOp : public OpKernel {
@@ -622,13 +624,13 @@ TF_CALL_complex64(REGISTER_CPU);
 TF_CALL_complex128(REGISTER_CPU);
 #endif  // INTEL_MKL && ENABLE_MKL
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 TF_CALL_float(REGISTER_GPU);
 TF_CALL_double(REGISTER_GPU);
 TF_CALL_complex64(REGISTER_GPU);
 TF_CALL_complex128(REGISTER_GPU);
 TF_CALL_half(REGISTER_GPU);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #ifdef TENSORFLOW_USE_SYCL
 #define REGISTER_SYCL(T) \
@@ -58,7 +58,7 @@ struct MatMulFunctor {
 
 }  // end namespace functor
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // Encapsulate all the shape information that is used in matmul operations.
 class MatmulParameters {
  public:
@@ -117,7 +117,7 @@ class MatmulParameters {
 
 typedef Eigen::GpuDevice GPUDevice;
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // end namespace tensorflow
 
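For context on the registration hunk above: each TF_CALL_float(REGISTER_GPU) / TF_CALL_half(REGISTER_GPU) line expands REGISTER_GPU once per listed dtype, and moving them behind the widened guard is what makes the GPU MatMul kernels available in ROCm builds. A rough, simplified sketch of what such a registration macro does (an assumption for illustration, not the exact macro from the file):

// Sketch only: registers MatMulOp for type T on the GPU device, which after
// this commit also happens when TENSORFLOW_USE_ROCM is defined.
#define REGISTER_GPU(T)                                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("MatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"),  \
      MatMulOp<GPUDevice, T, true /* USE_CUBLAS */>)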