diff --git a/tensorflow/stream_executor/rocm/rocm_driver.cc b/tensorflow/stream_executor/rocm/rocm_driver.cc index 94feef06d86..5b79aa0f92e 100644 --- a/tensorflow/stream_executor/rocm/rocm_driver.cc +++ b/tensorflow/stream_executor/rocm/rocm_driver.cc @@ -500,7 +500,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { hipDeviceptr_t location, uint8 value, size_t size) { ScopedActivateContext activation{context}; - hipError_t res = tensorflow::wrap::hipMemset(location, value, size); + hipError_t res = tensorflow::wrap::hipMemsetD8(location, value, size); if (res != hipSuccess) { LOG(ERROR) << "failed to memset memory: " << ToString(res); return false; @@ -514,15 +514,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { size_t uint32_count) { ScopedActivateContext activation{context}; void* pointer = absl::bit_cast(location); - unsigned char valueC = static_cast(value); - uint32_t value32 = (valueC << 24) | (valueC << 16) | (valueC << 8) | (valueC); - if (value32 != value) { - // mismatch indicates case where hipMemsetAsyc can't emulate hipMemSetD32 - LOG(ERROR) << "failed to memset memory"; - return false; - } - hipError_t res = tensorflow::wrap::hipMemset(pointer, static_cast(value), - uint32_count * 4); + hipError_t res = tensorflow::wrap::hipMemsetD32(pointer, value, uint32_count); if (res != hipSuccess) { LOG(ERROR) << "failed to memset memory: " << ToString(res); return false; @@ -553,17 +545,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { GpuStreamHandle stream) { ScopedActivateContext activation{context}; void* pointer = absl::bit_cast(location); - - // FIXME - need to set a 32-bit value here - unsigned char valueC = static_cast(value); - uint32_t value32 = (valueC << 24) | (valueC << 16) | (valueC << 8) | (valueC); - if (value32 != value) { - // mismatch indicates case where hipMemsetAsyc can't emulate hipMemSetD32 - LOG(ERROR) << "failed to memset memory"; - return false; - } - hipError_t res = tensorflow::wrap::hipMemsetAsync(pointer, value, - uint32_count * 4, stream); + hipError_t res = + tensorflow::wrap::hipMemsetD32Async(pointer, value, uint32_count, stream); if (res != hipSuccess) { LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res); return false; diff --git a/tensorflow/stream_executor/rocm/rocm_driver_wrapper.h b/tensorflow/stream_executor/rocm/rocm_driver_wrapper.h index c855bfb36a8..ba803edaafb 100644 --- a/tensorflow/stream_executor/rocm/rocm_driver_wrapper.h +++ b/tensorflow/stream_executor/rocm/rocm_driver_wrapper.h @@ -108,7 +108,10 @@ namespace wrap { __macro(hipMemcpyHtoD) \ __macro(hipMemcpyHtoDAsync) \ __macro(hipMemset) \ + __macro(hipMemsetD32) \ + __macro(hipMemsetD8) \ __macro(hipMemsetAsync) \ + __macro(hipMemsetD32Async) \ __macro(hipModuleGetFunction) \ __macro(hipModuleGetGlobal) \ __macro(hipModuleLaunchKernel) \