Adding ROCm support for the batch_matmul op

Deven Desai 2019-07-02 19:18:21 +00:00
parent 1979135fb7
commit 24974ed8bb
3 changed files with 16 additions and 16 deletions

View File

@@ -20,9 +20,9 @@ namespace tensorflow {
 TF_CALL_complex64(REGISTER_BATCH_MATMUL_CPU);
 TF_CALL_complex128(REGISTER_BATCH_MATMUL_CPU);
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 TF_CALL_complex64(REGISTER_BATCH_MATMUL_GPU);
 TF_CALL_complex128(REGISTER_BATCH_MATMUL_GPU);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
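
Note: the TF_CALL_<type>(REGISTER_...) lines are TensorFlow's type-macro pattern. Each TF_CALL_* macro (declared in tensorflow/core/framework/register_types.h) applies the macro it receives to one concrete type, so widening the #if guard is all it takes to register these kernels for ROCm builds as well. A minimal, self-contained sketch of the pattern; CALL_COMPLEX64 and REGISTER_DEMO below are illustrative stand-ins, not TensorFlow's actual macro bodies:

    #include <complex>
    #include <iostream>

    // Stand-in for TF_CALL_complex64(m): applies the given macro to one type.
    #define CALL_COMPLEX64(m) m(std::complex<float>)

    // Stand-in for REGISTER_BATCH_MATMUL_GPU(T); the real macro expands to a
    // REGISTER_KERNEL_BUILDER statement for BatchMatMul on DEVICE_GPU.
    #define REGISTER_DEMO(T) \
      std::cout << "register BatchMatMul<GPUDevice, " << #T << ">\n";

    int main() {
      CALL_COMPLEX64(REGISTER_DEMO)  // expands into the registration statement
      return 0;
    }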

View File

@@ -42,9 +42,9 @@ limitations under the License.
 #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
 #endif
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/platform/stream_executor.h"
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 namespace tensorflow {
@@ -248,22 +248,22 @@ struct LaunchBatchMatMul<CPUDevice, Scalar> {
   }
 };
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 namespace {
 template <typename T>
-se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
-  se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
+se::DeviceMemory<T> AsDeviceMemory(const T* gpu_memory) {
+  se::DeviceMemoryBase wrapped(const_cast<T*>(gpu_memory));
   se::DeviceMemory<T> typed(wrapped);
   return typed;
 }
 
-class CublasScratchAllocator : public se::ScratchAllocator {
+class BlasScratchAllocator : public se::ScratchAllocator {
  public:
   using Stream = se::Stream;
   using DeviceMemoryBytes = se::DeviceMemory<uint8>;
 
-  CublasScratchAllocator(OpKernelContext* context) : context_(context) {}
+  BlasScratchAllocator(OpKernelContext* context) : context_(context) {}
 
   int64 GetMemoryLimitInBytes(Stream* stream) override { return -1; }
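
The hunk ends mid-class: GetMemoryLimitInBytes returning -1 signals "no limit" to StreamExecutor, and the rest of BlasScratchAllocator (not shown in this diff) conventionally satisfies AllocateBytes by allocating a temporary tensor from the OpKernelContext. A sketch of that elided remainder, assumed from the surrounding pattern rather than copied from the file:

    // Assumed shape of the rest of the class (illustrative, not verbatim):
    se::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
        Stream* stream, int64 byte_size) override {
      Tensor temporary_memory;
      // Back the scratch space with a temp tensor so the framework tracks
      // its lifetime and allocator attribution.
      Status allocation_status(context_->allocate_temp(
          DT_UINT8, TensorShape({byte_size}), &temporary_memory));
      if (!allocation_status.ok()) {
        // Signal failure with an empty buffer.
        return se::port::StatusOr<DeviceMemoryBytes>(
            DeviceMemoryBytes::MakeFromByteSize(nullptr, 0));
      }
      // Hold a reference so the buffer outlives the batched GEMM launch.
      allocated_tensors_.push_back(temporary_memory);
      return se::port::StatusOr<DeviceMemoryBytes>(
          DeviceMemoryBytes::MakeFromByteSize(
              temporary_memory.flat<uint8>().data(),
              temporary_memory.flat<uint8>().size()));
    }

     private:
      OpKernelContext* context_;
      std::vector<Tensor> allocated_tensors_;
    };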
@@ -357,7 +357,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
     typedef Scalar Coefficient;
 
-    // Cublas does
+    // Blas does
     // C = A x B
     // where A, B and C are assumed to be in column major.
     // We want the output to be in row-major, so we can compute
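
The comment being renamed describes the standard trick for driving a column-major BLAS from row-major code: a row-major M x N buffer is byte-identical to the column-major layout of its N x M transpose, so asking the BLAS for C' = B' x A' writes exactly the row-major C = A x B, with no copies or explicit transposes. A self-contained check of the identity, with a naive column-major gemm standing in for the BLAS call:

    #include <cassert>
    #include <vector>

    // Naive column-major matmul: C(m x n) = A(m x k) * B(k x n); all three
    // buffers are interpreted column-major, standing in for a BLAS gemm.
    void gemm_colmajor(int m, int n, int k, const std::vector<double>& A,
                       const std::vector<double>& B, std::vector<double>& C) {
      for (int j = 0; j < n; ++j)
        for (int i = 0; i < m; ++i) {
          double acc = 0;
          for (int p = 0; p < k; ++p) acc += A[i + p * m] * B[p + j * k];
          C[i + j * m] = acc;
        }
    }

    int main() {
      std::vector<double> A = {1, 2, 3, 4, 5, 6};     // 2x3, row-major
      std::vector<double> B = {7, 8, 9, 10, 11, 12};  // 3x2, row-major
      std::vector<double> C(4);                       // 2x2 result
      // Swap the operands: column-major C'(2x2) = B'(2x3) * A'(3x2).
      // The buffer that comes back is exactly C = A * B in row-major order.
      gemm_colmajor(/*m=*/2, /*n=*/2, /*k=*/3, B, A, C);
      assert(C[0] == 58 && C[1] == 64 && C[2] == 139 && C[3] == 154);
      return 0;
    }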
@@ -406,7 +406,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
         }
       }
     } else {
-      CublasScratchAllocator scratch_allocator(context);
+      BlasScratchAllocator scratch_allocator(context);
       bool blas_launch_status =
           stream
               ->ThenBlasGemmBatchedWithScratch(
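
This rename is the substance of the change: ThenBlasGemmBatchedWithScratch is a backend-neutral StreamExecutor call, serviced by cuBLAS under GOOGLE_CUDA and by rocBLAS under TENSORFLOW_USE_ROCM, so nothing in this launch path is cuBLAS-specific and the old Cublas* names were misleading.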
@@ -493,7 +493,7 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
     typedef float Coefficient;
 
-    // Cublas does
+    // Blas does
     // C = A x B
     // where A, B and C are assumed to be in column major.
     // We want the output to be in row-major, so we can compute
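
Worth noting in the context lines above: this Eigen::half specialization declares typedef float Coefficient;, so the alpha/beta scaling factors handed to the batched GEMM are computed in float even though the operands are half, the usual mixed-precision arrangement for fp16 GEMMs.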
@@ -517,7 +517,7 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
                                         ", k=", k));
       }
     } else {
-      CublasScratchAllocator scratch_allocator(context);
+      BlasScratchAllocator scratch_allocator(context);
       bool blas_launch_status =
           stream
               ->ThenBlasGemmBatchedWithScratch(
@@ -537,7 +537,7 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
   }
 };
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #ifdef TENSORFLOW_USE_SYCL
 template <typename Scalar>

View File

@@ -27,11 +27,11 @@ TF_CALL_half(REGISTER_BATCH_MATMUL_CPU);
 TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
 TF_CALL_int64(REGISTER_BATCH_MATMUL_CPU);
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 TF_CALL_float(REGISTER_BATCH_MATMUL_GPU);
 TF_CALL_double(REGISTER_BATCH_MATMUL_GPU);
 TF_CALL_half(REGISTER_BATCH_MATMUL_GPU);
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #ifdef TENSORFLOW_USE_SYCL
 TF_CALL_float(REGISTER_BATCH_MATMUL_SYCL);
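
With all three guards widened, a ROCm build of TensorFlow (one where TENSORFLOW_USE_ROCM is defined, e.g. via bazel's --config=rocm, assuming the ROCm toolchain is configured) registers the same GPU BatchMatMul kernel set as a CUDA build: float, double, half, complex64, and complex128.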