Adding ROCm support for the batch_matmul op
This commit is contained in:
parent
1979135fb7
commit
24974ed8bb
tensorflow/core/kernels
@ -20,9 +20,9 @@ namespace tensorflow {
|
||||
TF_CALL_complex64(REGISTER_BATCH_MATMUL_CPU);
|
||||
TF_CALL_complex128(REGISTER_BATCH_MATMUL_CPU);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
TF_CALL_complex64(REGISTER_BATCH_MATMUL_GPU);
|
||||
TF_CALL_complex128(REGISTER_BATCH_MATMUL_GPU);
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -42,9 +42,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
|
||||
#endif
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -248,22 +248,22 @@ struct LaunchBatchMatMul<CPUDevice, Scalar> {
|
||||
}
|
||||
};
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
|
||||
se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
|
||||
se::DeviceMemory<T> AsDeviceMemory(const T* gpu_memory) {
|
||||
se::DeviceMemoryBase wrapped(const_cast<T*>(gpu_memory));
|
||||
se::DeviceMemory<T> typed(wrapped);
|
||||
return typed;
|
||||
}
|
||||
|
||||
class CublasScratchAllocator : public se::ScratchAllocator {
|
||||
class BlasScratchAllocator : public se::ScratchAllocator {
|
||||
public:
|
||||
using Stream = se::Stream;
|
||||
using DeviceMemoryBytes = se::DeviceMemory<uint8>;
|
||||
|
||||
CublasScratchAllocator(OpKernelContext* context) : context_(context) {}
|
||||
BlasScratchAllocator(OpKernelContext* context) : context_(context) {}
|
||||
|
||||
int64 GetMemoryLimitInBytes(Stream* stream) override { return -1; }
|
||||
|
||||
@ -357,7 +357,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
||||
|
||||
typedef Scalar Coefficient;
|
||||
|
||||
// Cublas does
|
||||
// Blas does
|
||||
// C = A x B
|
||||
// where A, B and C are assumed to be in column major.
|
||||
// We want the output to be in row-major, so we can compute
|
||||
@ -406,7 +406,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
CublasScratchAllocator scratch_allocator(context);
|
||||
BlasScratchAllocator scratch_allocator(context);
|
||||
bool blas_launch_status =
|
||||
stream
|
||||
->ThenBlasGemmBatchedWithScratch(
|
||||
@ -493,7 +493,7 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
|
||||
|
||||
typedef float Coefficient;
|
||||
|
||||
// Cublas does
|
||||
// Blas does
|
||||
// C = A x B
|
||||
// where A, B and C are assumed to be in column major.
|
||||
// We want the output to be in row-major, so we can compute
|
||||
@ -517,7 +517,7 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
|
||||
", k=", k));
|
||||
}
|
||||
} else {
|
||||
CublasScratchAllocator scratch_allocator(context);
|
||||
BlasScratchAllocator scratch_allocator(context);
|
||||
bool blas_launch_status =
|
||||
stream
|
||||
->ThenBlasGemmBatchedWithScratch(
|
||||
@ -537,7 +537,7 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
|
||||
}
|
||||
};
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
template <typename Scalar>
|
||||
|
@ -27,11 +27,11 @@ TF_CALL_half(REGISTER_BATCH_MATMUL_CPU);
|
||||
TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
|
||||
TF_CALL_int64(REGISTER_BATCH_MATMUL_CPU);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
TF_CALL_float(REGISTER_BATCH_MATMUL_GPU);
|
||||
TF_CALL_double(REGISTER_BATCH_MATMUL_GPU);
|
||||
TF_CALL_half(REGISTER_BATCH_MATMUL_GPU);
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
TF_CALL_float(REGISTER_BATCH_MATMUL_SYCL);
|
||||
|
Loading…
Reference in New Issue
Block a user