Merge pull request from dev0x13:master

PiperOrigin-RevId: 347640445
Change-Id: Ib6116394ff75def312becf733ef445f5fb6db899
This commit is contained in:
TensorFlower Gardener 2020-12-15 10:31:11 -08:00
commit 67e713e624

View File

@ -108,23 +108,32 @@ static void WarnIfBadPtxasVersion(const std::string& ptxas_path) {
port::StatusOr<absl::Span<const uint8>> CompileGpuAsmOrGetCached(
int device_ordinal, const char* ptx, GpuAsmOpts compilation_options) {
using PtxCacheKey = std::tuple<int, std::string, GpuAsmOpts::PtxOptionsTuple>;
using PtxCompilerResult = port::StatusOr<std::vector<uint8>>;
static tensorflow::mutex ptx_cache_mutex(tensorflow::LINKER_INITIALIZED);
static auto& ptx_cache TF_GUARDED_BY(ptx_cache_mutex) =
*new absl::flat_hash_map<PtxCacheKey, std::vector<uint8>>();
*new absl::flat_hash_map<PtxCacheKey, PtxCompilerResult>();
tensorflow::mutex_lock lock(ptx_cache_mutex);
PtxCacheKey cache_key{device_ordinal, std::string(ptx),
compilation_options.ToTuple()};
auto it = ptx_cache.find(cache_key);
if (it == ptx_cache.end()) {
TF_ASSIGN_OR_RETURN(
std::vector<uint8> compiled,
CompileGpuAsm(device_ordinal, ptx, compilation_options));
PtxCompilerResult compiled =
CompileGpuAsm(device_ordinal, ptx, compilation_options);
it = ptx_cache.emplace(cache_key, std::move(compiled)).first;
}
CHECK(it != ptx_cache.end());
const std::vector<uint8>& compiled = it->second;
// Failed compilation attempts are cached.
// Use separate status check and ValueOrDie invocation on ptx_cache
// entry to avoid value moving introduced by TF_ASSIGN_OR_RETURN.
if (TF_PREDICT_FALSE(!it->second.ok())) {
return it->second.status();
}
const std::vector<uint8>& compiled = it->second.ValueOrDie();
return absl::MakeSpan(compiled);
}