From a4b59d75778e4013acb068e9e4bfdf01a871cd8f Mon Sep 17 00:00:00 2001 From: Davide Libenzi Date: Tue, 1 Oct 2019 11:30:31 -0700 Subject: [PATCH] Create the XRT compilation cache if it does not exist, even in the execute path, so that invalid handles can be flagged with a more meaningful error. PiperOrigin-RevId: 272256488 --- .../compiler/xrt/kernels/xrt_compile_ops.cc | 14 +++-------- .../compiler/xrt/kernels/xrt_execute_op.cc | 14 ++++------- .../compiler/xrt/xrt_compilation_cache.cc | 24 +++++++++++++++++++ .../compiler/xrt/xrt_compilation_cache.h | 10 ++++++++ 4 files changed, 41 insertions(+), 21 deletions(-) diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc index 89daa98ee18..2ae996bdb0f 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc @@ -49,8 +49,6 @@ namespace tensorflow { namespace { -const int kDefaultCacheSize = 100; - class XRTCompileOp : public OpKernel { public: explicit XRTCompileOp(OpKernelConstruction* ctx); @@ -159,15 +157,9 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK(ctx, CompilationCacheKey(computation_proto, &key)); // Process-wide cache of XLA executables. 
- XRTCompilationCache* cache; - OP_REQUIRES_OK(ctx, - rm->LookupOrCreate( - rm->default_container(), kXRTCompilationCacheResourceName, - &cache, [](XRTCompilationCache** new_cache) { - *new_cache = new XRTCompilationCache(kDefaultCacheSize); - return Status::OK(); - })); - core::ScopedUnref cache_unref(cache); + auto cache_or = GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0); + OP_REQUIRES_OK(ctx, cache_or.status()); + auto cache = cache_or.ConsumeValueOrDie(); int64 uid; OP_REQUIRES_OK( diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index 1c4e1f7e2c7..a83b035dc04 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -267,11 +267,8 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { bool release_inputs = config_proto.release_input_handles(); bool release_compilation = config_proto.release_compilation_handle(); - XRTCompilationCache* cache; - TF_RETURN_IF_ERROR(rm->Lookup( - rm->default_container(), kXRTCompilationCacheResourceName, &cache)); - core::ScopedUnref cache_unref(cache); - + TF_ASSIGN_OR_RETURN( + auto cache, GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0)); // We are guaranteed that the underlying device object won't be deleted out // from under us, while the ScopedRef is live. 
class XRTGenericDeviceAccessor::ScopedRef device_ref; @@ -350,11 +347,8 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) { xrt::XRTChainedExecuteConfig config; TF_RET_CHECK(config.ParseFromString(execution_config.scalar()())); - XRTCompilationCache* cache; - TF_RETURN_IF_ERROR(rm->Lookup( - rm->default_container(), kXRTCompilationCacheResourceName, &cache)); - core::ScopedUnref cache_unref(cache); - + TF_ASSIGN_OR_RETURN( + auto cache, GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0)); // We are guaranteed that the underlying device object won't be deleted out // from under us, while the ScopedRef is live. class XRTGenericDeviceAccessor::ScopedRef device_ref; diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.cc b/tensorflow/compiler/xrt/xrt_compilation_cache.cc index 8bf0f28d223..28ee6ff0775 100644 --- a/tensorflow/compiler/xrt/xrt_compilation_cache.cc +++ b/tensorflow/compiler/xrt/xrt_compilation_cache.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xrt/xrt_compilation_cache.h" +#include <stdlib.h> + +#include <string> + #include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/lib/core/errors.h" @@ -29,6 +33,11 @@ int64 get_uid() { return static_cast<int64>(unsigned_rand); } +int64 GetCompilationCacheSizeFromEnv() { + const char* env = getenv("TF_XRT_COMPILATION_CACHE_SIZE"); + return env == nullptr ? 
1024 : std::stol(env); +} + } // namespace const char* kXRTCompilationCacheResourceName = "xrt_compilation_cache"; @@ -277,4 +286,19 @@ string XRTCompilationCache::DebugString() const { return "XRTCompilationCache"; } +xla::StatusOr<RefPtr<XRTCompilationCache>> GetOrCreateCompilationCache( + ResourceMgr* rm, int64 max_number_of_entries) { + if (max_number_of_entries == 0) { + max_number_of_entries = GetCompilationCacheSizeFromEnv(); + } + XRTCompilationCache* cache; + TF_RETURN_IF_ERROR(rm->LookupOrCreate( + rm->default_container(), kXRTCompilationCacheResourceName, &cache, + [&](XRTCompilationCache** new_cache) { + *new_cache = new XRTCompilationCache(max_number_of_entries); + return Status::OK(); + })); + return RefPtr<XRTCompilationCache>(cache); + } + } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.h b/tensorflow/compiler/xrt/xrt_compilation_cache.h index 7398e847d8b..02cb25ea35c 100644 --- a/tensorflow/compiler/xrt/xrt_compilation_cache.h +++ b/tensorflow/compiler/xrt/xrt_compilation_cache.h @@ -21,6 +21,8 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xrt/xrt_refptr.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/refcount.h" @@ -231,6 +233,14 @@ class XRTCompilationCache : public ResourceBase { std::map<int64, CompiledSubgraph*> entries_by_last_use_ GUARDED_BY(mu_); }; +// Looks up or creates an XRTCompilationCache object within the given resource +// manager, under the default container. The max_number_of_entries sets the +// maximum number of entries within the cache (which will be LRU-evicted). +// If max_number_of_entries is set to zero, the size of the cache will be +// configured using the TF_XRT_COMPILATION_CACHE_SIZE environment variable. 
+xla::StatusOr<RefPtr<XRTCompilationCache>> GetOrCreateCompilationCache( + ResourceMgr* rm, int64 max_number_of_entries); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_