PR review fixes

This commit is contained in:
Georgiy Manuilov 2020-11-19 12:27:44 +03:00 committed by GitHub
parent 73b7517e09
commit a090692ef8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -108,7 +108,7 @@ static void WarnIfBadPtxasVersion(const std::string& ptxas_path) {
port::StatusOr<absl::Span<const uint8>> CompileGpuAsmOrGetCached(
int device_ordinal, const char* ptx, GpuAsmOpts compilation_options) {
using PtxCacheKey = std::tuple<int, std::string, GpuAsmOpts::PtxOptionsTuple>;
using PtxCompilerResult = port::StatusOr<std::vector<tensorflow::uint8>>;
using PtxCompilerResult = port::StatusOr<std::vector<uint8>>;
static tensorflow::mutex ptx_cache_mutex(tensorflow::LINKER_INITIALIZED);
static auto& ptx_cache TF_GUARDED_BY(ptx_cache_mutex) =
*new absl::flat_hash_map<PtxCacheKey, PtxCompilerResult>();
@ -122,13 +122,14 @@ port::StatusOr<absl::Span<const uint8>> CompileGpuAsmOrGetCached(
it = ptx_cache.emplace(cache_key, std::move(compiled)).first;
}
CHECK(it != ptx_cache.end());
// Failed compilation attempts are cached.
if (TF_PREDICT_FALSE(!it->second.ok())) {
return it->second.status();
}
CHECK(it != ptx_cache.end());
const std::vector<uint8>& compiled = it->second.ValueOrDie();
TF_ASSIGN_OR_RETURN(const std::vector<uint8>& compiled, it->second);
return absl::MakeSpan(compiled);
}