Merge pull request #44486 from dev0x13:master
PiperOrigin-RevId: 347640445 Change-Id: Ib6116394ff75def312becf733ef445f5fb6db899
This commit is contained in:
commit
67e713e624
@ -108,23 +108,32 @@ static void WarnIfBadPtxasVersion(const std::string& ptxas_path) {
|
||||
port::StatusOr<absl::Span<const uint8>> CompileGpuAsmOrGetCached(
|
||||
int device_ordinal, const char* ptx, GpuAsmOpts compilation_options) {
|
||||
using PtxCacheKey = std::tuple<int, std::string, GpuAsmOpts::PtxOptionsTuple>;
|
||||
using PtxCompilerResult = port::StatusOr<std::vector<uint8>>;
|
||||
static tensorflow::mutex ptx_cache_mutex(tensorflow::LINKER_INITIALIZED);
|
||||
static auto& ptx_cache TF_GUARDED_BY(ptx_cache_mutex) =
|
||||
*new absl::flat_hash_map<PtxCacheKey, std::vector<uint8>>();
|
||||
*new absl::flat_hash_map<PtxCacheKey, PtxCompilerResult>();
|
||||
|
||||
tensorflow::mutex_lock lock(ptx_cache_mutex);
|
||||
PtxCacheKey cache_key{device_ordinal, std::string(ptx),
|
||||
compilation_options.ToTuple()};
|
||||
auto it = ptx_cache.find(cache_key);
|
||||
if (it == ptx_cache.end()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<uint8> compiled,
|
||||
CompileGpuAsm(device_ordinal, ptx, compilation_options));
|
||||
PtxCompilerResult compiled =
|
||||
CompileGpuAsm(device_ordinal, ptx, compilation_options);
|
||||
it = ptx_cache.emplace(cache_key, std::move(compiled)).first;
|
||||
}
|
||||
|
||||
CHECK(it != ptx_cache.end());
|
||||
const std::vector<uint8>& compiled = it->second;
|
||||
|
||||
// Failed compilation attempts are cached.
|
||||
// Use separate status check and ValueOrDie invocation on ptx_cache
|
||||
// entry to avoid value moving introduced by TF_ASSIGN_OR_RETURN.
|
||||
|
||||
if (TF_PREDICT_FALSE(!it->second.ok())) {
|
||||
return it->second.status();
|
||||
}
|
||||
|
||||
const std::vector<uint8>& compiled = it->second.ValueOrDie();
|
||||
return absl::MakeSpan(compiled);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user