Adding ROCm support for the mirror_pad op
This commit is contained in:
		
							parent
							
								
									3e0a311fa0
								
							
						
					
					
						commit
						2986a2139e
					
				| @ -198,7 +198,7 @@ TF_CALL_POD_TYPES(REGISTER_KERNEL); | ||||
| TF_CALL_string(REGISTER_KERNEL); | ||||
| #undef REGISTER_KERNEL | ||||
| 
 | ||||
| -#if GOOGLE_CUDA | ||||
| +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM | ||||
| namespace functor { | ||||
| // Forward declarations of the functor specializations for GPU.
 | ||||
| #define DECLARE_GPU_SPEC(T, Tpaddings, i)                     \ | ||||
| @ -243,7 +243,7 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); | ||||
| 
 | ||||
| TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL); | ||||
| #undef REGISTER_GPU_KERNEL | ||||
| -#endif  // GOOGLE_CUDA
 | ||||
| +#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 | ||||
| 
 | ||||
| // Gradient op.
 | ||||
| template <typename Device, typename T, typename Tpaddings> | ||||
| @ -404,7 +404,7 @@ TF_CALL_NUMBER_TYPES(DECLARE_CPU_SPECS); | ||||
| TF_CALL_NUMBER_TYPES(REGISTER_KERNEL); | ||||
| #undef REGISTER_KERNEL | ||||
| 
 | ||||
| -#if GOOGLE_CUDA | ||||
| +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM | ||||
| namespace functor { | ||||
| // Forward declarations of the functor specializations for GPU.
 | ||||
| #define DECLARE_GPU_SPEC(T, Tpaddings, k)                     \ | ||||
| @ -450,6 +450,6 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); | ||||
| 
 | ||||
| TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL); | ||||
| #undef REGISTER_GPU_KERNEL | ||||
| -#endif  // GOOGLE_CUDA
 | ||||
| +#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 | ||||
| 
 | ||||
| }  // namespace tensorflow
 | ||||
|  | ||||
| @ -13,7 +13,7 @@ See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| -#if GOOGLE_CUDA | ||||
| +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM | ||||
| 
 | ||||
| #define EIGEN_USE_GPU | ||||
| 
 | ||||
| @ -52,4 +52,4 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); | ||||
| 
 | ||||
| }  // namespace tensorflow
 | ||||
| 
 | ||||
| -#endif  // GOOGLE_CUDA
 | ||||
| +#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user