Adding ROCm support for the batch_norm op
This commit is contained in:
parent
b211c7a053
commit
b0597e4175
@ -175,7 +175,7 @@ TF_CALL_float(REGISTER_KERNEL);
|
||||
TF_CALL_double(REGISTER_KERNEL);
|
||||
#undef REGISTER_KERNEL
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
// Forward declarations of the functor specializations for GPU.
|
||||
namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
@ -206,7 +206,7 @@ TF_CALL_half(REGISTER_GPU_KERNEL);
|
||||
TF_CALL_float(REGISTER_GPU_KERNEL);
|
||||
#undef REGISTER_GPU_KERNEL
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#if TENSORFLOW_USE_SYCL
|
||||
#define REGISTER_KERNEL(T) \
|
||||
@ -231,7 +231,7 @@ TF_CALL_float(REGISTER_KERNEL);
|
||||
TF_CALL_double(REGISTER_KERNEL);
|
||||
#undef REGISTER_KERNEL
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
// Forward declarations of the functor specializations for GPU.
|
||||
namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
@ -265,7 +265,7 @@ TF_CALL_half(REGISTER_GPU_KERNEL);
|
||||
TF_CALL_float(REGISTER_GPU_KERNEL);
|
||||
#undef REGISTER_GPU_KERNEL
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#if TENSORFLOW_USE_SYCL
|
||||
#define REGISTER_KERNEL(T) \
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
@ -31,4 +31,4 @@ template struct functor::BatchNormGrad<GPUDevice, Eigen::half>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
Loading…
Reference in New Issue
Block a user