diff --git a/tensorflow/core/kernels/extract_image_patches_op.cc b/tensorflow/core/kernels/extract_image_patches_op.cc index 9306eccf9f0..7192fec37e6 100644 --- a/tensorflow/core/kernels/extract_image_patches_op.cc +++ b/tensorflow/core/kernels/extract_image_patches_op.cc @@ -130,7 +130,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER); #undef REGISTER -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { @@ -160,6 +160,6 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER); #undef REGISTER -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/extract_image_patches_op_gpu.cu.cc b/tensorflow/core/kernels/extract_image_patches_op_gpu.cu.cc index 50159282ff1..465b7acd475 100644 --- a/tensorflow/core/kernels/extract_image_patches_op_gpu.cu.cc +++ b/tensorflow/core/kernels/extract_image_patches_op_gpu.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU @@ -35,4 +35,4 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER); } // end namespace functor } // end namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM