Create the XRT compilation cache if it does not exist, even in the execute path, so that invalid handles can be flagged with a more meaningful error.
PiperOrigin-RevId: 272256488
This commit is contained in:
parent
a86dd1607a
commit
a4b59d7577
tensorflow/compiler/xrt
@ -49,8 +49,6 @@ namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
const int kDefaultCacheSize = 100;
|
||||
|
||||
class XRTCompileOp : public OpKernel {
|
||||
public:
|
||||
explicit XRTCompileOp(OpKernelConstruction* ctx);
|
||||
@ -159,15 +157,9 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) {
|
||||
OP_REQUIRES_OK(ctx, CompilationCacheKey(computation_proto, &key));
|
||||
|
||||
// Process-wide cache of XLA executables.
|
||||
XRTCompilationCache* cache;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
rm->LookupOrCreate<XRTCompilationCache>(
|
||||
rm->default_container(), kXRTCompilationCacheResourceName,
|
||||
&cache, [](XRTCompilationCache** new_cache) {
|
||||
*new_cache = new XRTCompilationCache(kDefaultCacheSize);
|
||||
return Status::OK();
|
||||
}));
|
||||
core::ScopedUnref cache_unref(cache);
|
||||
auto cache_or = GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0);
|
||||
OP_REQUIRES_OK(ctx, cache_or.status());
|
||||
auto cache = cache_or.ConsumeValueOrDie();
|
||||
|
||||
int64 uid;
|
||||
OP_REQUIRES_OK(
|
||||
|
@ -267,11 +267,8 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) {
|
||||
bool release_inputs = config_proto.release_input_handles();
|
||||
bool release_compilation = config_proto.release_compilation_handle();
|
||||
|
||||
XRTCompilationCache* cache;
|
||||
TF_RETURN_IF_ERROR(rm->Lookup<XRTCompilationCache>(
|
||||
rm->default_container(), kXRTCompilationCacheResourceName, &cache));
|
||||
core::ScopedUnref cache_unref(cache);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto cache, GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0));
|
||||
// We are guaranteed that the underlying device object won't be deleted out
|
||||
// from under us, while the ScopedRef is live.
|
||||
class XRTGenericDeviceAccessor::ScopedRef device_ref;
|
||||
@ -350,11 +347,8 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) {
|
||||
xrt::XRTChainedExecuteConfig config;
|
||||
TF_RET_CHECK(config.ParseFromString(execution_config.scalar<tstring>()()));
|
||||
|
||||
XRTCompilationCache* cache;
|
||||
TF_RETURN_IF_ERROR(rm->Lookup<XRTCompilationCache>(
|
||||
rm->default_container(), kXRTCompilationCacheResourceName, &cache));
|
||||
core::ScopedUnref cache_unref(cache);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto cache, GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0));
|
||||
// We are guaranteed that the underlying device object won't be deleted out
|
||||
// from under us, while the ScopedRef is live.
|
||||
class XRTGenericDeviceAccessor::ScopedRef device_ref;
|
||||
|
@ -15,6 +15,10 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
@ -29,6 +33,11 @@ int64 get_uid() {
|
||||
return static_cast<int64>(unsigned_rand);
|
||||
}
|
||||
|
||||
// Returns the maximum number of entries for the XRT compilation cache.
// Reads the TF_XRT_COMPILATION_CACHE_SIZE environment variable and falls
// back to 1024 when the variable is unset or not a valid number.
int64 GetCompilationCacheSizeFromEnv() {
  constexpr int64 kDefaultEnvCacheSize = 1024;
  const char* env = getenv("TF_XRT_COMPILATION_CACHE_SIZE");
  if (env == nullptr) {
    return kDefaultEnvCacheSize;
  }
  // Use strtol with explicit validation instead of std::stol: std::stol
  // throws std::invalid_argument / std::out_of_range on malformed input,
  // which would crash the process over a configuration typo.
  char* end = nullptr;
  const long parsed = strtol(env, &end, 10);
  if (end == env || *end != '\0') {
    return kDefaultEnvCacheSize;
  }
  return static_cast<int64>(parsed);
}
|
||||
|
||||
} // namespace
|
||||
|
||||
const char* kXRTCompilationCacheResourceName = "xrt_compilation_cache";
|
||||
@ -277,4 +286,19 @@ string XRTCompilationCache::DebugString() const {
|
||||
return "XRTCompilationCache";
|
||||
}
|
||||
|
||||
// Retrieves the process-wide XRT compilation cache from the resource
// manager's default container, creating it on first use. Passing zero for
// max_number_of_entries selects the size configured through the
// TF_XRT_COMPILATION_CACHE_SIZE environment variable.
xla::StatusOr<RefPtr<XRTCompilationCache>> GetOrCreateCompilationCache(
    ResourceMgr* rm, int64 max_number_of_entries) {
  const int64 cache_size = (max_number_of_entries == 0)
                               ? GetCompilationCacheSizeFromEnv()
                               : max_number_of_entries;
  // Creation callback, invoked only when the cache resource is absent.
  auto create_cache = [cache_size](XRTCompilationCache** created) {
    *created = new XRTCompilationCache(cache_size);
    return Status::OK();
  };
  XRTCompilationCache* raw_cache = nullptr;
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<XRTCompilationCache>(
      rm->default_container(), kXRTCompilationCacheResourceName, &raw_cache,
      create_cache));
  // Wrap in a RefPtr so the reference obtained from the resource manager is
  // released automatically.
  return RefPtr<XRTCompilationCache>(raw_cache);
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -21,6 +21,8 @@ limitations under the License.
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xrt/xrt_refptr.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
|
||||
@ -231,6 +233,14 @@ class XRTCompilationCache : public ResourceBase {
|
||||
std::map<int64, CompiledSubgraph*> entries_by_last_use_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
// Looks up or creates an XRTCompilationCache object within the given resource
|
||||
// manager, under the default container. The max_number_of_entries sets the
|
||||
// maximum number of entries within the cache (which will be LRU-evicted).
|
||||
// If max_number_of_entries is set to zero, the size of the cache will be
|
||||
// configured using the TF_XRT_COMPILATION_CACHE_SIZE environment variable.
|
||||
xla::StatusOr<RefPtr<XRTCompilationCache>> GetOrCreateCompilationCache(
|
||||
ResourceMgr* rm, int64 max_number_of_entries);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_
|
||||
|
Loading…
Reference in New Issue
Block a user