Merge pull request #44486 from dev0x13:master
PiperOrigin-RevId: 347640445 Change-Id: Ib6116394ff75def312becf733ef445f5fb6db899
This commit is contained in:
commit
67e713e624
@ -108,23 +108,32 @@ static void WarnIfBadPtxasVersion(const std::string& ptxas_path) {
|
|||||||
port::StatusOr<absl::Span<const uint8>> CompileGpuAsmOrGetCached(
|
port::StatusOr<absl::Span<const uint8>> CompileGpuAsmOrGetCached(
|
||||||
int device_ordinal, const char* ptx, GpuAsmOpts compilation_options) {
|
int device_ordinal, const char* ptx, GpuAsmOpts compilation_options) {
|
||||||
using PtxCacheKey = std::tuple<int, std::string, GpuAsmOpts::PtxOptionsTuple>;
|
using PtxCacheKey = std::tuple<int, std::string, GpuAsmOpts::PtxOptionsTuple>;
|
||||||
|
using PtxCompilerResult = port::StatusOr<std::vector<uint8>>;
|
||||||
static tensorflow::mutex ptx_cache_mutex(tensorflow::LINKER_INITIALIZED);
|
static tensorflow::mutex ptx_cache_mutex(tensorflow::LINKER_INITIALIZED);
|
||||||
static auto& ptx_cache TF_GUARDED_BY(ptx_cache_mutex) =
|
static auto& ptx_cache TF_GUARDED_BY(ptx_cache_mutex) =
|
||||||
*new absl::flat_hash_map<PtxCacheKey, std::vector<uint8>>();
|
*new absl::flat_hash_map<PtxCacheKey, PtxCompilerResult>();
|
||||||
|
|
||||||
tensorflow::mutex_lock lock(ptx_cache_mutex);
|
tensorflow::mutex_lock lock(ptx_cache_mutex);
|
||||||
PtxCacheKey cache_key{device_ordinal, std::string(ptx),
|
PtxCacheKey cache_key{device_ordinal, std::string(ptx),
|
||||||
compilation_options.ToTuple()};
|
compilation_options.ToTuple()};
|
||||||
auto it = ptx_cache.find(cache_key);
|
auto it = ptx_cache.find(cache_key);
|
||||||
if (it == ptx_cache.end()) {
|
if (it == ptx_cache.end()) {
|
||||||
TF_ASSIGN_OR_RETURN(
|
PtxCompilerResult compiled =
|
||||||
std::vector<uint8> compiled,
|
CompileGpuAsm(device_ordinal, ptx, compilation_options);
|
||||||
CompileGpuAsm(device_ordinal, ptx, compilation_options));
|
|
||||||
it = ptx_cache.emplace(cache_key, std::move(compiled)).first;
|
it = ptx_cache.emplace(cache_key, std::move(compiled)).first;
|
||||||
}
|
}
|
||||||
|
|
||||||
CHECK(it != ptx_cache.end());
|
CHECK(it != ptx_cache.end());
|
||||||
const std::vector<uint8>& compiled = it->second;
|
|
||||||
|
// Failed compilation attempts are cached.
|
||||||
|
// Use separate status check and ValueOrDie invocation on ptx_cache
|
||||||
|
// entry to avoid value moving introduced by TF_ASSIGN_OR_RETURN.
|
||||||
|
|
||||||
|
if (TF_PREDICT_FALSE(!it->second.ok())) {
|
||||||
|
return it->second.status();
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<uint8>& compiled = it->second.ValueOrDie();
|
||||||
return absl::MakeSpan(compiled);
|
return absl::MakeSpan(compiled);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user