Adding ROCm support for the random ops
This commit is contained in:
parent
fff00129e1
commit
76f7070031
@ -384,7 +384,7 @@ TF_CALL_int64(REGISTER_INT);
|
||||
#undef REGISTER
|
||||
#undef REGISTER_INT
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#define REGISTER(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
@ -435,7 +435,7 @@ TF_CALL_int64(REGISTER_INT);
|
||||
#undef REGISTER
|
||||
#undef REGISTER_INT
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
|
||||
|
@ -42,7 +42,7 @@ struct FillPhiloxRandom<CPUDevice, Distribution> {
|
||||
Distribution dist);
|
||||
};
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
// Declares the partially GPU-specialized functor struct.
|
||||
template <class Distribution>
|
||||
@ -52,7 +52,7 @@ struct FillPhiloxRandom<GPUDevice, Distribution> {
|
||||
typename Distribution::ResultElementType* data, int64 size,
|
||||
Distribution dist);
|
||||
};
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#if TENSORFLOW_USE_SYCL
|
||||
typedef Eigen::SyclDevice SYCLDevice;
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
@ -68,4 +68,4 @@ template struct FillPhiloxRandom<
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OP_GPU_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_RANDOM_OP_GPU_H_
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
#if defined(__CUDACC__) || TENSORFLOW_USE_ROCM
|
||||
|
||||
#include "tensorflow/core/kernels/random_op.h"
|
||||
#include "tensorflow/core/lib/random/philox_random.h"
|
||||
@ -222,14 +222,14 @@ void FillPhiloxRandom<GPUDevice, Distribution>::operator()(
|
||||
(d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) /
|
||||
block_size;
|
||||
|
||||
TF_CHECK_OK(CudaLaunchKernel(FillPhiloxRandomKernelLaunch<Distribution>,
|
||||
num_blocks, block_size, 0, d.stream(), gen, data,
|
||||
size, dist));
|
||||
TF_CHECK_OK(GpuLaunchKernel(FillPhiloxRandomKernelLaunch<Distribution>,
|
||||
num_blocks, block_size, 0, d.stream(), gen, data,
|
||||
size, dist));
|
||||
}
|
||||
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // defined(__CUDACC__)
|
||||
#endif // defined(__CUDACC__) || TENSORFLOW_USE_ROCM
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_RANDOM_OP_GPU_H_
|
||||
|
Loading…
Reference in New Issue
Block a user