Adding ROCm support for the random ops
This commit is contained in:
parent
fff00129e1
commit
76f7070031
@ -384,7 +384,7 @@ TF_CALL_int64(REGISTER_INT);
|
|||||||
#undef REGISTER
|
#undef REGISTER
|
||||||
#undef REGISTER_INT
|
#undef REGISTER_INT
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#define REGISTER(TYPE) \
|
#define REGISTER(TYPE) \
|
||||||
REGISTER_KERNEL_BUILDER( \
|
REGISTER_KERNEL_BUILDER( \
|
||||||
@ -435,7 +435,7 @@ TF_CALL_int64(REGISTER_INT);
|
|||||||
#undef REGISTER
|
#undef REGISTER
|
||||||
#undef REGISTER_INT
|
#undef REGISTER_INT
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#ifdef TENSORFLOW_USE_SYCL
|
#ifdef TENSORFLOW_USE_SYCL
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ struct FillPhiloxRandom<CPUDevice, Distribution> {
|
|||||||
Distribution dist);
|
Distribution dist);
|
||||||
};
|
};
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
typedef Eigen::GpuDevice GPUDevice;
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
// Declares the partially GPU-specialized functor struct.
|
// Declares the partially GPU-specialized functor struct.
|
||||||
template <class Distribution>
|
template <class Distribution>
|
||||||
@ -52,7 +52,7 @@ struct FillPhiloxRandom<GPUDevice, Distribution> {
|
|||||||
typename Distribution::ResultElementType* data, int64 size,
|
typename Distribution::ResultElementType* data, int64 size,
|
||||||
Distribution dist);
|
Distribution dist);
|
||||||
};
|
};
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#if TENSORFLOW_USE_SYCL
|
#if TENSORFLOW_USE_SYCL
|
||||||
typedef Eigen::SyclDevice SYCLDevice;
|
typedef Eigen::SyclDevice SYCLDevice;
|
||||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
@ -68,4 +68,4 @@ template struct FillPhiloxRandom<
|
|||||||
} // namespace functor
|
} // namespace functor
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OP_GPU_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OP_GPU_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_RANDOM_OP_GPU_H_
|
#define TENSORFLOW_CORE_KERNELS_RANDOM_OP_GPU_H_
|
||||||
|
|
||||||
#if defined(__CUDACC__)
|
#if defined(__CUDACC__) || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/random_op.h"
|
#include "tensorflow/core/kernels/random_op.h"
|
||||||
#include "tensorflow/core/lib/random/philox_random.h"
|
#include "tensorflow/core/lib/random/philox_random.h"
|
||||||
@ -222,14 +222,14 @@ void FillPhiloxRandom<GPUDevice, Distribution>::operator()(
|
|||||||
(d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) /
|
(d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) /
|
||||||
block_size;
|
block_size;
|
||||||
|
|
||||||
TF_CHECK_OK(CudaLaunchKernel(FillPhiloxRandomKernelLaunch<Distribution>,
|
TF_CHECK_OK(GpuLaunchKernel(FillPhiloxRandomKernelLaunch<Distribution>,
|
||||||
num_blocks, block_size, 0, d.stream(), gen, data,
|
num_blocks, block_size, 0, d.stream(), gen, data,
|
||||||
size, dist));
|
size, dist));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // defined(__CUDACC__)
|
#endif // defined(__CUDACC__) || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_RANDOM_OP_GPU_H_
|
#endif // TENSORFLOW_CORE_KERNELS_RANDOM_OP_GPU_H_
|
||||||
|
Loading…
Reference in New Issue
Block a user