Adding ROCm support for the batch_matmul op

This commit is contained in:
Deven Desai 2019-07-02 19:18:21 +00:00
parent 1979135fb7
commit 24974ed8bb
3 changed files with 16 additions and 16 deletions

View File

@ -20,9 +20,9 @@ namespace tensorflow {
TF_CALL_complex64(REGISTER_BATCH_MATMUL_CPU); TF_CALL_complex64(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_CPU); TF_CALL_complex128(REGISTER_BATCH_MATMUL_CPU);
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_complex64(REGISTER_BATCH_MATMUL_GPU); TF_CALL_complex64(REGISTER_BATCH_MATMUL_GPU);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_GPU); TF_CALL_complex128(REGISTER_BATCH_MATMUL_GPU);
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow } // namespace tensorflow

View File

@ -42,9 +42,9 @@ limitations under the License.
#include "tensorflow/core/kernels/eigen_contraction_kernel.h" #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif #endif
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace tensorflow { namespace tensorflow {
@ -248,22 +248,22 @@ struct LaunchBatchMatMul<CPUDevice, Scalar> {
} }
}; };
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace { namespace {
template <typename T> template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) { se::DeviceMemory<T> AsDeviceMemory(const T* gpu_memory) {
se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory)); se::DeviceMemoryBase wrapped(const_cast<T*>(gpu_memory));
se::DeviceMemory<T> typed(wrapped); se::DeviceMemory<T> typed(wrapped);
return typed; return typed;
} }
class CublasScratchAllocator : public se::ScratchAllocator { class BlasScratchAllocator : public se::ScratchAllocator {
public: public:
using Stream = se::Stream; using Stream = se::Stream;
using DeviceMemoryBytes = se::DeviceMemory<uint8>; using DeviceMemoryBytes = se::DeviceMemory<uint8>;
CublasScratchAllocator(OpKernelContext* context) : context_(context) {} BlasScratchAllocator(OpKernelContext* context) : context_(context) {}
int64 GetMemoryLimitInBytes(Stream* stream) override { return -1; } int64 GetMemoryLimitInBytes(Stream* stream) override { return -1; }
@ -357,7 +357,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
typedef Scalar Coefficient; typedef Scalar Coefficient;
// Cublas does // Blas does
// C = A x B // C = A x B
// where A, B and C are assumed to be in column major. // where A, B and C are assumed to be in column major.
// We want the output to be in row-major, so we can compute // We want the output to be in row-major, so we can compute
@ -406,7 +406,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
} }
} }
} else { } else {
CublasScratchAllocator scratch_allocator(context); BlasScratchAllocator scratch_allocator(context);
bool blas_launch_status = bool blas_launch_status =
stream stream
->ThenBlasGemmBatchedWithScratch( ->ThenBlasGemmBatchedWithScratch(
@ -493,7 +493,7 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
typedef float Coefficient; typedef float Coefficient;
// Cublas does // Blas does
// C = A x B // C = A x B
// where A, B and C are assumed to be in column major. // where A, B and C are assumed to be in column major.
// We want the output to be in row-major, so we can compute // We want the output to be in row-major, so we can compute
@ -517,7 +517,7 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
", k=", k)); ", k=", k));
} }
} else { } else {
CublasScratchAllocator scratch_allocator(context); BlasScratchAllocator scratch_allocator(context);
bool blas_launch_status = bool blas_launch_status =
stream stream
->ThenBlasGemmBatchedWithScratch( ->ThenBlasGemmBatchedWithScratch(
@ -537,7 +537,7 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
} }
}; };
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL #ifdef TENSORFLOW_USE_SYCL
template <typename Scalar> template <typename Scalar>

View File

@ -27,11 +27,11 @@ TF_CALL_half(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU); TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_int64(REGISTER_BATCH_MATMUL_CPU); TF_CALL_int64(REGISTER_BATCH_MATMUL_CPU);
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_float(REGISTER_BATCH_MATMUL_GPU); TF_CALL_float(REGISTER_BATCH_MATMUL_GPU);
TF_CALL_double(REGISTER_BATCH_MATMUL_GPU); TF_CALL_double(REGISTER_BATCH_MATMUL_GPU);
TF_CALL_half(REGISTER_BATCH_MATMUL_GPU); TF_CALL_half(REGISTER_BATCH_MATMUL_GPU);
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL #ifdef TENSORFLOW_USE_SYCL
TF_CALL_float(REGISTER_BATCH_MATMUL_SYCL); TF_CALL_float(REGISTER_BATCH_MATMUL_SYCL);