Changing "GpuVersion" datatype to include hipDeviceProp_t::gcnArchName
Currently the "GpuVersion" datatype (for AMDGPU in XLA code) is an `int`, whose value is the same as the `int hipDeviceProp_t::gcnArch;` field. Starting with ROCm 4.?, which introduces targetID support, that value will no longer be sufficient to create a LLVM AMDGPUTarget that accurately represents underlying GPU. We will need to information contained withing the `string hipDeviceProp_t gcnArchName` field for that purpose. This commit updates fthe "GpuVersion" datatype from being a simple int to a (int, string) pair, and stores the value of the `string hipDeviceProp_t gcnArchName` field in the string.
This commit is contained in:
parent
d236afda36
commit
ab9a9084d2
@ -100,8 +100,14 @@ GpuVersion AMDGPUCompiler::GetGpuVersion(se::StreamExecutor* stream_exec) {
|
||||
<< "Couldn't get AMDGPU ISA version for device; assuming gfx803.";
|
||||
isa_version = 803;
|
||||
}
|
||||
std::string gcn_arch_name =
|
||||
stream_exec->GetDeviceDescription().rocm_amdgpu_gcn_arch_name();
|
||||
if (gcn_arch_name == stream_exec->GetDeviceDescription().kUndefinedString) {
|
||||
LOG(WARNING) << "Couldn't get AMDGPU GCN Arch for device; assuming gfx803.";
|
||||
gcn_arch_name = "gfx803";
|
||||
}
|
||||
|
||||
return isa_version;
|
||||
return std::make_pair(isa_version, gcn_arch_name);
|
||||
}
|
||||
|
||||
StatusOr<std::pair<std::string, std::vector<uint8>>>
|
||||
|
@ -101,10 +101,11 @@ Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions(
|
||||
int stream_isa_version;
|
||||
main_stream->parent()->GetDeviceDescription().rocm_amdgpu_isa_version(
|
||||
&stream_isa_version);
|
||||
GpuVersion amd_isa_version = stream_isa_version;
|
||||
TF_RET_CHECK(amd_isa_version == gpu_version_)
|
||||
<< "AMDGPU GCN ISA version mismatch; expected {"
|
||||
<< absl::get<int>(gpu_version_) << ", but was " << stream_isa_version;
|
||||
int gpu_exec_isa_version =
|
||||
absl::get<std::pair<int, std::string>>(gpu_version_).first;
|
||||
TF_RET_CHECK(stream_isa_version == gpu_exec_isa_version)
|
||||
<< "AMDGPU GCN ISA version mismatch; expected {" << gpu_exec_isa_version
|
||||
<< ", but was " << stream_isa_version;
|
||||
} else if (platform_kind == stream_executor::PlatformKind::kCuda) {
|
||||
std::pair<int, int> stream_compute_compatibility;
|
||||
main_stream->parent()->GetDeviceDescription().cuda_compute_capability(
|
||||
|
@ -21,10 +21,19 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
// GpuVersion is used to abstract Gpu hardware version. On Cuda platform,
|
||||
// it comprises a pair of integers denoting major and minor version.
|
||||
// On ROCm platform, it comprises one integer for AMD GCN ISA version.
|
||||
using GpuVersion = absl::variant<std::pair<int, int>, int>;
|
||||
// GpuVersion is used to abstract Gpu hardware version.
|
||||
//
|
||||
// On Cuda platform, it comprises of an <int, int> pair
|
||||
// denoting major and minor version.
|
||||
//
|
||||
// On ROCm platform, it comprises of an <int, string> pair
|
||||
// the int has the contents of the hipDeviceProp_t::gcnArchValue field.
|
||||
// the string has the contents of the hipDeviceProp_t::gcnArchName field.
|
||||
// The string contains all the information needed to create an exact LLVM
|
||||
// AMDGPUTarget corresopnding the AMDGPU device it represents, the int value
|
||||
// by itself is not sufficient for this purpose
|
||||
using GpuVersion =
|
||||
absl::variant<std::pair<int, int>, std::pair<int, std::string>>;
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
||||
|
@ -787,13 +787,13 @@ Status AMDGPUTargetModuleLinker(llvm::Module* module, GpuVersion gpu_version,
|
||||
const HloModuleConfig& hlo_module_config,
|
||||
const string& device_bitcode_dir_path) {
|
||||
// Link the input module with ROCDL.
|
||||
auto amdgpu_version = absl::get_if<int>(&gpu_version);
|
||||
auto amdgpu_version = absl::get_if<std::pair<int, std::string>>(&gpu_version);
|
||||
if (!amdgpu_version) {
|
||||
return xla::InternalError(
|
||||
"Incompatible AMD GCN ISA version was specified.");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
LinkROCDLIfNecessary(module, *amdgpu_version, device_bitcode_dir_path));
|
||||
TF_RETURN_IF_ERROR(LinkROCDLIfNecessary(module, amdgpu_version->first,
|
||||
device_bitcode_dir_path));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
@ -861,13 +861,14 @@ StatusOr<std::vector<uint8>> CompileToHsaco(
|
||||
tensorflow::profiler::TraceMeLevel::kInfo);
|
||||
XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str());
|
||||
|
||||
auto amdgpu_version = absl::get_if<int>(&gpu_version);
|
||||
auto amdgpu_version =
|
||||
absl::get_if<std::pair<int, std::string>>(&gpu_version);
|
||||
if (!amdgpu_version) {
|
||||
return xla::InternalError(
|
||||
"Incompatible AMD GCN ISA version was specified.");
|
||||
}
|
||||
uint64_t hash;
|
||||
if (HsacoCache::Find(str, hash, *amdgpu_version, hsaco)) {
|
||||
if (HsacoCache::Find(str, hash, amdgpu_version->first, hsaco)) {
|
||||
VLOG(1) << "HSACO cache hit";
|
||||
return hsaco;
|
||||
}
|
||||
@ -885,7 +886,7 @@ StatusOr<std::vector<uint8>> CompileToHsaco(
|
||||
llvm::Triple default_target_triple("amdgcn--amdhsa-amdgiz");
|
||||
// Construct LLVM TargetMachine for AMDGPU.
|
||||
std::unique_ptr<llvm::TargetMachine> target_machine =
|
||||
AMDGPUGetTargetMachine(default_target_triple, *amdgpu_version,
|
||||
AMDGPUGetTargetMachine(default_target_triple, amdgpu_version->first,
|
||||
hlo_module_config);
|
||||
|
||||
// Link with ROCm-Device-Libs, and optimize the LLVM module.
|
||||
@ -896,7 +897,7 @@ StatusOr<std::vector<uint8>> CompileToHsaco(
|
||||
|
||||
// Lower optimized LLVM module to HSA code object.
|
||||
TF_ASSIGN_OR_RETURN(hsaco, EmitModuleToHsaco(module, target_machine.get()));
|
||||
HsacoCache::Add(str, hash, *amdgpu_version, hsaco);
|
||||
HsacoCache::Add(str, hash, amdgpu_version->first, hsaco);
|
||||
}
|
||||
return hsaco;
|
||||
}
|
||||
|
@ -53,7 +53,9 @@ class GpuDummyCompiler : public GpuCompiler {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) { return 0; }
|
||||
GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) {
|
||||
return std::make_pair(0, 0);
|
||||
}
|
||||
|
||||
StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
|
||||
const HloModuleConfig& module_config, llvm::Module* llvm_module,
|
||||
|
Loading…
x
Reference in New Issue
Block a user