diff --git a/tensorflow/core/grappler/clusters/utils.cc b/tensorflow/core/grappler/clusters/utils.cc index 567e7c075e0..f1d3a77e3f0 100644 --- a/tensorflow/core/grappler/clusters/utils.cc +++ b/tensorflow/core/grappler/clusters/utils.cc @@ -23,6 +23,10 @@ limitations under the License. #include "cuda/include/cudnn.h" #endif +#if TENSORFLOW_USE_ROCM +#include "rocm/include/hip/hip_runtime.h" +#endif + #ifdef EIGEN_USE_LIBXSMM #include "include/libxsmm.h" #endif @@ -109,6 +113,36 @@ DeviceProperties GetLocalGPUInfo(PlatformGpuId platform_gpu_id) { strings::StrCat(properties.major, ".", properties.minor); (*device.mutable_environment())["cuda"] = strings::StrCat(CUDA_VERSION); (*device.mutable_environment())["cudnn"] = strings::StrCat(CUDNN_VERSION); + +#elif TENSORFLOW_USE_ROCM + hipDeviceProp_t properties; + hipError_t error = + hipGetDeviceProperties(&properties, platform_gpu_id.value()); + if (error != hipSuccess) { + device.set_type("UNKNOWN"); + LOG(ERROR) << "Failed to get device properties, error code: " << error; + return device; + } + + // ROCM TODO review if numbers here are valid + device.set_vendor("Advanced Micro Devices, Inc"); + device.set_model(properties.name); + device.set_frequency(properties.clockRate * 1e-3); + device.set_num_cores(properties.multiProcessorCount); + device.set_num_registers(properties.regsPerBlock); + device.set_l1_cache_size(16 * 1024); + device.set_l2_cache_size(properties.l2CacheSize); + device.set_l3_cache_size(0); + device.set_shared_memory_size_per_multiprocessor( + properties.maxSharedMemoryPerMultiProcessor); + device.set_memory_size(properties.totalGlobalMem); + // 8 is the number of bits per byte. 2 is accounted for + // double data rate (DDR). + device.set_bandwidth(properties.memoryBusWidth / 8 * + properties.memoryClockRate * 2); + + (*device.mutable_environment())["architecture"] = + strings::StrCat("gfx", properties.gcnArch); #endif return device; diff --git a/tensorflow/core/kernels/softplus_op.cc b/tensorflow/core/kernels/softplus_op.cc index d3fc0e1461b..fb00e1bb08c 100644 --- a/tensorflow/core/kernels/softplus_op.cc +++ b/tensorflow/core/kernels/softplus_op.cc @@ -87,7 +87,7 @@ void SoftplusGradOp::OperateNoTemplate(OpKernelContext* context, TF_CALL_FLOAT_TYPES(REGISTER_KERNELS); #undef REGISTER_KERNELS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ @@ -119,6 +119,6 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/softplus_op_gpu.cu.cc b/tensorflow/core/kernels/softplus_op_gpu.cu.cc index 8df734588b8..900df277a5b 100644 --- a/tensorflow/core/kernels/softplus_op_gpu.cu.cc +++ b/tensorflow/core/kernels/softplus_op_gpu.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU @@ -37,4 +37,4 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); } // end namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/stream_executor/gpu/gpu_types.h b/tensorflow/stream_executor/gpu/gpu_types.h index c69177d0760..64a6e5e5efc 100644 --- a/tensorflow/stream_executor/gpu/gpu_types.h +++ b/tensorflow/stream_executor/gpu/gpu_types.h @@ -20,6 +20,8 @@ limitations under the License. #if TENSORFLOW_USE_ROCM +#define __HIP_DISABLE_CPP_FUNCTIONS__ + #include "rocm/include/hip/hip_complex.h" #include "rocm/include/hip/hip_runtime.h" #include "rocm/include/hiprand/hiprand.h" diff --git a/tensorflow/stream_executor/platform/default/dso_loader.cc b/tensorflow/stream_executor/platform/default/dso_loader.cc index a2a9ad67534..6ed7480ff0c 100644 --- a/tensorflow/stream_executor/platform/default/dso_loader.cc +++ b/tensorflow/stream_executor/platform/default/dso_loader.cc @@ -158,27 +158,27 @@ port::StatusOr GetCudnnDsoHandle() { port::StatusOr GetRocblasDsoHandle() { static auto result = new auto(DsoLoader::GetRocblasDsoHandle()); - return result; + return *result; } port::StatusOr GetMiopenDsoHandle() { static auto result = new auto(DsoLoader::GetMiopenDsoHandle()); - return result; + return *result; } port::StatusOr GetRocfftDsoHandle() { static auto result = new auto(DsoLoader::GetRocfftDsoHandle()); - return result; + return *result; } port::StatusOr GetRocrandDsoHandle() { static auto result = new auto(DsoLoader::GetRocrandDsoHandle()); - return result; + return *result; } port::StatusOr GetHipDsoHandle() { static auto result = new auto(DsoLoader::GetHipDsoHandle()); - return result; + return *result; } } // namespace CachedDsoLoader diff --git a/tensorflow/stream_executor/rocm/BUILD b/tensorflow/stream_executor/rocm/BUILD index 8c452a31d36..902d8f98ee0 100644 --- a/tensorflow/stream_executor/rocm/BUILD +++ b/tensorflow/stream_executor/rocm/BUILD @@ -232,7 +232,6 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/stream_executor:dnn", "//tensorflow/stream_executor:event", - "//tensorflow/stream_executor:logging_proto_cc", "//tensorflow/stream_executor:plugin_registry", "//tensorflow/stream_executor:scratch_allocator", "//tensorflow/stream_executor:stream_executor_pimpl_header", diff --git a/tensorflow/stream_executor/rocm/rocm_driver.cc b/tensorflow/stream_executor/rocm/rocm_driver.cc index 1b0b91426aa..1c44e9f814f 100644 --- a/tensorflow/stream_executor/rocm/rocm_driver.cc +++ b/tensorflow/stream_executor/rocm/rocm_driver.cc @@ -500,7 +500,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { hipDeviceptr_t location, uint8 value, size_t size) { ScopedActivateContext activation{context}; - hipError_t res = tensorflow::wrap::hipMemset(location, value, size); + hipError_t res = tensorflow::wrap::hipMemsetD8(location, value, size); if (res != hipSuccess) { LOG(ERROR) << "failed to memset memory: " << ToString(res); return false; @@ -514,15 +514,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { size_t uint32_count) { ScopedActivateContext activation{context}; void* pointer = absl::bit_cast(location); - unsigned char valueC = static_cast(value); - uint32_t value32 = (valueC << 24) | (valueC << 16) | (valueC << 8) | (valueC); - if (value32 != value) { - // mismatch indicates case where hipMemsetAsyc can't emulate hipMemSetD32 - LOG(ERROR) << "failed to memset memory"; - return false; - } - hipError_t res = tensorflow::wrap::hipMemset(pointer, static_cast(value), - uint32_count * 4); + hipError_t res = tensorflow::wrap::hipMemsetD32(pointer, value, uint32_count); if (res != hipSuccess) { LOG(ERROR) << "failed to memset memory: " << ToString(res); return false; @@ -553,17 +545,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { GpuStreamHandle stream) { ScopedActivateContext activation{context}; void* pointer = absl::bit_cast(location); - - // FIXME - need to set a 32-bit value here - unsigned char valueC = static_cast(value); - uint32_t value32 = (valueC << 24) | (valueC << 16) | (valueC << 8) | (valueC); - if (value32 != value) { - // mismatch indicates case where hipMemsetAsyc can't emulate hipMemSetD32 - LOG(ERROR) << "failed to memset memory"; - return false; - } - hipError_t res = tensorflow::wrap::hipMemsetAsync(pointer, value, - uint32_count * 4, stream); + hipError_t res = + tensorflow::wrap::hipMemsetD32Async(pointer, value, uint32_count, stream); if (res != hipSuccess) { LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res); return false; @@ -671,7 +654,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { uint64 bytes) { ScopedActivateContext activated{context}; hipDeviceptr_t result = 0; - hipError_t res = tensorflow::wrap::hipMallocVanilla(&result, bytes); + hipError_t res = tensorflow::wrap::hipMalloc(&result, bytes); if (res != hipSuccess) { LOG(ERROR) << "failed to allocate " << port::HumanReadableNumBytes::ToString(bytes) << " (" << bytes @@ -717,8 +700,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { ScopedActivateContext activation{context}; void* host_mem = nullptr; // "Portable" memory is visible to all ROCM contexts. Safe for our use model. - hipError_t res = tensorflow::wrap::hipHostMallocVanilla( - &host_mem, bytes, hipHostMallocPortable); + hipError_t res = + tensorflow::wrap::hipHostMalloc(&host_mem, bytes, hipHostMallocPortable); if (res != hipSuccess) { LOG(ERROR) << "failed to alloc " << bytes << " bytes on host: " << ToString(res); diff --git a/tensorflow/stream_executor/rocm/rocm_driver_wrapper.h b/tensorflow/stream_executor/rocm/rocm_driver_wrapper.h index 27495c2cbc0..ba803edaafb 100644 --- a/tensorflow/stream_executor/rocm/rocm_driver_wrapper.h +++ b/tensorflow/stream_executor/rocm/rocm_driver_wrapper.h @@ -20,6 +20,8 @@ limitations under the License. #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_ #define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_ +#define __HIP_DISABLE_CPP_FUNCTIONS__ + #include "rocm/include/hip/hip_runtime.h" #include "tensorflow/stream_executor/lib/env.h" #include "tensorflow/stream_executor/platform/dso_loader.h" @@ -48,21 +50,6 @@ namespace wrap { #define TO_STR_(x) #x #define TO_STR(x) TO_STR_(x) -// hipMalloc and hipHostMalloc are defined as funtion templates in the -// HIP header files, and hence their names get mangled and the attempt -// to resolve their name when trying to dynamically load them will fail -// Updating the HIP header files to make them C functions is underway. -// Until that change flows through, we will workaround the issue by -// creating dummy wrappers for them here - -hipError_t hipMallocVanilla(void** ptr, size_t size) { - return hipErrorNotInitialized; -} - -hipError_t hipHostMallocVanilla(void** ptr, size_t size, unsigned int flags) { - return hipErrorNotInitialized; -} - #define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName) \ template \ auto hipSymbolName(Args... args)->decltype(::hipSymbolName(args...)) { \ @@ -107,9 +94,11 @@ hipError_t hipHostMallocVanilla(void** ptr, size_t size, unsigned int flags) { __macro(hipGetDeviceCount) \ __macro(hipGetDeviceProperties) \ __macro(hipHostFree) \ + __macro(hipHostMalloc) \ __macro(hipHostRegister) \ __macro(hipHostUnregister) \ __macro(hipInit) \ + __macro(hipMalloc) \ __macro(hipMemGetAddressRange) \ __macro(hipMemGetInfo) \ __macro(hipMemcpyDtoD) \ @@ -119,7 +108,10 @@ hipError_t hipHostMallocVanilla(void** ptr, size_t size, unsigned int flags) { __macro(hipMemcpyHtoD) \ __macro(hipMemcpyHtoDAsync) \ __macro(hipMemset) \ + __macro(hipMemsetD32) \ + __macro(hipMemsetD8) \ __macro(hipMemsetAsync) \ + __macro(hipMemsetD32Async) \ __macro(hipModuleGetFunction) \ __macro(hipModuleGetGlobal) \ __macro(hipModuleLaunchKernel) \ diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index bbc1fbaacfa..28c189e9538 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1234,6 +1234,7 @@ def tf_gpu_library(deps = None, cuda_deps = None, copts = tf_copts(), **kwargs): "@local_config_cuda//cuda:cuda_headers", ]) + if_rocm_is_configured(cuda_deps + [ # rocm_header placeholder + # rocm_header placeholder ]), copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])), **kwargs