Adding ROCm support for the mirror_pad op
This commit is contained in:
parent
3e0a311fa0
commit
2986a2139e
@ -198,7 +198,7 @@ TF_CALL_POD_TYPES(REGISTER_KERNEL);
|
|||||||
TF_CALL_string(REGISTER_KERNEL);
|
TF_CALL_string(REGISTER_KERNEL);
|
||||||
#undef REGISTER_KERNEL
|
#undef REGISTER_KERNEL
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
namespace functor {
|
namespace functor {
|
||||||
// Forward declarations of the functor specializations for GPU.
|
// Forward declarations of the functor specializations for GPU.
|
||||||
#define DECLARE_GPU_SPEC(T, Tpaddings, i) \
|
#define DECLARE_GPU_SPEC(T, Tpaddings, i) \
|
||||||
@ -243,7 +243,7 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
|
|||||||
|
|
||||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
|
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
|
||||||
#undef REGISTER_GPU_KERNEL
|
#undef REGISTER_GPU_KERNEL
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
// Gradient op.
|
// Gradient op.
|
||||||
template <typename Device, typename T, typename Tpaddings>
|
template <typename Device, typename T, typename Tpaddings>
|
||||||
@ -404,7 +404,7 @@ TF_CALL_NUMBER_TYPES(DECLARE_CPU_SPECS);
|
|||||||
TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
|
TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
|
||||||
#undef REGISTER_KERNEL
|
#undef REGISTER_KERNEL
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
namespace functor {
|
namespace functor {
|
||||||
// Forward declarations of the functor specializations for GPU.
|
// Forward declarations of the functor specializations for GPU.
|
||||||
#define DECLARE_GPU_SPEC(T, Tpaddings, k) \
|
#define DECLARE_GPU_SPEC(T, Tpaddings, k) \
|
||||||
@ -450,6 +450,6 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
|
|||||||
|
|
||||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
|
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
|
||||||
#undef REGISTER_GPU_KERNEL
|
#undef REGISTER_GPU_KERNEL
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
@ -52,4 +52,4 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
|
|||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
Loading…
Reference in New Issue
Block a user