Create the XRT compilation cache if it does not exist, even in the execute path, so that invalid handles can be flagged with a more meaningful error.

PiperOrigin-RevId: 272256488
This commit is contained in:
Davide Libenzi 2019-10-01 11:30:31 -07:00 committed by TensorFlower Gardener
parent a86dd1607a
commit a4b59d7577
4 changed files with 41 additions and 21 deletions

View File

@ -49,8 +49,6 @@ namespace tensorflow {
namespace {
const int kDefaultCacheSize = 100;
class XRTCompileOp : public OpKernel {
public:
explicit XRTCompileOp(OpKernelConstruction* ctx);
@ -159,15 +157,9 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, CompilationCacheKey(computation_proto, &key));
// Process-wide cache of XLA executables.
XRTCompilationCache* cache;
OP_REQUIRES_OK(ctx,
rm->LookupOrCreate<XRTCompilationCache>(
rm->default_container(), kXRTCompilationCacheResourceName,
&cache, [](XRTCompilationCache** new_cache) {
*new_cache = new XRTCompilationCache(kDefaultCacheSize);
return Status::OK();
}));
core::ScopedUnref cache_unref(cache);
auto cache_or = GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0);
OP_REQUIRES_OK(ctx, cache_or.status());
auto cache = cache_or.ConsumeValueOrDie();
int64 uid;
OP_REQUIRES_OK(

View File

@ -267,11 +267,8 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) {
bool release_inputs = config_proto.release_input_handles();
bool release_compilation = config_proto.release_compilation_handle();
XRTCompilationCache* cache;
TF_RETURN_IF_ERROR(rm->Lookup<XRTCompilationCache>(
rm->default_container(), kXRTCompilationCacheResourceName, &cache));
core::ScopedUnref cache_unref(cache);
TF_ASSIGN_OR_RETURN(
auto cache, GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0));
// We are guaranteed that the underlying device object won't be deleted out
// from under us, while the ScopedRef is live.
class XRTGenericDeviceAccessor::ScopedRef device_ref;
@ -350,11 +347,8 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) {
xrt::XRTChainedExecuteConfig config;
TF_RET_CHECK(config.ParseFromString(execution_config.scalar<tstring>()()));
XRTCompilationCache* cache;
TF_RETURN_IF_ERROR(rm->Lookup<XRTCompilationCache>(
rm->default_container(), kXRTCompilationCacheResourceName, &cache));
core::ScopedUnref cache_unref(cache);
TF_ASSIGN_OR_RETURN(
auto cache, GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0));
// We are guaranteed that the underlying device object won't be deleted out
// from under us, while the ScopedRef is live.
class XRTGenericDeviceAccessor::ScopedRef device_ref;

View File

@ -15,6 +15,10 @@ limitations under the License.
#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
#include <stdlib.h>
#include <string>
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/lib/core/errors.h"
@ -29,6 +33,11 @@ int64 get_uid() {
return static_cast<int64>(unsigned_rand);
}
// Returns the compilation cache capacity (number of entries) configured via
// the TF_XRT_COMPILATION_CACHE_SIZE environment variable.
//
// Falls back to a default of 1024 entries when the variable is unset, is not
// a fully numeric string, or does not parse to a positive value. Unlike
// std::stol, strtol never throws, so a malformed environment value cannot
// abort the process.
int64 GetCompilationCacheSizeFromEnv() {
  constexpr int64 kDefaultEnvCacheSize = 1024;
  const char* env = getenv("TF_XRT_COMPILATION_CACHE_SIZE");
  if (env == nullptr) {
    return kDefaultEnvCacheSize;
  }
  char* end = nullptr;
  const long parsed = strtol(env, &end, 10);  // NOLINT(runtime/deprecated_fn)
  // Reject empty strings, trailing non-digit junk, and non-positive sizes;
  // a cache must hold at least one entry.
  if (end == env || *end != '\0' || parsed <= 0) {
    return kDefaultEnvCacheSize;
  }
  return static_cast<int64>(parsed);
}
} // namespace
const char* kXRTCompilationCacheResourceName = "xrt_compilation_cache";
@ -277,4 +286,19 @@ string XRTCompilationCache::DebugString() const {
return "XRTCompilationCache";
}
// Fetches the process-wide XRT compilation cache from the resource manager's
// default container, creating it on first use. A max_number_of_entries of
// zero means "use the size from the TF_XRT_COMPILATION_CACHE_SIZE environment
// variable". Returns a reference-counted handle to the cache.
xla::StatusOr<RefPtr<XRTCompilationCache>> GetOrCreateCompilationCache(
    ResourceMgr* rm, int64 max_number_of_entries) {
  const int64 num_entries = (max_number_of_entries != 0)
                                ? max_number_of_entries
                                : GetCompilationCacheSizeFromEnv();
  // Creation callback used only when the cache resource does not exist yet.
  auto create_cache = [num_entries](XRTCompilationCache** new_cache) {
    *new_cache = new XRTCompilationCache(num_entries);
    return Status::OK();
  };
  XRTCompilationCache* cache = nullptr;
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<XRTCompilationCache>(
      rm->default_container(), kXRTCompilationCacheResourceName, &cache,
      create_cache));
  // RefPtr adopts the reference handed out by LookupOrCreate.
  return RefPtr<XRTCompilationCache>(cache);
}
} // namespace tensorflow

View File

@ -21,6 +21,8 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xrt/xrt_refptr.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/refcount.h"
@ -231,6 +233,14 @@ class XRTCompilationCache : public ResourceBase {
std::map<int64, CompiledSubgraph*> entries_by_last_use_ GUARDED_BY(mu_);
};
Looks up or creates an XRTCompilationCache object within the given resource
manager, under the default container. The max_number_of_entries sets the
maximum number of entries within the cache (which will be LRU-evicted).
If max_number_of_entries is set to zero, the size of the cache will be
configured using the TF_XRT_COMPILATION_CACHE_SIZE environment variable.
xla::StatusOr<RefPtr<XRTCompilationCache>> GetOrCreateCompilationCache(
ResourceMgr* rm, int64 max_number_of_entries);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_