Merge pull request #46232 from ROCmSoftwarePlatform:google_upstream_rocm_use_gcn_arch_name
PiperOrigin-RevId: 351341559 Change-Id: I1477af5b453fbd76c9df93fdeed8a16899ff8b3a
This commit is contained in:
commit
5ac247e18b
@ -120,8 +120,9 @@ class GpuKernelToBlobPass
|
|||||||
|
|
||||||
std::string libdevice_dir = tensorflow::RocdlRoot();
|
std::string libdevice_dir = tensorflow::RocdlRoot();
|
||||||
auto llvm_module_copy = llvm::CloneModule(*llvmModule);
|
auto llvm_module_copy = llvm::CloneModule(*llvmModule);
|
||||||
|
xla::gpu::GpuVersion gpu_version{std::make_pair(arch, arch_str)};
|
||||||
auto hsaco_or = xla::gpu::amdgpu::CompileToHsaco(
|
auto hsaco_or = xla::gpu::amdgpu::CompileToHsaco(
|
||||||
llvm_module_copy.get(), arch, config, libdevice_dir);
|
llvm_module_copy.get(), gpu_version, config, libdevice_dir);
|
||||||
if (!hsaco_or.ok()) {
|
if (!hsaco_or.ok()) {
|
||||||
return InternalError("Failure when generating HSACO");
|
return InternalError("Failure when generating HSACO");
|
||||||
}
|
}
|
||||||
|
@ -100,8 +100,14 @@ GpuVersion AMDGPUCompiler::GetGpuVersion(se::StreamExecutor* stream_exec) {
|
|||||||
<< "Couldn't get AMDGPU ISA version for device; assuming gfx803.";
|
<< "Couldn't get AMDGPU ISA version for device; assuming gfx803.";
|
||||||
isa_version = 803;
|
isa_version = 803;
|
||||||
}
|
}
|
||||||
|
std::string gcn_arch_name =
|
||||||
|
stream_exec->GetDeviceDescription().rocm_amdgpu_gcn_arch_name();
|
||||||
|
if (gcn_arch_name == stream_exec->GetDeviceDescription().kUndefinedString) {
|
||||||
|
LOG(WARNING) << "Couldn't get AMDGPU GCN Arch for device; assuming gfx803.";
|
||||||
|
gcn_arch_name = "gfx803";
|
||||||
|
}
|
||||||
|
|
||||||
return isa_version;
|
return std::make_pair(isa_version, gcn_arch_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::pair<std::string, std::vector<uint8>>>
|
StatusOr<std::pair<std::string, std::vector<uint8>>>
|
||||||
|
@ -101,10 +101,11 @@ Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions(
|
|||||||
int stream_isa_version;
|
int stream_isa_version;
|
||||||
main_stream->parent()->GetDeviceDescription().rocm_amdgpu_isa_version(
|
main_stream->parent()->GetDeviceDescription().rocm_amdgpu_isa_version(
|
||||||
&stream_isa_version);
|
&stream_isa_version);
|
||||||
GpuVersion amd_isa_version = stream_isa_version;
|
int gpu_exec_isa_version =
|
||||||
TF_RET_CHECK(amd_isa_version == gpu_version_)
|
absl::get<std::pair<int, std::string>>(gpu_version_).first;
|
||||||
<< "AMDGPU GCN ISA version mismatch; expected {"
|
TF_RET_CHECK(stream_isa_version == gpu_exec_isa_version)
|
||||||
<< absl::get<int>(gpu_version_) << ", but was " << stream_isa_version;
|
<< "AMDGPU GCN ISA version mismatch; expected {" << gpu_exec_isa_version
|
||||||
|
<< ", but was " << stream_isa_version;
|
||||||
} else if (platform_kind == stream_executor::PlatformKind::kCuda) {
|
} else if (platform_kind == stream_executor::PlatformKind::kCuda) {
|
||||||
std::pair<int, int> stream_compute_compatibility;
|
std::pair<int, int> stream_compute_compatibility;
|
||||||
main_stream->parent()->GetDeviceDescription().cuda_compute_capability(
|
main_stream->parent()->GetDeviceDescription().cuda_compute_capability(
|
||||||
|
@ -21,10 +21,19 @@ limitations under the License.
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
|
|
||||||
// GpuVersion is used to abstract Gpu hardware version. On Cuda platform,
|
// GpuVersion is used to abstract Gpu hardware version.
|
||||||
// it comprises a pair of integers denoting major and minor version.
|
//
|
||||||
// On ROCm platform, it comprises one integer for AMD GCN ISA version.
|
// On Cuda platform, it comprises of an <int, int> pair
|
||||||
using GpuVersion = absl::variant<std::pair<int, int>, int>;
|
// denoting major and minor version.
|
||||||
|
//
|
||||||
|
// On ROCm platform, it comprises of an <int, string> pair
|
||||||
|
// the int has the contents of the hipDeviceProp_t::gcnArchValue field.
|
||||||
|
// the string has the contents of the hipDeviceProp_t::gcnArchName field.
|
||||||
|
// The string contains all the information needed to create an exact LLVM
|
||||||
|
// AMDGPUTarget corresopnding the AMDGPU device it represents, the int value
|
||||||
|
// by itself is not sufficient for this purpose
|
||||||
|
using GpuVersion =
|
||||||
|
absl::variant<std::pair<int, int>, std::pair<int, std::string>>;
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
|
@ -601,7 +601,7 @@ static std::vector<string> GetROCDLPaths(int amdgpu_version,
|
|||||||
struct HsacoCacheEntry {
|
struct HsacoCacheEntry {
|
||||||
uint64 hash;
|
uint64 hash;
|
||||||
std::string ir;
|
std::string ir;
|
||||||
int gfx;
|
std::string gfx;
|
||||||
std::vector<uint8> hsaco;
|
std::vector<uint8> hsaco;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -613,16 +613,16 @@ struct HsacoCache {
|
|||||||
int hit_count = 0;
|
int hit_count = 0;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static bool Find(const std::string& ir, uint64_t& hash, int gfx,
|
static bool Find(const std::string& ir, uint64_t& hash,
|
||||||
std::vector<uint8>& hsaco);
|
const std::string& gfx, std::vector<uint8>& hsaco);
|
||||||
static void Add(const std::string& ir, uint64_t hash, int gfx,
|
static void Add(const std::string& ir, uint64_t hash, const std::string& gfx,
|
||||||
const std::vector<uint8>& hsaco);
|
const std::vector<uint8>& hsaco);
|
||||||
};
|
};
|
||||||
|
|
||||||
static HsacoCache g_hsacoCache;
|
static HsacoCache g_hsacoCache;
|
||||||
|
|
||||||
bool HsacoCache::Find(const std::string& ir, uint64_t& hash, int gfx,
|
bool HsacoCache::Find(const std::string& ir, uint64_t& hash,
|
||||||
std::vector<uint8>& hsaco) {
|
const std::string& gfx, std::vector<uint8>& hsaco) {
|
||||||
std::lock_guard<std::mutex> lg(g_hsacoCache.m_mutex);
|
std::lock_guard<std::mutex> lg(g_hsacoCache.m_mutex);
|
||||||
hash = std::hash<std::string>{}(ir);
|
hash = std::hash<std::string>{}(ir);
|
||||||
bool hit = false;
|
bool hit = false;
|
||||||
@ -642,8 +642,8 @@ bool HsacoCache::Find(const std::string& ir, uint64_t& hash, int gfx,
|
|||||||
return hit;
|
return hit;
|
||||||
}
|
}
|
||||||
|
|
||||||
void HsacoCache::Add(const std::string& ir, uint64_t hash, int gfx,
|
void HsacoCache::Add(const std::string& ir, uint64_t hash,
|
||||||
const std::vector<uint8>& hsaco) {
|
const std::string& gfx, const std::vector<uint8>& hsaco) {
|
||||||
std::lock_guard<std::mutex> lg(g_hsacoCache.m_mutex);
|
std::lock_guard<std::mutex> lg(g_hsacoCache.m_mutex);
|
||||||
g_hsacoCache.cache.resize(g_hsacoCache.cache.size() + 1);
|
g_hsacoCache.cache.resize(g_hsacoCache.cache.size() + 1);
|
||||||
g_hsacoCache.cache.back().ir = ir;
|
g_hsacoCache.cache.back().ir = ir;
|
||||||
@ -787,29 +787,79 @@ Status AMDGPUTargetModuleLinker(llvm::Module* module, GpuVersion gpu_version,
|
|||||||
const HloModuleConfig& hlo_module_config,
|
const HloModuleConfig& hlo_module_config,
|
||||||
const string& device_bitcode_dir_path) {
|
const string& device_bitcode_dir_path) {
|
||||||
// Link the input module with ROCDL.
|
// Link the input module with ROCDL.
|
||||||
auto amdgpu_version = absl::get_if<int>(&gpu_version);
|
auto amdgpu_version = absl::get_if<std::pair<int, std::string>>(&gpu_version);
|
||||||
if (!amdgpu_version) {
|
if (!amdgpu_version) {
|
||||||
return xla::InternalError(
|
return xla::InternalError(
|
||||||
"Incompatible AMD GCN ISA version was specified.");
|
"Incompatible AMD GCN ISA version was specified.");
|
||||||
}
|
}
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(LinkROCDLIfNecessary(module, amdgpu_version->first,
|
||||||
LinkROCDLIfNecessary(module, *amdgpu_version, device_bitcode_dir_path));
|
device_bitcode_dir_path));
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<llvm::TargetMachine> AMDGPUGetTargetMachine(
|
// The following routine maps a feature token extracted from the
|
||||||
llvm::Triple target_triple, int amdgpu_version,
|
// hipDeviceProp_t::gcnArchName string, and maps it to a valid feature_str
|
||||||
const HloModuleConfig& hlo_module_config) {
|
// to be used for creating the AMDGPUTarget.
|
||||||
string feature_str = "+code-object-v3";
|
// This mapping is currently in a state of flux because TF XLA uses its
|
||||||
#if TF_ROCM_VERSION >= 30900
|
// own copy of LLVM, which is different from the LLVM version used by
|
||||||
// code-object-v3 is default, so no need to expliticitly specify it
|
// hipcc/runtime in the ROCm install. Ordinarily this is not a problem,
|
||||||
// in the feature string. Also, starting with ROCm 4.0, this feature string
|
// but right now, the LLVM version used by hipcc/runtime has "targetID"
|
||||||
// is deprecated, and we get a warning to that effect. So removing that
|
// related changes which have not yet been upstreamed (to the LLVM repo)
|
||||||
// feature string
|
// When that upstreaming happens (and TF LLVM pointer moves past the
|
||||||
|
// upstream commit), the following mapping will need to change
|
||||||
|
static std::string MapGCNArchNameTokenToFeatureStr(const std::string token) {
|
||||||
|
if (token == "sramecc+") {
|
||||||
|
return "+sram-ecc";
|
||||||
|
} else if (token == "sramecc-") {
|
||||||
|
return "-sram-ecc";
|
||||||
|
} else if (token == "xnack+") {
|
||||||
|
return "+xnack";
|
||||||
|
} else if (token == "xnack-") {
|
||||||
|
return "-xnack";
|
||||||
|
}
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string GetFeatureStrFromGCNArchName(
|
||||||
|
const std::string gcn_arch_name) {
|
||||||
|
std::string feature_str;
|
||||||
|
|
||||||
|
#if TF_ROCM_VERSION < 30900
|
||||||
|
// For ROCm versions older than 3.9, hardcode it to "+code-object-v3"
|
||||||
|
// This is simply to preserve how things were...nohing else
|
||||||
|
feature_str = "+code-object-v3";
|
||||||
|
#elif TF_ROCM_VERSION < 40000
|
||||||
|
// For ROCM versions 3.9 and 3.10, hardcode it to empty string
|
||||||
feature_str = "";
|
feature_str = "";
|
||||||
|
#else
|
||||||
|
// For ROCm versions 4.0 and greater, we need to specify the correct
|
||||||
|
// feature str, based on the underlying GPU HW to get max performance.
|
||||||
|
std::vector<std::string> tokens = absl::StrSplit(gcn_arch_name, ':');
|
||||||
|
std::vector<std::string> mapped_tokens;
|
||||||
|
for (auto it = tokens.begin(); it != tokens.end(); it++) {
|
||||||
|
// Skip the first token, that is the gfxNNN str
|
||||||
|
// The rest of the tokens are the feature/targetid strings
|
||||||
|
if (it != tokens.begin()) {
|
||||||
|
std::string token(*it);
|
||||||
|
std::string mapped_token = MapGCNArchNameTokenToFeatureStr(token);
|
||||||
|
mapped_tokens.push_back(mapped_token);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
feature_str = absl::StrJoin(mapped_tokens, ",");
|
||||||
#endif
|
#endif
|
||||||
return GetTargetMachine(target_triple, absl::StrCat("gfx", amdgpu_version),
|
|
||||||
|
return feature_str;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<llvm::TargetMachine> AMDGPUGetTargetMachine(
|
||||||
|
llvm::Triple target_triple, GpuVersion gpu_version,
|
||||||
|
const HloModuleConfig& hlo_module_config) {
|
||||||
|
auto amdgpu_version = absl::get_if<std::pair<int, std::string>>(&gpu_version);
|
||||||
|
int gcn_arch_value = amdgpu_version->first;
|
||||||
|
std::string gcn_arch_name = amdgpu_version->second;
|
||||||
|
std::string feature_str = GetFeatureStrFromGCNArchName(gcn_arch_name);
|
||||||
|
return GetTargetMachine(target_triple, absl::StrCat("gfx", gcn_arch_value),
|
||||||
hlo_module_config, feature_str);
|
hlo_module_config, feature_str);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -861,13 +911,14 @@ StatusOr<std::vector<uint8>> CompileToHsaco(
|
|||||||
tensorflow::profiler::TraceMeLevel::kInfo);
|
tensorflow::profiler::TraceMeLevel::kInfo);
|
||||||
XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str());
|
XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str());
|
||||||
|
|
||||||
auto amdgpu_version = absl::get_if<int>(&gpu_version);
|
auto amdgpu_version =
|
||||||
|
absl::get_if<std::pair<int, std::string>>(&gpu_version);
|
||||||
if (!amdgpu_version) {
|
if (!amdgpu_version) {
|
||||||
return xla::InternalError(
|
return xla::InternalError(
|
||||||
"Incompatible AMD GCN ISA version was specified.");
|
"Incompatible AMD GCN ISA version was specified.");
|
||||||
}
|
}
|
||||||
uint64_t hash;
|
uint64_t hash;
|
||||||
if (HsacoCache::Find(str, hash, *amdgpu_version, hsaco)) {
|
if (HsacoCache::Find(str, hash, amdgpu_version->second, hsaco)) {
|
||||||
VLOG(1) << "HSACO cache hit";
|
VLOG(1) << "HSACO cache hit";
|
||||||
return hsaco;
|
return hsaco;
|
||||||
}
|
}
|
||||||
@ -885,7 +936,7 @@ StatusOr<std::vector<uint8>> CompileToHsaco(
|
|||||||
llvm::Triple default_target_triple("amdgcn--amdhsa-amdgiz");
|
llvm::Triple default_target_triple("amdgcn--amdhsa-amdgiz");
|
||||||
// Construct LLVM TargetMachine for AMDGPU.
|
// Construct LLVM TargetMachine for AMDGPU.
|
||||||
std::unique_ptr<llvm::TargetMachine> target_machine =
|
std::unique_ptr<llvm::TargetMachine> target_machine =
|
||||||
AMDGPUGetTargetMachine(default_target_triple, *amdgpu_version,
|
AMDGPUGetTargetMachine(default_target_triple, gpu_version,
|
||||||
hlo_module_config);
|
hlo_module_config);
|
||||||
|
|
||||||
// Link with ROCm-Device-Libs, and optimize the LLVM module.
|
// Link with ROCm-Device-Libs, and optimize the LLVM module.
|
||||||
@ -896,7 +947,7 @@ StatusOr<std::vector<uint8>> CompileToHsaco(
|
|||||||
|
|
||||||
// Lower optimized LLVM module to HSA code object.
|
// Lower optimized LLVM module to HSA code object.
|
||||||
TF_ASSIGN_OR_RETURN(hsaco, EmitModuleToHsaco(module, target_machine.get()));
|
TF_ASSIGN_OR_RETURN(hsaco, EmitModuleToHsaco(module, target_machine.get()));
|
||||||
HsacoCache::Add(str, hash, *amdgpu_version, hsaco);
|
HsacoCache::Add(str, hash, amdgpu_version->second, hsaco);
|
||||||
}
|
}
|
||||||
return hsaco;
|
return hsaco;
|
||||||
}
|
}
|
||||||
|
@ -53,7 +53,9 @@ class GpuDummyCompiler : public GpuCompiler {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) { return 0; }
|
GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) {
|
||||||
|
return std::make_pair(0, 0);
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
|
StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
|
||||||
const HloModuleConfig& module_config, llvm::Module* llvm_module,
|
const HloModuleConfig& module_config, llvm::Module* llvm_module,
|
||||||
|
@ -1770,15 +1770,11 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
|
|||||||
<< strings::HumanReadableNumBytes(description->memory_bandwidth())
|
<< strings::HumanReadableNumBytes(description->memory_bandwidth())
|
||||||
<< "/s";
|
<< "/s";
|
||||||
#elif TENSORFLOW_USE_ROCM
|
#elif TENSORFLOW_USE_ROCM
|
||||||
int isa_version;
|
std::string gcn_arch_name = description->rocm_amdgpu_gcn_arch_name();
|
||||||
if (!description->rocm_amdgpu_isa_version(&isa_version)) {
|
|
||||||
// Logs internally on failure.
|
|
||||||
isa_version = 0;
|
|
||||||
}
|
|
||||||
LOG(INFO) << "Found device " << i << " with properties: "
|
LOG(INFO) << "Found device " << i << " with properties: "
|
||||||
<< "\npciBusID: " << description->pci_bus_id()
|
<< "\npciBusID: " << description->pci_bus_id()
|
||||||
<< " name: " << description->name()
|
<< " name: " << description->name()
|
||||||
<< " ROCm AMD GPU ISA: gfx" << isa_version
|
<< " ROCm AMDGPU Arch: " << gcn_arch_name
|
||||||
<< "\ncoreClock: " << description->clock_rate_ghz() << "GHz"
|
<< "\ncoreClock: " << description->clock_rate_ghz() << "GHz"
|
||||||
<< " coreCount: " << description->core_count()
|
<< " coreCount: " << description->core_count()
|
||||||
<< " deviceMemorySize: "
|
<< " deviceMemorySize: "
|
||||||
|
@ -1388,6 +1388,13 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64 bytes) {
|
|||||||
"Feature not supported on CUDA platform (GetGpuISAVersion)"};
|
"Feature not supported on CUDA platform (GetGpuISAVersion)"};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */ port::Status GpuDriver::GetGpuGCNArchName(
|
||||||
|
CUdevice device, std::string* gcnArchName) {
|
||||||
|
return port::Status{
|
||||||
|
port::error::INTERNAL,
|
||||||
|
"Feature not supported on CUDA platform (GetGpuGCNArchName)"};
|
||||||
|
}
|
||||||
|
|
||||||
// Helper function that turns the integer output of cuDeviceGetAttribute to type
|
// Helper function that turns the integer output of cuDeviceGetAttribute to type
|
||||||
// T and wraps it in a StatusOr.
|
// T and wraps it in a StatusOr.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -51,6 +51,7 @@ DeviceDescription::DeviceDescription()
|
|||||||
cuda_compute_capability_major_(-1),
|
cuda_compute_capability_major_(-1),
|
||||||
cuda_compute_capability_minor_(-1),
|
cuda_compute_capability_minor_(-1),
|
||||||
rocm_amdgpu_isa_version_(-1),
|
rocm_amdgpu_isa_version_(-1),
|
||||||
|
rocm_amdgpu_gcn_arch_name_(kUndefinedString),
|
||||||
numa_node_(-1),
|
numa_node_(-1),
|
||||||
core_count_(-1),
|
core_count_(-1),
|
||||||
ecc_enabled_(false) {}
|
ecc_enabled_(false) {}
|
||||||
@ -95,6 +96,8 @@ std::unique_ptr<std::map<std::string, std::string>> DeviceDescription::ToMap()
|
|||||||
result["CUDA Compute Capability"] = absl::StrCat(
|
result["CUDA Compute Capability"] = absl::StrCat(
|
||||||
cuda_compute_capability_major_, ".", cuda_compute_capability_minor_);
|
cuda_compute_capability_major_, ".", cuda_compute_capability_minor_);
|
||||||
|
|
||||||
|
result["AMDGPU GCN Arch Name"] = absl::StrCat(rocm_amdgpu_gcn_arch_name_);
|
||||||
|
|
||||||
result["NUMA Node"] = absl::StrCat(numa_node());
|
result["NUMA Node"] = absl::StrCat(numa_node());
|
||||||
result["Core Count"] = absl::StrCat(core_count());
|
result["Core Count"] = absl::StrCat(core_count());
|
||||||
result["ECC Enabled"] = absl::StrCat(ecc_enabled());
|
result["ECC Enabled"] = absl::StrCat(ecc_enabled());
|
||||||
|
@ -138,6 +138,13 @@ class DeviceDescription {
|
|||||||
// and the return value will be false.
|
// and the return value will be false.
|
||||||
bool rocm_amdgpu_isa_version(int *version) const;
|
bool rocm_amdgpu_isa_version(int *version) const;
|
||||||
|
|
||||||
|
// Returns the
|
||||||
|
// * AMDGPU GCN Architecture Name if we're running on the ROCm platform.
|
||||||
|
// * kUndefinedString otherwise
|
||||||
|
const std::string rocm_amdgpu_gcn_arch_name() const {
|
||||||
|
return rocm_amdgpu_gcn_arch_name_;
|
||||||
|
}
|
||||||
|
|
||||||
// Returns the maximum amount of shared memory present on a single core
|
// Returns the maximum amount of shared memory present on a single core
|
||||||
// (i.e. Streaming Multiprocessor on NVIDIA GPUs; Compute Unit for OpenCL
|
// (i.e. Streaming Multiprocessor on NVIDIA GPUs; Compute Unit for OpenCL
|
||||||
// devices). Note that some devices, such as NVIDIA's have a configurable
|
// devices). Note that some devices, such as NVIDIA's have a configurable
|
||||||
@ -203,6 +210,9 @@ class DeviceDescription {
|
|||||||
// ROCM AMDGPU ISA version, 0 if not available.
|
// ROCM AMDGPU ISA version, 0 if not available.
|
||||||
int rocm_amdgpu_isa_version_;
|
int rocm_amdgpu_isa_version_;
|
||||||
|
|
||||||
|
// ROCm AMDGPU GCN Architecture name, "" if not available.
|
||||||
|
std::string rocm_amdgpu_gcn_arch_name_;
|
||||||
|
|
||||||
int numa_node_;
|
int numa_node_;
|
||||||
int core_count_;
|
int core_count_;
|
||||||
bool ecc_enabled_;
|
bool ecc_enabled_;
|
||||||
@ -294,6 +304,10 @@ class DeviceDescriptionBuilder {
|
|||||||
device_description_->rocm_amdgpu_isa_version_ = version;
|
device_description_->rocm_amdgpu_isa_version_ = version;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void set_rocm_amdgpu_gcn_arch_name(const std::string &gcn_arch_name) {
|
||||||
|
device_description_->rocm_amdgpu_gcn_arch_name_ = gcn_arch_name;
|
||||||
|
}
|
||||||
|
|
||||||
void set_numa_node(int value) { device_description_->numa_node_ = value; }
|
void set_numa_node(int value) { device_description_->numa_node_ = value; }
|
||||||
void set_core_count(int value) { device_description_->core_count_ = value; }
|
void set_core_count(int value) { device_description_->core_count_ = value; }
|
||||||
void set_ecc_enabled(bool value) {
|
void set_ecc_enabled(bool value) {
|
||||||
|
@ -460,6 +460,12 @@ class GpuDriver {
|
|||||||
// (supported on ROCm only)
|
// (supported on ROCm only)
|
||||||
static port::Status GetGpuISAVersion(int* version, GpuDeviceHandle device);
|
static port::Status GetGpuISAVersion(int* version, GpuDeviceHandle device);
|
||||||
|
|
||||||
|
// Return the full GCN Architecture Name for the the device
|
||||||
|
// for eg: amdgcn-amd-amdhsa--gfx908:sramecc+:xnack-
|
||||||
|
// (supported on ROCm only)
|
||||||
|
static port::Status GetGpuGCNArchName(GpuDeviceHandle device,
|
||||||
|
std::string* gcnArchName);
|
||||||
|
|
||||||
// Returns the number of multiprocessors on the device (note that the device
|
// Returns the number of multiprocessors on the device (note that the device
|
||||||
// may be multi-GPU-per-board).
|
// may be multi-GPU-per-board).
|
||||||
static port::StatusOr<int> GetMultiprocessorCount(GpuDeviceHandle device);
|
static port::StatusOr<int> GetMultiprocessorCount(GpuDeviceHandle device);
|
||||||
|
@ -1080,6 +1080,21 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
device)};
|
device)};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */ port::Status GpuDriver::GetGpuGCNArchName(
|
||||||
|
hipDevice_t device, std::string* gcnArchName) {
|
||||||
|
hipDeviceProp_t props;
|
||||||
|
hipError_t result = tensorflow::wrap::hipGetDeviceProperties(&props, device);
|
||||||
|
if (result == hipSuccess) {
|
||||||
|
*gcnArchName = props.gcnArchName;
|
||||||
|
return port::Status::OK();
|
||||||
|
}
|
||||||
|
*gcnArchName = "";
|
||||||
|
return port::Status{
|
||||||
|
port::error::INTERNAL,
|
||||||
|
absl::StrFormat("failed to determine AMDGpu GCN Arch Name for device %d",
|
||||||
|
device)};
|
||||||
|
}
|
||||||
|
|
||||||
// Helper function that turns the integer output of hipDeviceGetAttribute to
|
// Helper function that turns the integer output of hipDeviceGetAttribute to
|
||||||
// type T and wraps it in a StatusOr.
|
// type T and wraps it in a StatusOr.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -820,6 +820,12 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) {
|
|||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string gcn_arch_name;
|
||||||
|
status = GpuDriver::GetGpuGCNArchName(device, &gcn_arch_name);
|
||||||
|
if (!status.ok()) {
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
internal::DeviceDescriptionBuilder builder;
|
internal::DeviceDescriptionBuilder builder;
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -888,7 +894,7 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
builder.set_platform_version(
|
builder.set_platform_version(
|
||||||
absl::StrCat("AMDGPU ISA version: gfx", version));
|
absl::StrCat("AMDGPU ISA version: ", gcn_arch_name));
|
||||||
|
|
||||||
// TODO(leary) should be a way to query this from the driver, but this is
|
// TODO(leary) should be a way to query this from the driver, but this is
|
||||||
// unlikely to change for us any time soon.
|
// unlikely to change for us any time soon.
|
||||||
@ -896,6 +902,8 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) {
|
|||||||
|
|
||||||
builder.set_device_vendor("Advanced Micro Devices, Inc");
|
builder.set_device_vendor("Advanced Micro Devices, Inc");
|
||||||
builder.set_rocm_amdgpu_isa_version(version);
|
builder.set_rocm_amdgpu_isa_version(version);
|
||||||
|
builder.set_rocm_amdgpu_gcn_arch_name(gcn_arch_name);
|
||||||
|
|
||||||
builder.set_shared_memory_per_core(
|
builder.set_shared_memory_per_core(
|
||||||
GpuDriver::GetMaxSharedMemoryPerCore(device).ValueOrDie());
|
GpuDriver::GetMaxSharedMemoryPerCore(device).ValueOrDie());
|
||||||
builder.set_shared_memory_per_block(
|
builder.set_shared_memory_per_block(
|
||||||
|
@ -140,6 +140,7 @@ typedef struct SE_DeviceDescription {
|
|||||||
int cuda_compute_capability_minor;
|
int cuda_compute_capability_minor;
|
||||||
|
|
||||||
int rocm_amdgpu_isa_version;
|
int rocm_amdgpu_isa_version;
|
||||||
|
char* rocm_amdgpu_gcn_arch_name;
|
||||||
|
|
||||||
int numa_node;
|
int numa_node;
|
||||||
int core_count;
|
int core_count;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user