[tf.data] Memory-safe implementation of sharing access to the memory cache.
PiperOrigin-RevId: 307736215 Change-Id: If10ef65e6706a106e6bb4fc2d6fe4542bbe056cc
This commit is contained in:
parent
df7e7b1617
commit
b546b463b5
@ -694,8 +694,10 @@ Status RestoreCache(IteratorContext* ctx, IteratorStateReader* reader, T* cache,
|
||||
class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
|
||||
public:
|
||||
explicit MemoryDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
|
||||
MemoryCache* cache)
|
||||
: DatasetBase(DatasetContext(ctx)), input_(input), cache_(cache) {
|
||||
std::shared_ptr<MemoryCache> cache)
|
||||
: DatasetBase(DatasetContext(ctx)),
|
||||
input_(input),
|
||||
cache_(std::move(cache)) {
|
||||
input_->Ref();
|
||||
}
|
||||
|
||||
@ -708,7 +710,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
|
||||
return absl::make_unique<MemoryIterator>(
|
||||
MemoryIterator::Params{
|
||||
this, name_utils::IteratorPrefix(kDatasetType, prefix, params)},
|
||||
cache_);
|
||||
cache_.get());
|
||||
}
|
||||
|
||||
const DataTypeVector& output_dtypes() const override {
|
||||
@ -964,7 +966,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
|
||||
}; // MemoryIterator
|
||||
|
||||
const DatasetBase* const input_;
|
||||
MemoryCache* const cache_;
|
||||
const std::shared_ptr<MemoryCache> cache_;
|
||||
}; // MemoryDatasetBase
|
||||
|
||||
// This version of memory dataset has exclusive ownership of the memory cache
|
||||
@ -973,22 +975,19 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
|
||||
class CacheDatasetOp::MemoryDataset : public CacheDatasetOp::MemoryDatasetBase {
|
||||
public:
|
||||
MemoryDataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||
MemoryCache* cache, const ResourceHandle& resource_handle)
|
||||
: MemoryDatasetBase(ctx, input, cache),
|
||||
resource_handle_(resource_handle) {
|
||||
cleanup_ = [this, mgr = ctx->resource_manager()]() {
|
||||
DCHECK(cache_->RefCountIsOne());
|
||||
Status s = mgr->Delete<MemoryCache>(resource_handle_.container(),
|
||||
resource_handle_.name());
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
|
||||
}
|
||||
};
|
||||
}
|
||||
MemoryCacheManager* manager, ResourceHandle&& resource_handle)
|
||||
: MemoryDatasetBase(ctx, input, manager->get()),
|
||||
manager_(manager),
|
||||
resource_handle_(std::move(resource_handle)),
|
||||
resource_mgr_(ctx->resource_manager()) {}
|
||||
|
||||
~MemoryDataset() override {
|
||||
cache_->Unref();
|
||||
cleanup_();
|
||||
manager_->Unref();
|
||||
Status s = resource_mgr_->Delete<MemoryCacheManager>(
|
||||
resource_handle_.container(), resource_handle_.name());
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -1005,8 +1004,9 @@ class CacheDatasetOp::MemoryDataset : public CacheDatasetOp::MemoryDatasetBase {
|
||||
}
|
||||
|
||||
private:
|
||||
std::function<void()> cleanup_;
|
||||
MemoryCacheManager* const manager_; // Owned.
|
||||
const ResourceHandle resource_handle_;
|
||||
ResourceMgr* const resource_mgr_; // Not owned.
|
||||
};
|
||||
|
||||
// This version of memory dataset has shared ownership of the memory cache
|
||||
@ -1016,28 +1016,23 @@ class CacheDatasetOp::MemoryDatasetV2
|
||||
: public CacheDatasetOp::MemoryDatasetBase {
|
||||
public:
|
||||
MemoryDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
|
||||
MemoryCache* cache, const ResourceHandle& resource_handle)
|
||||
: MemoryDatasetBase(ctx, input, cache),
|
||||
resource_handle_(std::move(resource_handle)) {
|
||||
cleanup_ = [this, mgr = ctx->resource_manager()]() {
|
||||
if (cache_->RefCountIsOne()) {
|
||||
Status s = mgr->Delete<MemoryCache>(resource_handle_.container(),
|
||||
resource_handle_.name());
|
||||
if (!s.ok()) {
|
||||
if (errors::IsNotFound(s)) {
|
||||
// This is a benign race resulting from concurrent deletion.
|
||||
VLOG(1) << "Failed to delete cache resource: " << s.ToString();
|
||||
} else {
|
||||
LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
MemoryCacheManager* manager, ResourceHandle&& resource_handle,
|
||||
bool owns_resource)
|
||||
: MemoryDatasetBase(ctx, input, manager->get()),
|
||||
manager_(manager),
|
||||
owns_resource_(owns_resource),
|
||||
resource_handle_(std::move(resource_handle)),
|
||||
resource_mgr_(ctx->resource_manager()) {}
|
||||
|
||||
~MemoryDatasetV2() override {
|
||||
cache_->Unref();
|
||||
cleanup_();
|
||||
manager_->Unref();
|
||||
if (owns_resource_) {
|
||||
Status s = resource_mgr_->Delete<MemoryCacheManager>(
|
||||
resource_handle_.container(), resource_handle_.name());
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -1058,8 +1053,10 @@ class CacheDatasetOp::MemoryDatasetV2
|
||||
}
|
||||
|
||||
private:
|
||||
std::function<void()> cleanup_;
|
||||
MemoryCacheManager* const manager_; // Owned.
|
||||
const bool owns_resource_;
|
||||
const ResourceHandle resource_handle_;
|
||||
ResourceMgr* const resource_mgr_; // Not owned.
|
||||
};
|
||||
|
||||
CacheDatasetOp::CacheDatasetOp(OpKernelConstruction* ctx)
|
||||
@ -1077,33 +1074,39 @@ void CacheDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
auto name = strings::StrCat(ctx->op_kernel().name(), "/", kMemoryCache, "_",
|
||||
resource_id_counter.fetch_add(1));
|
||||
if (op_version_ == 2) {
|
||||
MemoryCache* cache = nullptr;
|
||||
bool owns_resource = false;
|
||||
MemoryCacheManager* manager = nullptr;
|
||||
auto handle = HandleFromInput(ctx, 2);
|
||||
Status s = ctx->resource_manager()->Lookup<MemoryCache>(
|
||||
handle.container(), handle.name(), &cache);
|
||||
Status s = ctx->resource_manager()->Lookup<MemoryCacheManager>(
|
||||
handle.container(), handle.name(), &manager);
|
||||
if (errors::IsNotFound(s)) {
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->resource_manager()->LookupOrCreate<MemoryCache>(
|
||||
container, name, &cache, [](MemoryCache** cache) {
|
||||
*cache = new MemoryCache();
|
||||
return Status::OK();
|
||||
}));
|
||||
handle = MakeResourceHandle<MemoryCache>(ctx, container, name);
|
||||
owns_resource = true;
|
||||
OP_REQUIRES_OK(
|
||||
ctx,
|
||||
ctx->resource_manager()->LookupOrCreate<MemoryCacheManager>(
|
||||
container, name, &manager, [](MemoryCacheManager** manager) {
|
||||
*manager = new MemoryCacheManager();
|
||||
return Status::OK();
|
||||
}));
|
||||
handle = MakeResourceHandle<MemoryCacheManager>(ctx, container, name);
|
||||
} else {
|
||||
OP_REQUIRES_OK(ctx, s);
|
||||
}
|
||||
// Ownership of cache is transferred to `MemoryDatasetV2`.
|
||||
*output = new MemoryDatasetV2(ctx, input, cache, std::move(handle));
|
||||
// Ownership of manager is transferred to `MemoryDatasetV2`.
|
||||
*output = new MemoryDatasetV2(ctx, input, manager, std::move(handle),
|
||||
owns_resource);
|
||||
} else {
|
||||
MemoryCache* cache;
|
||||
OP_REQUIRES_OK(ctx, ctx->resource_manager()->LookupOrCreate<MemoryCache>(
|
||||
container, name, &cache, [](MemoryCache** cache) {
|
||||
*cache = new MemoryCache();
|
||||
return Status::OK();
|
||||
}));
|
||||
auto handle = MakeResourceHandle<MemoryCache>(ctx, container, name);
|
||||
// Ownership of cache is transferred to `MemoryDataset`.
|
||||
*output = new MemoryDataset(ctx, input, cache, handle);
|
||||
MemoryCacheManager* manager;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->resource_manager()->LookupOrCreate<MemoryCacheManager>(
|
||||
container, name, &manager, [](MemoryCacheManager** manager) {
|
||||
*manager = new MemoryCacheManager();
|
||||
return Status::OK();
|
||||
}));
|
||||
auto handle =
|
||||
MakeResourceHandle<MemoryCacheManager>(ctx, container, name);
|
||||
// Ownership of manager is transferred to `MemoryDataset`.
|
||||
*output = new MemoryDataset(ctx, input, manager, std::move(handle));
|
||||
}
|
||||
} else {
|
||||
if (op_version_ == 2) {
|
||||
|
@ -31,7 +31,7 @@ constexpr char kMemoryCache[] = "MemoryCache";
|
||||
|
||||
} // namespace
|
||||
|
||||
string MemoryCache::DebugString() const { return kMemoryCache; }
|
||||
string MemoryCacheManager::DebugString() const { return kMemoryCache; }
|
||||
|
||||
void MemoryCache::Complete(std::vector<std::vector<Tensor>>&& cache) {
|
||||
mutex_lock l(mu_);
|
||||
@ -65,19 +65,15 @@ size_t MemoryCache::size() {
|
||||
|
||||
AnonymousMemoryCacheHandleOp::AnonymousMemoryCacheHandleOp(
|
||||
OpKernelConstruction* ctx)
|
||||
: AnonymousResourceOp<MemoryCache>(ctx) {}
|
||||
|
||||
void AnonymousMemoryCacheHandleOp::Compute(OpKernelContext* ctx) {
|
||||
AnonymousResourceOp<MemoryCache>::Compute(ctx);
|
||||
}
|
||||
: AnonymousResourceOp<MemoryCacheManager>(ctx) {}
|
||||
|
||||
string AnonymousMemoryCacheHandleOp::name() { return kMemoryCache; }
|
||||
|
||||
Status AnonymousMemoryCacheHandleOp::CreateResource(
|
||||
OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
|
||||
FunctionLibraryRuntime* lib, MemoryCache** resource) {
|
||||
*resource = new MemoryCache();
|
||||
FunctionLibraryRuntime* lib, MemoryCacheManager** manager) {
|
||||
*manager = new MemoryCacheManager();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -27,12 +27,10 @@ namespace data {
|
||||
// The expected use is that a single `MemoryWriterIterator` populates the
|
||||
// cache with dataset elements. Once all elements are cached, the cache can
|
||||
// be used by one or more `MemoryReaderIterator`s.
|
||||
class MemoryCache : public ResourceBase {
|
||||
class MemoryCache {
|
||||
public:
|
||||
MemoryCache() = default;
|
||||
|
||||
string DebugString() const override;
|
||||
|
||||
// Marks the cache as completed.
|
||||
void Complete(std::vector<std::vector<Tensor>>&& cache);
|
||||
|
||||
@ -55,11 +53,24 @@ class MemoryCache : public ResourceBase {
|
||||
std::vector<std::vector<Tensor>> cache_ TF_GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
// A resource wrapping a shared instance of a memory cache.
|
||||
class MemoryCacheManager : public ResourceBase {
|
||||
public:
|
||||
MemoryCacheManager() : cache_(std::make_shared<MemoryCache>()) {}
|
||||
|
||||
string DebugString() const override;
|
||||
|
||||
std::shared_ptr<MemoryCache> get() { return cache_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<MemoryCache> cache_;
|
||||
};
|
||||
|
||||
// Creates an instance of cache resource and transfers ownership to the caller.
|
||||
class AnonymousMemoryCacheHandleOp : public AnonymousResourceOp<MemoryCache> {
|
||||
class AnonymousMemoryCacheHandleOp
|
||||
: public AnonymousResourceOp<MemoryCacheManager> {
|
||||
public:
|
||||
explicit AnonymousMemoryCacheHandleOp(OpKernelConstruction* ctx);
|
||||
void Compute(OpKernelContext* ctx) override;
|
||||
|
||||
private:
|
||||
string name() override;
|
||||
@ -67,7 +78,7 @@ class AnonymousMemoryCacheHandleOp : public AnonymousResourceOp<MemoryCache> {
|
||||
std::unique_ptr<FunctionLibraryDefinition> flib_def,
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
|
||||
FunctionLibraryRuntime* lib,
|
||||
MemoryCache** resource) override;
|
||||
MemoryCacheManager** manager) override;
|
||||
};
|
||||
|
||||
// Deletes an instance of cache resource.
|
||||
|
Loading…
Reference in New Issue
Block a user